Attention is the computational bottleneck in transformers. Standard attention has O(N²) memory complexity, making long sequences prohibitively expensive. FlashAttention changed this by computing attention in tiles, achieving O(N) memory with faster execution through better use of the GPU memory hierarchy. This guide covers the FlashAttention algorithm, multi-head parallelization, KV-cache optimization for inference, and emerging techniques like multi-query attention.
- Compute attention in SRAM tiles, never materializing the full N×N matrix in HBM.
- Use online softmax to compute exact results across tiles without the full score matrix.
- Share K/V heads across multiple Q heads (multi-query attention) for an 8x KV-cache reduction when 8 query heads share each K/V head; a rough size calculation follows this list.
- Use non-contiguous memory allocation for variable-length sequences.
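
As a rough illustration of the K/V-sharing bullet, the host-side sketch below compares KV-cache sizes for standard multi-head attention and a grouped layout with 8 query heads per K/V head. The model shape (32 layers, head_dim 128, fp16 cache, 32K-token context) is an illustrative assumption, not a configuration from this guide.

```cuda
#include <cstdio>

// The KV cache stores one K and one V vector per layer per token; its size
// scales with the number of K/V heads, not the number of query heads.
long long kv_cache_bytes(long long layers, long long kv_heads, long long head_dim,
                         long long seq_len, long long bytes_per_elem) {
    return 2 * layers * kv_heads * head_dim * seq_len * bytes_per_elem;  // K + V
}

int main() {
    // Assumed model: 32 layers, head_dim 128, fp16 cache, 32K-token context.
    long long mha = kv_cache_bytes(32, 32, 128, 32768, 2);  // 32 K/V heads (standard MHA)
    long long shared = kv_cache_bytes(32, 4, 128, 32768, 2); // 4 K/V heads, each shared by 8 Q heads
    printf("MHA cache: %lld MB, shared-KV cache: %lld MB (%lldx smaller)\n",
           mha >> 20, shared >> 20, mha / shared);
    return 0;
}
```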
Standard attention requires O(N²) memory, limiting sequence length.
```cuda
// Standard attention: O(N²) memory
// S = Q @ K^T / sqrt(d)
// P = softmax(S)
// O = P @ V
// Simplified sketch: one thread block per query row, one thread per block.
__global__ void attention_naive(const float* Q, const float* K, const float* V,
                                float* O, int N, int d) {
    // One row of the N×N attention matrix: N floats of dynamic shared memory.
    // This won't fit in SRAM for large N, and materializing the full matrix
    // in HBM costs O(N²) memory and bandwidth.
    extern __shared__ float S[];
    int row = blockIdx.x;
    float scale = 1.0f / sqrtf((float)d);

    // S[row, :] = Q[row] @ K^T / sqrt(d), tracking the row max for stability
    float row_max = -INFINITY;
    for (int j = 0; j < N; j++) {
        float sum = 0.0f;
        for (int k = 0; k < d; k++) {
            sum += Q[row * d + k] * K[j * d + k];
        }
        S[j] = sum * scale;
        row_max = fmaxf(row_max, S[j]);
    }
    // P[row, :] = softmax(S[row, :])
    float row_sum = 0.0f;
    for (int j = 0; j < N; j++) {
        S[j] = expf(S[j] - row_max);
        row_sum += S[j];
    }
    // O[row, :] = P[row, :] @ V
    for (int k = 0; k < d; k++) {
        float acc = 0.0f;
        for (int j = 0; j < N; j++) {
            acc += S[j] * V[j * d + k];
        }
        O[row * d + k] = acc / row_sum;
    }
}
```

FlashAttention tiles the computation to fit in SRAM, bringing memory down to O(N).
```cuda
// FlashAttention: O(N) memory, tiled computation
// Simplified sketch: one thread per block; each block owns Br query rows and streams
// K/V through SRAM in Bc-row tiles (assumes N % Br == 0); real kernels parallelize these loops across threads/warps.
__global__ void flash_attention(const float* Q, const float* K, const float* V,
                                float* O, int N, int d, int Bc, int Br) {
    // Bc, Br = block sizes for columns/rows; launch with (Br*d + 2*Bc*d + Br*Bc + 2*Br) * sizeof(float) dynamic shared memory so everything fits in SRAM.
    extern __shared__ float sram[];
    float* Qi  = sram;               // Br × d   query tile
    float* Kj  = Qi  + Br * d;       // Bc × d   key tile
    float* Vj  = Kj  + Bc * d;       // Bc × d   value tile
    float* Sij = Vj  + Bc * d;       // Br × Bc  score tile
    float* m   = Sij + Br * Bc;      // Br       running row maxima
    float* l   = m   + Br;           // Br       running row sums
    int row0 = blockIdx.x * Br;      // first query row owned by this block
    float scale = 1.0f / sqrtf((float)d);
    float* Oi = O + row0 * d;        // output accumulator, rescaled in place
    // Load Q block; initialize running max and sum for online softmax.
    for (int r = 0; r < Br; r++) {
        m[r] = -INFINITY; l[r] = 0.0f;
        for (int k = 0; k < d; k++) { Qi[r * d + k] = Q[(row0 + r) * d + k]; Oi[r * d + k] = 0.0f; }
    }
    // Process K, V in tiles of Bc rows.
    for (int j = 0; j < N; j += Bc) {
        int tile = min(Bc, N - j);
        // Load Kj, Vj tiles to SRAM.
        for (int c = 0; c < tile; c++)
            for (int k = 0; k < d; k++) { Kj[c * d + k] = K[(j + c) * d + k]; Vj[c * d + k] = V[(j + c) * d + k]; }
        for (int r = 0; r < Br; r++) {
            // Compute Sij = Qi @ Kj^T (scaled), tracking this tile's row max.
            float tile_max = -INFINITY;
            for (int c = 0; c < tile; c++) {
                float s = 0.0f;
                for (int k = 0; k < d; k++) s += Qi[r * d + k] * Kj[c * d + k];
                Sij[r * Bc + c] = s * scale;
                tile_max = fmaxf(tile_max, Sij[r * Bc + c]);
            }
            // Update row_max, row_sum with online softmax; rescale previous output.
            float new_max = fmaxf(m[r], tile_max);
            float corr = expf(m[r] - new_max);
            l[r] *= corr;
            for (int k = 0; k < d; k++) Oi[r * d + k] *= corr;
            // Accumulate: Oi += softmax(Sij) @ Vj (under the new running max).
            for (int c = 0; c < tile; c++) {
                float p = expf(Sij[r * Bc + c] - new_max);
                l[r] += p;
                for (int k = 0; k < d; k++) Oi[r * d + k] += p * Vj[c * d + k];
            }
            m[r] = new_max;
        }
    }
    // Final scaling by 1/row_sum; Oi already lives in global memory.
    for (int r = 0; r < Br; r++)
        for (int k = 0; k < d; k++) Oi[r * d + k] /= l[r];
}
```

| Metric | Standard Attention | FlashAttention | Improvement |
|---|---|---|---|
| Memory Usage | O(N²) | O(N) | Enables 64K+ sequences |
| Speed (A100) | 1x | 2-4x | Better HBM utilization |
| Training Throughput | 1x | 3x | Fused forward+backward |
FlashAttention never materializes the N×N attention matrix. It processes Q, K, V in tiles that fit in SRAM, using online softmax to compute correct results across tiles.
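
The online softmax is what makes the tiling exact rather than approximate. Here is a minimal host-side sketch (names and numbers are illustrative) that folds scores in one at a time, rescaling the running sum and the running weighted-value accumulator whenever the maximum grows, and compares the result with an ordinary softmax over the same scores.

```cuda
#include <math.h>
#include <stdio.h>

// Running statistics for one output row: max, softmax denominator, and the
// accumulated sum of exp(score - max) * value.
struct Partial { float m, l, o; };

// Fold one (score, value) pair into the running statistics.
Partial update(Partial a, float s, float v) {
    float m = fmaxf(a.m, s);
    float c = expf(a.m - m);                 // rescales everything seen so far
    return { m, a.l * c + expf(s - m), a.o * c + expf(s - m) * v };
}

int main() {
    float s[4] = {1.0f, 3.0f, 0.5f, 2.0f};   // attention scores (illustrative)
    float v[4] = {10.f, 20.f, 30.f, 40.f};   // corresponding values
    // Online pass: scores arrive one at a time, as they would tile by tile.
    Partial p = {-INFINITY, 0.0f, 0.0f};
    for (int i = 0; i < 4; i++) p = update(p, s[i], v[i]);
    // Reference: ordinary softmax over all scores at once (3.0f is max(s)).
    float Z = 0.0f, ref = 0.0f;
    for (int i = 0; i < 4; i++) Z += expf(s[i] - 3.0f);
    for (int i = 0; i < 4; i++) ref += expf(s[i] - 3.0f) / Z * v[i];
    printf("online: %f  reference: %f\n", p.o / p.l, ref);  // prints the same value
    return 0;
}
```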
FlashAttention is typically 2-4x faster than standard attention due to reduced HBM accesses. The speedup is larger for longer sequences where memory bandwidth dominates.
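
A back-of-envelope estimate of HBM traffic makes the bandwidth argument concrete. The sizes below (one head, fp16, N = 8192, d = 128) are illustrative assumptions, not measurements: the N×N score matrix dominates, and it is exactly the tensor FlashAttention never moves through HBM.

```cuda
#include <cstdio>

int main() {
    const long long N = 8192, d = 128, elem = 2;     // assumed: fp16, one head
    // Q, K, V, O are each N x d: touched once regardless of algorithm.
    long long qkvo_bytes  = 4 * N * d * elem;
    // Standard attention writes S = QK^T and P = softmax(S) to HBM and reads
    // them back: two N x N matrices, each written once and read once.
    long long score_bytes = 2 * 2 * N * N * elem;
    printf("O(N*d) tensors:      %lld MB\n", qkvo_bytes  >> 20);
    printf("N x N score traffic: %lld MB\n", score_bytes >> 20);
    // FlashAttention keeps the score tiles in SRAM, so the second term vanishes.
    return 0;
}
```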
Three structural facts drive these optimizations:

- Attention uses an online softmax.
- Q×K^T and P×V are matmuls.
- Multi-head attention is a batched matmul (see the cuBLAS sketch below).
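
The batched-matmul view maps directly onto cuBLAS. The sketch below computes the scaled scores S_h = Q_h·K_h^T for every head with a single strided-batched GEMM; the row-major [n_heads, N, d] layout and the wrapper name are assumptions for illustration, and cuBLAS's column-major convention is handled by computing S^T = K·Q^T per head.

```cuda
#include <cublas_v2.h>
#include <math.h>

// Assumed layout: Q, K are row-major [n_heads, N, d]; S is row-major [n_heads, N, N].
void attention_scores_batched(cublasHandle_t handle,
                              const float* Q, const float* K, float* S,
                              int n_heads, int N, int d) {
    const float alpha = 1.0f / sqrtf((float)d);   // fuse the 1/sqrt(d) scaling
    const float beta  = 0.0f;
    // Row-major S_h = Q_h K_h^T is column-major S_h^T = K_h Q_h^T, so:
    //   op(A) = K_h (transpose of the stored d x N view), op(B) = Q_h^T (as stored).
    cublasSgemmStridedBatched(
        handle, CUBLAS_OP_T, CUBLAS_OP_N,
        N, N, d,                                  // C is N x N, inner dimension d
        &alpha,
        K, d, (long long)N * d,                   // A: K_h, ld d, stride N*d per head
        Q, d, (long long)N * d,                   // B: Q_h, ld d, stride N*d per head
        &beta,
        S, N, (long long)N * N,                   // C: S_h, ld N, stride N*N per head
        n_heads);                                 // one GEMM per head
}
```

A second strided-batched GEMM of the same shape handles P×V, so everything outside the softmax runs on the GEMM pipeline.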
Ready to optimize your CUDA code? Download RightNow AI and get real-time performance analysis for your kernels.