Source code for slapo.verify

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import sys
import copy
from contextlib import ContextDecorator

import torch
from torch import nn
import torch.distributed as dist

from .schedule import create_schedule
from .build import build
from .random import set_random_seed
from .logger import get_logger
from .primitives.base import Primitive

logger = get_logger()


class Verify(ContextDecorator):
    def __init__(self, sch, example_inputs, device="cuda", eval_mode=True, enable=True):
        if not isinstance(example_inputs, list):
            example_inputs = [example_inputs]
        self.example_inputs = example_inputs
        self.original_trace = None
        self.sch = sch
        self.original_sch = create_schedule(copy.deepcopy(self.sch.mod))
        self.device = device
        self.enable = enable
        self.eval_mode = eval_mode

    def __enter__(self):
        self.original_trace = sys.gettrace()

        # pylint: disable=unused-argument
        def trace_calls(frame, event, arg):
            if event == "call":
                code = frame.f_code
                function_name = code.co_name
                if function_name == "apply":
                    # This part is useful only when we need to get the model from the schedule
                    # (the schedule is not passed in as an argument)
                    for _, value in frame.f_globals.items():
                        if isinstance(value, Primitive) and value.is_verifiable():
                            cls_name = getattr(value, "__name__", None)
                            logger.info("Verifying %s...", cls_name, ranks=0)
                            break
            return trace_calls

        sys.settrace(trace_calls)
        return self

    def __exit__(self, *exc):
        """Verify the correctness of the schedule.

        TODO: Support backward verification
        """
        if not self.enable:
            return
        # 1. Build the original model with random weights
        named_params = self.original_sch.mod.named_parameters()
        is_initialized = named_params.__next__()[1].device != torch.device("meta")
        original_mod, _ = build(self.original_sch, init_weights=not is_initialized)
        # make sure all the buffers are on the right device
        original_mod = original_mod.to(self.device)
        # 2. Get the example inputs
        self.example_inputs = [x.to(self.device) for x in self.example_inputs]
        # Broadcast the example inputs from rank 0 to other ranks
        if self.sch.world_size > 1:
            for inp in self.example_inputs:
                dist.broadcast(inp, src=0, group=self.sch.group)
        # 3. Run the original model
        # make sure the random seeds are the same, which may affect the output of dropout
        if self.eval_mode:
            original_mod.eval()
        set_random_seed(2023)
        original_output = original_mod(*self.example_inputs)
        # 4. Broadcast the original model from rank 0 to other ranks
        original_state_dict = original_mod.state_dict()
        if self.sch.world_size > 1:
            for param_name in original_state_dict:
                dist.broadcast(
                    original_state_dict[param_name], src=0, group=self.sch.group
                )
        # 5. Delete the original model to avoid excessive memory usage
        del original_mod
        # 6. Get the transformed model from the schedule
        # Copy it and build a new schedule to prevent the original schedule from being modified
        copied_mod = copy.deepcopy(self.sch.mod)
        # copy original attributes
        # TODO: find a better way to copy attributes
        for param_name, param in self.sch.mod.named_parameters():
            if hasattr(param, "orig_shape"):
                copied_mod.get_parameter(param_name).orig_shape = param.orig_shape
        new_sch = create_schedule(copied_mod)

        # 7. Use original weights to initialize the new model
        # Notice init_weights is called before actual sharding, so we only need to
        # assign the original weights to the corresponding modules
        def init_weights(mod, path):
            for name, _ in mod.named_parameters(recurse=False):
                setattr(
                    mod,
                    name,
                    nn.Parameter(
                        original_state_dict[f"{path}.{name}"].detach().to(self.device)
                    ),
                )

        new_mod, _ = build(new_sch, init_weights=init_weights)
        # 8. Run the new model
        # make sure all the buffers are on the right device
        new_mod.to(self.device)
        if self.eval_mode:
            new_mod.eval()
        # make sure the random seeds are the same, which may affect the output of dropout
        set_random_seed(2023)
        new_output = new_mod(*self.example_inputs)
        # 9. Compare the outputs
        torch.testing.assert_close(original_output, new_output)
        logger.info("Passed verification!")
        del new_mod
        sys.settrace(self.original_trace)
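A minimal usage sketch follows, assuming a single-process schedule; the toy model (MyModel), its dimensions, the input shape, and the shard call are illustrative placeholders and are not part of the module above.

# Illustrative sketch only: MyModel and the sharded parameter are hypothetical.
# Verify records the original model at __init__ and, on __exit__, rebuilds both
# the original and the transformed model and compares their outputs.
import torch
from torch import nn

import slapo
from slapo.verify import Verify


class MyModel(nn.Module):  # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 16)

    def forward(self, x):
        return self.fc(x)


sch = slapo.create_schedule(MyModel())
example_input = torch.randn(2, 16)

# Schedule mutations applied inside the context are verified on exit;
# device="cpu" keeps the sketch runnable without a GPU.
with Verify(sch, [example_input], device="cpu", eval_mode=True):
    sch["fc"].shard("weight", axis=0)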