Metal Flash SDPA

Optimized SDPA kernels inspired by Flash Attention for Metal.

Some components of these kernels are from mlx.

Supported Features

  • Variable-length sequences without padding
  • Causal masking
  • Grouped Query Attention (GQA) and Multi-Query Attention (MQA)
  • Softcapping support for attention score regularization
  • Data types: float32, float16, bfloat16
  • Head dimensions: 32, 64, 72, 80, 96, 128, 256

API Reference

flash_attention_varlen

metal_flash_sdpa.flash_attention_varlen(
    out: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    do_causal: bool,
    scale: float,
    softcapping: float
) -> None
  • out: Output tensor [total_q_tokens, num_heads, head_dim], modified in-place.
  • query/key/value: Input tensors [total_tokens, num_heads(_kv), head_dim].
  • cu_seqlens_q/cu_seqlens_k: Cumulative sequence lengths (torch.int32), [batch_size + 1].
  • max_seqlen_q/max_seqlen_k: Maximum sequence lengths.
  • do_causal: Enable causal masking.
  • scale: Attention score scaling factor (e.g., 1/sqrt(head_dim)).
  • softcapping: Softcapping value for score regularization (use 1.0 for no softcapping).

flash_attn_varlen_func

Compatibility wrapper matching the original Flash Attention API:

out = metal_flash_sdpa.flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    dropout_p: float = 0.0,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    alibi_slopes: Optional[torch.Tensor] = None,
    deterministic: bool = False,
    return_attn_probs: bool = False
)
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. ๐Ÿ™‹ Ask for provider support