Layer normalization is applied in every transformer layer (typically twice, around the attention and MLP blocks), making it critical for inference performance. Unlike batch normalization, LayerNorm normalizes across the feature dimension rather than the batch, so each row is independent and easy to parallelize on the GPU. This guide covers numerically stable variance computation with Welford's algorithm, warp-level optimizations, fusion with residual connections, and RMSNorm.
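For reference, given a feature vector x of dimension D, LayerNorm computes y = (x - mean(x)) / sqrt(var(x) + eps) * gamma + beta, where gamma and beta are learned per-feature scale and shift parameters (this is exactly what the kernels below implement).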
- Welford's algorithm: compute mean and variance in a single pass with numerical stability.
- Warp shuffles: fast parallel reductions for mean and variance within a warp.
- Fused residual + LayerNorm: combine the residual connection and LayerNorm in one kernel (sketched after the benchmark table below).
- RMSNorm: skip the mean computation entirely for a faster normalization (used in LLaMA).
The naive two-pass approach below reads the data three times: once for the mean, once for the variance, and once to normalize. Collapsing the statistics into a single pass with the textbook formula var = E[x²] - E[x]² would cut the reads, but it is numerically unsafe (see the stability note at the end).
```cuda
// Naive LayerNorm: blockIdx.x selects the row; every loop is serial,
// so launch with one thread per block (<<<N, 1>>>).
__global__ void layernorm_naive(const float* x, float* y,
                                const float* gamma, const float* beta,
                                int N, int D, float eps) {
    int row = blockIdx.x;
    const float* x_row = x + row * D;
    float* y_row = y + row * D;

    // Pass 1: compute mean
    float mean = 0.0f;
    for (int i = 0; i < D; i++) mean += x_row[i];
    mean /= D;

    // Pass 2: compute variance
    float var = 0.0f;
    for (int i = 0; i < D; i++) {
        float diff = x_row[i] - mean;
        var += diff * diff;
    }
    var /= D;

    // Pass 3: normalize, then scale and shift
    float inv_std = rsqrtf(var + eps);
    for (int i = 0; i < D; i++) {
        y_row[i] = (x_row[i] - mean) * inv_std * gamma[i] + beta[i];
    }
}
```

Welford's algorithm combined with warp shuffles is both numerically stable and fast: mean and variance come out of a single read of the data, and the per-lane statistics are merged with register-to-register shuffle instructions.
```cuda
// Welford LayerNorm: one warp per row (launch with <<<N, 32>>>);
// each lane accumulates partial statistics over a strided slice,
// then the warp merges them with shuffles.
__global__ void layernorm_welford(const float* x, float* y,
                                  const float* gamma, const float* beta,
                                  int N, int D, float eps) {
    int row = blockIdx.x;
    int tid = threadIdx.x;
    const float* x_row = x + row * D;
    float* y_row = y + row * D;

    // Welford online algorithm: single pass over this lane's elements
    float mean = 0.0f, M2 = 0.0f;
    int count = 0;
    for (int i = tid; i < D; i += blockDim.x) {
        float val = x_row[i];
        count++;
        float delta = val - mean;
        mean += delta / count;
        M2 += delta * (val - mean);
    }

    // Parallel Welford merge across the warp (Chan et al. pairwise combine).
    // Assumes blockDim.x == 32 and D >= 32; a larger block would need an
    // additional shared-memory stage to combine per-warp results.
    for (int offset = 16; offset > 0; offset /= 2) {
        float other_mean = __shfl_down_sync(0xffffffff, mean, offset);
        float other_M2 = __shfl_down_sync(0xffffffff, M2, offset);
        int other_count = __shfl_down_sync(0xffffffff, count, offset);
        int total = count + other_count;
        float delta = other_mean - mean;
        mean = (count * mean + other_count * other_mean) / total;
        M2 = M2 + other_M2 + delta * delta * count * other_count / total;
        count = total;
    }
    // Broadcast the merged statistics from lane 0 to the whole warp.
    mean = __shfl_sync(0xffffffff, mean, 0);
    float var = __shfl_sync(0xffffffff, M2, 0) / D;
    float inv_std = rsqrtf(var + eps);

    for (int i = tid; i < D; i += blockDim.x) {
        y_row[i] = (x_row[i] - mean) * inv_std * gamma[i] + beta[i];
    }
}
```

| Metric | Naive | Optimized | Improvement |
|---|---|---|---|
| Throughput (GB/s) | 320 | 680 | 2.1x |
| Latency (μs) | 28 | 14 | 2x |
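Fusing the residual add into the normalization kernel (the third item in the list above) saves a full round trip through global memory: the sum x + residual is consumed where it is produced instead of being written out and re-read. Below is a minimal sketch under the same assumptions as the Welford kernel (blockDim.x == 32, D >= 32); the kernel name and the residual parameter are illustrative, and the second pass simply recomputes the sum rather than staging it in shared memory.

```cuda
// Fused residual-add + LayerNorm: one kernel, one read of each input.
// Sketch only: assumes blockDim.x == 32 (one warp per row) and D >= 32.
__global__ void residual_layernorm_fused(const float* x, const float* residual,
                                         float* y,
                                         const float* gamma, const float* beta,
                                         int N, int D, float eps) {
    int row = blockIdx.x;
    int tid = threadIdx.x;
    const float* x_row = x + row * D;
    const float* r_row = residual + row * D;
    float* y_row = y + row * D;

    // Pass 1: Welford statistics over the fused sum x + residual.
    float mean = 0.0f, M2 = 0.0f;
    int count = 0;
    for (int i = tid; i < D; i += blockDim.x) {
        float val = x_row[i] + r_row[i];
        count++;
        float delta = val - mean;
        mean += delta / count;
        M2 += delta * (val - mean);
    }
    // Merge per-lane statistics across the warp (same combine as above).
    for (int offset = 16; offset > 0; offset /= 2) {
        float other_mean = __shfl_down_sync(0xffffffff, mean, offset);
        float other_M2 = __shfl_down_sync(0xffffffff, M2, offset);
        int other_count = __shfl_down_sync(0xffffffff, count, offset);
        int total = count + other_count;
        float delta = other_mean - mean;
        mean = (count * mean + other_count * other_mean) / total;
        M2 = M2 + other_M2 + delta * delta * count * other_count / total;
        count = total;
    }
    mean = __shfl_sync(0xffffffff, mean, 0);
    float inv_std = rsqrtf(__shfl_sync(0xffffffff, M2, 0) / D + eps);

    // Pass 2: recompute the sum (cheap, data is now in cache) and normalize.
    for (int i = tid; i < D; i += blockDim.x) {
        float val = x_row[i] + r_row[i];
        y_row[i] = (val - mean) * inv_std * gamma[i] + beta[i];
    }
}
```

It launches the same way as the Welford kernel, e.g. residual_layernorm_fused<<<N, 32>>>(x, res, y, gamma, beta, N, D, 1e-5f).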
RMSNorm skips the mean subtraction, normalizing only by the root mean square: y = x / sqrt(mean(x²) + eps), followed by a learned scale gamma in LLaMA-style models. This is faster, since only a sum of squares must be reduced, and works well in practice for LLMs (used in LLaMA and Gemma).
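A minimal RMSNorm sketch under the same one-warp-per-row assumption (blockDim.x == 32); since there is no mean, a plain sum-of-squares warp reduction replaces the Welford merge, and following the LLaMA convention there is a gamma scale but no beta bias:

```cuda
// RMSNorm: y = x / sqrt(mean(x^2) + eps) * gamma.
// Sketch only: assumes blockDim.x == 32 (one warp per row).
__global__ void rmsnorm(const float* x, float* y, const float* gamma,
                        int N, int D, float eps) {
    int row = blockIdx.x;
    int tid = threadIdx.x;
    const float* x_row = x + row * D;
    float* y_row = y + row * D;

    // Single pass: accumulate this lane's partial sum of squares.
    float sumsq = 0.0f;
    for (int i = tid; i < D; i += blockDim.x) {
        float val = x_row[i];
        sumsq += val * val;
    }
    // Warp reduction of the partial sums, then broadcast from lane 0.
    for (int offset = 16; offset > 0; offset /= 2)
        sumsq += __shfl_down_sync(0xffffffff, sumsq, offset);
    sumsq = __shfl_sync(0xffffffff, sumsq, 0);

    float inv_rms = rsqrtf(sumsq / D + eps);
    for (int i = tid; i < D; i += blockDim.x)
        y_row[i] = x_row[i] * inv_rms * gamma[i];
}
```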
Welford's algorithm stays numerically stable even for large input values. The standard single-pass formula var = E[x²] - E[x]² subtracts two large, nearly equal numbers, and floating-point cancellation can collapse the result to zero or even push it negative, as the example below shows.
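A small host-side demonstration (compiled without fast-math; the inputs are chosen so that in float arithmetic the naive formula returns exactly 0 while Welford recovers the true variance of 2/3):

```cuda
#include <stdio.h>

int main() {
    // Three values with a large common offset: true variance is 2/3.
    float x[3] = {10000.0f, 10001.0f, 10002.0f};

    // Naive single-pass formula: var = E[x^2] - E[x]^2.
    float sum = 0.0f, sumsq = 0.0f;
    for (int i = 0; i < 3; i++) { sum += x[i]; sumsq += x[i] * x[i]; }
    float mean = sum / 3.0f;
    float var_naive = sumsq / 3.0f - mean * mean;  // cancellation: prints 0

    // Welford: tracks the centered sum M2, immune to the cancellation.
    float m = 0.0f, M2 = 0.0f;
    for (int i = 0; i < 3; i++) {
        float delta = x[i] - m;
        m += delta / (i + 1);
        M2 += delta * (x[i] - m);
    }
    float var_welford = M2 / 3.0f;  // prints 0.666667

    printf("naive: %f  welford: %f\n", var_naive, var_welford);
    return 0;
}
```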
Ready to optimize your CUDA code? Download RightNow AI and get real-time performance analysis for your kernels.