# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""Common utilities used in schedule."""
import importlib
from functools import lru_cache
from types import FunctionType
from torch import nn
from torch import fx
# Maps the hook type names used by the schedule to the module attribute that
# holds the dictionary of registered hooks of that type.
HOOK_TYPE_TO_ATTR = dict(
    fwd_pre="_forward_pre_hooks",
    fwd_post="_forward_hooks",
    bwd_post="_backward_hooks",
)
def get_hooks(mod):
    """Collect the hooks registered on a module, grouped by hook type.

    Parameters
    ----------
    mod : torch.nn.Module
        The module whose hooks are collected.

    Returns
    -------
    dict
        A dictionary with the keys "fwd_pre", "fwd_post" and "bwd_post"
        (in that order). Each value is a list of the hook callables stored
        in the corresponding hook dictionary of ``mod``.
    """
    # (hook type, attribute on the module that stores hooks of that type)
    sources = (
        ("fwd_pre", "_forward_pre_hooks"),
        ("fwd_post", "_forward_hooks"),
        ("bwd_post", "_backward_hooks"),
    )
    return {
        hook_kind: list(getattr(mod, attr_name).values())
        for hook_kind, attr_name in sources
    }
def has_hook(mod, hook_type):
    """Check whether a module has at least one hook of the given type.

    Parameters
    ----------
    mod : torch.nn.Module
        The module to inspect.
    hook_type : str
        A key of ``HOOK_TYPE_TO_ATTR``: "fwd_pre", "fwd_post" or "bwd_post".

    Returns
    -------
    bool
        True if the hook dictionary for ``hook_type`` is non-empty.

    Raises
    ------
    KeyError
        If ``hook_type`` is not a known hook type.
    """
    registered = getattr(mod, HOOK_TYPE_TO_ATTR[hook_type])
    return len(registered) != 0
def transfer_hooks(old_mod, new_mod, hook_types=None):
    """Move the hook dictionaries of the given types from ``old_mod`` to ``new_mod``.

    The hook dictionary object of ``old_mod`` is assigned to ``new_mod`` as-is
    (it is shared, not copied), replacing whatever ``new_mod`` had for that
    hook type.

    Parameters
    ----------
    old_mod : torch.nn.Module
        The module the hooks are taken from.
    new_mod : torch.nn.Module
        The module the hooks are assigned to.
    hook_types : Optional[List[str]]
        The hook types (keys of ``HOOK_TYPE_TO_ATTR``) to transfer.
        If None, all hook types are transferred.

    Raises
    ------
    KeyError
        If an unknown hook type is requested. This is raised before any hook
        dictionary has been reassigned.
    """
    selected_types = (
        ["fwd_pre", "fwd_post", "bwd_post"] if hook_types is None else hook_types
    )
    # Resolve all attribute names up front so an unknown type fails early.
    attr_names = [HOOK_TYPE_TO_ATTR[kind] for kind in selected_types]
    for attr_name in attr_names:
        setattr(new_mod, attr_name, getattr(old_mod, attr_name))
def transfer_hooks_for_fusion(sch, subgraphs, new_mod):
    """Transfer hooks of modules in the subgraph to be fused.

    For vertical fusion (exactly one subgraph):

    - the fwd_pre and bwd_post hooks of the first node become the
      corresponding hooks of the fused module;
    - the fwd_post hook of the last node becomes the fwd_post hook of the
      fused module;
    - if the first node is also the last one (a single-node subgraph), all of
      its hooks are transferred.

    Modules in the middle of the subgraph cannot keep their hooks, so an error
    is raised if any of them has one. Horizontal fusion (more than one
    subgraph) would need to combine hooks, which is not supported. In that
    case an error is raised if any module in any subgraph has a hook.

    Parameters
    ----------
    sch : Schedule
        The parent schedule, used to look up modules by node target.
    subgraphs : List[List[Tuple[Node, Node]]]
        The fused subgraphs whose hooks need to be transferred.
    new_mod : torch.nn.Module
        The new module that will be created after fusion.

    Raises
    ------
    RuntimeError
        If a hook cannot be preserved in the fused module.
    """
    if not subgraphs:
        # Nothing is being fused, so there are no hooks to transfer.
        return

    hook_types = HOOK_TYPE_TO_ATTR.keys()
    if len(subgraphs) > 1:
        # Since horizontal fusion needs to combine the hooks together,
        # we cannot support it for now.
        for sublst in subgraphs:
            for _, node in sublst:
                if node.op != "call_module":
                    # Skip non-module nodes, but keep scanning: a later
                    # module in the same subgraph may still carry hooks.
                    continue
                old_mod = sch.get_module(node.target)
                for hook in hook_types:
                    if has_hook(old_mod, hook):
                        raise RuntimeError(
                            "Cannot use horizontal fusion since module "
                            f"{node.target} has a {hook} hook"
                        )
        return

    ops = subgraphs[0]
    last_idx = len(ops) - 1
    for i, (_, node) in enumerate(ops):
        if node.op != "call_module":
            continue
        old_mod = sch.get_module(node.target)
        if i == 0 and i == last_idx:
            # The fused module replaces this single node one-to-one,
            # so every hook can be kept.
            transfer_hooks(old_mod, new_mod)
        elif i == 0:  # the first node
            if has_hook(old_mod, "fwd_post"):
                raise RuntimeError(
                    f"Cannot transfer hooks from {node.target} to the "
                    f"new module since {node.target} has a fwd_post hook"
                )
            transfer_hooks(old_mod, new_mod, ["fwd_pre", "bwd_post"])
        elif i == last_idx:  # the last node
            if has_hook(old_mod, "fwd_pre") or has_hook(old_mod, "bwd_post"):
                raise RuntimeError(
                    f"Cannot transfer hooks from {node.target} to the new "
                    f"module since {node.target} has a fwd_pre/bwd_post hook"
                )
            transfer_hooks(old_mod, new_mod, ["fwd_post"])
        elif any(has_hook(old_mod, x) for x in hook_types):
            raise RuntimeError(
                f"Cannot transfer hooks from {node.target} to the new module "
                f"since {node.target} is in the middle of the subgraph"
            )
def transfor_param_tags(sch, param, new_param):
    """Copy the tag attributes registered in the schedule from one parameter to another.

    The tag names are read from ``metadata.param_tags`` of the top-level
    schedule. Every tag that exists as an attribute on ``param`` is set on
    ``new_param`` with the same value. Tags missing on ``param`` are skipped.

    Parameters
    ----------
    sch : Schedule
        A schedule; its top-level schedule provides the tag names.
    param : torch.nn.Parameter
        The parameter whose tags are copied.
    new_param : torch.nn.Parameter
        The parameter that receives the tags.
    """
    top_metadata = sch.get_top_schedule().metadata
    for tag in top_metadata.param_tags:
        if not hasattr(param, tag):
            continue
        setattr(new_param, tag, getattr(param, tag))
def is_lambda_function(obj):
    """Return True if ``obj`` is a Python function created by a lambda expression."""
    if not isinstance(obj, FunctionType):
        return False
    # Lambdas are the only functions whose __name__ is "<lambda>" by default.
    return obj.__name__ == "<lambda>"
def is_module_list(module, name="", parent=None):
    """Check whether a module should be treated as a module list.

    A module list will become nn.Module or fx.GraphModule after tracing,
    but we still want to treat it as a module list in the schedule.

    Parameters
    ----------
    module : torch.nn.Module
        The module to check.
    name : str
        The name of ``module`` as it would appear as a ``call_module`` target
        in the parent's graph. It is only used when both ``module`` and
        ``parent.mod`` are traced GraphModules.
    parent : Optional[object]
        An object whose ``mod`` attribute is the parent module (presumably the
        parent schedule; verify against callers), or None.

    Returns
    -------
    bool
        True if ``module`` should be treated as a module list.
    """
    if isinstance(module, nn.Sequential):
        return False
    if isinstance(module, nn.ModuleList):
        return True
    if (
        isinstance(module, fx.GraphModule)
        and parent is not None
        and isinstance(parent.mod, fx.GraphModule)
    ):
        # If the module and its parent are both traced, we can check
        # the caller in the parent. If there is a caller that directly
        # calls this module, then this is not a module list.
        if any(
            node.op == "call_module" and node.target == name
            for node in parent.mod.graph.nodes
        ):
            return False
    # If none of the checks above is conclusive, we can only check whether its
    # children are named by sequential integers, and treat it as a module list
    # if so.
    child_names = [child_name for child_name, _ in module.named_children()]
    if not child_names:
        return False
    try:
        indices = [int(child_name) for child_name in child_names]
    except ValueError:
        return False
    return indices == list(range(len(indices)))
@lru_cache()
def importlib_or_none(name):
    """Import and return the module called ``name``, or None if it cannot be found.

    Results are memoized with ``lru_cache``, so repeated lookups of the same
    name return the cached module object (or cached None).
    """
    try:
        module = importlib.import_module(name)
    except ModuleNotFoundError:
        module = None
    return module