slapo.op.attention¶
Attention module using highly efficient CUDA kernels.
The flash-attention kernel is tested with: jfc4050/flash-attention
The xFormers kernel is tested with: facebookresearch/xformers
If you encounter an error when using the above kernels, please check whether your commit hash matches the one we tested with.
Functions:
- Log the warning message only once.
- flash_attn_ref: The functional equivalent of FlashAttentionTriton for correctness checking.
- xformers_ref: The native PyTorch implementation of attention with the same signature as the attention implemented in xformers.
- validate_sm_version: Validate the SM version.
- get_xfoemers_attn_op_by_name: Get the xformers attention operator by name.
Classes:
- FlashAttentionOp: A wrapper module that converts a HuggingFace attention mask to a flash attention mask.
- FlashAttention: A HuggingFace self-attention module with flash attention kernels.
- slapo.op.attention.flash_attn_ref(q, k, v, bias=None, causal=False, dropout_p=0.0, softmax_scale=None, query_padding_mask=None, key_padding_mask=None, dropout_mask=None, upcast=True, reorder_ops=False)[source]¶
The functional equivalent of FlashAttentionTriton for correctness checking. Source: jfc4050/flash-attention
- Parameters
q (torch.Tensor) – Shape: (batch_size, seqlen_q, nheads, head_dim)
k (torch.Tensor) – Shape: (batch_size, seqlen_k, nheads, head_dim)
v (torch.Tensor) – Shape: (batch_size, seqlen_k, nheads, head_dim)
bias (Optional[torch.Tensor]) – Shape: (batch_size, nheads, seqlen_q, seqlen_k)
causal (bool) – Whether to apply lower triangular causal mask.
dropout_p (float) – The dropout probability.
softmax_scale (Optional[float]) – The softmax scale. If None, use 1 / sqrt(d).
query_padding_mask (Optional[torch.Tensor]) – Shape: (batch_size, seqlen_q)
key_padding_mask (Optional[torch.Tensor]) – Shape: (batch_size, seqlen_k)
dropout_mask (Optional[torch.Tensor]) – The dropout mask. Shape: (batch_size, nheads, seqlen_q, seqlen_k)
upcast (bool) – Whether to cast all inputs to fp32, do all computation in fp32, then cast output back to fp16/bf16.
reorder_ops (bool) – Whether to change the order of operations (e.g., scaling k instead of scaling q) without changing the math. This is used to estimate the numerical error introduced by operation reordering.
- Returns
Shape: (batch_size, seqlen_q, nheads, head_dim)
- Return type
torch.Tensor
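Example (a minimal sketch; the shapes, dtype, and device are illustrative and assume an fp16-capable GPU)::

    import torch
    from slapo.op.attention import flash_attn_ref

    batch, seqlen, nheads, head_dim = 2, 128, 8, 64
    q = torch.randn(batch, seqlen, nheads, head_dim, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Reference (non-fused) attention with a causal mask. With the default
    # upcast=True, the math is done in fp32 and the result is cast back to fp16.
    out = flash_attn_ref(q, k, v, causal=True, dropout_p=0.0)
    print(out.shape)  # torch.Size([2, 128, 8, 64])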
- slapo.op.attention.xformers_ref(q, k, v, attn_bias, p=0.0, scale=None)[source]¶
The native PyTorch implementation of attention with the same signature as the attention implemented in xformers. This is used mainly to check the correctness of the xformers implementation.
- Parameters
q (torch.Tensor) – Shape: (batch_size, seqlen_q, nheads, head_dim)
k (torch.Tensor) – Shape: (batch_size, seqlen_k, nheads, head_dim)
v (torch.Tensor) – Shape: (batch_size, seqlen_k, nheads, head_dim)
attn_bias (Optional[torch.Tensor]) – Shape: (batch_size, nheads, seqlen_q, seqlen_k)
p (float) – The dropout probability.
scale (Optional[float]) – The softmax scale. If None, use 1 / sqrt(d).
- Returns
Shape: (batch_size, seqlen_q, nheads, head_dim)
- Return type
torch.Tensor
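Example (a minimal sketch that cross-checks the two reference implementations on random inputs; shapes are illustrative and an fp16-capable GPU is assumed)::

    import torch
    from slapo.op.attention import flash_attn_ref, xformers_ref

    q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)

    out_xf = xformers_ref(q, k, v, attn_bias=None, p=0.0)
    out_fa = flash_attn_ref(q, k, v, causal=False, dropout_p=0.0)
    # Both follow the (batch, seqlen_q, nheads, head_dim) layout, so they can
    # be compared directly; small numerical differences are expected in fp16.
    print((out_xf.float() - out_fa.float()).abs().max())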
- slapo.op.attention.validate_sm_version(name, min_sm, max_sm=None)[source]¶
Validate the SM (compute capability) version.
- Parameters
name (str) – The name of the kernel.
min_sm (tuple[int, int]) – The minimum SM version.
max_sm (Optional[tuple[int, int]]) – The maximum SM version. If None, the maximum SM version is not checked.
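Example (a minimal sketch; the behavior on an unsupported GPU, assumed here to be raising an error, is not specified above)::

    from slapo.op.attention import validate_sm_version

    # Require at least SM 7.5 (Turing) for a hypothetical "triton" kernel,
    # with no upper bound on the SM version.
    validate_sm_version("triton", min_sm=(7, 5))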
- slapo.op.attention.get_xfoemers_attn_op_by_name(attn_name)[source]¶
Get the xformers attention operator by name.
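Example (a minimal sketch; only the lookup call is shown, since the structure of the returned operator is not documented here)::

    from slapo.op.attention import get_xfoemers_attn_op_by_name

    # "cutlass" and "auto" map to xFormers kernels (see FlashAttentionOp below).
    attn_op = get_xfoemers_attn_op_by_name("cutlass")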
- class slapo.op.attention.FlashAttentionOp(attn_op_name, apply_causal_mask, scale=None)[source]¶
A wrapper module that converts a HuggingFace attention mask to a flash attention mask.
- Parameters
attn_op_name (str) – The name of the attention operator. Can be “native_xformers”, “native_flash_attn”, “triton”, “cuda”, “cutlass”, or “auto”. “triton” and “cuda” use kernels from flash-attention, while “cutlass” and “auto” use kernels from xFormers.
apply_causal_mask (bool) – Whether to apply causal mask.
scale (Optional[float]) – The softmax scale. If None, use 1 / sqrt(d).
Methods:
- forward(query_layer, key_layer, value_layer, ...): Defines the computation performed at every call.
- forward(query_layer, key_layer, value_layer, attention_mask, p)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
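Example (a minimal sketch, assuming the (batch, seqlen, nheads, head_dim) layout used by the reference functions above and that attention_mask=None means no padding)::

    import torch
    from slapo.op.attention import FlashAttentionOp

    op = FlashAttentionOp(attn_op_name="auto", apply_causal_mask=True)
    q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)

    # p is the dropout probability applied inside the attention kernel.
    out = op(q, k, v, attention_mask=None, p=0.0)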
- class slapo.op.attention.FlashAttention(hidden_size, num_attention_heads, attn_pdrop=0.0, resid_pdrop=0.0, attn_op_name='auto', bias=True, output_proj=True, fused_qkv=False)[source]¶
A HuggingFace self-attention module with flash attention kernels. Note that this module has limited support for specialized processing, documented as follows:
- Only supports absolute positional embeddings.
- Does not support cross attention.
- Does not support head mask, encoder_attention_mask, or output attention.
The Attention module is organized as follows:
- Attention
  - SelfAttention
    - Q, K, V
    - CoreAttention
  - Projection
    - OutDense
Methods:
- reshape_for_scores(x): Copy of transpose_for_scores but without the transpose.
- forward(hidden_states[, attention_mask, ...]): Defines the computation performed at every call.
- reshape_for_scores(x)[source]¶
A copy of transpose_for_scores but without the transpose.
- Parameters
x (torch.Tensor) –
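For reference, a sketch of the assumed reshaping: HuggingFace's transpose_for_scores produces a (batch, nheads, seqlen, head_dim) layout, whereas dropping the transpose keeps the (batch, seqlen, nheads, head_dim) layout expected by the flash attention kernels (shapes below are illustrative)::

    import torch

    batch, seqlen, nheads, head_dim = 2, 128, 12, 64
    x = torch.randn(batch, seqlen, nheads * head_dim)
    # (batch, seqlen, hidden_size) -> (batch, seqlen, nheads, head_dim),
    # without the permute to (batch, nheads, seqlen, head_dim).
    x = x.view(batch, seqlen, nheads, head_dim)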
- forward(hidden_states, attention_mask=None, layer_past=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, use_cache=False, output_attentions=False)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters
hidden_states (Optional[tuple[torch.FloatTensor]]) –
attention_mask (Optional[torch.FloatTensor]) –
layer_past (Optional[tuple[torch.Tensor]]) –
head_mask (Optional[torch.FloatTensor]) –
encoder_hidden_states (Optional[torch.Tensor]) –
encoder_attention_mask (Optional[torch.FloatTensor]) –
use_cache (Optional[bool]) –
output_attentions (Optional[bool]) –
- Return type
tuple[torch.Tensor]
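Example (a minimal sketch of using FlashAttention as a drop-in HuggingFace-style self-attention layer; sizes are illustrative and an fp16-capable GPU is assumed)::

    import torch
    from slapo.op.attention import FlashAttention

    attn = FlashAttention(
        hidden_size=768,
        num_attention_heads=12,
        attn_pdrop=0.1,
        resid_pdrop=0.1,
        attn_op_name="auto",
    ).half().cuda()

    hidden_states = torch.randn(2, 128, 768, dtype=torch.float16, device="cuda")
    outputs = attn(hidden_states)  # attention_mask defaults to None (no padding)
    context = outputs[0]           # assumed to be the attention output, (2, 128, 768)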