Source code for slapo.pipeline

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""Pipeline stage wrappers for supported frameworks."""
import operator
from collections import OrderedDict

from torch import fx
from torch.fx.passes.split_module import split_module

from .logger import get_logger
from .framework_dialect import get_dialect_cls
from .utils.common import transfer_hooks

logger = get_logger()


def propagate_partition(sch, starting_stage_id=0, stop_at=None):
    assert isinstance(sch.mod, fx.GraphModule)

    # Assign partition (pipeline stage) ID to each node.
    curr_stage_id = starting_stage_id
    for node in sch.mod.graph.nodes:
        if "partition" not in node.meta:
            node.meta["partition"] = curr_stage_id
        else:
            node.meta["partition"] = curr_stage_id
            curr_stage_id += 1
    assert curr_stage_id > starting_stage_id

    # Partition this module.
    mod_after_split = split_module(
        sch.mod,
        None,
        lambda node: node.meta["partition"],
        keep_original_order=True,
    )
    if not sch.parent:
        sch.replace(mod_after_split)
        return mod_after_split, curr_stage_id

    # Propagate partitions to the parent module.
    placeholders = []
    for node in mod_after_split.graph.nodes:
        if node.op == "placeholder":
            placeholders.append(node)

    # Get target call node in parent graph
    target_call_node = None
    parent_mod = sch.parent.mod
    assert isinstance(parent_mod, fx.GraphModule)
    for node in parent_mod.graph.nodes:
        if node.op == "call_module" and node.target == sch.name:
            target_call_node = node
            break
    assert target_call_node is not None

    # Check if the target call node is also partitioned.
    keep_last_parition = "partition" in target_call_node.meta

    # Create mapping from placeholder in subgraph to arguments in parent graph
    ph2arg = {}
    for i, arg in enumerate(target_call_node.args):
        ph2arg[placeholders[i].name] = arg
    for i, (_, kwarg) in enumerate(
        target_call_node.kwargs.items(), len(target_call_node.args)
    ):
        ph2arg[placeholders[i].name] = kwarg

    # Register partitioned submodules to parent module
    for name, child in mod_after_split.named_children():
        parent_mod.add_module(name, child)

    # Replace call_module in parent graph with submodule calls (inline).
    new_node = target_call_node
    last_call_mod_node = None
    last_node_except_output = None
    output_node = None
    for child_node in mod_after_split.graph.nodes:
        if child_node.op == "call_module":
            new_args = fx.map_arg(child_node.args, lambda node: ph2arg[node.name])
            new_kwargs = fx.map_arg(child_node.kwargs, lambda node: ph2arg[node.name])
            with parent_mod.graph.inserting_after(new_node):
                new_node = parent_mod.graph.call_module(
                    child_node.target, new_args, new_kwargs
                )
                # Specify new partition points
                new_node.meta["partition"] = True
            # Add current node to mapping
            ph2arg[child_node.name] = new_node
            last_call_mod_node = new_node
        elif child_node.op == "call_function":
            new_args = fx.map_arg(child_node.args, lambda node: ph2arg[node.name])
            new_kwargs = fx.map_arg(child_node.kwargs, lambda node: ph2arg[node.name])
            with parent_mod.graph.inserting_after(new_node):
                new_node = parent_mod.graph.call_function(
                    child_node.target, new_args, new_kwargs
                )
            # Add current node to mapping
            ph2arg[child_node.name] = new_node
        elif child_node.op == "output":
            output_node = child_node
            continue
        elif child_node.op != "placeholder":
            raise RuntimeError(
                f"Cannot support {child_node.name} with op type {child_node.op} "
                "in the splitted module"
            )
        last_node_except_output = new_node
    if last_call_mod_node is None or last_node_except_output is None:
        raise RuntimeError("Cannot find call_module node in the splitted module")
    if last_call_mod_node != last_node_except_output:
        raise RuntimeError(
            f"The second last node is {last_node_except_output} with op type "
            f"{last_node_except_output.op}, "
            "which is not call_module node in the splitted module. "
            "A possible reason is that fx.split_module generates getitems "
            "after the partitioned submodule call, and this is not supported yet."
        )

    target_call_node.replace_all_uses_with(last_node_except_output)

    # If the parent call node is also partitioned (see example),
    # we keep the partition of the last partition; otherwise
    # we remove the partition so that the last partition will
    # be fused with following logic.
    # Example 1:
    #  sch["t5.encoder.block.11"].cut_pipeline_stage()
    #  sch["t5.encoder"].cut_pipeline_stage()
    #  sch["t5.decoder.block.11"].cut_pipeline_stage()
    # The above schdule results in 4 stages:
    #  0. encoder block 0-11
    #  1. encoder block 12-23
    #  2. decoder block 0-11
    #  3. decoder block 12-23
    # Example 2:
    #  sch["t5.encoder.block.11"].cut_pipeline_stage()
    #  sch["t5.decoder.block.11"].cut_pipeline_stage()
    # The above schdule results in 3 stages:
    #  0. encoder block 0-11
    #  1. encoder block 12-23 + decoder block 0-11
    #  2. decoder block 12-23
    if not keep_last_parition:
        last_call_mod_node.meta.pop("partition")

    # Fix output
    if len(output_node.args) > 1 or len(output_node.kwargs) > 0:
        raise RuntimeError("Multiple output arguments not supported yet!")
    if len(output_node.args) == 1 and (isinstance(output_node.args[0], (dict, tuple))):
        if isinstance(output_node.args[0], tuple):
            raise RuntimeError("Tuple return not supported yet!")
        ret_dict = output_node.args[0]
        ph2arg[None] = None
        users_to_replace = []
        for user in last_call_mod_node.users:
            if user.op == "call_method" and user.target == "get":
                value = ret_dict.get(user.args[1], user.args[2])
                users_to_replace.append(
                    (
                        user,
                        ph2arg[value.name if isinstance(value, fx.Node) else None],
                    )
                )
            elif (  # pylint: disable=comparison-with-callable
                user.op == "call_function" and user.target == operator.getitem
            ):
                users_to_replace.append((user, ph2arg[ret_dict[user.args[1]].name]))
        for user, target in users_to_replace:
            user.replace_all_uses_with(target)

    # Recompile
    parent_mod.graph.erase_node(target_call_node)
    parent_mod.delete_all_unused_submodules()
    parent_mod.graph.eliminate_dead_code()
    parent_mod.graph.lint()
    parent_mod.recompile()
    mod_after_split.delete_all_unused_submodules()
    mod_after_split.graph.eliminate_dead_code()
    mod_after_split.graph.lint()
    mod_after_split.recompile()
    sch.replace(mod_after_split)
    if stop_at is None or sch.parent.name != stop_at:
        return propagate_partition(sch.parent, starting_stage_id, stop_at)
    return mod_after_split, curr_stage_id


def add_partition_notation(sch, idx):
    sch.partition_idx = idx
    for subsch in sch.child.values():
        add_partition_notation(subsch, idx)


def analyze_pipeline_module(top_mod):
    def get_itemized_name(node, suffix=""):
        """Recursively traverse getitem nodes to get the itemized name
        (e.g., submod_1.0.1).
        """
        if node.op != "call_function":
            assert not suffix or suffix.startswith(".")
            return f"{node.target}{suffix}"

        # pylint: disable=comparison-with-callable
        assert node.target == operator.getitem, (
            "Expect only getitem "
            f"function call in top pipeline module, but got {node.target}"
        )
        return get_itemized_name(node.args[0], f"{suffix}.{node.args[1]}")

    submod_2_stage_id = {}
    curr_stage_id = 0
    tensor_2_id = {}
    liveness = {-1: []}
    stage_id_2_arg_names = {-1: []}

    # 1st pass (forward): Construct ID-tensor mapping and a part of the liveness.
    # After this pass, liveness of each submod should include all live tensors
    # used by itself.
    for idx, node in enumerate(top_mod.graph.nodes):
        itemized_name = get_itemized_name(node)
        tensor_2_id[itemized_name] = idx
        if node.op == "placeholder":
            # Liveness -1 indicates the primary inputs in order.
            liveness[-1].append(itemized_name)
            stage_id_2_arg_names[-1].append(itemized_name)
        elif node.op == "call_module" and node.target.startswith("submod_"):
            liveness[curr_stage_id] = []
            stage_id_2_arg_names[curr_stage_id] = []
            for arg_idx, arg in enumerate(node.args):
                arg_name = get_itemized_name(arg)
                assert (
                    arg_name in tensor_2_id
                ), f"arg {arg_idx} in {node} is not defined (in {tensor_2_id})"
                liveness[curr_stage_id].append(arg_name)
                stage_id_2_arg_names[curr_stage_id].append(arg_name)
            submod_2_stage_id[node.target] = curr_stage_id
            curr_stage_id += 1

    # 2nd pass (backward): Construct the rest of the liveness by adding
    # more live tensors used by rest submods and should be bypassed).
    # Note that we used an OrderedDict as a ordered set to keep the order
    # of tensors in the live set; otherwise the generated liveness for
    # each rank in a TP group may be different.
    live_set = OrderedDict()
    for node in reversed(top_mod.graph.nodes):
        if node.op == "call_module" and node.target.startswith("submod_"):
            curr_stage_id = submod_2_stage_id[node.target]
            # Add all arguments to the live set.
            live_set.update({arg: 1 for arg in liveness[curr_stage_id]})
            # Remove tensors that are defined in this stage from the live set.
            # (i.e., the output of this stage).
            live_set = OrderedDict(
                {t: 1 for t in live_set if not t.startswith(node.target)}
            )
            # Get the difference between the live tensors used by this stage
            # and the current live set.
            diff = [t for t in live_set if t not in liveness[curr_stage_id]]
            # Add all diff tensors to the liveness of this stage.
            liveness[curr_stage_id].extend(diff)

    # Override the liveness of the first stage to match the input order.
    if list(liveness[0]) != list(liveness[-1]):
        logger.warning(
            "Inputs between first submodule and top module are mismatched"
            " (first submodule: %s, top module: %s). "
            "Possibly because (1) some arguments in top modules are specified to None "
            "when tracing, and they are not removed by the PyTorch tracer. "
            "(2) FX tracer changes the order of the arguments in the function calls.",
            liveness[0],
            liveness[-1],
        )
    else:
        liveness[0] = liveness[-1]
    del liveness[-1]

    stage_id_2_name = {v: k for k, v in submod_2_stage_id.items()}
    return stage_id_2_arg_names, stage_id_2_name, liveness


[docs]def analyze_tie_weights(top_mod, is_pipeline_partitioned): """Analyze if there is any tie weights (two weights in different module share the same memory) partitioned into different pipeline stages. Parameters ---------- top_mod : torch.nn.Module The top-level module. This should be a top pipeline module, so 1) it should already be traced and partitioned, and 2) it should have a number of submodules that matches pipeline stages. is_pipeline_partitioned : bool Whether the module is partitioned for pipeline or not. If not, then all tie weights will have stage ID 0. Returns ------- tie_groups : Dict[int, Set[Tuple[str, int]]] Mapping from the nn.Parameter object to the set of parameter names that are tied to it. The set of parameter names is a tuple of (parameter name, stage ID). The stage ID is 0 if the module is not partitioned for pipeline. """ # Mapping from parameter name to (the pipeline stage ID, the parameter object). params = {} # The result tie groups. tie_groups = {} def _traverse_children(stage_id, prefix, curr_mod_name, curr_mod): full_prefix = f"{prefix}.{curr_mod_name}" if prefix else curr_mod_name # We cannot use named_parameters(recurse=True) but have to explicitly # traverse submodules recursively. This is because tie weights will # only show up once in named_parameters. In other words, we cannot # identify tie weights with named_parameters(recurse=True). for curr_param_name, curr_param in curr_mod.named_parameters(): curr_param_full_name = ( f"{full_prefix}.{curr_param_name}" if full_prefix else curr_param_name ) if curr_param_full_name in params: continue # Check if this parameter is tie to another one in a different stage. for target_param_full_name, ( target_stage, target_param, ) in params.items(): if is_pipeline_partitioned and stage_id == target_stage: continue if id(curr_param) == id(target_param): if curr_param in tie_groups: # Get the tie group of the target parameter. tie_group = tie_groups[curr_param] else: # Create a new tie group, and use the target parameter name # as the primary key. tie_group = set([(target_param_full_name, target_stage)]) tie_groups[curr_param] = tie_group # Add the current parameter to the tie group. tie_group.add((curr_param_full_name, stage_id)) # Add this parameter for the rest analysis. params[curr_param_full_name] = (stage_id, curr_param) # Traverse children. for name, mod in curr_mod.named_children(): _traverse_children(stage_id, f"{full_prefix}", name, mod) if is_pipeline_partitioned: for stage_id, (mod_name, stage_mod) in enumerate(top_mod.named_children()): _traverse_children(stage_id, "", mod_name, stage_mod) else: _traverse_children(0, "", "", top_mod) # Explicitly delete the reference to parameters for safety. del params # Remove module name. if is_pipeline_partitioned: ret = {} for param, tie_group_set in tie_groups.items(): group = [(name[name.find(".") + 1 :], sid) for name, sid in tie_group_set] ret[param] = group else: ret = tie_groups return ret
def generate_pipeline_partition(sch): # Identify the common ancestor of all pipeline cutting paths. common_ancestor_path = "" if len(sch.metadata.primitives["cut_pipeline_stage"]) > 1: path_tokens = [ path.split(".") for path in sch.metadata.primitives["cut_pipeline_stage"] ] for tokens in zip(*path_tokens): if len(list(set(tokens))) != 1: break common_ancestor_path += f"{tokens[0]}." # Remove the last dot. common_ancestor_path = common_ancestor_path[:-1] # Propogate pipeline partitioning from the cutting level to the common ancestor. starting_stage_id = 0 assert sch.metadata.primitives["cut_pipeline_stage"] for path in sch.metadata.primitives["cut_pipeline_stage"]: if path == common_ancestor_path: # Skip the common ancestor because it should be handled in the next step. continue pipe_level_sch = sch[path] _, starting_stage_id = propagate_partition( pipe_level_sch, starting_stage_id, stop_at=common_ancestor_path ) starting_stage_id += 1 propagate_partition(sch[common_ancestor_path]) # Add partition notations to the submodules for idx, subsch in enumerate(sch.child.values()): add_partition_notation(subsch, idx) return sch def build_pipeline_model(sch, target, **kwargs): # Analyze pipelien module for liveness and arguments. partitioned_mod = sch.mod ( stage_id_2_arg_names, stage_id_2_name, liveness, ) = analyze_pipeline_module(partitioned_mod) # Get the pipeline wrapper class. pipe_wrapper_cls = get_dialect_cls("pipeline_stage", target) # Generate wrappers for pipelined modules res_partition = [] named_children = dict(partitioned_mod.named_children()) for idx, partition in enumerate(named_children.values()): res_partition.append( pipe_wrapper_cls( partition, idx, stage_id_2_name[idx], len(named_children), liveness, stage_id_2_arg_names, ) ) # Transfer hooks from the original module to the corresponding pipeline stage. transfer_hooks(partitioned_mod, res_partition[0].mod, ["fwd_pre", "bwd_post"]) transfer_hooks(partitioned_mod, res_partition[-1].mod, ["fwd_post"]) pipe_engine_fn = get_dialect_cls("pipeline_engine", target) return pipe_engine_fn( sch.metadata, res_partition, **kwargs, )