Optimize MLP Module on Multi-Device

This guide uses the multi-layer perceptron (MLP) module, one of the basic components in Transformer-based models, as an example to show how we can leverage Slapo to optimize its performance on multiple devices. We will cover tensor parallelism, synchronization, and operator fusion in this tutorial.

We first import the necessary packages. Make sure you have already installed the PyTorch framework.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import slapo

Since we will use multiple GPUs to run the model, we need to initialize the distributed backend. We only initialize the CPU (gloo) backend in this tutorial, but you can initialize the NCCL backend on GPUs by passing in backend="nccl" and adjusting the number of devices accordingly (see the sketch after the output below).

slapo.env.setup(rank=0, world_size=1, backend="gloo")
print(f"rank: {dist.get_rank()}, world_size: {dist.get_world_size()}")
rank: 0, world_size: 1
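
For reference, here is a minimal sketch of a multi-GPU initialization. It assumes the script is launched with a launcher such as torchrun that sets the LOCAL_RANK and WORLD_SIZE environment variables; these variable names are an assumption, not part of the tutorial, so adapt them to your setup.

import os

# Assumed launcher-provided environment variables; adjust to your launcher.
rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
slapo.env.setup(rank=rank, world_size=world_size, backend="nccl")
torch.cuda.set_device(rank)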

Model Definition

We first define an MLP module consisting of two linear layers and a GELU activation, a basic building block in Transformer-based models like GPT. Users can instantiate the module as usual.

class MLP(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.linear1 = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(hidden_size, hidden_size)

    def forward(self, data):
        out = self.linear1(data)
        out = self.activation(out)
        out = self.linear2(out)
        return out


model = MLP(1024)

Create Model Schedule

We then create a default schedule sch for the model. Users can always check the corresponding PyTorch model by calling sch.mod.

sch = slapo.create_schedule(model)
print(sch.mod)
MLP(
  (linear1): Linear(in_features=1024, out_features=1024, bias=True)
  (activation): GELU(approximate='none')
  (linear2): Linear(in_features=1024, out_features=1024, bias=True)
)

Tensor Parallelism

Here comes the most important part: transforming the single-device model into a parallelized one. Slapo provides a .shard() primitive to realize tensor parallelism. Users specify the name of the tensor and the axis along which to shard it. We follow the convention of Megatron-LM and shard the weight \(A\) of the first linear layer by column and the weight \(B\) of the second linear layer by row. For a machine with two devices, the computation becomes:

\[f(XA)B = f\left(X\begin{bmatrix}A_1 & A_2\end{bmatrix}\right)\begin{bmatrix}B_1 \\ B_2\end{bmatrix} = f(XA_1)B_1 + f(XA_2)B_2\]

where \(X\) is the input tensor. Since PyTorch's nn.Linear module stores the weight matrix transposed, axis=0 shards the output dimension. Because each device then holds only part of the result, we need to synchronize at the end of both the forward and the backward pass. We use .sync() to specify the synchronization point and strategy: an all_reduce after the second linear layer in the forward pass, and another all_reduce before the first linear layer in the backward pass. Users only need to write the following few lines of code to realize complex tensor parallelism, without caring about the low-level implementation details.

sch["linear1"].shard("weight", axis=0)
sch["linear1"].shard("bias", axis=0)
sch["linear2"].shard("weight", axis=1)
sch["linear2"].sync(mode="fwd_post", sync_op_or_fn="all_reduce")
sch["linear1"].sync(mode="bwd_post", sync_op_or_fn="all_reduce")
print(sch.mod)
MLP(
  (linear1): Linear(in_features=1024, out_features=1024, bias=True)
  (activation): GELU(approximate='none')
  (linear2): LinearWithSyncFunc(in_features=1024, out_features=1024, bias=True, sync_fn=partial(reduce_forward_output, group=None))
)

If you launch this script with two devices, you can see that the weights and biases of the linear layers are correctly sharded: the output dimension of the first linear layer becomes half of the original, and each device holds only half of the weight. A quick way to check is to print the parameter shapes, as sketched below.
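
A minimal sketch for inspecting the sharded shapes, assuming world_size is 2. With hidden_size=1024, linear1.weight would become (512, 1024), linear1.bias (512,), and linear2.weight (1024, 512) on each device:

for name, param in sch.mod.named_parameters():
    # Each rank prints only its local shard of the parameters.
    print(f"rank {dist.get_rank()}: {name} -> {tuple(param.shape)}")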

To further verify the end-to-end numerical correctness, Slapo also provides a slapo.Verify() context manager that executes the forward function and compares the results with those of the original module. For example, users can wrap the .shard() and .sync() primitives in slapo.Verify(sch, example_inputs=[torch.randn(2, 512, 1024)]), as sketched below. If no errors are reported, numerical correctness is guaranteed.
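
A sketch of this usage, assuming two devices so that the comparison against the unsharded module is meaningful (the primitives and the example input shape are taken from the text above):

with slapo.Verify(sch, example_inputs=[torch.randn(2, 512, 1024)]):
    # The same sharding and synchronization schedule as before,
    # now checked against the original module's forward output.
    sch["linear1"].shard("weight", axis=0)
    sch["linear1"].shard("bias", axis=0)
    sch["linear2"].shard("weight", axis=1)
    sch["linear2"].sync(mode="fwd_post", sync_op_or_fn="all_reduce")
    sch["linear1"].sync(mode="bwd_post", sync_op_or_fn="all_reduce")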

Operator Fusion

Another optimization we can apply is fusing the bias addition of the first linear layer with the GELU activation. To do so, we first use .decompose() to decompose the linear layer into a matrix multiplication and a separate bias addition. As shown in the output below, the nn.Linear layer is replaced with the predefined LinearWithSeparateBias module.

sch["linear1"].decompose()
print(sch.mod)
MLP(
  (linear1): LinearWithSeparateBias(in_features=1024, out_features=1024, bias=True)
  (activation): GELU(approximate='none')
  (linear2): LinearWithSyncFunc(in_features=1024, out_features=1024, bias=True, sync_fn=partial(reduce_forward_output, group=None))
)

To enable operator fusion, we need a static dataflow graph. Here, we explicitly call .trace() to trace the module and break the linear layer into separate multiply and add operators. Users can decide whether the dataflow graph should be flattened by passing in the corresponding flag.

sch.trace(flatten=True)
print(sch.mod)
GraphModule(
  (linear1): Module()
  (activation): GELU(approximate='none')
  (linear2): LinearWithSyncFunc(in_features=1024, out_features=1024, bias=True, sync_fn=partial(reduce_forward_output, group=None))
)



def forward(self, data):
    linear1_weight = self.linear1.weight
    linear = torch._C._nn.linear(data, linear1_weight, None);  data = linear1_weight = None
    linear1_bias = self.linear1.bias
    add = linear + linear1_bias;  linear = linear1_bias = None
    activation = self.activation(add);  add = None
    linear2 = self.linear2(activation);  activation = None
    return linear2

# To see more debug info, please use `graph_module.print_readable()`

Next, we define a pattern for matching the bias addition and the GELU activation. Note that Slapo supports different types of patterns, including subgraphs with multiple inputs and fuzzy matching, which gives users enough flexibility to express their subgraphs.

def pattern(x, bias):
    x = F.gelu(bias + x)
    return x


subgraph = sch.find(pattern)
print(subgraph)
[[('', add), ('', activation)]]

As expected, the subgraph consists of two nodes, one for the bias addition and the other for the GELU activation. We can then fuse the subgraph into a single node by calling .fuse(). By default, Slapo uses TorchScript with nvFuser as the backend compiler.

sch.fuse(subgraph, compiler="TorchScript", name="BiasGeLU")
print(sch.mod)
GraphModule(
  (linear1): Module()
  (linear2): LinearWithSyncFunc(in_features=1024, out_features=1024, bias=True, sync_fn=partial(reduce_forward_output, group=None))
  (BiasGeLU_0): RecursiveScriptModule(
    original_name=GraphModule
    (activation): RecursiveScriptModule(original_name=GELU)
  )
)



def forward(self, data):
    linear1_weight = self.linear1.weight
    linear = torch._C._nn.linear(data, linear1_weight, None);  data = linear1_weight = None
    linear1_bias = self.linear1.bias
    bias_ge_lu_0 = self.BiasGeLU_0(linear, linear1_bias);  linear = linear1_bias = None
    linear2 = self.linear2(bias_ge_lu_0);  bias_ge_lu_0 = None
    return linear2

# To see more debug info, please use `graph_module.print_readable()`

Build the Optimized Model

We can see the previous sharding optimization is still preserved, and the fused kernel is correctly inserted into the hierarchical module definition and the corresponding dataflow graph.

Finally, we can build the optimized model by calling .build().

opt_model, _ = slapo.build(sch, init_weights=False)
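
As a quick sanity check, a sketch of running the optimized model with the example input shape used earlier; if you initialized the NCCL backend, move the model and data to the corresponding GPU first:

# Hypothetical smoke test on the CPU/gloo setup from this tutorial.
data = torch.randn(2, 512, 1024)
out = opt_model(data)
print(out.shape)  # expected: torch.Size([2, 512, 1024])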
