slapo.schedule
Classes:
| ScheduleMetadata | The metadata of a schedule. |
Functions:
| list_primitives | List all available schedule primitives. |
| create_schedule | Create a schedule for the given module and preserve the module hierarchy. |
- class slapo.schedule.ScheduleMetadata(tie_weights=<factory>, param_tags=<factory>, primitives=<factory>)
The metadata of a schedule. It stores metadata about the applied primitives and the top module, mainly for (1) verification and (2) applying framework dialects. Note that when a module is replaced, the schedule metadata of the original module is NOT transferred to the new schedule, because the new module may not have the same structure as the original one.
- Parameters
tie_weights (dict[nn.Parameter, nn.Parameter])
param_tags (set[str])
primitives (dict[str, Any])
- Return type
None
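The `<factory>` defaults in the signature are how a dataclass field declared with `dataclasses.field(default_factory=...)` shows up in the generated `__init__` signature: each `ScheduleMetadata` instance gets its own fresh containers instead of sharing one mutable default. The sketch below mirrors the three fields with a hypothetical `ScheduleMetadataSketch` class; it uses `object` in place of `nn.Parameter` so it runs without PyTorch, and it is not slapo's actual definition.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Set


@dataclass
class ScheduleMetadataSketch:
    # Hypothetical stand-in for slapo.schedule.ScheduleMetadata; `object`
    # replaces nn.Parameter so the sketch runs without PyTorch.
    tie_weights: Dict[object, object] = field(default_factory=dict)
    param_tags: Set[str] = field(default_factory=set)
    primitives: Dict[str, Any] = field(default_factory=dict)


# default_factory builds a new container per instance, so recording a
# primitive on one schedule's metadata does not leak into another's.
meta_a = ScheduleMetadataSketch()
meta_b = ScheduleMetadataSketch()
meta_a.primitives["shard"] = ["weight"]
meta_a.param_tags.add("tied")

print(meta_a.primitives)  # {'shard': ['weight']}
print(meta_b.primitives)  # {}
```

The per-instance containers are also consistent with the note above: a replaced module starts from an empty metadata object rather than inheriting the original one.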
- slapo.schedule.list_primitives(name_only=True)
List all available schedule primitives.
- Parameters
name_only (bool) – If True, return only the names of the primitives. Otherwise, return the primitive classes.
- Returns
If name_only is True, a list of the names of all available schedule primitives; otherwise, a dictionary mapping each primitive name to its primitive class.
- Return type
Union[list[str], dict[str, Primitive]]
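The `name_only` switch amounts to returning either the keys or the whole of a name-to-class registry. The sketch below illustrates that contract with a hypothetical `PRIMITIVES` dict, a hypothetical `register_primitive` decorator, and two hypothetical primitive classes; slapo's actual registration mechanism and primitive set may differ.

```python
from typing import Dict, List, Type, Union


class Primitive:
    """Hypothetical base class standing in for slapo's primitive class."""


# Hypothetical registry mapping primitive names to primitive classes.
PRIMITIVES: Dict[str, Type[Primitive]] = {}


def register_primitive(name: str):
    """Hypothetical decorator that records a primitive class under `name`."""

    def wrapper(cls: Type[Primitive]) -> Type[Primitive]:
        PRIMITIVES[name] = cls
        return cls

    return wrapper


@register_primitive("shard")
class ShardPrimitive(Primitive):
    pass


@register_primitive("sync")
class SyncPrimitive(Primitive):
    pass


def list_primitives(
    name_only: bool = True,
) -> Union[List[str], Dict[str, Type[Primitive]]]:
    # name_only=True: just the registered names. Otherwise: a copy of the
    # name -> class mapping, so callers cannot mutate the registry itself.
    if name_only:
        return list(PRIMITIVES.keys())
    return dict(PRIMITIVES)


print(list_primitives())                                  # ['shard', 'sync']
print(list_primitives(name_only=False)["sync"].__name__)  # SyncPrimitive
```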
- slapo.schedule.create_schedule(root, name='', path='', parent=None, group=None, **kwargs)
Create a schedule for the given module and preserve the module hierarchy.
- Parameters
root (nn.Module) – The root module to create the schedule for.
name (str) – The name of the module.
path (str) – The path from the top module.
parent (Optional[Schedule]) – The parent schedule. None if the module is the top module.
group (Optional[dist.ProcessGroup]) – The process group for the module. If None, use all available devices.
**kwargs – Additional arguments for the schedule.
- Returns
The schedule for the module.
- Return type
Schedule
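Preserving the module hierarchy means the returned schedule mirrors the module tree: each submodule gets its own child schedule, whose `path` is its dotted path from the top module and whose `parent` points back to the enclosing schedule. The sketch below shows that recursion on a hypothetical `ToyModule` tree with a hypothetical `ToySchedule` class so it runs without PyTorch; it omits `group` and `**kwargs` and is not slapo's real `Schedule`.

```python
from typing import Dict, Optional


class ToyModule:
    """Hypothetical stand-in for nn.Module: only a dict of named children."""

    def __init__(self, **children: "ToyModule"):
        self.children: Dict[str, "ToyModule"] = children


class ToySchedule:
    """Hypothetical stand-in for slapo.schedule.Schedule."""

    def __init__(self, mod: ToyModule, name: str = "", path: str = "",
                 parent: Optional["ToySchedule"] = None):
        self.mod = mod
        self.name = name
        self.path = path
        self.parent = parent
        self.child: Dict[str, "ToySchedule"] = {}


def create_toy_schedule(root: ToyModule, name: str = "", path: str = "",
                        parent: Optional[ToySchedule] = None) -> ToySchedule:
    sch = ToySchedule(root, name, path, parent)
    for child_name, child_mod in root.children.items():
        # A child's path is the parent's path plus its own name, dot-separated;
        # the top module has an empty path, so no leading dot is added.
        child_path = f"{path}.{child_name}" if path else child_name
        sch.child[child_name] = create_toy_schedule(child_mod, child_name, child_path, sch)
    return sch


model = ToyModule(encoder=ToyModule(layer0=ToyModule(), layer1=ToyModule()), head=ToyModule())
top = create_toy_schedule(model)
leaf = top.child["encoder"].child["layer1"]
print(leaf.path)                  # encoder.layer1
print(leaf.parent.parent is top)  # True
```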
- class slapo.schedule.Schedule(mod, name='', path='', parent=None, group=None)
- Parameters
mod (nn.Module)
name (str)
path (str)
parent (Optional['Schedule'])
group (Optional[dist.ProcessGroup])
- named_schedules(prefix='')
Returns an iterator over all subschedules in the current schedule, yielding both the name of each subschedule and the subschedule itself.
- Parameters
prefix (str)
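A sketch of such an iterator, assuming it behaves like `torch.nn.Module.named_modules`: it yields the current schedule under `prefix` first and then recurses depth-first into each child with a dotted prefix. `ToySchedule` is a hypothetical stand-in that holds only a dict of child schedules; whether slapo's version also yields the current schedule itself is an assumption here.

```python
from typing import Dict, Iterator, Tuple


class ToySchedule:
    """Hypothetical stand-in for a schedule with named child schedules."""

    def __init__(self, **child: "ToySchedule"):
        self.child: Dict[str, "ToySchedule"] = child

    def named_schedules(self, prefix: str = "") -> Iterator[Tuple[str, "ToySchedule"]]:
        # Assumption mirroring nn.Module.named_modules: yield this schedule
        # first, then every descendant in depth-first order.
        yield prefix, self
        for name, sub in self.child.items():
            sub_prefix = f"{prefix}.{name}" if prefix else name
            yield from sub.named_schedules(sub_prefix)


top = ToySchedule(encoder=ToySchedule(layer0=ToySchedule()), head=ToySchedule())
print([name for name, _ in top.named_schedules()])
# ['', 'encoder', 'encoder.layer0', 'head']
```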
- find_node(regex_or_pattern_fn)
Find a node in a static dataflow graph.
- Parameters
regex_or_pattern_fn (Union[str, Callable]) – If this argument is a regular expression, only call_module nodes whose targets match the regex are returned; otherwise, all nodes that satisfy the pattern function are returned. The pattern function should have the form lambda node: ….
- Returns
Returns all nodes whose names match the regex, or all nodes that satisfy the given pattern constraints.
- Return type
Union[List[Tuple[str, fx.Node]], List[List[Tuple[str, fx.Node]]]]
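The two matching modes can be illustrated on a flat list of nodes. In the sketch below, `ToyNode` is a hypothetical stand-in for `torch.fx.Node` carrying only `name`, `op`, and a string `target`; every node is attributed to one parent module name; and the regex is applied with `re.search`, though whether slapo uses `search`, `match`, or `fullmatch` is an assumption. Results use the `(parent module name, node)` tuple shape from the return type.

```python
import re
from dataclasses import dataclass
from typing import Callable, List, Tuple, Union


@dataclass
class ToyNode:
    """Hypothetical stand-in for torch.fx.Node (target is a string here)."""
    name: str
    op: str       # e.g. "call_module", "call_function", "placeholder"
    target: str


def find_node(nodes: List[ToyNode],
              regex_or_pattern_fn: Union[str, Callable[[ToyNode], bool]],
              parent_name: str = "") -> List[Tuple[str, ToyNode]]:
    if isinstance(regex_or_pattern_fn, str):
        # Regex mode: only call_module nodes are candidates, matched on target.
        # Using re.search (unanchored) is an assumption of this sketch.
        pattern = re.compile(regex_or_pattern_fn)
        return [(parent_name, n) for n in nodes
                if n.op == "call_module" and pattern.search(n.target)]
    # Predicate mode: every node is tested with the `lambda node: ...` function.
    return [(parent_name, n) for n in nodes if regex_or_pattern_fn(n)]


graph = [
    ToyNode("x", "placeholder", "x"),
    ToyNode("attn_query", "call_module", "attn.query"),
    ToyNode("add", "call_function", "add"),
    ToyNode("attn_key", "call_module", "attn.key"),
]

print([n.name for _, n in find_node(graph, r"attn\.(query|key)")])
# ['attn_query', 'attn_key']
print([n.name for _, n in find_node(graph, lambda n: n.op == "call_function")])
# ['add']
```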
- find_subgraph(pattern_fn)
Find a subgraph in a static dataflow graph.
- Parameters
pattern_fn (Union[FunctionType, Pattern]) – The subgraph pattern. A lambda function is the simplest way to specify a pattern, while the Pattern class makes it possible to create patterns that include submodules.
- Returns
Returns all the subgraphs containing nodes that satisfy the pattern constraints. The outermost list contains different subgraphs, the inner list contains the nodes inside a specific subgraph, and the innermost tuple holds the name of the parent module that the node belongs to together with the matched node object.
- Return type
List[List[Tuple[str, fx.Node]]]
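The nested return shape is easiest to see with a simplified matcher. The sketch below swaps in plain consecutive-window matching over a linear list of nodes, where the pattern is a hypothetical list of per-node predicates; slapo's `find_subgraph` instead takes a lambda function or `Pattern` describing the subgraph, and this sketch does not reproduce that matcher. `ToyNode` and `find_chain` are hypothetical names. Each match is a list of `(parent module name, node)` tuples, and the outer list collects all matches.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass
class ToyNode:
    """Hypothetical stand-in for torch.fx.Node (name and target only)."""
    name: str
    target: str


def find_chain(nodes: List[ToyNode], pattern: List[Callable[[ToyNode], bool]],
               parent_name: str = "") -> List[List[Tuple[str, ToyNode]]]:
    # Swapped-in technique: consecutive-window matching, not slapo's matcher.
    matches: List[List[Tuple[str, ToyNode]]] = []
    k = len(pattern)
    for start in range(len(nodes) - k + 1):
        window = nodes[start:start + k]
        # A window matches when each node satisfies the predicate at its position.
        if all(pred(node) for pred, node in zip(pattern, window)):
            matches.append([(parent_name, node) for node in window])
    return matches


graph = [ToyNode("mm0", "matmul"), ToyNode("add0", "add"), ToyNode("relu0", "relu"),
         ToyNode("mm1", "matmul"), ToyNode("add1", "add"), ToyNode("gelu1", "gelu")]
pattern = [lambda n: n.target == "add", lambda n: n.target in ("relu", "gelu")]

for subgraph in find_chain(graph, pattern, parent_name="encoder"):
    print([(parent, node.name) for parent, node in subgraph])
# [('encoder', 'add0'), ('encoder', 'relu0')]
# [('encoder', 'add1'), ('encoder', 'gelu1')]
```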
- find(regex_or_pattern_fn)
Find a node or a subgraph in a static dataflow graph. This API dispatches to find_node or find_subgraph depending on the type of its argument.
To match a general node pattern, use the find_node API directly.
- Parameters
regex_or_pattern_fn (Union[str, Callable]) – A regular expression specifying the target of call_module nodes, or a callable function or Pattern class specifying the subgraph pattern.
- Returns
For find_node, it returns all nodes whose names match the regex. For find_subgraph, it returns all subgraphs containing nodes that satisfy the pattern constraints: the outermost list contains different subgraphs, the inner list contains the nodes inside a specific subgraph, and the innermost tuple holds the name of the parent module that the node belongs to together with the matched node object.
- Return type
Union[List[Tuple[str, fx.Node]], List[List[Tuple[str, fx.Node]]]]
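The dispatch rule described above can be sketched as a type check on the argument: a string goes to `find_node`, anything else (a lambda or a `Pattern`) goes to `find_subgraph`. The two callees below are hypothetical placeholders that only report which path was taken. The sketch shows why a node-level lambda passed to `find` is treated as a subgraph pattern, which is the reason to call `find_node` directly for general node predicates.

```python
from typing import Callable, Union


def find_node(regex_or_pattern_fn):
    # Hypothetical placeholder for Schedule.find_node.
    return "find_node"


def find_subgraph(pattern_fn):
    # Hypothetical placeholder for Schedule.find_subgraph.
    return "find_subgraph"


def find(regex_or_pattern_fn: Union[str, Callable]):
    # Strings are regexes over call_module targets; everything else is
    # treated as a subgraph pattern.
    if isinstance(regex_or_pattern_fn, str):
        return find_node(regex_or_pattern_fn)
    return find_subgraph(regex_or_pattern_fn)


print(find(r"attn\.query"))                           # find_node
print(find(lambda node: node.op == "call_function"))  # find_subgraph
```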
- trace_until(paths, **kwargs)
Syntactic sugar that traces from the top module down to the submodule(s) specified in paths, so that computation optimizations, such as cutting pipeline stages, can be applied at that level.
- Parameters
paths (Union[str, List[str]]) – The path, or list of paths, to the submodule(s) to trace until.
**kwargs – Other arguments passed to the trace API.
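One way to read "until": given one or more dotted paths, the modules traced through are the top module plus every dotted prefix of each path, while everything else stays opaque. The sketch below computes that set with a hypothetical helper `modules_on_the_way`. Whether the target submodule itself is traced into (included here) or kept as a leaf is an assumption, and accepting a single string as well as a list simply mirrors the `Union[str, List[str]]` parameter type.

```python
from typing import List, Set, Union


def modules_on_the_way(paths: Union[str, List[str]]) -> Set[str]:
    """Hypothetical helper: module paths from the top ("") down to each target."""
    if isinstance(paths, str):
        paths = [paths]  # accept a single path as well as a list
    result = {""}  # the top module is always traced
    for path in paths:
        parts = path.split(".")
        # Include every dotted prefix, e.g. "encoder", then "encoder.layer".
        # Including the target path itself is an assumption of this sketch.
        for i in range(1, len(parts) + 1):
            result.add(".".join(parts[:i]))
    return result


print(sorted(modules_on_the_way("encoder.layer")))
# ['', 'encoder', 'encoder.layer']
print(sorted(modules_on_the_way(["encoder.layer", "decoder.layer"])))
# ['', 'decoder', 'decoder.layer', 'encoder', 'encoder.layer']
```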