Optimize Attention Module on a Single Device
This guide uses the Attention module, the core and most time-consuming module in Transformer-based models, as an example to show how we can leverage Slapo to optimize its performance on a single device. We will cover module tracing, pattern matching, operator fusion, and partial module replacement in this tutorial.
We first import the necessary packages. Make sure you have already installed the PyTorch framework.
import torch
import torch.nn as nn
import torch.nn.functional as F
import slapo
Model Definition
The Attention module consists of SelfAttention and Projection modules, where SelfAttention takes the hidden states and passes them through three different linear layers to generate the query, key, and value tensors. Scaled dot-product attention is then computed over these tensors:
\[\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V\]
where \(d_k\) is the hidden dimension of each attention head. Finally, the output of the attention module is passed through a linear projection layer, added to the residual connection, and normalized by a layer norm to generate the final output. The following code shows the implementation of the Attention module.
def scaled_dot_product(q, k, v):
    # (bs, head, seq, hs // head)
    d_k = q.shape[-1]
    attn_score = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(d_k)
    # (bs, head, seq, seq)
    attn_probs = F.softmax(attn_score, dim=-1)
    attn_probs = F.dropout(attn_probs, 0.1)
    # (bs, head, seq, hs // head)
    attn = torch.matmul(attn_probs, v)
    return attn
class SelfAttention(nn.Module):
    def __init__(self, hidden_size, n_heads):
        super().__init__()
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.n_heads = n_heads

    def permute_for_scores(self, x):
        # x: (batch_size, seq_len, hidden_size)
        new_shape = x.shape[:-1] + (self.n_heads, -1)
        x = x.view(new_shape)
        # output: (bs, head, seq, hs // head)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        # hidden_states: (batch_size, seq_len, hidden_size)
        # qkv layers
        q = self.permute_for_scores(self.q_proj(hidden_states))
        k = self.permute_for_scores(self.k_proj(hidden_states))
        v = self.permute_for_scores(self.v_proj(hidden_states))
        # core attention
        output = scaled_dot_product(q, k, v)
        # output: (bs, seq, head, hs // head)
        output.permute(0, 2, 1, 3)
        output.view(output.shape[0], output.shape[1], -1)
        return output
class Projection(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.layer_norm(hidden_states + input_tensor)
        return hidden_states
class Attention(nn.Module):
    def __init__(self, hidden_size, n_heads):
        super().__init__()
        self.self_attn = SelfAttention(hidden_size, n_heads)
        self.proj = Projection(hidden_size)

    def forward(self, hidden_states):
        self_output = self.self_attn(hidden_states)
        attention_output = self.proj(self_output, hidden_states)
        return attention_output
Users can instantiate the model based on the above definition as usual.
model = Attention(hidden_size=1024, n_heads=16)
Create Model Schedule
We then pass the model to Slapo and create a default schedule for it. The schedule always holds the current module, either the original one or its transformed version. Users can check the module through the mod attribute.
sch = slapo.create_schedule(model)
print(sch.mod)
Attention(
  (self_attn): SelfAttention(
    (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
    (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
    (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (proj): Projection(
    (dense): Linear(in_features=1024, out_features=1024, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
)
As we can see, Slapo works seamlessly with PyTorch models and preserves the hierarchical structure of the original model. Since we have not applied any optimizations yet, the module is exactly the same as the original one. We can easily obtain a submodule by passing its name to the schedule, which returns a new schedule for that submodule.
attn_sch = sch["self_attn"]
print(attn_sch.mod)
SelfAttention(
  (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
  (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
  (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
)
This is also the idea of progressive optimization: we apply optimizations to a small part of the model at a time without affecting the other parts. If no optimizations are applied, the model is left unchanged, which differs from the traditional whole-graph static optimization employed by deep learning compilers.
In the following, we will show how to gradually apply optimizations to the model.
Optimize SelfAttention Module
Replace QKV Linear Layers
Since the three linear layers in the SelfAttention module are independent, we can merge them into a single linear layer to reduce the number of GEMM operations and thus the kernel launch overhead.
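To make the idea concrete, here is a minimal sketch (not part of the tutorial code; the helper name is ours) of how the weights of three independent linear layers could be concatenated along the output dimension so that a single GEMM produces Q, K, and V at once:

def merge_qkv_linears(q_proj, k_proj, v_proj):
    # Concatenate the (out_features, in_features) weight matrices and the
    # bias vectors along the output dimension, yielding one larger Linear.
    fused = nn.Linear(q_proj.in_features, q_proj.out_features * 3)
    with torch.no_grad():
        fused.weight.copy_(
            torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0)
        )
        fused.bias.copy_(torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0))
    return fused

The FusedQKV module defined later in this tutorial applies the same idea by creating a single Linear layer with three times the output features.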
The first step is to find those three linear layers and their subsequent operations in the model. Slapo provides an easy-to-use API to help users define a pattern and find the corresponding module or subgraph in the model. We can define a subgraph pattern function as shown below. The call_module function matches a call node in the dataflow graph that satisfies the user-defined constraint. Its first argument specifies the name of the module, and regular expressions are supported, which enables fuzzy matching in this case. The remaining arguments are the arguments of the call node. Using this function, one line of code is enough to match the three linear layers. We also need to incorporate the view and permute operations, which should be fused together as well instead of being executed three times separately.
from slapo.pattern import call_module
def pattern(x):
    x = call_module(r"[qkv]_proj", x)
    new_shape = x.shape[:-1] + (16, -1)
    x = x.view(new_shape)
    return x.permute(0, 2, 1, 3)
After defining the pattern, we can use the .find() primitive to find the corresponding subgraph in the model.
qkv_subgraphs = attn_sch.find(pattern)
The primitive does two things. First, it implicitly traces the submodule into a static subgraph. Currently, we use torch.fx as the IR, so the traced module becomes a torch.fx.GraphModule, and we can inspect its forward function.
print(attn_sch.mod)
GraphModule(
  (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
  (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
  (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
)

def forward(self, hidden_states):
    q_proj = self.q_proj(hidden_states)
    getattr_1 = q_proj.shape
    getitem = getattr_1[slice(None, -1, None)]; getattr_1 = None
    add = getitem + (16, -1); getitem = None
    view = q_proj.view(add); q_proj = add = None
    permute = view.permute(0, 2, 1, 3); view = None
    k_proj = self.k_proj(hidden_states)
    getattr_2 = k_proj.shape
    getitem_1 = getattr_2[slice(None, -1, None)]; getattr_2 = None
    add_1 = getitem_1 + (16, -1); getitem_1 = None
    view_1 = k_proj.view(add_1); k_proj = add_1 = None
    permute_1 = view_1.permute(0, 2, 1, 3); view_1 = None
    v_proj = self.v_proj(hidden_states); hidden_states = None
    getattr_3 = v_proj.shape
    getitem_2 = getattr_3[slice(None, -1, None)]; getattr_3 = None
    add_2 = getitem_2 + (16, -1); getitem_2 = None
    view_2 = v_proj.view(add_2); v_proj = add_2 = None
    permute_2 = view_2.permute(0, 2, 1, 3); view_2 = None
    getattr_4 = permute.shape
    getitem_3 = getattr_4[-1]; getattr_4 = None
    transpose = permute_1.transpose(-2, -1); permute_1 = None
    matmul = torch.matmul(permute, transpose); permute = transpose = None
    sqrt = torch.sqrt(getitem_3); getitem_3 = None
    truediv = matmul / sqrt; matmul = sqrt = None
    softmax = torch.nn.functional.softmax(truediv, dim = -1, _stacklevel = 3, dtype = None); truediv = None
    dropout = torch.nn.functional.dropout(softmax, p = 0.1, training = True, inplace = False); softmax = None
    matmul_1 = torch.matmul(dropout, permute_2); dropout = permute_2 = None
    return matmul_1
# To see more debug info, please use `graph_module.print_readable()`
Second, the .find() primitive returns a list of subgraphs that match the pattern. In our case, there are three subgraphs, one for each linear layer and its subsequent view and permute operations.
print(qkv_subgraphs)
[[('', q_proj), ('', getattr_1), ('', getitem), ('', add), ('', view), ('', permute)], [('', k_proj), ('', getattr_2), ('', getitem_1), ('', add_1), ('', view_1), ('', permute_1)], [('', v_proj), ('', getattr_3), ('', getitem_2), ('', add_2), ('', view_2), ('', permute_2)]]
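If you want to inspect what was matched, each entry appears to be an (owner, node) pair over the traced torch.fx graph; a small illustrative loop (assuming the nodes are torch.fx.Node objects) could print their ops and targets:

# Illustrative only: walk the matched subgraphs and print each fx node.
for i, subgraph in enumerate(qkv_subgraphs):
    print(f"subgraph {i}:")
    for owner, node in subgraph:
        # node.op is e.g. "call_module" or "call_method", and node.target is
        # the module path or the method/function being called.
        print(f"  {node.op}: {node.target}")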
Then, we define a fused QKV module as follows and instantiate it.
class FusedQKV(nn.Module):
    def __init__(self, hidden_size, n_heads) -> None:
        super().__init__()
        self.n_heads = n_heads
        self.fused_linear = nn.Linear(hidden_size, hidden_size * 3)

    def permute_for_scores(self, x):
        # x: (bs, seq, 3 * hidden_size)
        new_shape = x.shape[:-1] + (self.n_heads, -1, 3)
        x = x.view(new_shape)
        # output: (bs, head, seq, hs // head, 3)
        return x.permute(0, 2, 1, 3, 4)

    def forward(self, hidden_states):
        qkv = self.fused_linear(hidden_states)
        reshaped_qkv = self.permute_for_scores(qkv)
        # split the trailing dimension of size 3 into q, k, and v
        q, k, v = torch.split(reshaped_qkv, 1, dim=-1)
        q = torch.squeeze(q, -1).contiguous()
        k = torch.squeeze(k, -1).contiguous()
        v = torch.squeeze(v, -1).contiguous()
        return [q, k, v]
fused_qkv = FusedQKV(hidden_size=1024, n_heads=16)
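As a quick sanity check (not part of the original tutorial), we can run the fused module on random data and confirm it returns query, key, and value tensors with the expected per-head layout:

# Dummy batch: 2 sequences of length 512 with hidden size 1024.
x = torch.randn(2, 512, 1024)
q, k, v = fused_qkv(x)
print(q.shape, k.shape, v.shape)  # expected: torch.Size([2, 16, 512, 64]) each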
We can replace the subgraphs with the fused QKV module by calling the .replace() primitive. The first argument is the new module, and the second argument is the subgraphs to be replaced. After replacing them, we can check the model again to see the changes.
attn_sch.replace(fused_qkv, qkv_subgraphs)
print(attn_sch.mod)
GraphModule(
  (FusedQKV_0): FusedQKV(
    (fused_linear): Linear(in_features=1024, out_features=3072, bias=True)
  )
)

def forward(self, hidden_states):
    fused_qkv_0 = self.FusedQKV_0(hidden_states); hidden_states = None
    getitem_8 = fused_qkv_0[2]
    getitem_7 = fused_qkv_0[1]
    getitem_6 = fused_qkv_0[0]; fused_qkv_0 = None
    getattr_4 = getitem_6.shape
    getitem_3 = getattr_4[-1]; getattr_4 = None
    transpose = getitem_7.transpose(-2, -1); getitem_7 = None
    matmul = torch.matmul(getitem_6, transpose); getitem_6 = transpose = None
    sqrt = torch.sqrt(getitem_3); getitem_3 = None
    truediv = matmul / sqrt; matmul = sqrt = None
    softmax = torch.nn.functional.softmax(truediv, dim = -1, _stacklevel = 3, dtype = None); truediv = None
    dropout = torch.nn.functional.dropout(softmax, p = 0.1, training = True, inplace = False); softmax = None
    matmul_1 = torch.matmul(dropout, getitem_8); dropout = getitem_8 = None
    return matmul_1
# To see more debug info, please use `graph_module.print_readable()`
From the above output, we can see there is a new module called FusedQKV_0 with \(3\times\) out_features compared to the original linear layer. The corresponding forward function is also changed to leverage the fused module.
Replace Scaled Dot-Product Attention
Next, we again use the .find() primitive to find the core attention function and replace it with a more efficient implementation. Unlike the QKV example, which required us to explicitly write a fuzzy pattern, here we can directly use a function with the identical computation subgraph as the pattern. Since the scaled_dot_product function has already been defined, we can simply reuse it and pass it to .find().
core_attn_subgraph = attn_sch.find(scaled_dot_product)
print(core_attn_subgraph)
[[('', getattr_4), ('', getitem_3), ('', transpose), ('', matmul), ('', sqrt), ('', truediv), ('', softmax), ('', dropout), ('', matmul_1)]]
We can use the FlashAttentionOp provided by Slapo, which makes use of flash attention kernels from the xFormers and flash-attention libraries, to replace the core attention. We directly import FlashAttentionOp and replace the subgraph with it. Notice that the scaled_dot_product function we defined above only accepts the query, key, and value tensors, while FlashAttentionOp requires five arguments, so we need to explicitly pass None to the attention_mask argument and set the dropout probability p to 0.1 through concrete_args.
Note
We use native_xformers in this tutorial to demonstrate the functionality. In practice, users can choose the cutlass, triton, or cuda kernels to achieve better performance, though the latter two only support the NVIDIA V100 GPU. Please refer to slapo.op.attention.FlashAttentionOp for more details.
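For reference, switching to one of those kernels only changes the attn_op_name argument; the variant below is hypothetical and is not executed in this tutorial:

from slapo.op.attention import FlashAttentionOp

# Hypothetical: request the cutlass kernel instead of native_xformers.
# This assumes a supported NVIDIA GPU and the corresponding library are available.
if torch.cuda.is_available():
    flash_attn_cutlass = FlashAttentionOp(attn_op_name="cutlass", apply_causal_mask=False)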
from slapo.op.attention import FlashAttentionOp
flash_attn = FlashAttentionOp(attn_op_name="native_xformers", apply_causal_mask=False)
attn_sch.replace(
    flash_attn, core_attn_subgraph, concrete_args={"attention_mask": None, "p": 0.1}
)
print(attn_sch.mod)
GraphModule(
  (FusedQKV_0): FusedQKV(
    (fused_linear): Linear(in_features=1024, out_features=3072, bias=True)
  )
  (FlashAttentionOp_0): FlashAttentionOp()
)

def forward(self, hidden_states):
    fused_qkv_0 = self.FusedQKV_0(hidden_states); hidden_states = None
    getitem_8 = fused_qkv_0[2]
    getitem_7 = fused_qkv_0[1]
    getitem_6 = fused_qkv_0[0]; fused_qkv_0 = None
    flash_attention_op_0 = self.FlashAttentionOp_0(getitem_6, getitem_7, getitem_8, attention_mask = None, p = 0.1); getitem_6 = getitem_7 = getitem_8 = None
    return flash_attention_op_0
# To see more debug info, please use `graph_module.print_readable()`
Again, the FlashAttentionOp is attached to the GraphModule, and the forward function becomes much simpler, merely calling these two submodules.
Optimize the Projection Module
We then optimize the Projection module. A common practice is to fuse the dropout and layer norm layers with the element-wise addition operations. We first create a subschedule for the Projection module.
proj_sch = sch["proj"]
print(proj_sch.mod)
Projection(
  (dense): Linear(in_features=1024, out_features=1024, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
Since we want to fuse the linear bias with the subsequent layers, we need to decompose the linear layer into separate matrix multiplication and bias add operations. In Slapo, this is easy to achieve by simply calling .decompose() on the linear module.
Note
The default nn.Linear module in PyTorch passes both the weight and bias directly to the backend F.linear function, which dispatches to the corresponding C/CUDA library, so there is no way to fuse the bias unless we take it apart. Another reason for decomposing is that we can still optimize the weight parameter later (e.g., sharding) even though the bias may be fused.
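Conceptually, decomposition turns the single fused linear call into two separate operations, as in the illustrative sketch below (this is not Slapo's actual LinearWithSeparateBias implementation):

def decomposed_linear(x, weight, bias):
    # Matrix multiplication without the bias, so the bias add shows up as its
    # own node in the traced graph and can be fused with later element-wise ops.
    out = F.linear(x, weight, None)
    out = out + bias
    return out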
proj_sch["dense"].decompose()
print(proj_sch.mod)
Projection(
  (dense): LinearWithSeparateBias(in_features=1024, out_features=1024, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
We can see the Linear module has changed into LinearWithSeparateBias, while the other submodules remain the same. Next, we explicitly call the .trace() primitive to trace the module into a static subgraph, which gives us more control over the traced module. For example, we can pass in the flatten flag to let the tracer go into each submodule, so that the bias add is captured as a node in the subgraph.
proj_sch.trace(flatten=True)
print(proj_sch.mod)
GraphModule(
  (dense): Module()
  (dropout): Dropout(p=0.1, inplace=False)
  (layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)

def forward(self, hidden_states, input_tensor):
    dense_weight = self.dense.weight
    linear = torch._C._nn.linear(hidden_states, dense_weight, None); hidden_states = dense_weight = None
    dense_bias = self.dense.bias
    add = linear + dense_bias; linear = dense_bias = None
    dropout = self.dropout(add); add = None
    add_1 = dropout + input_tensor; dropout = input_tensor = None
    layer_norm = self.layer_norm(add_1); add_1 = None
    return layer_norm
# To see more debug info, please use `graph_module.print_readable()`
We can again define the fusion pattern as follows. Although the pattern takes three input arguments, Slapo can still handle it correctly and grab all the required nodes in the subgraph.
def ln_pattern(x, bias, residual):
    return F.layer_norm(F.dropout(x + bias) + residual, 1024)
ln_subgraph = proj_sch.find(ln_pattern)
print(ln_subgraph)
[[('', add), ('', dropout), ('', add_1), ('', layer_norm)]]
For this case of vertical fusion, Slapo provides a .fuse() primitive to easily fuse the subgraph. Users can specify the backend fusion compiler and the name of the fused module. By default, Slapo will use TorchScript with nvFuser to fuse the subgraph.
proj_sch.fuse(ln_subgraph, compiler="TorchScript", name="FusedLayerNorm")
print(proj_sch.mod)
/usr/local/lib/python3.8/site-packages/torch/jit/_check.py:172: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
warnings.warn("The TorchScript type system doesn't support "
GraphModule(
  (dense): Module()
  (FusedLayerNorm_0): RecursiveScriptModule(
    original_name=GraphModule
    (dropout): RecursiveScriptModule(original_name=Dropout)
    (layer_norm): RecursiveScriptModule(original_name=LayerNorm)
  )
)

def forward(self, hidden_states, input_tensor):
    dense_weight = self.dense.weight
    linear = torch._C._nn.linear(hidden_states, dense_weight, None); hidden_states = dense_weight = None
    dense_bias = self.dense.bias
    fused_layer_norm_0 = self.FusedLayerNorm_0(linear, dense_bias, input_tensor); linear = dense_bias = input_tensor = None
    return fused_layer_norm_0
# To see more debug info, please use `graph_module.print_readable()`
As shown in the above output, the FusedLayerNorm module is attached to the GraphModule, and only torch._C._nn.linear and FusedLayerNorm are called in the forward function.
Build the Optimized Model
Finally, we have finished all the optimizations for the Attention module on a single device. We can pass the schedule to slapo.build to build the optimized model for execution. It returns the optimized model and a default optimizer. We can print out the top-level module to see the changes. The optimizations are clearly reflected in the new module, and the module hierarchy is preserved, which greatly enhances the readability and debuggability of the code.
opt_model, _ = slapo.build(sch, init_weights=False)
print(opt_model)
Attention(
  (self_attn): GraphModule(
    (FusedQKV_0): FusedQKV(
      (fused_linear): Linear(in_features=1024, out_features=3072, bias=True)
    )
    (FlashAttentionOp_0): FlashAttentionOp()
  )
  (proj): GraphModule(
    (dense): Module()
    (FusedLayerNorm_0): RecursiveScriptModule(
      original_name=GraphModule
      (dropout): RecursiveScriptModule(original_name=Dropout)
      (layer_norm): RecursiveScriptModule(original_name=LayerNorm)
    )
  )
)
Total running time of the script: ( 0 minutes 0.118 seconds)