slapo

Classes:

Database([db_file_name])

A simple database that stores tuning results in a dict.

Space()

The tuning space.

Symbol(name, vals)

A tunable symbol.

FunctionType

alias of function

ModulePattern(name)

OrderedDict

Dictionary that remembers insertion order

Pattern(*args, **kwargs)

ScheduleMetadata([tie_weights, param_tags, ...])

The metadata of a schedule.

partial

partial(func, *args, **keywords) - new function with partial application of the given arguments and keywords.

Verify(sch, example_inputs[, device, ...])

Functions:

init_empty_weights([enable, include_buffers])

A context manager under which models are initialized with all parameters on the meta device, therefore creating an empty model.

get_logger([name, level])

Attach to the default logger.

register_primitive()

Register a primitive to the schedule.

set_random_seed([seed, dp_rank, pp_rank, ...])

Set random seed for reproducibility.

analyze_tie_weights(top_mod, ...)

Analyze whether any tied weights (two weights in different modules that share the same memory) are partitioned into different pipeline stages.

consolidate_model(sch, target[, param_init_fn])

Consolidate the model weights.

create_schedule(root[, name, path, parent, ...])

Create a schedule for the given module and preserve the module hierarchy.

dataclass([cls, init, repr, eq, order, ...])

Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

field(*[, default, default_factory, init, ...])

Return an object to identify dataclass fields.

get_cuda_rng_tracker()

Get the CUDA RNG tracker.

get_dialect_cls(cls_type, target[, allow_none])

Get the framework dialect class.

init_target_engine(sch, target, **kwargs)

Initialize the runtime engine for a specific target framework.

is_module_list(module[, name, parent])

Check whether a module should be treated as a module list. A module list becomes an nn.Module or fx.GraphModule after tracing, but it should still be treated as a module list in the schedule.

list_primitives([name_only])

List all available schedule primitives.

trace(model, **kwargs)

Traces a model to a GraphModule.

trace_module(model, **kwargs)

Traces a model to a GraphModule.

is_random_seed_set()

Check whether the random seed has been set.

checkpoint(function, *args[, use_reentrant])

Checkpoint a model or part of the model.

class slapo.Database(db_file_name=None)

A simple database that stores tuning results in a dict.

Methods:

load()

Load the database from the file.

commit(key, data)

Commit the data to the database and update the DB file.

load()

Load the database from the file.

commit(key, data)

Commit the data to the database and update the DB file.
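The entries above do not say how the DB file is laid out. Purely to illustrate the load/commit contract, here is a minimal sketch of a hypothetical stand-in, JsonDatabase (not part of slapo), which assumes a JSON file and rewrites the whole file on every commit.

```python
import json
import os
import tempfile
from typing import Any, Dict, Optional


class JsonDatabase:
    """Hypothetical stand-in for slapo.Database; the JSON format is an assumption."""

    def __init__(self, db_file_name: Optional[str] = None) -> None:
        self.db_file_name = db_file_name
        self.data: Dict[str, Any] = {}
        # Reuse previous results if the file already exists.
        if db_file_name is not None and os.path.exists(db_file_name):
            self.load()

    def load(self) -> None:
        """Replace the in-memory dict with the contents of the DB file."""
        with open(self.db_file_name, "r", encoding="utf-8") as f:
            self.data = json.load(f)

    def commit(self, key: str, data: Any) -> None:
        """Record one result, then rewrite the whole DB file (if one is set)."""
        self.data[key] = data
        if self.db_file_name is not None:
            with open(self.db_file_name, "w", encoding="utf-8") as f:
                json.dump(self.data, f, indent=2)


path = os.path.join(tempfile.mkdtemp(), "tuning_db.json")
db = JsonDatabase(path)
db.commit("batch_size=16", {"status": "ok"})  # illustrative key and payload
print(JsonDatabase(path).data)  # {'batch_size=16': {'status': 'ok'}}
```

Rewriting the full file on each commit keeps the sketch simple and means a crash loses at most the in-flight result; a real implementation may trade this off differently.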

class slapo.Space

The tuning space.

Methods:

create_symbol(name, vals)

Create a symbol in the space.

next()

Get the next symbol to fix.

reset()

Reset the space to the initial state.

to_dict()

Convert the space to a dict.

clone()

Clone the space.

cfg_dict_to_str(cfg_dict)

Convert a config dict to a string for logging and debugging.

log_space(training_script_args, update_space_fn)

Print the tuning space for logging.

create_symbol(name, vals)

Create a symbol in the space. If the symbol already exists:

  1. If the symbol is fixed, do nothing;

  2. Otherwise re-create the symbol, because its candidate values may change due to other fixed symbols.

next()

Get the next symbol to fix.

reset()

Reset the space to the initial state.

to_dict()

Convert the space to a dict. Note that all symbols must be fixed before calling this function.

clone()

Clone the space.

static cfg_dict_to_str(cfg_dict)

Convert a config dict to a string for logging and debugging.

log_space(training_script_args, update_space_fn)

Print the tuning space for logging.

class slapo.Symbol(name, vals)

A tunable symbol.

Methods:

add(val)

Add a value to the symbol.

fix_at(idx)

Fix the value of this symbol at the given index.

is_fixed()

Check if the value of this symbol has been fixed.

add(val)

Add a value to the symbol. If the value is already in the symbol, do nothing.

fix_at(idx)

Fix the value of this symbol at the given index.

is_fixed()

Check if the value of this symbol has been fixed.
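Space and Symbol together describe a workflow in which a tuning space holds symbols with candidate values and a search fixes them one at a time. The toy classes below (ToySymbol and ToySpace, hypothetical, not slapo's implementation) mirror only the documented behaviors: add() ignores duplicates, fix_at() selects a candidate by index, create_symbol() keeps a fixed symbol but re-creates an unfixed one, and to_dict() requires every symbol to be fixed. The order in which next() returns symbols is an assumption (first unfixed symbol in insertion order), and the symbol name "ckpt_ratio" is made up.

```python
from typing import Any, Dict, List, Optional


class ToySymbol:
    """Hypothetical stand-in for slapo.Symbol (not slapo's implementation)."""

    def __init__(self, name: str, vals: List[Any]) -> None:
        self.name = name
        self.vals: List[Any] = []
        for val in vals:
            self.add(val)
        self.fixed_idx: Optional[int] = None

    def add(self, val: Any) -> None:
        # Documented behavior: adding a value that already exists does nothing.
        if val not in self.vals:
            self.vals.append(val)

    def fix_at(self, idx: int) -> None:
        self.fixed_idx = idx

    def is_fixed(self) -> bool:
        return self.fixed_idx is not None

    @property
    def value(self) -> Any:
        assert self.is_fixed(), f"{self.name} is not fixed"
        return self.vals[self.fixed_idx]


class ToySpace:
    """Hypothetical stand-in for slapo.Space."""

    def __init__(self) -> None:
        self.space: Dict[str, ToySymbol] = {}

    def create_symbol(self, name: str, vals: List[Any]) -> ToySymbol:
        # Keep a fixed symbol as-is; otherwise (re)create it, because its
        # candidates may depend on symbols that were fixed earlier.
        if name in self.space and self.space[name].is_fixed():
            return self.space[name]
        self.space[name] = ToySymbol(name, vals)
        return self.space[name]

    def next(self) -> Optional[ToySymbol]:
        # Assumption: return the first unfixed symbol in insertion order.
        for sym in self.space.values():
            if not sym.is_fixed():
                return sym
        return None

    def to_dict(self) -> Dict[str, Any]:
        # Documented precondition: every symbol must be fixed.
        return {name: sym.value for name, sym in self.space.items()}


space = ToySpace()
bs = space.create_symbol("batch_size", [8, 16, 16, 32])  # duplicate 16 ignored
bs.fix_at(1)
# The candidates of a later symbol may depend on an already fixed one.
space.create_symbol("ckpt_ratio", [0.0, 0.5] if bs.value > 8 else [0.0])
space.next().fix_at(0)
print(space.to_dict())  # {'batch_size': 16, 'ckpt_ratio': 0.0}
```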

slapo.init_empty_weights(enable=True, include_buffers=False)

A context manager under which models are initialized with all parameters on the meta device, therefore creating an empty model. This is useful when merely initializing the model would exceed the available RAM.

Parameters
  • enable (bool) – Whether or not to enable this context manager.

  • include_buffers (bool) – Whether or not to also put all buffers on the meta device while initializing.

slapo.get_logger(name='Slapo', level=20)

Attach to the default logger.

slapo.register_primitive()

Register a primitive to the schedule.

slapo.set_random_seed(seed=2013, dp_rank=None, pp_rank=None, tp_rank=None, always_enable_tp_seed=False)

Set random seed for reproducibility.

Parameters
  • seed (int) – Random seed. Default is 2013.

  • dp_rank (Optional[int]) – Data parallel rank. Default is None, meaning no data parallelism.

  • pp_rank (Optional[int]) – Pipeline parallel rank. Default is None, meaning no pipeline parallelism.

  • tp_rank (Optional[int]) – Tensor model parallel rank. Default is None, meaning no tensor parallelism.

  • always_enable_tp_seed (bool) – Always enable the tensor model parallel seed. This is used when sequence parallelism is enabled and all dropouts should use different seeds even if they are in the same TP group. Default is False, meaning that the tensor model parallel seed is only enabled via get_cuda_rng_tracker().fork().

Returns

Random seed of this rank.

Return type

int

slapo.FunctionType

alias of function
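"alias of function" suggests that slapo.FunctionType is the standard library's types.FunctionType, the type of functions defined with def or lambda. The sketch below uses types directly, so it runs without slapo; the scale function and Layer class are illustrative only. Note that built-ins and bound methods are not instances of this type.

```python
from types import FunctionType


def scale(x: float) -> float:
    return 2 * x


class Layer:
    def forward(self, x):
        return x


print(isinstance(scale, FunctionType))            # True
print(isinstance(lambda x: x, FunctionType))      # True
print(isinstance(len, FunctionType))              # False: built-in function
print(isinstance(Layer().forward, FunctionType))  # False: bound method
print(isinstance(Layer.forward, FunctionType))    # True: plain function on the class
```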

class slapo.ModulePattern(name)

Methods:

forward(*args)

Defines the computation performed at every call.

forward(*args)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class slapo.OrderedDict

Dictionary that remembers insertion order

Methods:

clear()

popitem([last])

Remove and return a (key, value) pair from the dictionary.

move_to_end(key[, last])

Move an existing element to the end (or beginning if last is false).

update([E, ]**F)

If E is present and has a .keys() method, then does: for k in E: D[k] = E[k]. If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v. In either case, this is followed by: for k in F: D[k] = F[k].

keys()

items()

values()

pop(k[, d])

Remove the specified key and return the corresponding value.

setdefault(key[, default])

Insert key with a value of default if key is not in the dictionary.

copy()

fromkeys(iterable[, value])

Create a new ordered dictionary with keys from iterable and values set to value.

clear()

Remove all items from od.

popitem(last=True)

Remove and return a (key, value) pair from the dictionary.

Pairs are returned in LIFO order if last is true or FIFO order if false.

move_to_end(key, last=True)

Move an existing element to the end (or beginning if last is false).

Raise KeyError if the element does not exist.

update([E, ]**F)

Update D from dict/iterable E and F. If E is present and has a .keys() method, then does: for k in E: D[k] = E[k]. If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v. In either case, this is followed by: for k in F: D[k] = F[k].

keys()

Return a set-like object providing a view on D's keys.

items()

Return a set-like object providing a view on D's items.

values()

Return an object providing a view on D's values.

pop(k[, d])

Remove the specified key and return the corresponding value. If the key is not found, d is returned if given; otherwise, KeyError is raised.

setdefault(key, default=None)

Insert key with a value of default if key is not in the dictionary.

Return the value for key if key is in the dictionary, else default.

copy()

Return a shallow copy of od.

fromkeys(iterable, value=None)

Create a new ordered dictionary with keys from iterable and values set to value.
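The OrderedDict docstrings above match collections.OrderedDict from the standard library, so slapo.OrderedDict appears to be a re-export. This sketch exercises fromkeys(), move_to_end(), update() and popitem() on collections.OrderedDict directly; the key names are illustrative.

```python
from collections import OrderedDict

od = OrderedDict.fromkeys(["embed", "layer0", "layer1"], 0)
print(list(od))  # ['embed', 'layer0', 'layer1']

od.move_to_end("embed")               # move to the end
od.move_to_end("layer1", last=False)  # move to the beginning
print(list(od))  # ['layer1', 'layer0', 'embed']

# update(): E has .keys(), so D[k] = E[k]; the keyword arguments F come after.
# Updating an existing key keeps its position; new keys are appended.
od.update({"layer0": 1}, head=2)
print(list(od.items()))  # [('layer1', 0), ('layer0', 1), ('embed', 0), ('head', 2)]

print(od.popitem())            # ('head', 2): LIFO by default
print(od.popitem(last=False))  # ('layer1', 0): FIFO when last=False
print(od.pop("missing", None)) # None: default returned instead of KeyError
```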

class slapo.Pattern(*args, **kwargs)

Methods:

forward(*args)

Defines the computation performed at every call.

Return type

None

forward(*args)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class slapo.ScheduleMetadata(tie_weights=<factory>, param_tags=<factory>, primitives=<factory>)

The metadata of a schedule. It stores the metadata of primitives and of the top module, mainly for (1) verification and (2) applying framework dialects. Note that when a module is replaced, the schedule metadata of the original module is NOT transferred to the new schedule, because the new module may not have the same structure as the original one.

Parameters
  • tie_weights (dict[nn.Parameter, nn.Parameter]) –

  • param_tags (set[str]) –

  • primitives (dict[str, Any]) –

Return type

None
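The `<factory>` defaults in the ScheduleMetadata signature are how dataclasses display fields declared with field(default_factory=...): every instance gets its own fresh containers. The sketch below uses a hypothetical MetadataSketch dataclass with the same field names but plain keys instead of nn.Parameter objects, so it runs without torch; the tag and primitive entries are made up.

```python
from dataclasses import dataclass, field, fields
from typing import Any, Dict, Set


@dataclass
class MetadataSketch:
    # Hypothetical stand-in for slapo.ScheduleMetadata (same field names only).
    tie_weights: Dict[Any, Any] = field(default_factory=dict)
    param_tags: Set[str] = field(default_factory=set)
    primitives: Dict[str, Any] = field(default_factory=dict)


a = MetadataSketch()
b = MetadataSketch()
a.param_tags.add("tied")             # hypothetical tag
a.primitives["shard"] = {"axis": 0}  # hypothetical primitive record

# default_factory builds new containers per instance, so b is unaffected.
print(b.param_tags, b.primitives)  # set() {}
print([f.name for f in fields(MetadataSketch)])  # ['tie_weights', 'param_tags', 'primitives']
```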

slapo.analyze_tie_weights(top_mod, is_pipeline_partitioned)

Analyze whether any tied weights (two weights in different modules that share the same memory) are partitioned into different pipeline stages.

Parameters
  • top_mod (torch.nn.Module) – The top-level module. This should be a top pipeline module, so (1) it should already be traced and partitioned, and (2) its number of submodules should match the number of pipeline stages.

  • is_pipeline_partitioned (bool) – Whether the module is partitioned for pipeline or not. If not, then all tie weights will have stage ID 0.

Returns

tie_groups – Mapping from the nn.Parameter object to the set of parameter names that are tied to it. Each element of the set is a tuple of (parameter name, stage ID). The stage ID is 0 if the module is not partitioned for pipeline.

Return type

Dict[int, Set[Tuple[str, int]]]
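A natural follow-up is to check which tie groups actually straddle pipeline stages. The helper below, cross_stage_groups, is hypothetical (not a slapo API) and runs on a hand-written dict shaped like the documented return type; the keys are arbitrary integers here and the helper only inspects the (parameter name, stage ID) values.

```python
from typing import Dict, List, Set, Tuple

TieGroups = Dict[int, Set[Tuple[str, int]]]


def cross_stage_groups(tie_groups: TieGroups) -> List[Set[Tuple[str, int]]]:
    """Hypothetical helper: keep only tie groups whose members span >1 stage."""
    result = []
    for group in tie_groups.values():
        stages = {stage_id for _, stage_id in group}
        if len(stages) > 1:
            result.append(group)
    return result


# Hand-written input shaped like the documented return value.
example: TieGroups = {
    0: {("embed.weight", 0), ("lm_head.weight", 3)},  # tied across stages 0 and 3
    1: {("encoder.w", 1), ("decoder.w", 1)},          # tied within stage 1
}
for group in cross_stage_groups(example):
    print(sorted(group))  # [('embed.weight', 0), ('lm_head.weight', 3)]
```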

slapo.consolidate_model(sch, target, param_init_fn=None, **kwargs)

Consolidate the model weights.

FIXME: When pipeline parallelism is enabled, this function only supports the DeepSpeed runtime because it relies on the DeepSpeed topology. We should use dialects in this function to make it generally applicable.

Parameters
  • sch (Schedule) –

  • target (str) –

  • param_init_fn (Optional[Callable[[nn.Module], None]]) –

slapo.create_schedule(root, name='', path='', parent=None, group=None, **kwargs)

Create a schedule for the given module and preserve the module hierarchy.

Parameters
  • root (nn.Module) – The root module to create the schedule for.

  • name (str) – The name of the module.

  • path (str) – The path from the top module.

  • parent (Optional[Schedule]) – The parent schedule. None if the module is the top module.

  • group (Optional[dist.ProcessGroup]) – The process group for the module. If None, use all available devices.

  • **kwargs – Additional arguments for the schedule.

Returns

The schedule for the module.

Return type

Schedule

slapo.dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False)

Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

Examines PEP 526 __annotations__ to determine fields.

If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method function is added. If frozen is true, fields may not be assigned to after instance creation.
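slapo.dataclass carries the docstring of the standard library's dataclasses.dataclass, so it appears to be a re-export. The sketch below uses dataclasses directly with a hypothetical Candidate class to show what order=True and frozen=True add: comparisons use the fields as a tuple in declaration order, and assignment after creation raises FrozenInstanceError.

```python
from dataclasses import FrozenInstanceError, dataclass


@dataclass(order=True, frozen=True)
class Candidate:
    # Hypothetical tuning candidate; fields are compared in declaration order.
    throughput: float
    batch_size: int


a = Candidate(throughput=10.0, batch_size=8)
b = Candidate(throughput=12.5, batch_size=4)
print(a < b)                    # True: compares (10.0, 8) < (12.5, 4)
print(max([a, b]))              # Candidate(throughput=12.5, batch_size=4)
print(a == Candidate(10.0, 8))  # True: eq=True by default

try:
    a.batch_size = 16  # frozen=True forbids assignment after creation
except FrozenInstanceError as err:
    print("frozen:", err)
```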

slapo.field(*, default=<dataclasses._MISSING_TYPE object>, default_factory=<dataclasses._MISSING_TYPE object>, init=True, repr=True, hash=None, compare=True, metadata=None)

Return an object to identify dataclass fields.

default is the default value of the field. default_factory is a 0-argument function called to initialize a field’s value. If init is True, the field will be a parameter to the class’s __init__() function. If repr is True, the field will be included in the object’s repr(). If hash is True, the field will be included in the object’s hash(). If compare is True, the field will be used in comparison functions. metadata, if specified, must be a mapping which is stored but not otherwise examined by dataclass.

It is an error to specify both default and default_factory.
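Likewise, slapo.field matches dataclasses.field. The hypothetical TrialRecord class below shows per-instance defaults via default_factory, a field excluded from repr() and comparisons, attached metadata, and the error raised when both default and default_factory are given.

```python
from dataclasses import dataclass, field, fields


@dataclass
class TrialRecord:
    # Hypothetical record type used only for illustration.
    name: str
    tags: list = field(default_factory=list)  # fresh list per instance
    cache: dict = field(default_factory=dict, repr=False, compare=False)
    note: str = field(default="", metadata={"unit": "free text"})


r1 = TrialRecord("a")
r2 = TrialRecord("a", cache={"x": 1})
print(r1)        # TrialRecord(name='a', tags=[], note='')
print(r1 == r2)  # True: cache has compare=False
print(fields(TrialRecord)[3].metadata["unit"])  # free text

try:
    field(default=0, default_factory=int)
except ValueError as err:
    print("rejected:", err)  # both default and default_factory were given
```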

slapo.get_cuda_rng_tracker()

Get the CUDA RNG tracker.

slapo.get_dialect_cls(cls_type, target, allow_none=False)

Get the framework dialect class.

slapo.init_target_engine(sch, target, **kwargs)

Initialize the runtime engine for a specific target framework.

slapo.is_module_list(module, name='', parent=None)

Check whether a module should be treated as a module list. A module list becomes an nn.Module or fx.GraphModule after tracing, but it should still be treated as a module list in the schedule.

slapo.list_primitives(name_only=True)

List all available schedule primitives.

Parameters

name_only (bool) – If True, only return the name of the primitives. Otherwise, return the primitive class.

Returns

If name_only, return a list of all available schedule primitives; otherwise return a dictionary mapping the name of the primitive to the primitive class.

Return type

Union[list[str], dict[str, Primitive]]

class slapo.partial

partial(func, *args, **keywords) - new function with partial application of the given arguments and keywords.

Attributes:

args

tuple of arguments to future partial calls

func

function object to use in future partial calls

keywords

dictionary of keyword arguments to future partial calls

args

tuple of arguments to future partial calls

func

function object to use in future partial calls

keywords

dictionary of keyword arguments to future partial calls
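slapo.partial's docstring is that of the standard library's functools.partial. The sketch below shows the three attributes listed above and how stored arguments combine with call-time ones; the shard function is a hypothetical stand-in for any callable.

```python
from functools import partial


def shard(tensor_name: str, axis: int, world_size: int = 1) -> str:
    # Hypothetical function standing in for any callable.
    return f"{tensor_name}@axis{axis}/ws{world_size}"


shard_rows = partial(shard, axis=0, world_size=4)
print(shard_rows.func is shard)  # True
print(shard_rows.args)           # ()
print(shard_rows.keywords)       # {'axis': 0, 'world_size': 4}
print(shard_rows("fc1.weight"))  # fc1.weight@axis0/ws4

# Stored positional args are prepended; call-time keywords override stored ones.
print(partial(shard, "fc2.weight")(axis=1))     # fc2.weight@axis1/ws1
print(shard_rows("fc1.weight", world_size=8))  # fc1.weight@axis0/ws8
```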

slapo.trace(model, **kwargs)

Traces a model to a GraphModule.

Parameters
  • model (nn.Module) –

  • kwargs (dict[str, Any]) –

slapo.trace_module(model, **kwargs)

Traces a model to a GraphModule.

Parameters
  • model (nn.Module) –

  • kwargs (dict[str, Any]) –

slapo.is_random_seed_set()

Check whether the random seed has been set.

slapo.checkpoint(function, *args, use_reentrant=True, **kwargs)

Checkpoint a model or part of a model. See the PyTorch checkpoint documentation for details about its behavior and arguments. The only difference is that when the random seed is set by Slapo, this function also tracks the random states and restores them properly.

TODO: The implementation in Megatron-LM has a mode to distribute the saved activations across model parallel groups to further reduce the memory footprint. This is not implemented here yet.

class slapo.Verify(sch, example_inputs, device='cuda', eval_mode=True, enable=True)