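# Flash Attention

Flash Attention is a fast, memory-efficient implementation of exact attention. It computes attention in blocks without materializing the full attention score matrix in GPU memory, which makes it well suited to large models and long sequences. This is a build of Flash Attention packaged for the Hugging Face `kernels` library, so it can be loaded directly from the Hub with `get_kernel`.

The original code is available at https://github.com/Dao-AILab/flash-attention.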
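The memory saving comes from never building the full score matrix. The sketch below illustrates the underlying idea, a blockwise ("online") softmax, in plain NumPy for a single head on the CPU: keys and values are visited in tiles while a running maximum, denominator, and unnormalized output are rescaled as each tile arrives. It is a simplified illustration, not the CUDA kernel shipped here (the real kernel also tiles the queries and keeps each tile in fast on-chip memory), and `naive_attention` / `tiled_attention` are hypothetical names.

```python
import numpy as np


def naive_attention(q, k, v):
    """Reference softmax(q @ k.T / sqrt(d)) @ v; builds the full (Sq, Sk) score matrix."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    p = np.exp(scores)
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def tiled_attention(q, k, v, block=4):
    """Same result, visiting K/V in tiles of `block` rows with a running softmax,
    so only a (Sq, block) slice of the scores exists at any time."""
    scale = 1.0 / np.sqrt(q.shape[-1])
    row_max = np.full(q.shape[0], -np.inf)      # running max of scores per query row
    denom = np.zeros(q.shape[0])                # running softmax denominator
    acc = np.zeros((q.shape[0], v.shape[-1]))   # running unnormalized output
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = (q @ kb.T) * scale                  # scores for this tile only
        new_max = np.maximum(row_max, s.max(axis=-1))
        rescale = np.exp(row_max - new_max)     # shrink earlier sums to the new max (0 on first tile)
        p = np.exp(s - new_max[:, None])
        denom = denom * rescale + p.sum(axis=-1)
        acc = acc * rescale[:, None] + p @ vb
        row_max = new_max
    return acc / denom[:, None]


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((5, 8)) for _ in range(3))
print(np.allclose(naive_attention(q, k, v), tiled_attention(q, k, v, block=2)))  # True
```

Because each tile's contribution is rescaled when a larger maximum appears, the final division by the running denominator gives exactly the same softmax as the naive version, for any tile size.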
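The following script loads the kernel from the Hub with `get_kernel` and runs its three forward entry points (`mha_fwd`, `mha_varlen_fwd`, `mha_fwd_kvcache`) on small random inputs:

```python
# /// script
# dependencies = ["numpy", "torch", "kernels"]
# ///
import torch
from kernels import get_kernel

# Setup
torch.manual_seed(42)
flash_attn = get_kernel("kernels-community/flash-attn")  # fetches the prebuilt kernel from the Hub
device = torch.device("cuda")  # this example runs on a CUDA GPU

# Show available functions
print("Flash Attention functions:", [name for name in dir(flash_attn) if name.startswith("mha")])

# 1. Standard attention
print("\n1. Standard attention:")
B, S, H, D = 2, 5, 4, 8  # batch, seq_len, heads, head_dim
# Layout: (batch, seq_len, heads, head_dim)
q = k = v = torch.randn(B, S, H, D, device=device, dtype=torch.float16)
out = flash_attn.mha_fwd(q=q, k=k, v=v, is_causal=False)[0]  # element 0 is the attention output
print(f"Output: {out.shape}")

# 2. Variable length sequences
print("\n2. Variable length sequences:")
# Sequences are packed along the first dimension with no padding: (total_tokens, heads, head_dim)
q_var = torch.randn(10, H, D, device=device, dtype=torch.float16)  # total_q = 3 + 4 + 3 = 10
k_var = v_var = torch.randn(12, H, D, device=device, dtype=torch.float16)  # total_k = 4 + 5 + 3 = 12
# Cumulative boundaries for 3 sequences: q lengths [3, 4, 3], k lengths [4, 5, 3]
cu_q = torch.tensor([0, 3, 7, 10], device=device, dtype=torch.int32)
cu_k = torch.tensor([0, 4, 9, 12], device=device, dtype=torch.int32)
out_var = flash_attn.mha_varlen_fwd(
    q=q_var,
    k=k_var,
    v=v_var,
    cu_seqlens_q=cu_q,
    cu_seqlens_k=cu_k,
    max_seqlen_q=4,  # longest query sequence
    max_seqlen_k=5,  # longest key sequence
)[0]
print(f"Output: {out_var.shape}")

# 3. KV-cache for autoregressive generation
print("\n3. KV-cache:")
cache_len, new_len = 10, 2  # tokens already cached, tokens added in this step
# The cache needs room for cache_len + new_len positions; only the first cache_len are filled
kcache = torch.zeros(B, cache_len + new_len, H, D, device=device, dtype=torch.float16)
vcache = torch.zeros_like(kcache)
kcache[:, :cache_len] = torch.randn(B, cache_len, H, D, device=device, dtype=torch.float16)
vcache[:, :cache_len] = torch.randn(B, cache_len, H, D, device=device, dtype=torch.float16)
q_new = k_new = v_new = torch.randn(
    B, new_len, H, D, device=device, dtype=torch.float16
)
# Valid tokens already in each cache; k_new/v_new are written into the cache starting here
seqlens = torch.full((B,), cache_len, device=device, dtype=torch.int32)
out_kv = flash_attn.mha_fwd_kvcache(
    q=q_new,
    kcache=kcache,
    vcache=vcache,
    k=k_new,
    v=v_new,
    seqlens_k=seqlens,
    is_causal=True,
)[0]
print(f"Output: {out_kv.shape}")
```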
Expected output:

```text
Fetching 3 files: 100%|██████████████████████████████| 3/3 [00:00<00:00, 16384.00it/s]
Flash Attention functions: ['mha_bwd', 'mha_fwd', 'mha_fwd_kvcache', 'mha_varlen_bwd', 'mha_varlen_fwd']

1. Standard attention:
Output: torch.Size([2, 5, 4, 8])

2. Variable length sequences:
Output: torch.Size([10, 4, 8])

3. KV-cache:
Output: torch.Size([2, 2, 4, 8])
```
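The `cu_seqlens_q` / `cu_seqlens_k` arguments of `mha_varlen_fwd` are cumulative offsets into the packed token dimension: sequence `i` occupies rows `cu[i]:cu[i+1]`, and `max_seqlen_q` / `max_seqlen_k` are the longest individual lengths. A minimal NumPy sketch of deriving both from per-sequence lengths, and of packing a padded batch into that layout, follows; `cu_seqlens_from_lengths` and `pack` are hypothetical helpers, not part of the `kernels` package.

```python
import numpy as np


def cu_seqlens_from_lengths(lengths):
    """Return int32 offsets [0, l0, l0+l1, ...] and the longest length (hypothetical helper)."""
    lengths = np.asarray(lengths)
    cu = np.zeros(len(lengths) + 1, dtype=np.int32)
    cu[1:] = np.cumsum(lengths)  # sequence i occupies packed rows cu[i]:cu[i+1]
    return cu, int(lengths.max())


def pack(padded, lengths):
    """Drop padding from a (B, S_max, H, D) batch, giving (sum(lengths), H, D) (hypothetical helper)."""
    return np.concatenate([padded[i, :n] for i, n in enumerate(lengths)], axis=0)


cu_q, max_q = cu_seqlens_from_lengths([3, 4, 3])
cu_k, max_k = cu_seqlens_from_lengths([4, 5, 3])
print(cu_q.tolist(), max_q)  # [0, 3, 7, 10] 4
print(cu_k.tolist(), max_k)  # [0, 4, 9, 12] 5

padded_q = np.zeros((3, 4, 4, 8), dtype=np.float16)  # (batch, max_seqlen_q, heads, head_dim)
print(pack(padded_q, [3, 4, 3]).shape)  # (10, 4, 8)
```

In PyTorch the same offsets can be built on the GPU with `torch.nn.functional.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0))`.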
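For the KV-cache call, the semantics below are assumed from the upstream flash-attn documentation of `flash_attn_with_kvcache` rather than stated in this card: `seqlens_k` counts the tokens each cache already holds, the new `k`/`v` are written into the cache starting at that position, and with `is_causal=True` the mask is aligned to the bottom-right, so each new query sees the whole cached prefix plus the new tokens up to its own position. This NumPy sketch models that for one sequence and one head; `softmax_attention` and `kvcache_step` are hypothetical reference helpers.

```python
import numpy as np


def softmax_attention(q, k, v, causal):
    """Single-head attention. With causal=True the mask is aligned bottom-right:
    query i (of Sq) may attend to keys 0 .. i + (Sk - Sq)."""
    sq, sk = q.shape[0], k.shape[0]
    scores = q @ k.T / np.sqrt(q.shape[-1])
    if causal:
        allowed = np.arange(sk)[None, :] <= np.arange(sq)[:, None] + (sk - sq)
        scores = np.where(allowed, scores, -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)
    p = np.exp(scores)  # masked positions become exp(-inf) = 0
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def kvcache_step(q_new, k_new, v_new, kcache, vcache, cache_seqlen):
    """Write k_new/v_new into the cache at cache_seqlen (in place), then run causal
    attention of q_new over the valid prefix (hypothetical model of mha_fwd_kvcache)."""
    end = cache_seqlen + k_new.shape[0]
    assert end <= kcache.shape[0], "cache has no room for the new tokens"
    kcache[cache_seqlen:end] = k_new
    vcache[cache_seqlen:end] = v_new
    return softmax_attention(q_new, kcache[:end], vcache[:end], causal=True)


rng = np.random.default_rng(0)
cache_len, new_len, d = 10, 2, 8
q_all, k_all, v_all = (rng.standard_normal((cache_len + new_len, d)) for _ in range(3))
kcache = np.zeros((cache_len + new_len, d))  # capacity for cached + new tokens
vcache = np.zeros((cache_len + new_len, d))
kcache[:cache_len] = k_all[:cache_len]
vcache[:cache_len] = v_all[:cache_len]
out = kvcache_step(q_all[cache_len:], k_all[cache_len:], v_all[cache_len:],
                   kcache, vcache, cache_seqlen=cache_len)
ref = softmax_attention(q_all, k_all, v_all, causal=True)[cache_len:]
print(out.shape, np.allclose(out, ref))  # (2, 8) True
```

Comparing against the last rows of full causal attention is a convenient check: one decoding step with a cache should produce the same output as recomputing attention over the whole sequence.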