Batched GEMM executes many small matrix multiplications in parallel from a single launch, which is critical for multi-head attention, batch processing, and grouped convolutions. Knowing when to use batched GEMM versus reshaping to a single large GEMM is key to transformer optimization. This guide covers cuBLAS batched operations, memory layouts, and custom kernels for matrices small enough that cuBLAS launch overhead dominates.
- Use `cublasSgemmStridedBatched` for contiguous batch layouts; it needs no pointer arrays.
- For many very small GEMMs, a custom persistent kernel avoids per-launch overhead (see the sketch after this list).
- Store Q, K, V as [batch, head, seq, dim] so each head sits a constant stride apart.
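A minimal sketch of such a persistent kernel, assuming every matrix fits in a single TILE x TILE tile; the kernel name, TILE, and the launch shape are illustrative, not a definitive implementation:

```cuda
// Hypothetical persistent kernel: a fixed grid strides over the batch,
// amortizing one kernel launch across all small GEMMs.
// Assumes M, N, K <= TILE and a dim3(TILE, TILE) block.
template <int TILE>
__global__ void small_gemm_persistent(const float* A, const float* B,
                                      float* C, int M, int N, int K,
                                      int batch) {
    __shared__ float As[TILE][TILE];
    __shared__ float Bs[TILE][TILE];
    const int tx = threadIdx.x, ty = threadIdx.y;

    // Persistent loop: far fewer blocks than batch elements.
    for (int b = blockIdx.x; b < batch; b += gridDim.x) {
        const float* Ab = A + (size_t)b * M * K;
        const float* Bb = B + (size_t)b * K * N;
        float*       Cb = C + (size_t)b * M * N;

        // Stage this element's A and B in shared memory (zero-pad edges).
        As[ty][tx] = (ty < M && tx < K) ? Ab[ty * K + tx] : 0.0f;
        Bs[ty][tx] = (ty < K && tx < N) ? Bb[ty * N + tx] : 0.0f;
        __syncthreads();

        if (ty < M && tx < N) {
            float acc = 0.0f;
            for (int k = 0; k < K; k++)
                acc += As[ty][k] * Bs[k][tx];
            Cb[ty * N + tx] = acc;
        }
        __syncthreads();  // protect shared memory before the next element
    }
}

// Launch sketch: small_gemm_persistent<16><<<blocks, dim3(16, 16)>>>(
//     A, B, C, M, N, K, batch);  // e.g. blocks = a few per SM
```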
Launching a separate GEMM for each small matrix pays the full kernel-launch and dispatch cost on every iteration:
```cpp
// Naive: launch a separate GEMM for each batch element
void batched_gemm_naive(cublasHandle_t handle,
                        float* A, float* B, float* C,
                        int M, int N, int K, int batch) {
    float alpha = 1.0f, beta = 0.0f;
    for (int i = 0; i < batch; i++) {
        // Row-major trick: compute C^T = B^T * A^T in column-major cuBLAS
        cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                    N, M, K, &alpha,
                    B + (size_t)i * K * N, N,
                    A + (size_t)i * M * K, K,
                    &beta, C + (size_t)i * M * N, N);
    }
}
```

Strided batched GEMM eliminates this loop overhead and lets cuBLAS schedule the entire batch from one launch.
```cpp
// Optimized: a single batched launch covers the whole batch
void batched_gemm_strided(cublasHandle_t handle,
                          float* A, float* B, float* C,
                          int M, int N, int K, int batch) {
    float alpha = 1.0f, beta = 0.0f;
    // Strided batched: element i lives at base + i * stride,
    // so no pointer array is needed
    cublasSgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                              N, M, K,
                              &alpha,
                              B, N, (long long)K * N,  // B stride
                              A, K, (long long)M * K,  // A stride
                              &beta,
                              C, N, (long long)M * N,  // C stride
                              batch);
}
```
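For comparison, cuBLAS also offers a pointer-array variant, `cublasSgemmBatched`, useful when batch elements are not evenly strided. A sketch with error checks omitted; here the pointers happen to be contiguous, but each could point anywhere:

```cpp
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <vector>

// Pointer-array variant: each batch element can live anywhere in memory.
void batched_gemm_ptr_array(cublasHandle_t handle,
                            float* A, float* B, float* C,
                            int M, int N, int K, int batch) {
    const float alpha = 1.0f, beta = 0.0f;

    // Build per-element pointers on the host, then copy them to the device.
    std::vector<const float*> hA(batch), hB(batch);
    std::vector<float*> hC(batch);
    for (int i = 0; i < batch; i++) {
        hA[i] = A + (size_t)i * M * K;
        hB[i] = B + (size_t)i * K * N;
        hC[i] = C + (size_t)i * M * N;
    }
    const float **dA, **dB; float **dC;
    cudaMalloc(&dA, batch * sizeof(float*));
    cudaMalloc(&dB, batch * sizeof(float*));
    cudaMalloc(&dC, batch * sizeof(float*));
    cudaMemcpy(dA, hA.data(), batch * sizeof(float*), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, hB.data(), batch * sizeof(float*), cudaMemcpyHostToDevice);
    cudaMemcpy(dC, hC.data(), batch * sizeof(float*), cudaMemcpyHostToDevice);

    // Same row-major trick as before, applied per pointer
    cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                       N, M, K, &alpha,
                       dB, N, dA, K,
                       &beta, dC, N, batch);

    cudaFree(dA); cudaFree(dB); cudaFree(dC);
}
```

The pointer setup and copies add overhead, which is why the strided form is preferred whenever the batch is evenly spaced.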
For multi-head attention, reshape Q, K, and V from [batch, seq, heads * dim] to [batch * heads, seq, dim]; each head then becomes one element of the strided batch, as in the sketch below.
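A minimal sketch of that layout in action, computing the score matrices S = Q * K^T for every head in one call; the function name and the folding of batch and heads into one leading dimension are my assumptions, and scaling/softmax are omitted:

```cpp
// Hypothetical helper: S = Q * K^T for all heads at once.
// Q, K: [batch * heads, seq, dim] row-major, contiguous.
// S:    [batch * heads, seq, seq] row-major.
void attention_scores(cublasHandle_t handle, const float* Q, const float* K,
                      float* S, int batch_heads, int seq, int dim) {
    const float alpha = 1.0f, beta = 0.0f;  // fold 1/sqrtf(dim) into alpha if desired
    // Row-major trick: S^T = K * Q^T in column-major terms,
    // so K takes CUBLAS_OP_T and no explicit transpose copy is needed.
    cublasSgemmStridedBatched(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                              seq, seq, dim,
                              &alpha,
                              K, dim, (long long)seq * dim,  // K stride
                              Q, dim, (long long)seq * dim,  // Q stride
                              &beta,
                              S, seq, (long long)seq * seq,  // S stride
                              batch_heads);
}
```

Folding the transpose into the GEMM via CUBLAS_OP_T avoids a separate transpose kernel and an extra pass over K.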
| Metric | Naive | Optimized | Improvement |
|---|---|---|---|
| Throughput (small matrices) | 1x | 8-15x | Reduced launch overhead |
| Multi-head attention | 1x | 2x | Proper batching |
Use batched GEMM when each batch element has its own operands (as in attention, where every head multiplies a distinct Q and K) or when reshape/concat would require extra memory copies; batches with genuinely different per-element sizes need a grouped GEMM (e.g., CUTLASS grouped GEMM) rather than the uniform cuBLAS batched APIs. For same-size batches that share an operand over contiguous memory, folding the batch into a single large GEMM is often faster.
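For example, a per-token linear layer multiplies every batch element by the same weight matrix, so the batch dimension folds into the row dimension and one large GEMM replaces the whole batch; the function name and row-major layout below are my assumptions:

```cpp
// Sketch: Y = X * W where W is shared across the batch.
// X: [batch * seq, dim] row-major, W: [dim, out] row-major, Y: [batch * seq, out].
void linear_as_single_gemm(cublasHandle_t handle, const float* X,
                           const float* W, float* Y,
                           int batch, int seq, int dim, int out) {
    const float alpha = 1.0f, beta = 0.0f;
    int rows = batch * seq;  // fold batch into the M dimension
    // Row-major trick: compute Y^T = W^T * X^T in column-major cuBLAS
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                out, rows, dim, &alpha,
                W, out,
                X, dim,
                &beta, Y, out);
}
```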
Ready to optimize your CUDA code? Download RightNow AI and get real-time performance analysis for your kernels.