# Source code for slapo.utils.common

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""Common utilities used in schedule."""
import importlib
from functools import lru_cache
from types import FunctionType

from torch import nn
from torch import fx

# Maps slapo's hook-type names to the torch.nn.Module attributes that hold the
# corresponding hook dictionaries.
HOOK_TYPE_TO_ATTR = {
    "fwd_pre": "_forward_pre_hooks",
    "fwd_post": "_forward_hooks",
    "bwd_post": "_backward_hooks",
}

def get_hooks(mod):
    """Get the hooks of a module.

    Parameters
    ----------
    mod : torch.nn.Module
        The module.

    Returns
    -------
    dict
        A dictionary of hooks with keys "fwd_pre", "fwd_post" and "bwd_post",
        each mapping to the list of hook callables registered on ``mod`` in
        registration order.
    """
    return {
        "fwd_pre": list(mod._forward_pre_hooks.values()),
        "fwd_post": list(mod._forward_hooks.values()),
        "bwd_post": list(mod._backward_hooks.values()),
    }

def has_hook(mod, hook_type):
    """Check whether ``mod`` has at least one hook of the given type.

    Parameters
    ----------
    mod : torch.nn.Module
        The module to inspect.
    hook_type : str
        A key of ``HOOK_TYPE_TO_ATTR`` ("fwd_pre", "fwd_post" or "bwd_post").
        Unknown keys raise ``KeyError``.

    Returns
    -------
    bool
        True if the corresponding hook dictionary is non-empty.
    """
    hook_attr = HOOK_TYPE_TO_ATTR[hook_type]
    registered = getattr(mod, hook_attr)
    return len(registered) > 0

def transfer_hooks(old_mod, new_mod, hook_types=None):
    """Transfer the hooks from old_mod to new_mod.

    Parameters
    ----------
    old_mod : torch.nn.Module
        The old module.
    new_mod : torch.nn.Module
        The new module.
    hook_types : Optional[List[str]]
        The types of hooks to transfer (keys of ``HOOK_TYPE_TO_ATTR``:
        "fwd_pre", "fwd_post", "bwd_post"). If None, transfer all hooks.

    Notes
    -----
    The hook dictionary object of ``old_mod`` is assigned to ``new_mod`` as-is,
    not copied. Any hooks of the transferred types already on ``new_mod`` are
    therefore replaced, and the two modules share the dictionary afterwards.
    """
    if hook_types is None:
        hook_types = ["fwd_pre", "fwd_post", "bwd_post"]

    for hook_type in hook_types:
        hook_attr = HOOK_TYPE_TO_ATTR[hook_type]
        setattr(new_mod, hook_attr, getattr(old_mod, hook_attr))

def transfer_hooks_for_fusion(sch, subgraphs, new_mod):
    """Transfer hooks of modules in the subgraph to be fused.

    For example, the fwd_pre hook of the first module in the subgraph will become
    the fwd_pre hook of the fused module.
    Note that if middle modules have hooks, we will throw errors because we cannot
    keep these hooks in the fused module.

    Parameters
    ----------
    sch : Schedule
        The parent schedule.
    subgraphs : List[List[Tuple[Node, Node]]]
        The fused subgraphs that need to be transferred hooks.
    new_mod : torch.nn.Module
        The new module that will be created after fusion.

    Raises
    ------
    RuntimeError
        If more than one subgraph is given (horizontal fusion) and any module in
        them has a hook, or if a module in a single subgraph has a hook that
        cannot be kept on the fused module given its position.
    """
    hook_types = HOOK_TYPE_TO_ATTR.keys()
    if not subgraphs:
        # Nothing is fused, so there are no hooks to transfer.
        return

    if len(subgraphs) > 1:
        # Since horizontal fusion needs to combine the hooks together,
        # we cannot support it for now.
        for sublst in subgraphs:
            for _, node in sublst:
                if node.op != "call_module":
                    continue
                old_mod = sch.get_module(node.target)
                for hook in hook_types:
                    if has_hook(old_mod, hook):
                        raise RuntimeError(
                            "Cannot use horizontal fusion since module "
                            f"{node.target} has a {hook} hook"
                        )
        return

    # Vertical fusion: the fused module takes over the fwd_pre/bwd_post hooks
    # of the first op and the fwd_post hook of the last op. Positions count
    # every op in the subgraph, not only call_module ones.
    ops = subgraphs[0]
    last_idx = len(ops) - 1
    for i, (_, node) in enumerate(ops):
        if node.op != "call_module":
            continue
        old_mod = sch.get_module(node.target)
        if i == 0 and i == last_idx:
            # A single-op subgraph: this module is both first and last, so all
            # of its hooks can be kept on the fused module.
            transfer_hooks(old_mod, new_mod)
        elif i == 0:  # the first node
            if has_hook(old_mod, "fwd_post"):
                raise RuntimeError(
                    f"Cannot transfer hooks from {node.target} to the "
                    f"new module since {node.target} has a fwd_post hook"
                )
            transfer_hooks(old_mod, new_mod, ["fwd_pre", "bwd_post"])
        elif i == last_idx:  # the last node
            if has_hook(old_mod, "fwd_pre") or has_hook(old_mod, "bwd_post"):
                raise RuntimeError(
                    f"Cannot transfer hooks from {node.target} to the new "
                    f"module since {node.target} has a fwd_pre/bwd_post hook"
                )
            transfer_hooks(old_mod, new_mod, ["fwd_post"])
        elif any(has_hook(old_mod, x) for x in hook_types):
            raise RuntimeError(
                f"Cannot transfer hooks from {node.target} to the new module "
                f"since {node.target} is in the middle of the subgraph"
            )

def transfor_param_tags(sch, param, new_param):
    """Copy parameter tags from ``param`` to ``new_param``.

    The tag names are read from ``metadata.param_tags`` of the top-level
    schedule returned by ``sch.get_top_schedule()``. Only tags present as
    attributes on ``param`` are copied; the others are skipped.

    Parameters
    ----------
    sch : Schedule
        Any schedule in the hierarchy; its top schedule supplies the tag names.
    param : torch.nn.Parameter
        The source parameter.
    new_param : torch.nn.Parameter
        The parameter that receives the tag attributes.
    """
    tag_names = sch.get_top_schedule().metadata.param_tags
    for tag in tag_names:
        if not hasattr(param, tag):
            continue
        setattr(new_param, tag, getattr(param, tag))

def is_lambda_function(obj):
    """Return True if ``obj`` is a Python function whose name is ``<lambda>``.

    Non-function objects (builtins, classes, strings, ...) return False.
    """
    if not isinstance(obj, FunctionType):
        return False
    return obj.__name__ == "<lambda>"

def is_module_list(module, name="", parent=None):
    """A module list will become nn.Module or fx.GraphModule after tracing,
    but we still want to treat it as a module list in the schedule.

    Parameters
    ----------
    module : torch.nn.Module
        The module to check.
    name : str
        The name of ``module`` within its parent. It is used to look for
        direct callers in the parent's traced graph.
    parent : Optional[Schedule]
        The parent schedule (presumably). Only its ``mod`` attribute is read.

    Returns
    -------
    bool
        Whether ``module`` should be treated as a module list.
    """
    if isinstance(module, nn.Sequential):
        return False
    if isinstance(module, nn.ModuleList):
        return True
    if (
        isinstance(module, fx.GraphModule)
        and parent is not None
        and isinstance(parent.mod, fx.GraphModule)
    ):
        # If the module and its parent are both traced, we can check
        # the caller in the parent. If there is a caller that directly
        # calls this module, then this is not a module list.
        for node in parent.mod.graph.nodes:
            if node.op == "call_module" and node.target == name:
                return False

    # If all above cannot work, we could only check if its children are indexed by
    # sequential integers, and treat it as a module list if so.
    child_names = [child_name for child_name, _ in module.named_children()]
    if not child_names:
        return False
    try:
        child_indices = [int(n) for n in child_names]
    except ValueError:
        return False
    return child_indices == list(range(len(child_indices)))
@lru_cache()
def importlib_or_none(name):
    """Import the module if available, otherwise return None.

    Parameters
    ----------
    name : str
        The (possibly dotted) module name to import.

    Returns
    -------
    Optional[module]
        The imported module, or None if it raises ``ModuleNotFoundError``.
        Other import errors propagate.

    Notes
    -----
    Results are memoized with ``lru_cache``. A None result is cached too, so a
    module installed later in the same process is not picked up.
    """
    try:
        return importlib.import_module(name)
    except ModuleNotFoundError:
        return None