Source code for slapo.tracer

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import copy
import inspect
import operator
import random
import traceback
from typing import Any

import torch
import torch.utils._pytree as pytree
from torch import fx, nn
from torch.fx._symbolic_trace import HAS_VARSTUFF, PH, _assert_is_none, _patch_function
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.fx.node import base_types

from .logger import get_logger

logger = get_logger()


def is_fx_tracable(mod):
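    """Return True unless the module opts out of fx tracing by defining
    ``traceable = False`` on itself."""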
    return not hasattr(mod, "traceable") or mod.traceable


def fix_hf_module(
    root: nn.Module, root_graph: fx.Graph, submods: dict[str, fx.GraphModule]
):
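    """Patch the traced graph of a HuggingFace module so it can be re-assembled
    into a GraphModule: register tensor constants and referenced attributes,
    match call_module arguments against each submodule's forward signature, and
    rewrite accesses to HF output objects into graph-friendly forms.
    """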
    # Fix tensor constants
    for target in dir(root):
        if "_tensor_constant" in target or target in {"position_ids", "token_type_ids"}:
            submods[target] = getattr(root, target)
    for node in root_graph.nodes:
        # Register the attribute or submodule referenced by the node if it is not already tracked
        if node.op in {"call_module", "get_attr"} and node.target not in submods:
            attr_itr = root
            atoms = node.target.split(".")
            for atom in atoms:
                attr_itr = getattr(attr_itr, atom)
            submods[node.target] = attr_itr
        # Warn about submodule names containing "self" (e.g., HF SelfAttention.self)
        if node.op == "call_module" and "self" in node.target:
            logger.warning(
                "`self` in %s is a reserved name, "
                "please rename it to avoid possible conflicts.",
                root.__class__.__name__,
            )
    # Fix arguments
    for node in root_graph.nodes:
        if node.op == "call_module":
            args = node.args
            kwargs = node.kwargs

            # Checkpoint wrapper has a unified interface (*args, **kwargs)
            # which ruins the argument matching.
            orig_forward = submods[node.target].forward
            if submods[node.target].__class__.__name__ == "CheckPointWrapper":
                orig_forward = submods[node.target].mod.forward

            sig = inspect.signature(orig_forward)
            target_args = list(sig.parameters.keys())
            res_kwargs = {}
            for key in kwargs:
                if key in target_args:
                    res_kwargs[key] = kwargs[key]
                    target_args.remove(key)
            node.args = tuple(args[: len(target_args)])
            node.kwargs = res_kwargs
    # FIXME: Dirty hack for getitem
    # https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L1021-L1033
    # https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L1042-L1045
    for node in root_graph.nodes:
        if (
            # pylint: disable=comparison-with-callable
            node.op == "call_function"
            and node.target == operator.getitem
            and len(node.args) == 2
            and node.args[0].target in {"encoder", "decoder", "bert", "transformer"}
            and node.args[1] == 0
        ):
            node.args = (node.args[0], "last_hidden_state")
        if (
            # pylint: disable=comparison-with-callable
            node.op == "call_function"
            and node.target == getattr
            and len(node.args) == 2
            and node.args[0].target in {"encoder", "decoder", "bert", "transformer"}
        ):
            node.op = "call_method"
            node.target = "get"
            node.args = (node.args[0], node.args[1], None)
    return root_graph


def generate_hf_tracer_inputs(
    root: nn.Module,
    tracer: fx.Tracer,
    is_top: bool,
    call_node: fx.Node,
    kwargs: dict[str, Any],
):
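    """Build concrete_args and dummy_inputs for the HuggingFace tracer.

    For the top-level module, dummy inputs are generated by the tracer itself;
    for submodules, zero tensors of a random (batch_size, sequence_length)
    shape are created for every tensor argument of the parent's call node.
    """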
    # generate random shape
    batch_size = random.randint(10, 20)
    sequence_length = random.randint(10, 20)
    shape = [batch_size, sequence_length]
    # generate concrete_args and dummy_inputs
    if is_top:
        sig = inspect.signature(
            root.forward if isinstance(root, torch.nn.Module) else root
        )
        assert "concrete_args" in kwargs
        if not hasattr(root, "device"):
            root.device = next(root.named_parameters())[1].device
        if not hasattr(root, "config"):
            assert "config" in kwargs, "Please provide `config` for HF tracer"
            root.config = kwargs["config"]
        concrete_args = kwargs["concrete_args"]  # the arguments whose value is None
        input_names = sig.parameters.keys() - concrete_args.keys()
        inputs = {}
        for input_name in input_names:
            inputs.update(tracer._generate_dummy_input(root, input_name, shape))
        kwargs["dummy_inputs"] = inputs
        dummy_inputs = copy.copy(inputs)
    else:
        assert call_node is not None
        sig = inspect.signature(root.forward)
        arg_names = list(sig.parameters.keys())
        dummy_inputs = {}
        for i, arg in enumerate(call_node.args):
            if isinstance(arg, fx.Node):
                # FIXME: shape and dtype do affect the control flow branches
                dummy_inputs[arg_names[i]] = torch.zeros(shape, dtype=torch.float32)
            else:
                # ignore value=None
                pass
        for key, arg in call_node.kwargs.items():
            assert key in arg_names
            if isinstance(arg, fx.Node):
                dummy_inputs[key] = torch.zeros(shape, dtype=torch.float32)
        concrete_args = {
            p.name: p.default
            for p in sig.parameters.values()
            if p.name not in dummy_inputs
        }
    return concrete_args, dummy_inputs


def trace_submodule(
    root: nn.Module,
    tracer_class,
    is_top: bool = False,
    call_node: fx.Node = None,
    **kwargs,
):
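    """Trace root with the given tracer class. Unless flatten is set, children
    are traced separately (recursively, except for leaf modules or when
    recursive=False) so the resulting GraphModule keeps the module hierarchy.
    The original module is returned unchanged if tracing fails.
    """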
    # generate top graph module
    named_children = dict(root.named_children())
    leaf_modules = kwargs.get("leaf_modules", [])

    recursive = kwargs.get("recursive", True)

    # Create a tracer with the original leaf modules. This is only used
    # to judge whether a submodule is really a leaf or not.
    tracer_with_orig_leaf = tracer_class(leaf_modules=leaf_modules)

    leaf_modules = copy.deepcopy(leaf_modules)
    if not kwargs.get("flatten", False):
        # Add all children module (submodule) to be leaf module to prevent
        # the tracer from tracing into them, because we will trace submodules
        # separately to maintain the module hierarchy.
        for key, leaf_mod in named_children.items():
            if isinstance(leaf_mod, nn.ModuleList):
                leaf_modules += [
                    f"{key}.{s}" for s in list(dict(leaf_mod.named_children()).keys())
                ]
            else:
                leaf_modules.append(key)
    tracer = tracer_class(leaf_modules=leaf_modules)

    if tracer.name == "huggingface":
        concrete_args, dummy_inputs = generate_hf_tracer_inputs(
            root, tracer, is_top, call_node, kwargs
        )
        try:
            root_graph = tracer.trace(
                root, concrete_args=concrete_args, dummy_inputs=dummy_inputs
            )
        except Exception as err:
            logger.warning(traceback.format_exc())
            logger.warning("Cannot trace module %s: %s", root.__class__.__name__, err)
            return root
    else:
        concrete_args = kwargs.get("concrete_args", {})
        try:
            root_graph = tracer.trace(root, concrete_args=concrete_args)
        except Exception as err:
            logger.warning(traceback.format_exc())
            logger.warning("Cannot trace module %s: %s", root.__class__.__name__, err)
            return root
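    # Map each call_module node's target (the submodule's qualified name) to its
    # call node, so submodule tracing below can derive dummy inputs from the
    # call-site arguments.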
    call_arg_map = {}
    for node in root_graph.nodes:
        if node.op == "call_module":
            call_arg_map[node.target] = node

    # Trace submodules
    submods = {}
    if not kwargs.get("flatten", False):
        for name, submod in named_children.items():
            if isinstance(submod, nn.ModuleList):
                # We assume the ModuleList is iterated over in the forward function.
                # For example:
                # In __init__:
                #     self.layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])
                # In forward:
                #     for layer in self.layers:
                #         x = layer(x)
                # In this case, the fx IR creates a unique name for each layer,
                # such as layers.0, layers.1, etc. We follow this convention when
                # tracing each layer in the ModuleList to register the submodule name.
                for i, layer in enumerate(submod):
                    module_qualified_name = tracer.path_of_module(layer)
                    if not recursive or tracer_with_orig_leaf.is_leaf_module(
                        layer, module_qualified_name
                    ):
                        gm_submod = layer
                    else:
                        gm_submod = trace_submodule(
                            layer,
                            tracer_class,
                            is_top=False,
                            call_node=call_arg_map[f"{name}.{i}"],
                            **kwargs,
                        )
                    submods[f"{name}.{i}"] = gm_submod
            else:
                # For other submodules, including nn.Sequential, we assume they are
                # called directly in the forward function. For example:
                # In __init__: self.block = nn.Sequential(...)
                # In forward:  out = self.block(x)
                # In this case, the fx IR directly calls the submodule (e.g., block).
                module_qualified_name = tracer.path_of_module(submod)
                if not recursive or tracer_with_orig_leaf.is_leaf_module(
                    submod, module_qualified_name
                ):
                    # If it is a real leaf module, stop tracing.
                    gm_submod = submod
                else:
                    gm_submod = trace_submodule(
                        submod,
                        tracer_class,
                        is_top=False,
                        call_node=call_arg_map[name],
                        **kwargs,
                    )
                submods[name] = gm_submod
    if tracer.name == "huggingface":
        root_graph = fix_hf_module(root, root_graph, submods)
    if not kwargs.get("flatten", False):
        final_gm = fx.GraphModule(submods, root_graph)
    else:
        final_gm = fx.GraphModule(root, root_graph)
    # Remove dead code and unused submodules
    final_gm.graph.eliminate_dead_code()
    final_gm.delete_all_unused_submodules()
    final_gm.graph.lint()
    final_gm.recompile()
    # remove meta tensors generated by HF tracer
    for name in dict(final_gm.named_buffers()):
        if "tensor_constant" in name and hasattr(final_gm, name):
            final_gm.__delattr__(name)  # pylint: disable=unnecessary-dunder-call
    return final_gm


def trace(model: nn.Module, **kwargs: dict[str, Any]):
    """Traces a model to a GraphModule."""
    tracer_cls_name = kwargs.get("tracer", "pytorch")
    logger.debug("Tracer: %s Model: %s", tracer_cls_name, model.__class__.__name__)

    def create_args_for_root(cls, root_fn, is_module, concrete_args=None):
        """Override this method to make sure the argument names are the same as
        the original module, so that the traced module can be injected.

        FIXME: Implement a fx pass that fixes the argument names, so that
        we don't need to override this method.
        """
        # In some cases, a function or method has been decorated with
        # a wrapper defined via ``functools.wraps``. In this case,
        # the outer code object will likely not contain the actual
        # parameters we care about, so unwrap the function to get to
        # the innermost callable.
        fn_for_analysis = inspect.unwrap(root_fn)
        co = fn_for_analysis.__code__
        total_args = co.co_argcount + co.co_kwonlyargcount
        orig_args = list(co.co_varnames)
        names_iter = iter(co.co_varnames)
        args: list[Any] = []
        skip_arg_idx = 0
        if is_module:
            if total_args == 0:
                raise RuntimeError(
                    "``cls`` argument cannot be part of *args expansion!"
                )
            skip_arg_idx = 1
            next(names_iter)  # skip cls
            args.append(cls.root)

        sig = inspect.signature(fn_for_analysis)

        def proxy_placeholder(name: str):
            if concrete_args is not None and name in concrete_args:
                cnt = 0

                def replace_ph(x):
                    nonlocal cnt
                    cnt += 1
                    param = sig.parameters[name]
                    default = (
                        ()
                        if param.default is inspect.Parameter.empty
                        else (param.default,)
                    )
                    proxy_name = f"{name}_{str(cnt)}" if cnt > 1 else name
                    out = cls.create_proxy("placeholder", proxy_name, default, {})
                    if x == PH:
                        return out
                    # Union[int, bool] == bool in Python <= 3.6
                    if isinstance(x, (bool, base_types)) and not isinstance(
                        x, torch.Tensor
                    ):
                        torch._assert(
                            out == x,
                            f"{name} has been specialized to have value "
                            f"{x} but got another value",
                        )
                    elif x is None:
                        args = (
                            out,
                            f"{name} has been specialized to have value "
                            "None but got another value",
                        )
                        cls.create_proxy("call_function", _assert_is_none, args, {})
                    else:
                        logger.warning(
                            "Was not able to add assertion to guarantee "
                            "correct input %s to "
                            "specialized function. It is up to the user "
                            "to make sure that your inputs match the "
                            "inputs you specialized the function with.",
                            name,
                        )
                    return x

                return pytree.tree_map(replace_ph, concrete_args[name])
            if name[0] == "*":
                default = ()
            else:
                param = sig.parameters[name]
                default = (
                    () if param.default is inspect.Parameter.empty else (param.default,)
                )
            return cls.create_proxy(
                "placeholder",
                name,
                default,
                {},
                type_expr=fn_for_analysis.__annotations__.get(name, None),
            )

        arg_names = [next(names_iter) for _ in range(skip_arg_idx, total_args)]
        if isinstance(concrete_args, tuple):
            if len(arg_names) != len(concrete_args):
                raise RuntimeError(
                    f"Tracing expected {len(arg_names)} arguments but "
                    f"got {len(concrete_args)} concrete arguments"
                )
            concrete_args = dict(zip(arg_names, concrete_args))
        args.extend(proxy_placeholder(names) for names in arg_names)

        if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF:
            # TODO: type annotations for *args and **kwargs
            if co.co_flags & inspect.CO_VARARGS:
                args.append(proxy_placeholder("*" + next(names_iter)))
            if co.co_flags & inspect.CO_VARKEYWORDS:
                args.append(proxy_placeholder("**" + next(names_iter)))
            root_fn = _patch_function(root_fn, len(args))

        flat_args, in_spec = pytree.tree_flatten(tuple(args))
        if any(not isinstance(i, pytree.LeafSpec) for i in in_spec.children_specs):
            # In the case that we have pytree-flattened inputs in
            # `concrete_args`, generate a flattening wrapper around the
            # original root function and return that.
            cls.graph._codegen = _PyTreeCodeGen(
                _PyTreeInfo(orig_args[:total_args], in_spec, None)
            )

            def flatten_fn(*args):
                tree_args = pytree.tree_unflatten(list(args), in_spec)
                tree_out = root_fn(*tree_args)
                out_args, out_spec = pytree.tree_flatten(tree_out)
                assert isinstance(cls.graph._codegen, _PyTreeCodeGen)
                cls.graph._codegen.pytree_info = (
                    cls.graph._codegen.pytree_info._replace(out_spec=out_spec)
                )
                return out_args

            return flatten_fn, flat_args
        return root_fn, args

    if isinstance(tracer_cls_name, str):
        if tracer_cls_name == "huggingface":
            from transformers.utils.fx import (
                HFTracer,
                _IS_IN_DEBUG_MODE,
                _MANUAL_META_OVERRIDES,
                Proxy,
                _proxies_to_metas,
                _generate_random_int,
                get_values,
                MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
                _gen_constructor_wrapper,
            )

            assert (
                "concrete_args" in kwargs
            ), "Please provide concrete_args for HF tracer"
            concrete_args = kwargs.pop("concrete_args")

            class TracerWrapper(HFTracer):
                def __init__(self, **config: dict[str, Any]) -> None:
                    super().__init__()
                    self.name = "huggingface"
                    self.leaf_modules = config.get("leaf_modules", [])

                def create_proxy(
                    self,
                    kind,
                    target,
                    args,
                    kwargs,
                    name=None,
                    type_expr=None,
                    proxy_factory_fn=None,
                ):
                    rv = super().create_proxy(
                        kind, target, args, kwargs, name, type_expr, proxy_factory_fn
                    )  # grandparent method

                    if kind == "placeholder" and target in self.meta_args:
                        rv.install_metadata(self.meta_args[target])
                        return rv

                    if target in self.orig_fns:
                        if "device" in kwargs:
                            kwargs["device"] = "meta"

                    try:
                        args_metas = torch.fx.node.map_aggregate(
                            args, _proxies_to_metas
                        )
                        kwargs_metas = torch.fx.node.map_aggregate(
                            kwargs, _proxies_to_metas
                        )

                        if kind == "call_function":
                            meta_target = _MANUAL_META_OVERRIDES.get(target, target)
                            meta_out = meta_target(*args_metas, **kwargs_metas)
                            if isinstance(meta_out, torch.Tensor):
                                meta_out = meta_out.to(device="meta")
                        elif kind == "call_method":
                            method = getattr(args_metas[0].__class__, target)
                            meta_target = _MANUAL_META_OVERRIDES.get(method, method)
                            meta_out = meta_target(*args_metas, **kwargs_metas)
                        elif kind == "call_module":
                            if not hasattr(self, "orig_forward"):
                                raise AttributeError(
                                    f"{self} does not have an attribute "
                                    "called orig_forward"
                                )
                            return rv  # delete original code here
                        elif kind == "get_attr":
                            self._disable_module_getattr = True
                            try:
                                attr_itr = self.root
                                atoms = target.split(".")
                                for atom in atoms:
                                    attr_itr = getattr(attr_itr, atom)
                                if isinstance(attr_itr, torch.Tensor):
                                    meta_out = attr_itr.to(device="meta")
                                else:
                                    meta_out = attr_itr
                            finally:
                                self._disable_module_getattr = False
                        else:
                            return rv

                        if not isinstance(rv, Proxy):
                            raise ValueError("Don't support composite output yet")
                        rv.install_metadata(meta_out)
                    except Exception as e:
                        if _IS_IN_DEBUG_MODE:
                            logger.warning(
                                "Could not compute metadata for %s target %s: %s",
                                kind,
                                target,
                                e,
                            )
                    return rv

                def is_leaf_module(
                    self, m: nn.Module, module_qualified_name: str
                ) -> bool:
                    if (
                        not is_fx_tracable(m)
                        or any(t in type(m).__name__ for t in self.leaf_modules)
                        or any(t == module_qualified_name for t in self.leaf_modules)
                    ):
                        return True
                    return super().is_leaf_module(m, module_qualified_name)

                def trace(
                    self,
                    root,
                    concrete_args=None,
                    dummy_inputs=None,
                    complete_concrete_args_with_inputs_not_in_dummy_inputs=True,
                ):
                    sig = inspect.signature(
                        root.forward if isinstance(root, torch.nn.Module) else root
                    )

                    if concrete_args is None:
                        concrete_args = {}

                    if (
                        dummy_inputs is not None
                        and complete_concrete_args_with_inputs_not_in_dummy_inputs
                    ):
                        for param in sig.parameters.values():
                            if param.name in dummy_inputs:
                                continue
                            if param.default is inspect.Parameter.empty:
                                raise ValueError(
                                    f"You need to specify a default value for the parameter {param.name}."
                                )
                        concrete_args.update(
                            {
                                p.name: p.default
                                for p in sig.parameters.values()
                                if (
                                    p.name not in dummy_inputs
                                    and p.name not in concrete_args
                                )
                            }
                        )

                    input_names = sig.parameters.keys() - concrete_args.keys()

                    # Creating a random input shape to generate dummy inputs.
                    batch_size = _generate_random_int()
                    sequence_length = _generate_random_int()
                    shape = [batch_size, sequence_length]

                    if root.__class__.__name__ in get_values(
                        MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
                    ):
                        num_choices = _generate_random_int(low=2, high=5)
                        shape.insert(1, num_choices)

                    inputs = dict(dummy_inputs) if dummy_inputs is not None else {}

                    concrete_metas = {
                        input_name: input_.to("meta")
                        if isinstance(input_, torch.Tensor)
                        else input_
                        for input_name, input_ in inputs.items()
                    }
                    for param in sig.parameters.values():
                        if (
                            param.kind == inspect.Parameter.VAR_KEYWORD
                            and param.name not in input_names
                        ):
                            concrete_metas[f"**{param.name}"] = {}
                    self.meta_args = concrete_metas
                    self.patched_torch_methods = {
                        target: _gen_constructor_wrapper(getattr(torch, target))
                        for target in self._TORCH_METHODS_TO_PATCH
                    }
                    self.orig_fns = set()

                    for name, (wrapper, orig) in self.patched_torch_methods.items():
                        setattr(torch, name, wrapper)
                        self.orig_fns.add(orig)

                    try:
                        # pylint: disable=bad-super-call
                        self.graph = super(HFTracer, self).trace(
                            root, concrete_args=concrete_args
                        )
                    finally:
                        for name, (_, orig) in self.patched_torch_methods.items():
                            setattr(torch, name, orig)
                    return self.graph

                def create_args_for_root(self, root_fn, is_module, concrete_args=None):
                    return create_args_for_root(self, root_fn, is_module, concrete_args)

            top_gm = trace_submodule(
                model,
                TracerWrapper,
                is_top=True,
                concrete_args=concrete_args,
                **kwargs,
            )

        elif tracer_cls_name == "pytorch":

            class TracerWrapper(fx.Tracer):
                def __init__(self, **config: dict[str, Any]) -> None:
                    super().__init__(param_shapes_constant=True)
                    self.leaf_modules = config.get("leaf_modules", [])
                    self.name = "pytorch"

                def is_leaf_module(
                    self, m: nn.Module, module_qualified_name: str
                ) -> bool:
                    if (
                        not is_fx_tracable(m)
                        or any(t in type(m).__name__ for t in self.leaf_modules)
                        or any(t == module_qualified_name for t in self.leaf_modules)
                    ):
                        return True
                    return super().is_leaf_module(m, module_qualified_name)

                def create_args_for_root(self, root_fn, is_module, concrete_args=None):
                    return create_args_for_root(self, root_fn, is_module, concrete_args)

            top_gm = trace_submodule(model, TracerWrapper, **kwargs)

        elif tracer_cls_name == "dynamo":
            from torch import _dynamo as dynamo

            assert (
                "concrete_args" in kwargs
            ), "Please provide concrete_args for dynamo tracer"
            device = next(model.named_parameters())[1].device
            concrete_args = kwargs.pop("concrete_args")
            args = []
            for arg in concrete_args.values():
                if isinstance(arg, torch.Tensor):
                    args.append(arg.to(device))
                else:
                    args.append(arg)
            kwargs.pop("tracer")
            kwargs.pop("recursive")
            kwargs.pop("flatten")
            top_gm, _ = dynamo.export(model, *args, **kwargs)
        else:
            raise ValueError(f"Unknown tracer: {tracer_cls_name}")
    else:
        # A custom tracer class.
        raise NotImplementedError("Not supported yet")
    return top_gm
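

# Example usage (an illustrative sketch; ``MyModel`` is a placeholder for any
# nn.Module, and only keyword arguments consumed by ``trace``/``trace_submodule``
# are shown):
#
#   model = MyModel()
#   gm = trace(
#       model,
#       tracer="pytorch",              # or "huggingface" / "dynamo"
#       leaf_modules=["MyLeafBlock"],  # names/types to keep untraced
#       recursive=True,                # trace into non-leaf children
#       flatten=False,                 # keep the module hierarchy
#   )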