slapo.schedule

Classes:

ScheduleMetadata([tie_weights, param_tags, ...])

The metadata of a schedule.

Functions:

list_primitives([name_only])

List all available schedule primitives.

create_schedule(root[, name, path, parent, ...])

Create a schedule for the given module and preserve the module hierarchy.

class slapo.schedule.ScheduleMetadata(tie_weights=<factory>, param_tags=<factory>, primitives=<factory>)[source]

The metadata of a schedule. It stores metadata about the primitives and the top module, mainly for 1) verification and 2) applying framework dialects. Note that when a module is replaced, the schedule metadata of the original module is NOT transferred to the new schedule, because the new module may not have the same structure as the original one.

Parameters
  • tie_weights (dict[nn.Parameter, nn.Parameter]) –

  • param_tags (set[str]) –

  • primitives (dict[str, Any]) –

Return type

None

slapo.schedule.list_primitives(name_only=True)[source]

List all available schedule primitives.

Parameters

name_only (bool) – If True, only return the name of the primitives. Otherwise, return the primitive class.

Returns

If name_only is True, return a list of the names of all available schedule primitives; otherwise, return a dictionary mapping each primitive name to its primitive class.

Return type

Union[list[str], dict[str, Primitive]]
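The two return shapes can be illustrated with a minimal sketch. It does not import slapo: the registry, the Primitive stand-in, and the primitive names below are placeholders invented for this example, and the real function may be implemented differently.

```python
from typing import Dict, List, Union


class Primitive:
    """Stand-in base class; not the real slapo Primitive."""


class ShardPrimitive(Primitive):
    """Placeholder primitive used only for illustration."""


class FusePrimitive(Primitive):
    """Placeholder primitive used only for illustration."""


# Hypothetical registry mapping primitive names to primitive classes.
_REGISTRY: Dict[str, type] = {"shard": ShardPrimitive, "fuse": FusePrimitive}


def list_primitives_sketch(name_only: bool = True) -> Union[List[str], Dict[str, type]]:
    """Mimic the documented behavior: names only, or a name-to-class mapping."""
    if name_only:
        return list(_REGISTRY)
    # Return a copy so callers cannot mutate the registry by accident.
    return dict(_REGISTRY)


names = list_primitives_sketch()
classes = list_primitives_sketch(name_only=False)
```

With name_only left at its default, the caller gets plain strings; passing name_only=False gives access to the classes themselves.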

slapo.schedule.create_schedule(root, name='', path='', parent=None, group=None, **kwargs)[source]

Create a schedule for the given module and preserve the module hierarchy.

Parameters
  • root (nn.Module) – The root module to create the schedule for.

  • name (str) – The name of the module.

  • path (str) – The path from the top module.

  • parent (Optional[Schedule]) – The parent schedule. None if the module is the top module.

  • group (Optional[dist.ProcessGroup]) – The process group for the module. If None, use all available devices.

  • **kwargs – Additional arguments for the schedule.

Returns

The schedule for the module.

Return type

Schedule
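To make "preserve the module hierarchy" concrete, here is a minimal sketch that builds one schedule object per module in a toy module tree. ToyModule, ToySchedule, the child attribute, and the dot-separated path format are all assumptions made for this illustration; they are not taken from the slapo implementation.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class ToyModule:
    """Stand-in for nn.Module: only a dict of named children."""
    children: Dict[str, "ToyModule"] = field(default_factory=dict)


@dataclass(eq=False)
class ToySchedule:
    """Stand-in for Schedule; the 'child' attribute is hypothetical."""
    mod: ToyModule
    name: str = ""
    path: str = ""
    parent: Optional["ToySchedule"] = None
    child: Dict[str, "ToySchedule"] = field(default_factory=dict)


def create_schedule_sketch(root, name="", path="", parent=None):
    """Recursively mirror the module tree as a tree of schedules."""
    sch = ToySchedule(root, name, path, parent)
    for child_name, child_mod in root.children.items():
        # Assumed path format: names joined with dots, starting at the top module.
        child_path = f"{path}.{child_name}" if path else child_name
        sch.child[child_name] = create_schedule_sketch(
            child_mod, child_name, child_path, sch
        )
    return sch


model = ToyModule({"encoder": ToyModule({"layer": ToyModule()}), "head": ToyModule()})
top = create_schedule_sketch(model)
layer_sch = top.child["encoder"].child["layer"]
```

In this sketch the top schedule has parent None, matching the documented meaning of the parent argument, and every sub-schedule records both its own name and its path from the top.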

class slapo.schedule.Schedule(mod, name='', path='', parent=None, group=None)[source]

Parameters
  • mod (nn.Module) –

  • name (str) –

  • path (str) –

  • parent (Optional['Schedule']) –

  • group (Optional[dist.ProcessGroup]) –

named_schedules(prefix='')[source]

Returns an iterator over all subschedules in the current schedule, yielding both the name of the subschedule as well as the subschedule itself.

Parameters

prefix (str) –
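A recursive generator is one plausible way to produce such (name, subschedule) pairs. The sketch below is modeled on how torch.nn.Module.named_modules behaves; whether the real method also yields the current schedule itself, and how it joins names, are assumptions here. ToySchedule and its child attribute are stand-ins.

```python
from typing import Dict, Iterator, Tuple


class ToySchedule:
    """Stand-in for Schedule holding only named children."""

    def __init__(self, **children: "ToySchedule") -> None:
        self.child: Dict[str, "ToySchedule"] = dict(children)

    def named_schedules(self, prefix: str = "") -> Iterator[Tuple[str, "ToySchedule"]]:
        # Assumption: yield the current schedule first, then recurse into
        # each child with the prefix extended by the child's name.
        yield prefix, self
        for name, sub in self.child.items():
            sub_prefix = f"{prefix}.{name}" if prefix else name
            yield from sub.named_schedules(sub_prefix)


top = ToySchedule(encoder=ToySchedule(layer=ToySchedule()), head=ToySchedule())
names = [name for name, _ in top.named_schedules()]
prefixed = [name for name, _ in top.named_schedules(prefix="model")]
```

Because the method is a generator, callers can stop iterating early without walking the whole hierarchy.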

find_node(regex_or_pattern_fn)[source]

Find a node in a static dataflow graph.

Parameters

regex_or_pattern_fn (Union[str, Callable]) – If this argument is a regular expression, it matches only the call_module nodes whose target satisfies the regex; otherwise, it matches all the nodes that satisfy the pattern function. The pattern_fn should be in lambda node: … format.

Returns

Returns all the nodes whose names satisfy the regex, or the nodes that satisfy the given pattern constraints.

Return type

Union[List[Tuple[str, fx.Node]], List[List[Tuple[str, fx.Node]]]]
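The two matching modes can be sketched without torch by using a small stand-in for fx.Node. ToyNode, find_node_sketch, and the parent_name argument are hypothetical, and the use of re.search (rather than re.match or re.fullmatch) is an assumption; the source does not say which regex function the real method uses.

```python
import re
from typing import Callable, List, NamedTuple, Tuple, Union


class ToyNode(NamedTuple):
    """Stand-in for fx.Node with just the fields needed here."""
    op: str
    target: str
    name: str


def find_node_sketch(
    nodes: List[ToyNode],
    regex_or_pattern_fn: Union[str, Callable[[ToyNode], bool]],
    parent_name: str = "",
) -> List[Tuple[str, ToyNode]]:
    if isinstance(regex_or_pattern_fn, str):
        regex = re.compile(regex_or_pattern_fn)

        def matches(node: ToyNode) -> bool:
            # Regex mode only considers call_module nodes and tests their target.
            return node.op == "call_module" and regex.search(node.target) is not None
    else:
        # Pattern-function mode applies the predicate to every node.
        matches = regex_or_pattern_fn
    return [(parent_name, node) for node in nodes if matches(node)]


graph = [
    ToyNode("placeholder", "x", "x"),
    ToyNode("call_module", "dense", "dense"),
    ToyNode("call_function", "relu", "relu"),
    ToyNode("call_module", "dropout", "dropout"),
]

by_regex = find_node_sketch(graph, r"^d")
by_fn = find_node_sketch(graph, lambda node: node.op == "call_function")
```

Note that the regex "^d" skips the placeholder and call_function nodes even if their targets happened to match, because regex mode is restricted to call_module nodes.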

find_subgraph(pattern_fn)[source]

Find a subgraph in a static dataflow graph.

Parameters

pattern_fn (Union[FunctionType, Pattern]) – This argument specifies the subgraph pattern. A lambda function is the easier way to specify a pattern, while the Pattern class makes it possible to create patterns that include submodules.

Returns

Returns all the subgraphs containing the nodes satisfying the pattern constraints. The outer-most list contains different subgraphs, and the inner list contains the nodes inside a specific subgraph. The inner-most tuple includes the name of the parent module that the node belongs to, and the matched node object.

Return type

List[List[Tuple[str, fx.Node]]]
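The nesting of the result is easiest to see when unpacking it. In this sketch the matches value is hand-written stand-in data with strings in place of fx.Node objects, and the parent names are invented for illustration; it is not the output of a real find_subgraph call.

```python
from typing import List, Tuple

# Stand-in result: two matched subgraphs, each a list of (parent_name, node) tuples.
matches: List[List[Tuple[str, str]]] = [
    [("layer.0", "linear"), ("layer.0", "relu")],
    [("layer.1", "linear"), ("layer.1", "relu")],
]

summary = []
for subgraph in matches:  # outer list: one entry per matched subgraph
    # inner list: the nodes of that subgraph, each paired with its parent module name
    summary.append([f"{parent_name}:{node}" for parent_name, node in subgraph])
```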

find(regex_or_pattern_fn)[source]

Find a node or a subgraph in a static dataflow graph. This API is a dispatcher for find_node and find_subgraph.

If you need to match a general node pattern, please directly use the find_node API.

Parameters

regex_or_pattern_fn (Union[str, Callable]) – A regular expression specifying the target of a call_module node, or a callable function/Pattern class specifying the subgraph pattern.

Returns

For find_node, it returns all the nodes whose names satisfy the regex. For find_subgraph, it returns all the subgraphs containing the nodes that satisfy the pattern constraints. The outer-most list contains different subgraphs, and the inner list contains the nodes inside a specific subgraph. The inner-most tuple includes the name of the parent module that the node belongs to, and the matched node object.

Return type

Union[List[Tuple[str, fx.Node]], List[List[Tuple[str, fx.Node]]]]
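The dispatch described here can be sketched as a type check on the argument. The stub functions below only report which branch was taken, and the isinstance test is an assumption about how the dispatch might work, not the actual slapo code.

```python
from typing import Callable, Union


def find_node_stub(regex: str) -> str:
    """Placeholder for find_node; only records the call."""
    return f"find_node({regex!r})"


def find_subgraph_stub(pattern_fn: Callable) -> str:
    """Placeholder for find_subgraph; only records the call."""
    label = getattr(pattern_fn, "__name__", type(pattern_fn).__name__)
    return f"find_subgraph({label})"


def find_sketch(regex_or_pattern_fn: Union[str, Callable]) -> str:
    # Strings are treated as regexes over call_module targets;
    # anything else is treated as a subgraph pattern.
    if isinstance(regex_or_pattern_fn, str):
        return find_node_stub(regex_or_pattern_fn)
    return find_subgraph_stub(regex_or_pattern_fn)


def bias_add(x, bias):
    """Hypothetical subgraph pattern function."""
    return x + bias


regex_result = find_sketch("dense")
pattern_result = find_sketch(bias_add)
```

Under this reading, a callable passed to the dispatcher is always interpreted as a subgraph pattern, which is consistent with the advice to call find_node directly when a per-node pattern function is needed.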

trace_until(paths, **kwargs)[source]

Syntactic sugar that traces from the top module down to the sub-modules specified in paths, so that computation optimizations, such as cutting pipeline stages, can be applied at that level.

Parameters
  • paths (Union[str, List[str]]) – The path, or list of paths, to the sub-module(s) at which tracing should stop.

  • **kwargs – Other arguments for trace API.