Matrix multiplication (GEMM - General Matrix Multiply) is the foundation of modern deep learning and scientific computing. A naive CUDA implementation can be 100x slower than optimized code. This guide walks you through the critical optimizations that transform a basic kernel into one that rivals cuBLAS performance. Understanding matrix multiplication optimization teaches you the core concepts that apply to almost every CUDA kernel: memory hierarchy exploitation, thread cooperation, and instruction-level parallelism.
Load tiles of input matrices into shared memory, enabling data reuse across threads. Each element is loaded once from global memory but used multiple times for computation.
```cuda
__shared__ float As[TILE_SIZE][TILE_SIZE];
__shared__ float Bs[TILE_SIZE][TILE_SIZE];
// Load tile from global to shared memory
As[ty][tx] = A[row * K + (tile * TILE_SIZE + tx)];
Bs[ty][tx] = B[(tile * TILE_SIZE + ty) * N + col];
__syncthreads();
// Compute partial dot product
for (int k = 0; k < TILE_SIZE; k++) {
sum += As[ty][k] * Bs[k][tx];
}
__syncthreads();
```

Ensure consecutive threads access consecutive memory addresses. This allows the GPU to combine multiple memory requests into fewer, larger transactions.
```cuda
// GOOD: Coalesced - threads access consecutive columns
// Thread 0: B[row][0], Thread 1: B[row][1], ...
float val = B[row * N + threadIdx.x];
// BAD: Non-coalesced - threads access consecutive rows
// Thread 0: B[0][col], Thread 1: B[1][col], ...
float val = B[threadIdx.x * N + col]; // Strided access
```

Choose block dimensions that maximize occupancy while minimizing register pressure. For matrix multiplication, 16x16 or 32x32 tiles are common choices.
```cuda
// 16x16 block = 256 threads
// Good balance of occupancy and shared memory usage
dim3 block(16, 16);
dim3 grid((N + 15) / 16, (M + 15) / 16);
// For larger matrices, 32x32 can be better
// But uses 4x more shared memory per block
dim3 block(32, 32); // 1024 threads
```

Have each thread compute multiple output elements, increasing arithmetic intensity and hiding memory latency through instruction-level parallelism.
```cuda
// Each thread computes a 4x4 tile of outputs (register blocking).
// The loads below sit inside the k-loop over the current shared-memory tile.
float regA[4], regB[4];
float regC[4][4] = {0};
// Load A elements into registers
for (int i = 0; i < 4; i++)
regA[i] = As[ty * 4 + i][k];
// Load B elements into registers
for (int i = 0; i < 4; i++)
regB[i] = Bs[k][tx * 4 + i];
// Compute 16 multiply-adds per iteration
for (int i = 0; i < 4; i++)
for (int j = 0; j < 4; j++)
regC[i][j] += regA[i] * regB[j];
```

Pad shared memory arrays or access patterns to prevent multiple threads from accessing the same memory bank simultaneously.
```cuda
// Without padding: potential 32-way bank conflict
__shared__ float Bs[32][32]; // Column access = bank conflict
// With padding: conflict-free access
__shared__ float Bs[32][33]; // Extra column breaks the pattern
// Column access: consecutive threads read down column k
float val = Bs[tx][k]; // Conflict-free thanks to the padding
```

This naive implementation suffers from poor memory access patterns and no data reuse: every element of A is loaded N times and every element of B is loaded M times from global memory.
```cuda
__global__ void matmul_naive(float *A, float *B, float *C,
int M, int N, int K) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (row < M && col < N) {
float sum = 0.0f;
for (int k = 0; k < K; k++) {
sum += A[row * K + k] * B[k * N + col];
}
C[row * N + col] = sum;
}
}
// Launch configuration
dim3 block(16, 16);
dim3 grid((N + 15) / 16, (M + 15) / 16);
matmul_naive<<<grid, block>>>(d_A, d_B, d_C, M, N, K);
```

This optimized version uses shared memory tiling, coalesced memory access, and bank conflict avoidance. It achieves a 10-20x speedup over the naive version.
```cuda
#define TILE_SIZE 32
__global__ void matmul_optimized(float *A, float *B, float *C,
int M, int N, int K) {
__shared__ float As[TILE_SIZE][TILE_SIZE];
__shared__ float Bs[TILE_SIZE][TILE_SIZE + 1]; // +1 padding to avoid bank conflicts
int bx = blockIdx.x, by = blockIdx.y;
int tx = threadIdx.x, ty = threadIdx.y;
int row = by * TILE_SIZE + ty;
int col = bx * TILE_SIZE + tx;
float sum = 0.0f;
// Loop over tiles
for (int tile = 0; tile < (K + TILE_SIZE - 1) / TILE_SIZE; tile++) {
// Collaborative loading into shared memory
int aCol = tile * TILE_SIZE + tx;
int bRow = tile * TILE_SIZE + ty;
As[ty][tx] = (row < M && aCol < K) ? A[row * K + aCol] : 0.0f;
Bs[ty][tx] = (bRow < K && col < N) ? B[bRow * N + col] : 0.0f;
__syncthreads();
// Compute partial dot product
#pragma unroll
for (int k = 0; k < TILE_SIZE; k++) {
sum += As[ty][k] * Bs[k][tx];
}
__syncthreads();
}
// Write result
if (row < M && col < N) {
C[row * N + col] = sum;
}
}
```

| Metric | Naive | Optimized | Improvement |
|---|---|---|---|
| Execution Time (1024x1024) | 45.2 ms | 3.8 ms | 11.9x faster |
| Memory Bandwidth Utilization | 12% | 78% | 6.5x better |
| GFLOPS (2048x2048) | 180 GFLOPS | 2,100 GFLOPS | 11.7x higher |
| Fraction of cuBLAS throughput | 6% | 85% | 14x higher |
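These numbers vary with GPU, driver, and matrix size, so measure on your own hardware. Below is a minimal sketch of how such timings can be collected with CUDA events, reusing TILE_SIZE and matmul_optimized from above; the helper name benchmark_matmul is illustrative and error checking is omitted.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Times one launch of matmul_optimized and reports GFLOPS.
// Assumes d_A, d_B, d_C are device buffers that are already allocated and filled.
void benchmark_matmul(float *d_A, float *d_B, float *d_C, int M, int N, int K) {
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    dim3 block(TILE_SIZE, TILE_SIZE);
    dim3 grid((N + TILE_SIZE - 1) / TILE_SIZE, (M + TILE_SIZE - 1) / TILE_SIZE);

    cudaEventRecord(start);
    matmul_optimized<<<grid, block>>>(d_A, d_B, d_C, M, N, K);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);

    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);

    // A GEMM performs 2*M*N*K floating-point operations (one multiply + one add each)
    double gflops = 2.0 * M * N * K / (ms * 1e6);
    printf("%dx%dx%d: %.3f ms, %.1f GFLOPS\n", M, N, K, ms, gflops);

    cudaEventDestroy(start);
    cudaEventDestroy(stop);
}
```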
The optimal tile size depends on your GPU architecture. For most modern GPUs, 32x32 tiles provide a good balance between shared memory usage and occupancy. Start with 32x32 and profile to find the best size for your specific hardware and matrix dimensions.
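One low-effort way to compare tile sizes is to make the tile size a template parameter so the 16x16 and 32x32 variants share one kernel body. A sketch under that assumption follows; matmul_tiled is not part of the code above, just the same structure with TILE made compile-time.

```cuda
// Same structure as matmul_optimized, with the tile size as a compile-time parameter
template <int TILE>
__global__ void matmul_tiled(float *A, float *B, float *C, int M, int N, int K) {
    __shared__ float As[TILE][TILE];
    __shared__ float Bs[TILE][TILE + 1];  // +1 padding, as in the kernel above
    int row = blockIdx.y * TILE + threadIdx.y;
    int col = blockIdx.x * TILE + threadIdx.x;
    float sum = 0.0f;
    for (int tile = 0; tile < (K + TILE - 1) / TILE; tile++) {
        int aCol = tile * TILE + threadIdx.x;
        int bRow = tile * TILE + threadIdx.y;
        As[threadIdx.y][threadIdx.x] = (row < M && aCol < K) ? A[row * K + aCol] : 0.0f;
        Bs[threadIdx.y][threadIdx.x] = (bRow < K && col < N) ? B[bRow * N + col] : 0.0f;
        __syncthreads();
        #pragma unroll
        for (int k = 0; k < TILE; k++)
            sum += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        __syncthreads();
    }
    if (row < M && col < N)
        C[row * N + col] = sum;
}

// Profile both variants and keep whichever is faster on your GPU:
// matmul_tiled<16><<<dim3((N + 15) / 16, (M + 15) / 16), dim3(16, 16)>>>(d_A, d_B, d_C, M, N, K);
// matmul_tiled<32><<<dim3((N + 31) / 32, (M + 31) / 32), dim3(32, 32)>>>(d_A, d_B, d_C, M, N, K);
```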
Use cuBLAS for production workloads where maximum performance is critical. Write custom kernels when you need to fuse matrix multiplication with other operations, have unusual matrix shapes, or need custom precision formats. cuBLAS is highly optimized for standard use cases.
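For reference, here is a minimal sketch of the equivalent cuBLAS call for row-major C = A x B. cuBLAS stores matrices column-major, so the operands are swapped; error checking and linking with -lcublas are left out, and the wrapper name gemm_cublas is illustrative.

```cuda
#include <cublas_v2.h>

// Row-major C(MxN) = A(MxK) * B(KxN) via cublasSgemm.
// cuBLAS is column-major, so we compute B * A in column-major terms,
// which lands in memory exactly as row-major A * B.
void gemm_cublas(cublasHandle_t handle, const float *d_A, const float *d_B,
                 float *d_C, int M, int N, int K) {
    const float alpha = 1.0f, beta = 0.0f;
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                N, M, K,        // dimensions of the column-major product
                &alpha,
                d_B, N,         // first operand: B, leading dimension N
                d_A, K,         // second operand: A, leading dimension K
                &beta,
                d_C, N);        // C, leading dimension N
}
```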
The tiling approach works for any matrix dimensions. Use boundary checks when loading tiles (as shown in the optimized code) to handle cases where the matrix dimensions are not multiples of the tile size. Zero-padding the out-of-bounds accesses maintains correctness.
In naive matrix multiplication, every element of A and B is re-read from global memory for each output element that uses it. For square N x N matrices this means O(N³) global memory accesses for O(N³) floating-point operations, roughly a 1:1 compute-to-memory ratio. Tiling reduces the traffic to O(N³/T) accesses, where T is the tile size, achieving a T:1 compute-to-memory ratio.
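As a back-of-the-envelope example (my numbers, assuming N = 1024, T = 32, and 4-byte floats):

```cuda
#include <cstdio>

// Rough global-memory traffic estimate for square N x N matrices
int main() {
    const double N = 1024.0, T = 32.0, bytes_per_float = 4.0;
    double flops       = 2.0 * N * N * N;                    // ~2.1 GFLOP
    double naive_bytes = 2.0 * N * N * N * bytes_per_float;  // ~8.6 GB of global loads
    double tiled_bytes = naive_bytes / T;                    // ~270 MB with 32x32 tiles
    printf("FLOPs: %.2e  naive traffic: %.2e B  tiled traffic: %.2e B\n",
           flops, naive_bytes, tiled_bytes);
    return 0;
}
```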
Ready to optimize your CUDA code? Download RightNow AI and get real-time performance analysis for your kernels.