Flash Attention - Fast Memory-Efficient Attention
Quick start
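Flash Attention computes exact attention 2-4x faster than a standard implementation, with 10-20x less memory at long sequence lengths (2K-4K tokens). Its IO-aware tiling and recomputation keep attention memory linear, rather than quadratic, in sequence length.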
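The tiling idea can be shown without a GPU. The sketch below is framework-free Python that assumes small nested-list inputs. It streams K and V in blocks and keeps only three things per query: a running max, a running softmax denominator, and a running output. This is the online-softmax rescaling used in the FlashAttention forward pass. The sketch checks its result against naive attention, which materializes every score. It leaves out everything that makes the real kernel fast: SRAM tiling, fused CUDA code, and backward-pass recomputation. `naive_attention`, `tiled_attention`, and `block_size` are hypothetical names used only for illustration.

```python
import math

def naive_attention(q, k, v):
    """Reference: build the full score row for each query, softmax, weight V."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out

def tiled_attention(q, k, v, block_size=2):
    """Stream K/V in blocks, keeping per query only a running max, a running
    softmax denominator, and a running unnormalized output."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        m, denom = -math.inf, 0.0
        acc = [0.0] * len(v[0])
        for start in range(0, len(k), block_size):
            k_blk = k[start:start + block_size]
            v_blk = v[start:start + block_size]
            s_blk = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k_blk]
            m_new = max(m, max(s_blk))
            correction = math.exp(m - m_new)  # rescale earlier blocks to the new max
            p_blk = [math.exp(s - m_new) for s in s_blk]
            denom = denom * correction + sum(p_blk)
            acc = [a * correction + sum(p * vj[c] for p, vj in zip(p_blk, v_blk))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / denom for a in acc])
    return out

q = [[0.1, 0.2], [0.3, -0.1], [0.0, 0.5]]
k = [[0.2, 0.1], [-0.3, 0.4], [0.5, 0.5], [0.1, -0.2], [0.0, 0.3]]
v = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0], [0.3, 0.7]]
reference = naive_attention(q, k, v)
errors = {}
for bs in (1, 2, 3, 5):
    tiled = tiled_attention(q, k, v, block_size=bs)
    errors[bs] = max(abs(a - b) for rt, rr in zip(tiled, reference) for a, b in zip(rt, rr))
    print(f"block_size={bs}: max abs error vs. naive = {errors[bs]:.1e}")
```

Each new block only rescales the running statistics by exp(m_old - m_new). The full row of scores therefore never has to exist at once, which is why memory stays linear in sequence length.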
PyTorch native (easiest, PyTorch 2.2+):
import torch
import torch.nn.functional as F

# [batch, heads, seq, dim]
q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)

# Automatically uses Flash Attention if available
out = F.scaled_dot_product_attention(q, k, v)
flash-attn library (more features):
pip install flash-attn --no-build-isolation
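from flash_attn import flash_attn_func

# flash_attn_func expects q, k, v as [batch, seqlen, nheads, headdim] (fp16/bf16 on GPU),
# so transpose the [batch, heads, seq, dim] tensors above
out = flash_attn_func(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
    dropout_p=0.0, causal=True,
)  # out: [batch, seqlen, nheads, headdim]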
Common workflows
Workflow 1: Enable in existing PyTorch model
Copy this checklist:
Flash Attention Integration:
- Step 1: Check PyTorch version (≥2.2)
- Step 2: Enable Flash Attention backend
- Step 3: Verify speedup with profiling
- Step 4: Test accuracy matches baseline
Step 1: Check PyTorch version
python -c "import torch; print(torch.__version__)"
Should be ≥2.2.0
If <2.2, upgrade:
pip install --upgrade torch
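You can also run Step 1 in code. This is a minimal sketch that assumes the version string looks like `torch.__version__` output, for example `2.3.1+cu121`. `meets_min_version` is a hypothetical helper that compares only the leading MAJOR.MINOR numbers:

```python
import re

def meets_min_version(version: str, minimum=(2, 2)) -> bool:
    """True if the leading MAJOR.MINOR of a version string such as
    '2.3.1+cu121' or '2.2.0a0+git1234' is at least `minimum`."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= tuple(minimum)

for candidate in ("2.1.2", "2.2.0a0+git1234", "2.3.1+cu121"):
    print(candidate, meets_min_version(candidate))
```

In a real environment, call `meets_min_version(torch.__version__)`. If the extra dependency is acceptable, `packaging.version` handles full version semantics.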
Step 2: Enable Flash Attention backend
Replace standard attention:
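# Before (standard attention)
import math
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)), dim=-1)
out = attn_weights @ v

# After (Flash Attention). An explicit attn_mask rules out the Flash kernel,
# so express causal masking with is_causal=True instead of a mask tensor.
import torch.nn.functional as F
out = F.scaled_dot_product_attention(q, k, v)  # add is_causal=True for decoder models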
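Force the Flash Attention backend. With the other backends disabled, PyTorch raises an error instead of silently falling back:
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)
# Newer PyTorch releases deprecate sdp_kernel in favor of
# torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION)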
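When only the Flash backend is enabled, as above, a call the Flash kernel cannot handle fails instead of falling back. The sketch below is a simplified, hypothetical model of the conditions this guide mentions. It is not PyTorch's actual dispatch code. The conditions are: a CUDA device, float16/bfloat16 inputs, no explicit `attn_mask`, and an Ampere-or-newer GPU. `AttentionCall` and `flash_backend_blockers` are illustrative names. PyTorch also checks other constraints, such as head dimension.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

@dataclass
class AttentionCall:
    """Hypothetical description of one F.scaled_dot_product_attention call."""
    device: str                      # e.g. "cuda" or "cpu"
    dtype: str                       # e.g. "float16", "bfloat16", "float32"
    has_attn_mask: bool              # was an explicit attn_mask tensor passed?
    compute_capability: Optional[Tuple[int, int]] = None

def flash_backend_blockers(call: AttentionCall) -> List[str]:
    """Reasons the Flash backend would be skipped (simplified model only)."""
    reasons = []
    if call.device != "cuda":
        reasons.append("Flash backend needs a CUDA GPU")
    if call.dtype not in ("float16", "bfloat16"):
        reasons.append("dtype must be float16 or bfloat16")
    if call.has_attn_mask:
        reasons.append("explicit attn_mask; use is_causal=True for causal masking")
    if call.compute_capability is not None and call.compute_capability < (8, 0):
        reasons.append("GPU older than Ampere (compute capability 8.0)")
    return reasons

print(flash_backend_blockers(AttentionCall("cuda", "float16", False, (8, 0))))
print(flash_backend_blockers(AttentionCall("cuda", "float32", True, (7, 5))))
```

An empty list means none of the modeled blockers apply. Any entry is a reason to expect either a fallback to another backend or, with Flash forced, an error.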
Step 3: Verify speedup with profiling
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark

# Create inputs once so the timing covers only the attention call
q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

def test_attention(use_flash):
    if use_flash:
        # Disable the other backends so this really measures the Flash kernel
        with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
            return F.scaled_dot_product_attention(q, k, v)
    else:
        attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)  # 8.0 = sqrt(head_dim=64)
        return attn @ v

# Benchmark (Timer synchronizes CUDA for you)
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())

print(f"Flash: {t_flash.timeit(100).mean * 1e3:.3f} ms")
print(f"Standard: {t_standard.timeit(100).mean * 1e3:.3f} ms")
Expected: 2-4x speedup for sequences >512 tokens.
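The gap grows with sequence length for a simple reason. The baseline in Step 3 writes a full [batch, heads, seq, seq] score tensor to GPU memory, then reads it back for the softmax and the product with V. Flash Attention keeps tiles in on-chip SRAM instead. The back-of-the-envelope sketch below computes that tensor's size, assuming 2-byte float16 elements. `score_matrix_bytes` is a hypothetical helper. Real peak memory is higher, because the pre-softmax scores and the softmax output are separate tensors.

```python
def score_matrix_bytes(batch, heads, seq_len, bytes_per_elem=2):
    """Bytes in one [batch, heads, seq_len, seq_len] attention tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

# batch=2 and 8 heads as in Step 3; float16 = 2 bytes per element
for n in (512, 2048, 8192):
    print(f"seq_len={n}: {score_matrix_bytes(2, 8, n) / 2**20:.1f} MiB")
# prints 8.0, 128.0 and 2048.0 MiB: 4x more tokens means 16x more memory
```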
Step 4: Test accuracy matches baseline
# Compare outputs
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

# Flash Attention
out_flash = F.scaled_dot_product_attention(q, k, v)

# Standard attention
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)
out_standard = attn_weights @ v

# Check difference
diff = (out_flash - out_standard).abs().max().item()
print(f"Max difference: {diff:.6f}")
# Should be <1e-3 for float16
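A single max-difference threshold ignores magnitude: an error of 1e-3 matters for values near 0.01 but not for values near 100. A common alternative is an elementwise combined tolerance, |actual - expected| <= atol + rtol * |expected|. This is the same test `torch.allclose` performs. The sketch below applies it in plain Python to flattened lists. The default tolerances and the sample values are illustrative assumptions, not recommended settings.

```python
def max_abs_diff(actual, expected):
    return max(abs(a - e) for a, e in zip(actual, expected))

def all_close(actual, expected, atol=1e-3, rtol=1e-3):
    """True if every |actual - expected| <= atol + rtol * |expected|."""
    if len(actual) != len(expected):
        raise ValueError("inputs must have the same length")
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))

# Hypothetical flattened outputs: a small relative error on a large value
flash_out = [1.0005, 100.05]
reference = [1.0, 100.0]
print(max_abs_diff(flash_out, reference))  # about 0.05, fails a 1e-3 absolute check
print(all_close(flash_out, reference))     # passes the combined tolerance
```

With real tensors, `torch.allclose(out_flash, out_standard, atol=..., rtol=...)` performs the same check directly.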
Workflow 2: Use flash-attn library for advanced features
For multi-query attention, sliding window, or H100 FP8.
Copy this checklist:
flash-attn Library Setup:
- Step 1: Install flash-attn library
- Step 2: Modify attention code
- Step 3: Enable advanced features
- Step 4: Benchmark performance
Step 1: Install flash-attn library
NVIDIA GPUs (CUDA 12.0+)
pip install flash-attn --no-build-isolation
Verify installation
python -c "from flash_attn import flash_attn_func; print('Success')"
Step 2: Modify attention code
from flash_attn import flash_attn_func
# Input: [batch_size, seq_len, num_heads, head_dim]
# Transpose from [batch, heads, seq, dim] if needed
q = q.transpose(1, 2)  # [batch, seq, heads, dim]
k = k.transpose(1, 2)
v = v.transpose(1, 2)

out = flash_attn_func(
    q, k, v,
    dropout_p=0.1,         # Training only; use 0.0 for evaluation and inference
    causal=True,           # For autoregressive models
    window_size=(-1, -1),  # No sliding window
    softmax_scale=None,    # Defaults to 1/sqrt(head_dim)
)

out = out.transpose(1, 2)  # Back to [batch, heads, seq, dim]
Step 3: Enable advanced features
Multi-query / grouped-query attention (K/V shared across groups of query heads):
from flash_attn import flash_attn_func

# q: [batch, seq, num_q_heads, dim]
# k, v: [batch, seq, num_kv_heads, dim]  (fewer KV heads; num_q_heads must be divisible by num_kv_heads)
out = flash_attn_func(q, k, v)  # Handles MQA/GQA automatically
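When there are fewer K/V heads than query heads, each K/V head serves a group of query heads. The sketch below shows that mapping. It assumes consecutive query heads share a K/V head (heads 0..g-1 use K/V head 0, and so on), which is the common grouped-query convention. `kv_head_for_query_head` is a hypothetical helper, so confirm the layout your model was trained with.

```python
def kv_head_for_query_head(q_head, num_q_heads, num_kv_heads):
    """Index of the K/V head that query head `q_head` reads under MQA/GQA,
    assuming consecutive query heads are grouped together."""
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be divisible by num_kv_heads")
    group_size = num_q_heads // num_kv_heads
    return q_head // group_size

# GQA: 32 query heads sharing 8 K/V heads -> groups of 4
print([kv_head_for_query_head(h, 32, 8) for h in range(8)])
# MQA: every query head reads K/V head 0
print({kv_head_for_query_head(h, 32, 1) for h in range(32)})
```

The K/V tensors, and any KV cache built from them, shrink by a factor of num_q_heads / num_kv_heads. In the GQA example that factor is 4.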
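Sliding window attention (local attention):
# Causal model: each query sees itself and the 256 preceding tokens
# (with causal=True, the right side of the window has no effect)
out = flash_attn_func(q, k, v, window_size=(256, 0), causal=True)  # (left, right) window

# Bidirectional model: up to 256 tokens before and after each query
out = flash_attn_func(q, k, v, window_size=(256, 256), causal=False)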
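The `window_size` and `causal` arguments together define, for each query, a range of visible keys. This plain-Python sketch follows the rule described in the flash_attn_func docstring. First, the causal mask is aligned to the bottom-right corner when seqlen_q != seqlen_k. Second, query i sees keys in [i + seqlen_k - seqlen_q - left, i + seqlen_k - seqlen_q + right] inclusive, where -1 means unbounded. `allowed_keys` is a hypothetical helper; check the behavior against the docstring of your installed version.

```python
def allowed_keys(i, seqlen_q, seqlen_k, causal=False, window_size=(-1, -1)):
    """Key positions query i may attend to, as a list (inclusive range)."""
    offset = seqlen_k - seqlen_q      # bottom-right alignment of the mask
    left, right = window_size
    lo, hi = 0, seqlen_k - 1
    if causal:
        hi = min(hi, i + offset)      # nothing past the aligned diagonal
    if left >= 0:
        lo = max(lo, i + offset - left)
    if right >= 0:
        hi = min(hi, i + offset + right)
    return list(range(lo, hi + 1))    # empty when the whole row is masked

# Causal mask with seqlen_q=2, seqlen_k=5: rows keep keys 0-3 and 0-4
print([allowed_keys(i, 2, 5, causal=True) for i in range(2)])

# Sliding window on a 1024-token causal sequence, query 1000
w_causal = allowed_keys(1000, 1024, 1024, causal=True, window_size=(256, 0))
w_both = allowed_keys(1000, 1024, 1024, causal=True, window_size=(256, 256))
print(w_causal[0], w_causal[-1], w_causal == w_both)  # 744 1000 True
```

The last comparison shows that with causal=True the right half of the window has no effect. For causal models, (256, 0) and (256, 256) therefore behave the same.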
Step 4: Benchmark performance
import time
import torch
from flash_attn import flash_attn_func

q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]

# Warmup
for _ in range(10):
    _ = flash_attn_func(q, k, v)

# Benchmark (synchronize so the timer measures finished GPU work)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(100):
    out = flash_attn_func(q, k, v)
torch.cuda.synchronize()
end = time.perf_counter()

print(f"Time per iteration: {(end - start) / 100 * 1000:.2f}ms")
print(f"Peak memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.2f}GB")
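You can convert the measured time into achieved throughput and compare it with the GPU's peak. The sketch below assumes the FLOP-counting convention used in the flash-attention repository's benchmark scripts. Under that convention, the forward pass is two matmuls of 2 * seqlen^2 * headdim FLOPs per head, halved when causal. The helper names are hypothetical, and `ms_per_iter` is a placeholder for whatever your own run printed.

```python
def attention_fwd_flops(batch, seqlen, nheads, headdim, causal=False):
    """Forward FLOPs: Q @ K^T and P @ V, each 2 * seqlen^2 * headdim per head."""
    flops = 4 * batch * seqlen ** 2 * nheads * headdim
    return flops // 2 if causal else flops  # causal skips about half the scores

def achieved_tflops(flops, ms_per_iter):
    """Convert FLOPs per iteration and milliseconds per iteration to TFLOP/s."""
    return flops / (ms_per_iter * 1e-3) / 1e12

# Shapes from the benchmark above: batch=4, seqlen=4096, 32 heads, head_dim=64
flops = attention_fwd_flops(4, 4096, 32, 64)
print(f"{flops / 1e12:.3f} TFLOP per forward call")
ms_per_iter = 5.0  # hypothetical timing; substitute your measured value
print(f"{achieved_tflops(flops, ms_per_iter):.1f} TFLOP/s")
```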
Workflow 3: H100 FP8 optimization (FlashAttention-3)
For maximum performance on H100 GPUs.
FP8 Setup:
- Step 1: Verify H100 GPU available
- Step 2: Install flash-attn with FP8 support
- Step 3: Convert inputs to FP8
- Step 4: Run with FP8 attention
Step 1: Verify H100 GPU
nvidia-smi --query-gpu=name --format=csv
Should show "H100" or "H800"
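Step 2: Install flash-attn with FP8 support
The flash-attn pip package is FlashAttention-2, which supports only float16/bfloat16. FP8 kernels ship with FlashAttention-3. As of its beta release, FlashAttention-3 is built from the hopper/ directory of the flash-attention repository, on an H100/H800 with a recent CUDA 12 toolkit:
git clone https://github.com/Dao-AILab/flash-attention
cd flash-attention/hopper
python setup.py install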
Step 3: Convert inputs to FP8
import torch

q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)

# Convert to float8_e4m3fn (FP8)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)
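A direct `.to(torch.float8_e4m3fn)` cast works for the unit-variance random tensors above. However, float8_e4m3fn's largest finite value is 448 and it has only 3 mantissa bits. Real activations are therefore typically rescaled per tensor before the cast, and the inverse (descale) factor is applied afterwards. The plain-Python sketch below computes those two factors. `per_tensor_fp8_scale` is a hypothetical helper. How descale factors are passed to the FlashAttention-3 kernel depends on the build you install, so check its interface.

```python
E4M3FN_MAX = 448.0  # largest finite float8_e4m3fn value

def per_tensor_fp8_scale(values, fp8_max=E4M3FN_MAX):
    """Return (scale, descale): values * scale fits in [-fp8_max, fp8_max],
    and multiplying by descale = 1 / scale restores the original range."""
    amax = max((abs(x) for x in values), default=0.0)
    if amax == 0.0:
        return 1.0, 1.0  # all zeros: nothing to rescale
    scale = fp8_max / amax
    return scale, 1.0 / scale

activations = [1.5, -896.0, 12.0, 0.25]  # hypothetical values; -896 exceeds 448
scale, descale = per_tensor_fp8_scale(activations)
scaled = [x * scale for x in activations]
print(scale, descale, max(abs(x) for x in scaled))
```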
Step 4: Run with FP8 attention
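# FlashAttention-3 is imported from its own module, not from flash_attn
import flash_attn_interface

# Uses FP8 kernels for float8_e4m3fn inputs on H100. The exact signature varies by
# build: some versions also return the softmax log-sum-exp or accept descale factors.
out = flash_attn_interface.flash_attn_func(q_fp8, k_fp8, v_fp8)
Result: the FlashAttention-3 paper reports close to 1.2 PFLOPS with FP8 on H100, versus up to 740 TFLOPS with FP16. The FP16 figure is itself 1.5-2x faster than FlashAttention-2.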
When to use vs alternatives
Use Flash Attention when:
- Training transformers with sequences >512 tokens
- Running inference with long context (>2K tokens)
- GPU memory constrained (OOM with standard attention)
- Need 2-4x speedup without accuracy loss
- Using PyTorch 2.2+ or can install flash-attn
Use alternatives instead:
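- Standard attention: Sequences <256 tokens (overhead not worth it)
- xFormers: Need more attention variants (not just speed)
- PyTorch SDPA on CPU or standard attention: CPU inference (the Flash Attention kernels need a GPU; F.scaled_dot_product_attention still runs on CPU using a CPU implementation)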
Common issues
Issue: ImportError: cannot import flash_attn
Install with no-build-isolation flag:
pip install flash-attn --no-build-isolation
Or install the CUDA toolkit and build dependencies first:
conda install cuda -c nvidia
pip install packaging ninja
pip install flash-attn --no-build-isolation
Issue: Slower than expected (no speedup)
Flash Attention benefits increase with sequence length:
- <512 tokens: Minimal speedup (10-20%)
- 512-2K tokens: 2-3x speedup
- >2K tokens: 3-4x speedup
Check that the sequence length is long enough to benefit.
Issue: RuntimeError: CUDA error
Verify GPU supports Flash Attention:
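import torch
print(torch.cuda.get_device_capability())
# Should be ≥(8, 0), i.e. Ampere or newer, for FlashAttention-2
Flash Attention requires:
- Ampere (A100, A10), Ada (RTX 4090), Hopper (H100): ✅ Supported
- Turing (T4): ❌ Not supported by FlashAttention-2 (use FlashAttention 1.x, or PyTorch's memory-efficient SDPA backend)
- Volta (V100): ❌ Not supported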
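You can turn GPU support into a quick check on the tuple returned by `torch.cuda.get_device_capability()`. The hypothetical helper below encodes GPU support as we understand it from the flash-attn README: FlashAttention-2 runs on Ampere, Ada, and Hopper; Turing needs FlashAttention 1.x; FlashAttention-3 targets Hopper. The labels are illustrative. Anything newer than Hopper is reported as unknown rather than guessed.

```python
def flash_attn_support(capability):
    """Classify a (major, minor) CUDA compute capability (hypothetical labels)."""
    cap = tuple(capability)
    if cap >= (10, 0):
        return "unknown"      # newer than Hopper: check current release notes
    if cap >= (9, 0):
        return "fa2+fa3"      # Hopper (H100/H800): FA2 plus the FA3 Hopper build
    if cap >= (8, 0):
        return "fa2"          # Ampere and Ada (A100, A10, RTX 3090/4090)
    if cap == (7, 5):
        return "fa1-only"     # Turing (T4): FlashAttention 1.x
    return "unsupported"      # e.g. Volta (V100) is (7, 0)

for name, cap in [("V100", (7, 0)), ("T4", (7, 5)), ("A100", (8, 0)),
                  ("RTX 4090", (8, 9)), ("H100", (9, 0))]:
    print(f"{name:8s} {cap} -> {flash_attn_support(cap)}")
```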
Issue: Accuracy degradation
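Flash Attention kernels run only in float16 or bfloat16. Float32 is not supported: flash_attn_func raises an error, and F.scaled_dot_product_attention falls back to a slower backend. Check the dtype, and expect small differences from a float32 baseline. If float16 values overflow (its largest finite value is 65504), prefer bfloat16, which keeps float32's exponent range:
q = q.to(torch.bfloat16)  # Likewise for k and v; torch.float16 also works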
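The trade-off between the two half-precision formats is easy to see numerically. Float16 tops out at 65504. Bfloat16 keeps float32's exponent range but has only 8 bits of significand precision. The plain-Python emulation below uses the `struct` module's 'e' format (IEEE half precision) for float16. For bfloat16, it applies round-to-nearest-even by hand to float32 bit patterns. The helper names are hypothetical, and NaN edge cases are ignored.

```python
import math
import struct

def round_to_fp16(x: float) -> float:
    """Round to IEEE float16; values past the float16 range (largest finite
    value 65504) become +/-inf, as an fp16 overflow would."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

def round_to_bf16(x: float) -> float:
    """Round a finite float32-range value to bfloat16: keep the top 16 bits of
    the float32 pattern, rounding to nearest-even (NaN handling omitted)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even bias
    bits &= 0xFFFF0000                   # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

for value in (0.1, 1000.5, 70000.0):
    print(value, round_to_fp16(value), round_to_bf16(value))
```

For 70000, float16 overflows to inf, while bfloat16 stays finite but coarse (70144.0). This is why bfloat16 is often preferred when activations are large.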
Advanced topics
Integration with HuggingFace Transformers: See references/transformers-integration.md for enabling Flash Attention in BERT, GPT, Llama models.
Performance benchmarks: See references/benchmarks.md for detailed speed and memory comparisons across GPUs and sequence lengths.
Algorithm details: See references/algorithm.md for tiling strategy, recomputation, and IO complexity analysis.
Advanced features: See references/advanced-features.md for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.
Hardware requirements
- GPU: NVIDIA Ampere+ (A100, A10, A30) or AMD MI200+
- VRAM: No more than standard attention; attention memory grows linearly rather than quadratically with sequence length
- CUDA: 12.0+ (11.8 minimum)
- PyTorch: 2.2+ for native support
Not supported: V100 (Volta), T4 (Turing) with FlashAttention-2, CPU inference
Resources
- Paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (NeurIPS 2022)
- Paper: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (ICLR 2024)
- PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html