slapo.op.attention

Attention module using highly efficient CUDA kernels.

The flash-attention kernel is tested with: jfc4050/flash-attention

The xFormers kernel is tested with: facebookresearch/xformers

If you encounter an error when using the above kernels, please check whether your installed commit hash matches the one we tested with.

Functions:

warning_once(msg)

Log the warning message only once.

flash_attn_ref(q, k, v[, bias, causal, ...])

The functional equivalent of FlashAttentionTriton for correctness checking.

xformers_ref(q, k, v, attn_bias[, p, scale])

The native PyTorch implementation of attention with the same signature as the attention implemented in xformers.

validate_sm_version(name, min_sm[, max_sm])

Validate the SM (compute capability) version.

get_xfoemers_attn_op_by_name(attn_name)

Get the xformers attention operator by name.

Classes:

FlashAttentionOp(attn_op_name, apply_causal_mask)

A wrapper module that converts an HF attention mask to a flash attention mask.

FlashAttention(hidden_size, num_attention_heads)

A HuggingFace self-attention module with flash attention kernels.

slapo.op.attention.warning_once(msg)[source]

Log the warning message only once.

slapo.op.attention.flash_attn_ref(q, k, v, bias=None, causal=False, dropout_p=0.0, softmax_scale=None, query_padding_mask=None, key_padding_mask=None, dropout_mask=None, upcast=True, reorder_ops=False)[source]

The functional equivalent of FlashAttentionTriton for correctness checking. Source: jfc4050/flash-attention

Parameters
  • q (torch.Tensor) – Shape: (batch_size, seqlen_q, nheads, head_dim)

  • k (torch.Tensor) – Shape: (batch_size, seqlen_k, nheads, head_dim)

  • v (torch.Tensor) – Shape: (batch_size, seqlen_k, nheads, head_dim)

  • bias (Optional[torch.Tensor]) – Shape: (batch_size, nheads, seqlen_q, seqlen_k)

  • causal (bool) – Whether to apply lower triangular causal mask.

  • dropout_p (float) – The dropout probability.

  • softmax_scale (Optional[float]) – The softmax scale. If None, use 1 / sqrt(d).

  • query_padding_mask (Optional[torch.Tensor]) – Shape: (batch_size, seqlen_q)

  • key_padding_mask (Optional[torch.Tensor]) – Shape: (batch_size, seqlen_k)

  • dropout_mask (Optional[torch.Tensor]) – The dropout mask. Shape: (batch_size, nheads, seqlen_q, seqlen_k)

  • upcast (bool) – Whether to cast all inputs to fp32, do all computation in fp32, then cast output back to fp16/bf16.

  • reorder_ops (bool) – Whether to change the order of operations (scaling k instead of scaling q, etc.) without changing the math. This is to estimate the numerical error from operation reordering.

Returns

Shape: (batch_size, seqlen_q, nheads, head_dim)

Return type

torch.Tensor
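
Example

A minimal usage sketch against the signature above; the concrete shapes, dtype, and device are illustrative:

import torch
from slapo.op.attention import flash_attn_ref

batch_size, seqlen, nheads, head_dim = 2, 128, 12, 64
q = torch.randn(batch_size, seqlen, nheads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Reference attention with a lower triangular causal mask and the default
# softmax scale of 1 / sqrt(head_dim).
out = flash_attn_ref(q, k, v, causal=True)
print(out.shape)  # (batch_size, seqlen, nheads, head_dim)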

slapo.op.attention.xformers_ref(q, k, v, attn_bias, p=0.0, scale=None)[source]

The native PyTorch implementation of attention with the same signature as the attention implemented in xformers. This is used mainly to check the correctness of the xformers implementation.

Parameters
  • q (torch.Tensor) – Shape: (batch_size, seqlen_q, nheads, head_dim)

  • k (torch.Tensor) – Shape: (batch_size, seqlen_k, nheads, head_dim)

  • v (torch.Tensor) – Shape: (batch_size, seqlen_k, nheads, head_dim)

  • attn_bias (Optional[torch.Tensor]) – Shape: (batch_size, nheads, seqlen_q, seqlen_k)

  • p (float) – The dropout probability.

  • scale (Optional[float]) – The softmax scale. If None, use 1 / sqrt(d).

Returns

Shape: (batch_size, seqlen_q, nheads, head_dim)

Return type

torch.Tensor
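
Example

A minimal usage sketch; the bias tensor and shapes below are illustrative:

import torch
from slapo.op.attention import xformers_ref

batch_size, seqlen, nheads, head_dim = 2, 128, 12, 64
q = torch.randn(batch_size, seqlen, nheads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Additive bias applied to the attention scores; pass None to disable it.
attn_bias = torch.zeros(batch_size, nheads, seqlen, seqlen, dtype=q.dtype, device=q.device)

out = xformers_ref(q, k, v, attn_bias, p=0.0)
print(out.shape)  # (batch_size, seqlen, nheads, head_dim)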

slapo.op.attention.validate_sm_version(name, min_sm, max_sm=None)[source]

Validate the SM (compute capability) version.

Parameters
  • name (str) – The name of the kernel.

  • min_sm (tuple[int, int]) – The minimum SM version as a (major, minor) tuple.

  • max_sm (Optional[tuple[int, int]]) – The maximum SM version. If None, the maximum SM version is not checked.
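
Example

A usage sketch. The SM thresholds below are illustrative, not the requirements enforced by the library:

from slapo.op.attention import validate_sm_version

# Check the "triton" kernel against a minimum SM version of 7.5
# (i.e., compute capability sm_75).
validate_sm_version("triton", (7, 5))

# Additionally bound the maximum supported SM version.
validate_sm_version("cuda", (7, 5), max_sm=(8, 6))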

slapo.op.attention.get_xfoemers_attn_op_by_name(attn_name)[source]

Get the xformers attention operator by name.
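
Example

A usage sketch; the operator name is assumed to be one of the xFormers backends accepted by FlashAttentionOp (e.g. "cutlass" or "auto"):

from slapo.op.attention import get_xfoemers_attn_op_by_name

# Resolve the xFormers attention operator for the CUTLASS backend.
attn_op = get_xfoemers_attn_op_by_name("cutlass")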

class slapo.op.attention.FlashAttentionOp(attn_op_name, apply_causal_mask, scale=None)[source]

A wrapper module that converts an HF attention mask to a flash attention mask.

Parameters
  • attn_op_name (str) – The name of the attention operator. Can be “native_xformers”, “native_flash_attn”, “triton”, “cuda”, “cutlass”, or “auto”. “triton” and “cuda” use the kernels from flash-attention, while “cutlass” and “auto” use the kernels from xFormers.

  • apply_causal_mask (bool) – Whether to apply causal mask.

  • scale (Optional[float]) – The softmax scale. If None, use 1 / sqrt(d).

Methods:

forward(query_layer, key_layer, value_layer, ...)

Defines the computation performed at every call.

forward(query_layer, key_layer, value_layer, attention_mask, p)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
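
Example

A minimal usage sketch of FlashAttentionOp. The (batch_size, seqlen, nheads, head_dim) layout mirrors the reference functions above and, together with the fp16/CUDA placement, is an assumption rather than a documented requirement:

import torch
from slapo.op.attention import FlashAttentionOp

batch_size, seqlen, nheads, head_dim = 2, 128, 12, 64
attn_op = FlashAttentionOp(attn_op_name="auto", apply_causal_mask=True)

q = torch.randn(batch_size, seqlen, nheads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
attention_mask = None  # or an HF-style additive attention mask

# Call the module instance (not .forward) so registered hooks run.
out = attn_op(q, k, v, attention_mask, p=0.0)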

class slapo.op.attention.FlashAttention(hidden_size, num_attention_heads, attn_pdrop=0.0, resid_pdrop=0.0, attn_op_name='auto', bias=True, output_proj=True, fused_qkv=False)[source]

A HuggingFace self-attention module with flash attention kernels. Note that this module has limited support for specialized processing, documented as follows:

  • Only supports absolute positional embeddings.

  • Does not support cross attention.

  • Does not support head mask, encoder_attention_mask, or output attentions.

We organize the Attention module as follows:

  • Attention
    • SelfAttention
      • Q, K, V

      • CoreAttention

    • Projection
      • OutDense

Methods:

reshape_for_scores(x)

Copied from transpose_for_scores but without the transpose.

forward(hidden_states[, attention_mask, ...])

Defines the computation performed at every call.

reshape_for_scores(x)[source]

Copied from transpose_for_scores but without the transpose.

Parameters

x (torch.Tensor) –

forward(hidden_states, attention_mask=None, layer_past=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, use_cache=False, output_attentions=False)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters
  • hidden_states (Optional[tuple[torch.FloatTensor]]) –

  • attention_mask (Optional[torch.FloatTensor]) –

  • layer_past (Optional[tuple[torch.Tensor]]) –

  • head_mask (Optional[torch.FloatTensor]) –

  • encoder_hidden_states (Optional[torch.Tensor]) –

  • encoder_attention_mask (Optional[torch.FloatTensor]) –

  • use_cache (Optional[bool]) –

  • output_attentions (Optional[bool]) –

Return type

tuple[torch.Tensor]
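
Example

A minimal end-to-end sketch of the FlashAttention module, following the constructor and forward signatures above; the hidden-state shape, dtype, and device placement are illustrative assumptions:

import torch
from slapo.op.attention import FlashAttention

hidden_size, num_heads = 768, 12
attn = FlashAttention(
    hidden_size,
    num_heads,
    attn_pdrop=0.1,
    resid_pdrop=0.1,
    attn_op_name="auto",
).half().cuda()

batch_size, seqlen = 2, 128
hidden_states = torch.randn(
    batch_size, seqlen, hidden_size, dtype=torch.float16, device="cuda"
)

# attention_mask may be an HF-style additive float mask; None disables masking.
outputs = attn(hidden_states, attention_mask=None)
attn_output = outputs[0]  # first element of the returned tuple[torch.Tensor]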