Group Normalization divides channels into groups and normalizes within each group. Unlike batch norm, it's independent of batch size, making it ideal for small batches or variable batch training. GN is used in detection models (Mask R-CNN) and increasingly in transformers.
Compute mean, variance, and normalize in single kernel.
__global__ void group_norm_fused(float* x, float* y, float* gamma, float* beta,
int N, int C, int HW, int G) {
int n = blockIdx.x;
int g = blockIdx.y;
int channels_per_group = C / G;
int group_size = channels_per_group * HW;
// Welford online algorithm for mean/var
float mean = 0, M2 = 0;
int count = 0;
for (int i = threadIdx.x; i < group_size; i += blockDim.x) {
int c = g * channels_per_group + i / HW;
int hw = i % HW;
float val = x[n * C * HW + c * HW + hw];
count++;
float delta = val - mean;
mean += delta / count;
M2 += delta * (val - mean);
}
// Warp reduction for mean and M2...
// Then normalize with gamma/beta
}Three-kernel approach with multiple memory passes.
// Three separate kernels
// 1. Compute group means
// 2. Compute group variances
// 3. Normalize with gamma/beta
void group_norm_naive(float* x, float* y, int N, int C, int HW, int G) {
compute_group_means<<<N*G, 256>>>(x, means, N, C, HW, G);
compute_group_vars<<<N*G, 256>>>(x, means, vars, N, C, HW, G);
normalize_groups<<<N*C, 256>>>(x, y, means, vars, gamma, beta, N, C, HW, G);
}Single kernel with online statistics and fused normalization.
__global__ void group_norm_opt(float* __restrict__ x, float* __restrict__ y,
float* __restrict__ gamma, float* __restrict__ beta,
int C, int HW, int G, float eps) {
extern __shared__ float smem[];
int n = blockIdx.x;
int g = blockIdx.y;
int cpg = C / G;
int group_size = cpg * HW;
float sum = 0, sum_sq = 0;
for (int i = threadIdx.x; i < group_size; i += blockDim.x) {
int c = g * cpg + i / HW;
float val = x[n * C * HW + c * HW + i % HW];
sum += val;
sum_sq += val * val;
}
// Block reduction
sum = blockReduceSum(sum);
sum_sq = blockReduceSum(sum_sq);
__shared__ float s_mean, s_var;
if (threadIdx.x == 0) {
s_mean = sum / group_size;
s_var = sum_sq / group_size - s_mean * s_mean;
}
__syncthreads();
float inv_std = rsqrtf(s_var + eps);
for (int i = threadIdx.x; i < group_size; i += blockDim.x) {
int c = g * cpg + i / HW;
int idx = n * C * HW + c * HW + i % HW;
y[idx] = (x[idx] - s_mean) * inv_std * gamma[c] + beta[c];
}
}| Metric | Naive | Optimized | Improvement |
|---|---|---|---|
| Latency (ResNet block) | 0.8ms | 0.15ms | 5.3x faster |
| Memory reads | 3x | 1x | 3x reduction |
Common choices are 32 groups (original paper) or groups of 16 channels. More groups = closer to layer norm, fewer groups = closer to instance norm.
GN with G=1 equals LayerNorm
GN is batch-independent alternative
Ready to optimize your CUDA code? Download RightNow AI and get real-time performance analysis for your kernels.