Optimize MLP Module on Multi-Device¶
This guide uses the multi-layer perceptron (MLP) module, one of the basic building blocks in Transformer-based models, as an example to show how we can leverage Slapo to optimize its performance on multiple devices. We will cover tensor parallelism, synchronization, and operator fusion in this tutorial.
We first import the necessary packages. Make sure you have already installed the PyTorch framework.
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import slapo
Since we will use multiple GPUs to run the model, we need to initialize the distributed backend. We only initialize the CPU (Gloo) backend in this tutorial, but you can initialize the NCCL backend on GPU by passing in backend="nccl" and changing the actual number of devices accordingly.
slapo.env.setup(rank=0, world_size=1, backend="gloo")
print(f"rank: {dist.get_rank()}, world_size: {dist.get_world_size()}")
rank: 0, world_size: 1
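For reference, here is a minimal sketch of how the same setup could be driven on two GPUs. The use of torchrun and the LOCAL_RANK/WORLD_SIZE environment variables is our assumption and not part of the original script; only slapo.env.setup comes from the tutorial itself.

# Hypothetical multi-GPU launch (assumption): torchrun --nproc_per_node=2 this_script.py
import os

import slapo

rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
# Use the NCCL backend so that collectives run on the GPUs.
slapo.env.setup(rank=rank, world_size=world_size, backend="nccl")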
Model Definition¶
We first define an MLP module that consists of two linear layers and a GELU activation, a basic building block in Transformer-based models such as GPT. Users can instantiate the module as usual.
class MLP(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.linear1 = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(hidden_size, hidden_size)

    def forward(self, data):
        out = self.linear1(data)
        out = self.activation(out)
        out = self.linear2(out)
        return out
model = MLP(1024)
Create Model Schedule¶
We then create a default schedule sch for the model. Users can always check the corresponding PyTorch model by calling sch.mod.
sch = slapo.create_schedule(model)
print(sch.mod)
MLP(
  (linear1): Linear(in_features=1024, out_features=1024, bias=True)
  (activation): GELU(approximate='none')
  (linear2): Linear(in_features=1024, out_features=1024, bias=True)
)
Tensor Parallelism¶
Here comes the most important part of transforming the single-device model into a parallelized one. Slapo provides a .shard() primitive to realize tensor parallelism. Users can specify the name of the tensor and the axis to shard the tensor along. We follow the convention of Megatron-LM to shard the weight \(A\) in the first linear layer by column and the weight \(B\) in the second linear layer by row. Consider a machine with two devices; the computation becomes as follows:

\[A = \begin{bmatrix} A_1 & A_2 \end{bmatrix}, \quad B = \begin{bmatrix} B_1 \\ B_2 \end{bmatrix}, \quad Y_i = \mathrm{GeLU}(X A_i), \quad Z = \sum_{i=1}^{2} Y_i B_i,\]

where \(X\) is the input tensor and device \(i\) holds the shards \(A_i\) and \(B_i\), so the partial results \(Y_i B_i\) must be summed across devices. Since PyTorch's nn.Linear module by default transposes the weight matrix, axis=0 means sharding the output dimension.
As each device only holds a part of the result, we need to synchronize the results at the end of both the forward and backward passes. We can use .sync() to specify the synchronization point and strategy. Here we use all_reduce to synchronize the results after the second linear layer during the forward pass, and insert another all_reduce before the first linear layer during the backward pass. Users only need to write the following few lines of code to realize complex tensor parallelism, without having to care about the low-level implementation details.
sch["linear1"].shard("weight", axis=0)
sch["linear1"].shard("bias", axis=0)
sch["linear2"].shard("weight", axis=1)
sch["linear2"].sync(mode="fwd_post", sync_op_or_fn="all_reduce")
sch["linear1"].sync(mode="bwd_post", sync_op_or_fn="all_reduce")
print(sch.mod)
MLP(
  (linear1): Linear(in_features=1024, out_features=1024, bias=True)
  (activation): GELU(approximate='none')
  (linear2): LinearWithSyncFunc(in_features=1024, out_features=1024, bias=True, sync_fn=partial(reduce_forward_output, group=None))
)
If you launch this script with two devices, you can see that the weights and biases of the linear layers are correctly sharded: the output dimension of the first linear layer becomes half of the original one, and each device only holds half of the weight.
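As an illustration, the sharded parameter shapes can be inspected directly. The expected shapes in the comments below are a sketch assuming a world size of two and the 1024-dimensional MLP defined above.

# Inspect the sharded parameters (expected shapes assume world_size == 2).
# linear1.weight: sharded along axis 0 -> torch.Size([512, 1024])
# linear1.bias:   sharded along axis 0 -> torch.Size([512])
# linear2.weight: sharded along axis 1 -> torch.Size([1024, 512])
print(sch.mod.linear1.weight.shape)
print(sch.mod.linear1.bias.shape)
print(sch.mod.linear2.weight.shape)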
To further verify the end-to-end numerical correctness, Slapo also provides a slapo.Verify() context that executes the forward function and compares the results with those of the original module. For example, users can leverage this context, with slapo.Verify(sch, example_inputs=[torch.randn(2, 512, 1024)]), to encapsulate those .shard() and .sync() primitives. If no errors are reported, the numerical correctness is guaranteed.
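The sketch below shows how this could look in the script. It simply wraps the same shard and sync calls shown above in the verification context (as an alternative to applying them without verification); the example input shape follows the one mentioned in the text.

import torch

# Verify the scheduled module against the original one on a random input.
with slapo.Verify(sch, example_inputs=[torch.randn(2, 512, 1024)]):
    sch["linear1"].shard("weight", axis=0)
    sch["linear1"].shard("bias", axis=0)
    sch["linear2"].shard("weight", axis=1)
    sch["linear2"].sync(mode="fwd_post", sync_op_or_fn="all_reduce")
    sch["linear1"].sync(mode="bwd_post", sync_op_or_fn="all_reduce")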
Operator Fusion¶
Another optimization we can do is to fuse the GELU activation with the bias addition of the first linear layer. We can use .decompose() to decompose the linear layer into a matrix multiplication and a bias addition. As shown in the output below, the nn.Linear layer is replaced with the predefined LinearWithSeparateBias module.
sch["linear1"].decompose()
print(sch.mod)
MLP(
  (linear1): LinearWithSeparateBias(in_features=1024, out_features=1024, bias=True)
  (activation): GELU(approximate='none')
  (linear2): LinearWithSyncFunc(in_features=1024, out_features=1024, bias=True, sync_fn=partial(reduce_forward_output, group=None))
)
To enable operator fusion, we need a static dataflow graph. Here, we explicitly call .trace() to trace the module and break the linear layer into two separate multiply and add operators. Users can easily determine whether they want their dataflow graph to be flattened or not by just passing in a flag.
sch.trace(flatten=True)
print(sch.mod)
GraphModule(
  (linear1): Module()
  (activation): GELU(approximate='none')
  (linear2): LinearWithSyncFunc(in_features=1024, out_features=1024, bias=True, sync_fn=partial(reduce_forward_output, group=None))
)

def forward(self, data):
    linear1_weight = self.linear1.weight
    linear = torch._C._nn.linear(data, linear1_weight, None); data = linear1_weight = None
    linear1_bias = self.linear1.bias
    add = linear + linear1_bias; linear = linear1_bias = None
    activation = self.activation(add); add = None
    linear2 = self.linear2(activation); activation = None
    return linear2
# To see more debug info, please use `graph_module.print_readable()`
Next, we define a pattern for matching the bias addition and GELU activation. Notice that Slapo supports different types of patterns, including subgraphs with multiple inputs and fuzzy matching, which gives users enough flexibility to express their subgraphs.
def pattern(x, bias):
    x = F.gelu(bias + x)
    return x
subgraph = sch.find(pattern)
print(subgraph)
[[('', add), ('', activation)]]
As expected, the subgraph consists of two nodes, one for the bias addition and the other for the GELU activation. We can then fuse the subgraph into a single node by calling .fuse(). By default, Slapo will use TorchScript with nvFuser as the backend compiler.
sch.fuse(subgraph, compiler="TorchScript", name="BiasGeLU")
print(sch.mod)
/usr/local/lib/python3.8/site-packages/torch/jit/_check.py:172: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
warnings.warn("The TorchScript type system doesn't support "
GraphModule(
  (linear1): Module()
  (linear2): LinearWithSyncFunc(in_features=1024, out_features=1024, bias=True, sync_fn=partial(reduce_forward_output, group=None))
  (BiasGeLU_0): RecursiveScriptModule(
    original_name=GraphModule
    (activation): RecursiveScriptModule(original_name=GELU)
  )
)

def forward(self, data):
    linear1_weight = self.linear1.weight
    linear = torch._C._nn.linear(data, linear1_weight, None); data = linear1_weight = None
    linear1_bias = self.linear1.bias
    bias_ge_lu_0 = self.BiasGeLU_0(linear, linear1_bias); linear = linear1_bias = None
    linear2 = self.linear2(bias_ge_lu_0); bias_ge_lu_0 = None
    return linear2
# To see more debug info, please use `graph_module.print_readable()`
Build the Optimized Model¶
We can see the previous sharding optimization is still preserved, and the fused kernel is correctly inserted into the hierarchical module definition and the corresponding dataflow graph.
Finally, we can build the optimized model by calling .build().
opt_model, _ = slapo.build(sch, init_weights=False)
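As a quick sanity check, not part of the original script, the built model can be run like any other PyTorch module; the sketch below assumes the single-process CPU/Gloo setup used in this tutorial.

import torch

# Run one forward pass through the optimized model.
batch = torch.randn(2, 512, 1024)
out = opt_model(batch)
print(out.shape)  # expected: torch.Size([2, 512, 1024])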
Total running time of the script: ( 0 minutes 0.077 seconds)