slapo.op.linear

Custom linear modules.

Classes:

FusedQKV(hidden_size, num_heads, world_size)

A linear module with fused QKV weights.

LinearWithSeparateBias(in_features, out_features)

Implementation modified from nn.Linear; the bias is added as a separate operation. Arguments are the same as those of nn.Linear.

LinearWithSyncFunc(in_features, out_features)

Derived from nn.Linear but with a sync function that will be invoked before the bias addition.

LinearWithAct(in_features, out_features[, ...])

Derived from LinearWithSyncFunc, but fuses the activation function.

LinearWithDropout(in_features, out_features)

Derived from LinearWithSyncFunc, but fuses the dropout.

class slapo.op.linear.FusedQKV(hidden_size, num_heads, world_size)[source]

A linear module with fused QKV weights.

Parameters
  • hidden_size (int) – The hidden size of the input.

  • num_heads (int) – The number of heads.

  • world_size (int) – The size of the tensor parallelism group.

Return type

None

Methods:

forward(hidden_states)

Defines the computation performed at every call.

forward(hidden_states)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
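
A minimal usage sketch (the sizes below are hypothetical, and the exact layout of the fused output tensor is determined by the implementation, not by this example):

    import torch

    from slapo.op.linear import FusedQKV

    # Hypothetical sizes; world_size=1 means no tensor parallelism.
    hidden_size, num_heads, world_size = 768, 12, 1
    fused_qkv = FusedQKV(hidden_size, num_heads, world_size)

    # One fused projection replaces three separate Q/K/V linears,
    # so a single GEMM produces all three attention inputs.
    hidden_states = torch.randn(2, 16, hidden_size)  # (batch, seq, hidden)
    qkv = fused_qkv(hidden_states)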

class slapo.op.linear.LinearWithSeparateBias(in_features, out_features, bias=True, device=None, dtype=None)[source]

Implementation modified from nn.Linear; the bias is added as a separate operation. Arguments are the same as those of nn.Linear.

Methods:

forward(x)

Defines the computation performed at every call.

Parameters
  • in_features (int) – Size of each input sample.

  • out_features (int) – Size of each output sample.

  • bias (bool) – If set to False, the layer will not learn an additive bias.

Return type

None

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters

x (torch.Tensor) – The input tensor.

Return type

torch.Tensor
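
A minimal sketch (sizes are hypothetical). Functionally this should match nn.Linear; the difference is that the matmul and the bias addition run as two separate ops:

    import torch

    from slapo.op.linear import LinearWithSeparateBias

    linear = LinearWithSeparateBias(1024, 1024)
    x = torch.randn(8, 1024)
    # Same result as nn.Linear(1024, 1024)(x), but the bias is added
    # in a separate op rather than inside a single fused addmm.
    y = linear(x)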

class slapo.op.linear.LinearWithSyncFunc(in_features, out_features, bias=True, device=None, dtype=None, sync_fn=None)[source]

Derived from nn.Linear but with a sync function that will be invoked before the bias addition.

Parameters
  • in_features (int) – Size of each input sample.

  • out_features (int) – Size of each output sample.

  • bias (bool) – This is to align the interface with nn.Linear. However, this module requires bias to be True.

  • device (torch.device) – The device of the module.

  • dtype (torch.dtype) – The data type of the module.

  • sync_fn (Callable) – The sync function to be invoked before the bias addition.

Methods:

forward(x)

Defines the computation performed at every call.

extra_repr()

Set the extra representation of the module.

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

extra_repr()[source]

Set the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
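
A sketch of the intended use under tensor parallelism. It assumes sync_fn receives the pre-bias output and returns the synchronized tensor (the exact callable signature is an assumption here), and it requires an initialized process group:

    import torch
    import torch.distributed as dist

    from slapo.op.linear import LinearWithSyncFunc

    def sync_fn(out):
        # Hypothetical tensor-parallel pattern: all-reduce the partial
        # matmul outputs across ranks so the bias is added exactly once,
        # after the reduction.
        dist.all_reduce(out)
        return out

    linear = LinearWithSyncFunc(1024, 1024, sync_fn=sync_fn)
    x = torch.randn(8, 1024)
    y = linear(x)  # sync_fn runs between the matmul and the bias addition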

class slapo.op.linear.LinearWithAct(in_features, out_features, bias=True, act_fn='gelu', device=None, dtype=None, sync_fn=None)[source]

Derived from LinearWithSyncFunc, but fuses the activation function.

Parameters
  • in_features (int) – Size of each input sample.

  • out_features (int) – Size of each output sample.

  • bias (bool) – This is to align the interface with nn.Linear. However, this module requires bias to be True.

  • act_fn (str) – The activation function to be fused. Currently supports “gelu” and “gelu_new”.

  • device (torch.device) – The device of the module.

  • dtype (torch.dtype) – The data type of the module.

  • sync_fn (Callable) – The sync function to be invoked before the bias addition.

Methods:

forward(x)

Defines the computation performed at every call.

extra_repr()

Set the extra representation of the module.

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters

x (torch.Tensor) – The input tensor.

Return type

torch.Tensor

extra_repr()[source]

Set the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
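
A minimal sketch (sizes are hypothetical). Numerically the result should match a plain linear followed by the chosen activation; the fusion only changes how the epilogue is executed:

    import torch

    from slapo.op.linear import LinearWithAct

    linear = LinearWithAct(1024, 4096, act_fn="gelu")
    x = torch.randn(8, 1024)
    # Equivalent to gelu(x @ W.T + b), with the activation fused into
    # the linear epilogue instead of dispatched as a separate op.
    y = linear(x)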

class slapo.op.linear.LinearWithDropout(in_features, out_features, bias=True, p=0.5, inplace=False, device=None, dtype=None, sync_fn=None, use_torchscript=False)[source]

Derived from LinearWithSyncFunc, but fuses the dropout.

Parameters
  • in_features (int) – Size of each input sample.

  • out_features (int) – Size of each output sample.

  • bias (bool) – This is to align the interface with nn.Linear. However, this module requires bias to be True.

  • p (float) – The probability of an element to be zeroed.

  • inplace (bool) – If set to True, will do dropout in-place.

  • device (torch.device) – The device of the module.

  • dtype (torch.dtype) – The data type of the module.

  • sync_fn (Callable) – The sync function to be invoked before the bias addition.

  • use_torchscript (bool) – Whether to use TorchScript or memory_efficient_fusion to fuse the dropout.

Methods:

forward(x)

Defines the computation performed at every call.

extra_repr()

Set the extra representation of the module.

forward(x)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters

x (torch.Tensor) – The input tensor.

Return type

torch.Tensor

extra_repr()[source]

Set the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
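
A minimal sketch (sizes are hypothetical; whether the fused dropout respects train/eval mode the way nn.Dropout does is an assumption here):

    import torch

    from slapo.op.linear import LinearWithDropout

    linear = LinearWithDropout(4096, 1024, p=0.1)
    linear.train()  # assuming dropout is only active in training mode
    x = torch.randn(8, 4096)
    # Linear, then dropout with p=0.1, with the dropout fused into the
    # epilogue (via TorchScript or memory_efficient_fusion).
    y = linear(x)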