


A simple database that stores tuning results in a dict.


The tuning space.

Symbol(name, vals)

A tunable symbol.


Dictionary that remembers insertion order

Pattern(*args, **kwargs)

ScheduleMetadata([tie_weights, param_tags, ...])

The metadata of a schedule.


partial(func, *args, **keywords) - new function with partial application of the given arguments and keywords.

Verify(sch, example_inputs[, device, ...])


init_empty_weights([enable, include_buffers])

A context manager under which models are initialized with all parameters on the meta device, therefore creating an empty model.

get_logger([name, level])

Attach to the default logger.


Register a primitive to the schedule.

set_random_seed([seed, dp_rank, pp_rank, ...])

Set random seed for reproducibility.

analyze_tie_weights(top_mod, ...)

Analyze whether any tied weights (two weights in different modules that share the same memory) are partitioned into different pipeline stages.

consolidate_model(sch, target[, param_init_fn])

Consolidate the model weights.

create_schedule(root[, name, path, parent, ...])

Create a schedule for the given module and preserve the module hierarchy.

dataclass([cls, init, repr, eq, order, ...])

Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

field(*[, default, default_factory, init, ...])

Return an object to identify dataclass fields.


Get the CUDA RNG tracker.

get_dialect_cls(cls_type, target[, allow_none])

Get the framework dialect class.

init_target_engine(sch, target, **kwargs)

Initialize the runtime engine for a specific target framework.

is_module_list(module[, name, parent])

A module list will become nn.Module or fx.GraphModule after tracing, but we still want to treat it as a module list in the schedule.


List all available schedule primitives.

trace(model, **kwargs)

Traces a model to a GraphModule.

trace_module(model, **kwargs)

Traces a model to a GraphModule.


Check whether the random seed is set.

checkpoint(function, *args[, use_reentrant])

Checkpoint a model or part of the model.

class slapo.Database(db_file_name=None)

Load the database from the file.

commit(key, data)

Commit the data to the database and update the DB file.


commit(key, data)
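The load and commit behavior described above can be sketched as follows. This is a minimal sketch, not Slapo's implementation: it assumes the database is a plain dict serialized to JSON, and the class name TinyDatabase and its `db` attribute are hypothetical.

```python
import json
import os
from typing import Any, Dict, Optional


class TinyDatabase:
    """Hypothetical stand-in for slapo.Database: a dict persisted as JSON."""

    def __init__(self, db_file_name: Optional[str] = None) -> None:
        self.db_file_name = db_file_name
        self.db: Dict[str, Any] = {}
        # Load existing results if a DB file is given and already exists.
        if db_file_name is not None and os.path.exists(db_file_name):
            with open(db_file_name, "r", encoding="utf-8") as f:
                self.db = json.load(f)

    def commit(self, key: str, data: Any) -> None:
        # Update the in-memory dict first, then rewrite the whole DB file.
        self.db[key] = data
        if self.db_file_name is not None:
            with open(self.db_file_name, "w", encoding="utf-8") as f:
                json.dump(self.db, f, indent=2)
```

Rewriting the whole file on every commit keeps the sketch simple and means a crash mid-tuning loses at most the result being committed.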

class slapo.Space

create_symbol(name, vals)

log_space(training_script_args, update_space_fn)

Print the tuning space for logging.

create_symbol(name, vals)

Create a symbol in the space. If the symbol already exists:

  1. If the symbol is fixed, do nothing;

  2. Otherwise re-create the symbol, because its candidate values may change due to other fixed symbols.
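The two re-creation rules above can be sketched like this. It is a minimal illustration under stated assumptions: the `fix_at` method, the `fixed_idx` and `space` attributes, and the internal layout of Symbol and Space are hypothetical and only model the documented behavior.

```python
from typing import Any, Dict, List, Optional


class Symbol:
    """Hypothetical tunable symbol: a name plus candidate values."""

    def __init__(self, name: str, vals: List[Any]) -> None:
        self.name = name
        self.vals = list(vals)
        self.fixed_idx: Optional[int] = None  # set once a value is chosen

    def fix_at(self, idx: int) -> None:
        self.fixed_idx = idx

    @property
    def is_fixed(self) -> bool:
        return self.fixed_idx is not None


class Space:
    def __init__(self) -> None:
        self.space: Dict[str, Symbol] = {}

    def create_symbol(self, name: str, vals: List[Any]) -> Symbol:
        existing = self.space.get(name)
        # Rule 1: a fixed symbol is left untouched.
        if existing is not None and existing.is_fixed:
            return existing
        # Rule 2: otherwise (re)create it, since its candidates may have changed.
        self.space[name] = Symbol(name, vals)
        return self.space[name]
```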


Clone the space.

static cfg_dict_to_str(cfg_dict)

Convert a config dict to a string for logging and debugging.

log_space(training_script_args, update_space_fn)

Print the tuning space for logging.

class slapo.Symbol(name, vals)

slapo.init_empty_weights(enable=True, include_buffers=False)

A context manager under which models are initialized with all parameters on the meta device, therefore creating an empty model. Useful when just initializing the model would blow the available RAM.

  • enable (bool) – Whether or not to enable this context manager.

  • include_buffers (bool) – Whether or not to also put all buffers on the meta device while initializing.

slapo.get_logger(name='Slapo', level=20)
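The default level=20 is the numeric value of logging.INFO in Python's standard logging module. A minimal sketch of a logger helper with the same defaults, assuming it only wraps logging.getLogger; whether Slapo also installs handlers or formatters is not stated here, so this function is hypothetical.

```python
import logging


def get_logger(name: str = "Slapo", level: int = logging.INFO) -> logging.Logger:
    # logging.INFO == 20, matching the documented default level.
    logger = logging.getLogger(name)  # the same name always returns the same logger
    logger.setLevel(level)
    return logger
```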

slapo.set_random_seed(seed=2013, dp_rank=None, pp_rank=None, tp_rank=None, always_enable_tp_seed=False)

  • seed (int) – Random seed. Default is 2013.

  • dp_rank (Optional[int]) – Data parallel rank. Default is None, meaning no data parallelism.

  • pp_rank (Optional[int]) – Pipeline parallel rank. Default is None, meaning no pipeline parallelism.

  • tp_rank (Optional[int]) – Tensor model parallel rank. Default is None, meaning no tensor parallelism.

  • always_enable_tp_seed (bool) – Always enable the tensor model parallel seed. This is used when sequence parallelism is enabled and all dropouts should use different seeds even if they are in the same TP group. Default is False, meaning that the tensor model parallel seed is only enabled with get_cuda_rng_tracker().fork().


Random seed of this rank.




class slapo.ModulePattern(name)



Defines the computation performed at every call.


class slapo.OrderedDict

Dictionary that remembers insertion order




Remove and return a (key, value) pair from the dictionary.

move_to_end(key[, last])

Move an existing element to the end (or beginning if last is false).

update([E, ]**F)

If E is present and has a .keys() method, then does: for k in E: D[k] = E[k].
If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v.
In either case, this is followed by: for k in F: D[k] = F[k].






setdefault(key[, default])

Insert key with a value of default if key is not in the dictionary.



Create a new ordered dictionary with keys from iterable and values set to value.

clear() -> None. Remove all items from od.

move_to_end(key, last=True)

Move an existing element to the end (or beginning if last is false).

Raise KeyError if the element does not exist.
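The ordering methods above behave as in the standard library; the docstrings match collections.OrderedDict, so this sketch assumes slapo.OrderedDict is that class re-exported and uses it directly.

```python
from collections import OrderedDict

od = OrderedDict.fromkeys("abc", 0)  # keys a, b, c, all with value 0
od.move_to_end("a")                  # order is now b, c, a
od.move_to_end("c", last=False)      # order is now c, b, a
last = od.popitem()                  # LIFO by default: removes ("a", 0)
first = od.popitem(last=False)       # FIFO: removes ("c", 0)

# move_to_end on a missing key raises KeyError.
try:
    od.move_to_end("zzz")
    missing_raised = False
except KeyError:
    missing_raised = True
```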

update([E, ]**F) -> None. Update D from dict/iterable E and F.

keys() -> a set-like object providing a view on D's keys
items() -> a set-like object providing a view on D's items
values() -> an object providing a view on D's values
pop(k[, d]) -> v, remove specified key and return the corresponding value. If key is not found, d is returned if given, otherwise KeyError is raised.

setdefault(key, default=None)

Insert key with a value of default if key is not in the dictionary.

Return the value for key if key is in the dictionary, else default.
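The update, setdefault, and pop rules above, shown with collections.OrderedDict (assuming, as its docstrings suggest, that slapo.OrderedDict is the standard-library class):

```python
from collections import OrderedDict

od = OrderedDict(a=1)
od.update({"b": 2}, c=3)     # mapping with .keys() first, then keyword args
od.update([("d", 4)])        # iterable of (key, value) pairs
x = od.setdefault("a", 100)  # "a" exists: returns 1 and changes nothing
y = od.setdefault("e", 5)    # "e" is missing: inserts 5 and returns it
z = od.pop("missing", None)  # default is returned instead of raising KeyError
```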

copy() -> a shallow copy of od

Create a new ordered dictionary with keys from iterable and values set to value.

class slapo.Pattern(*args, **kwargs)



Defines the computation performed at every call.




class slapo.ScheduleMetadata(tie_weights=<factory>, param_tags=<factory>, primitives=<factory>)

The metadata of a schedule. It is mainly used to store the metadata of primitives and the top module for 1) verification and 2) applying framework dialects. Note that when a module is replaced, the schedule metadata of the original module is NOT transferred to the new schedule, because the new module may not have the same structure as the original module.

  • tie_weights (dict[nn.Parameter, nn.Parameter]) –

  • param_tags (set[str]) –

  • primitives (dict[str, Any]) –

slapo.analyze_tie_weights(top_mod, is_pipeline_partitioned)

Analyze whether any tied weights (two weights in different modules that share the same memory) are partitioned into different pipeline stages.

  • top_mod (torch.nn.Module) – The top-level module. This should be a top pipeline module, so 1) it should already be traced and partitioned, and 2) its number of submodules should match the number of pipeline stages.

  • is_pipeline_partitioned (bool) – Whether the module is partitioned for pipeline or not. If not, then all tie weights will have stage ID 0.


tie_groups – Mapping from the nn.Parameter object to the set of parameter names that are tied to it. Each element of the set is a tuple of (parameter name, stage ID). The stage ID is 0 if the module is not partitioned for pipeline.

Dict[int, Set[Tuple[str, int]]]
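To illustrate the shape of tie_groups, here is a pure-Python sketch that groups names by the identity of the object they refer to. It assumes the int key is id() of the shared parameter (one reading of the documented Dict[int, ...] type) and uses plain objects in place of nn.Parameter; the function group_tied is hypothetical and is not how Slapo discovers tied weights.

```python
from collections import defaultdict
from typing import Any, Dict, Iterable, Set, Tuple


def group_tied(
    named_params: Iterable[Tuple[str, Any, int]],
) -> Dict[int, Set[Tuple[str, int]]]:
    """Group (name, stage ID) pairs by the identity of the shared object."""
    groups: Dict[int, Set[Tuple[str, int]]] = defaultdict(set)
    for name, param, stage_id in named_params:
        groups[id(param)].add((name, stage_id))
    # Only objects referenced under more than one name are tied.
    return {key: names for key, names in groups.items() if len(names) > 1}
```

A group whose stage IDs differ is exactly the case the function description warns about: a tied weight split across pipeline stages.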

slapo.consolidate_model(sch, target, param_init_fn=None, **kwargs)

Consolidate the model weights. FIXME: When pipeline parallelism is enabled, this function only supports the DeepSpeed runtime because it relies on the DeepSpeed topology. We should use dialects in this function to make it generally applicable.

  • sch (Schedule) –

  • target (str) –

  • param_init_fn (Optional[Callable[[nn.Module], None]]) –

slapo.create_schedule(root, name='', path='', parent=None, group=None, **kwargs)

Create a schedule for the given module and preserve the module hierarchy.

  • root (nn.Module) – The root module to create the schedule for.

  • name (str) – The name of the module.

  • path (str) – The path from the top module.

  • parent (Optional[Schedule]) – The parent schedule. None if the module is the top module.

  • group (Optional[dist.ProcessGroup]) – The process group for the module. If None, use all available devices.

  • **kwargs – Additional arguments for the schedule.


The schedule for the module.

slapo.dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False, unsafe_hash=False, frozen=False)

Returns the same class as was passed in, with dunder methods added based on the fields defined in the class.

Examines PEP 526 __annotations__ to determine fields.

If init is true, an __init__() method is added to the class. If repr is true, a __repr__() method is added. If order is true, rich comparison dunder methods are added. If unsafe_hash is true, a __hash__() method function is added. If frozen is true, fields may not be assigned to after instance creation.
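The signature and docstring match the standard library's dataclasses.dataclass, so this sketch assumes slapo.dataclass is that decorator re-exported and uses the stdlib import. It shows the effect of order=True and frozen=True; the Version class is illustrative only.

```python
from dataclasses import FrozenInstanceError, dataclass


@dataclass(order=True, frozen=True)
class Version:
    major: int
    minor: int = 0  # fields with defaults must follow fields without them


v1 = Version(1, 2)
v2 = Version(1, 10)
# order=True compares fields as tuples: (1, 2) < (1, 10).
# frozen=True makes assignment raise FrozenInstanceError; with eq=True
# (the default) it also makes instances hashable.
try:
    v1.major = 3
    frozen_raised = False
except FrozenInstanceError:
    frozen_raised = True
```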

slapo.field(*, default=<dataclasses._MISSING_TYPE object>, default_factory=<dataclasses._MISSING_TYPE object>, init=True, repr=True, hash=None, compare=True, metadata=None)

Return an object to identify dataclass fields.

default is the default value of the field. default_factory is a 0-argument function called to initialize a field’s value. If init is True, the field will be a parameter to the class’s __init__() function. If repr is True, the field will be included in the object’s repr(). If hash is True, the field will be included in the object’s hash(). If compare is True, the field will be used in comparison functions. metadata, if specified, must be a mapping which is stored but not otherwise examined by dataclass.

It is an error to specify both default and default_factory.
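A short example of field() using the standard-library dataclasses module, which slapo.field appears to re-export. A mutable default needs default_factory, which is also why ScheduleMetadata's fields show <factory> as their defaults. The Meta class is illustrative only.

```python
from dataclasses import dataclass, field, fields


@dataclass
class Meta:
    # Mutable default: default_factory gives each instance its own dict.
    tags: dict = field(default_factory=dict)
    # Hidden from repr(), ignored by ==; metadata is stored but not examined.
    note: str = field(default="", repr=False, compare=False, metadata={"doc": "free text"})


# Passing both default and default_factory raises ValueError immediately.
try:
    field(default=1, default_factory=int)
    both_raised = False
except ValueError:
    both_raised = True
```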


Get the CUDA RNG tracker.

slapo.get_dialect_cls(cls_type, target, allow_none=False)

Get the framework dialect class.

slapo.init_target_engine(sch, target, **kwargs)

Initialize the runtime engine for a specific target framework.

slapo.is_module_list(module, name='', parent=None)

A module list will become nn.Module or fx.GraphModule after tracing, but we still want to treat it as a module list in the schedule.


List all available schedule primitives.


name_only (bool) – If True, only return the name of the primitives. Otherwise, return the primitive class.


If name_only, return a list of all available schedule primitives; otherwise return a dictionary mapping the name of the primitive to the primitive class.

Union[list[str], dict[str, Primitive]]
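The name_only switch above can be sketched with a simple registry. Everything here is hypothetical: the Primitive base class, the register_primitive decorator, the PRIMITIVES dict, and the example primitive names are stand-ins that only model the documented return types, not Slapo's registry.

```python
from typing import Callable, Dict, List, Type, Union


class Primitive:
    """Hypothetical base class standing in for a schedule primitive."""


PRIMITIVES: Dict[str, Type[Primitive]] = {}


def register_primitive(name: str) -> Callable[[Type[Primitive]], Type[Primitive]]:
    def decorator(cls: Type[Primitive]) -> Type[Primitive]:
        PRIMITIVES[name] = cls  # insertion order is preserved
        return cls
    return decorator


@register_primitive("shard")
class ShardPrimitive(Primitive):
    pass


@register_primitive("replace")
class ReplacePrimitive(Primitive):
    pass


def list_primitives(name_only: bool) -> Union[List[str], Dict[str, Type[Primitive]]]:
    # Names only, or a copy of the name -> class mapping.
    return list(PRIMITIVES) if name_only else dict(PRIMITIVES)
```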

class slapo.partial

partial(func, *args, **keywords) - new function with partial application of the given arguments and keywords.



tuple of arguments to future partial calls


function object to use in future partial calls


dictionary of keyword arguments to future partial calls


tuple of arguments to future partial calls


function object to use in future partial calls


dictionary of keyword arguments to future partial calls
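The three attributes above are func, args, and keywords of functools.partial, which slapo.partial appears to re-export. A small example using the stdlib import; the power function is illustrative only.

```python
from functools import partial


def power(base: int, exp: int) -> int:
    return base ** exp


square = partial(power, exp=2)  # keyword frozen: exp=2
two_to = partial(power, 2)      # positional frozen: base=2

# Stored arguments are applied before the arguments of each call.
sq = square(7)
pw = two_to(10)
```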

slapo.trace(model, **kwargs)

Traces a model to a GraphModule.

  • model (nn.Module) –

  • kwargs (dict[str, Any]) –

slapo.trace_module(model, **kwargs)

Check whether the random seed is set.

slapo.checkpoint(function, *args, use_reentrant=True, **kwargs)

Checkpoint a model or part of the model. See PyTorch checkpoint for details about behaviors and arguments. The only difference is that, when the random seed is set by Slapo, the checkpoint function also tracks the random states and restores them properly.

TODO: The implementation in Megatron-LM has a mode to distribute the saved activations across model parallel groups to further reduce the memory footprint. This is not implemented here yet.

class slapo.Verify(sch, example_inputs, device='cuda', eval_mode=True, enable=True)