slapo.op.linear
Custom linear modules.
Classes:
- FusedQKV – A linear module with fused QKV weights.
- LinearWithSeparateBias – Implementation modified from nn.Linear; arguments are the same as those of nn.Linear.
- LinearWithSyncFunc – Derived from nn.Linear but with a sync function that will be invoked before the bias addition.
- LinearWithAct – Derived from LinearWithSyncFunc but fusing the activation function.
- LinearWithDropout – Derived from LinearWithSyncFunc but fusing the dropout.
- class slapo.op.linear.FusedQKV(hidden_size, num_heads, world_size)
A linear module with fused QKV weights.
- Parameters
hidden_size (int) – The hidden size of the input.
num_heads (int) – The number of heads.
world_size (int) – The size of the tensor parallelism group.
- Return type
None
Methods:
- forward(hidden_states) – Defines the computation performed at every call.
- forward(hidden_states)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
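A minimal usage sketch (the sizes below are illustrative, and the exact layout of the fused output depends on the implementation):

```python
import torch

from slapo.op.linear import FusedQKV

# Fuse the Q/K/V projections of a 12-head, 768-dim attention block.
# world_size > 1 would shard the fused projection across a tensor-parallel group.
fused_qkv = FusedQKV(hidden_size=768, num_heads=12, world_size=1)

hidden_states = torch.randn(8, 128, 768)  # (batch, seq_len, hidden_size)
qkv = fused_qkv(hidden_states)            # Q, K, and V in a single matmul
```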
- class slapo.op.linear.LinearWithSeparateBias(in_features, out_features, bias=True, device=None, dtype=None)
Implementation modified from nn.Linear. Arguments are the same as those of nn.Linear.
Methods:
- forward(x) – Defines the computation performed at every call.
- Parameters
in_features (int) – Size of each input sample.
out_features (int) – Size of each output sample.
bias (bool) – If set to False, the layer will not learn an additive bias.
- Return type
None
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters
x (torch.Tensor) –
- Return type
torch.Tensor
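A minimal sketch; the constructor mirrors nn.Linear, and, as the name suggests, the bias is assumed to be applied as an addition separate from the matrix multiply:

```python
import torch

from slapo.op.linear import LinearWithSeparateBias

# Drop-in for nn.Linear: same constructor arguments.
linear = LinearWithSeparateBias(in_features=768, out_features=3072)

x = torch.randn(8, 128, 768)
y = linear(x)  # bias assumed to be added as a separate op after the matmul
```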
- class slapo.op.linear.LinearWithSyncFunc(in_features, out_features, bias=True, device=None, dtype=None, sync_fn=None)
Derived from nn.Linear but with a sync function that will be invoked before the bias addition.
- Parameters
in_features (int) – Size of each input sample.
out_features (int) – Size of each output sample.
bias (bool) – This is to align the interface with nn.Linear. However, this module requires bias to be True.
device (torch.device) – The device of the module.
dtype (torch.dtype) – The data type of the module.
sync_fn (Callable) – The sync function to be invoked before the bias addition.
Methods:
- forward(x) – Defines the computation performed at every call.
- extra_repr() – Set the extra representation of the module.
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
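A minimal sketch, assuming sync_fn receives the pre-bias output tensor and returns it. The all-reduce below is one plausible choice for tensor parallelism and requires an initialized process group; it is not part of the module itself:

```python
import torch
import torch.distributed as dist

from slapo.op.linear import LinearWithSyncFunc

def all_reduce_sync(partial_out):
    # Hypothetical sync function: sum the partial (sharded) output across
    # ranks before the bias is added. Requires dist.init_process_group().
    dist.all_reduce(partial_out)  # in-place sum across the group
    return partial_out

linear = LinearWithSyncFunc(in_features=768, out_features=768,
                            sync_fn=all_reduce_sync)
y = linear(torch.randn(8, 128, 768))  # sync_fn fires before the bias addition
```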
- class slapo.op.linear.LinearWithAct(in_features, out_features, bias=True, act_fn='gelu', device=None, dtype=None, sync_fn=None)
Derived from LinearWithSyncFunc but fusing the activation function.
- Parameters
in_features (int) – Size of each input sample.
out_features (int) – Size of each output sample.
bias (bool) – This is to align the interface with nn.Linear. However, this module requires bias to be True.
act_fn (str) – The activation function to be fused. Currently supports “gelu” and “gelu_new”.
device (torch.device) – The device of the module.
dtype (torch.dtype) – The data type of the module.
sync_fn (Callable) – The sync function to be invoked before the bias addition.
Methods:
- forward(x) – Defines the computation performed at every call.
- extra_repr() – Set the extra representation of the module.
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters
x (torch.Tensor) –
- Return type
torch.Tensor
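A minimal sketch, e.g. the up-projection of a transformer FFN; "gelu_new" is the other documented option for act_fn:

```python
import torch

from slapo.op.linear import LinearWithAct

# Linear layer with the activation fused in.
ffn_in = LinearWithAct(in_features=768, out_features=3072, act_fn="gelu")

x = torch.randn(8, 128, 768)
y = ffn_in(x)  # semantically gelu(x @ W.T + b), computed with the fusion
```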
- class slapo.op.linear.LinearWithDropout(in_features, out_features, bias=True, p=0.5, inplace=False, device=None, dtype=None, sync_fn=None, use_torchscript=False)
Derived from LinearWithSyncFunc but fusing the dropout.
- Parameters
in_features (int) – Size of each input sample.
out_features (int) – Size of each output sample.
bias (bool) – This is to align the interface with nn.Linear. However, this module requires bias to be True.
p (float) – The probability of an element to be zeroed.
inplace (bool) – If set to True, will do dropout in-place.
device (torch.device) – The device of the module.
dtype (torch.dtype) – The data type of the module.
sync_fn (Callable) – The sync function to be invoked before the bias addition.
use_torchscript (bool) – Whether to use TorchScript or memory_efficient_fusion to fuse the dropout.
Methods:
- forward(x) – Defines the computation performed at every call.
- extra_repr() – Set the extra representation of the module.
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters
x (torch.Tensor) –
- Return type
torch.Tensor
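A minimal sketch, e.g. the down-projection of a transformer FFN; the fused dropout is assumed to follow the usual train/eval switch like nn.Dropout:

```python
import torch

from slapo.op.linear import LinearWithDropout

# Linear layer with dropout fused in; use_torchscript selects TorchScript
# fusion instead of memory_efficient_fusion.
ffn_out = LinearWithDropout(in_features=3072, out_features=768,
                            p=0.1, use_torchscript=True)

ffn_out.train()  # dropout assumed active only in training mode, as in nn.Dropout
y = ffn_out(torch.randn(8, 128, 3072))
```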