Cross-entropy loss is the standard objective for classification tasks, but a naive implementation suffers from numerical overflow in the softmax. Fusing log-softmax with the negative log-likelihood into a single kernel gives both numerical stability and better performance through reduced memory traffic.

The key trick is to subtract the maximum logit before exponentiating: this leaves the result unchanged but keeps every exponent at or below zero, so nothing overflows.

```cuda
// log_softmax(x)_i = x_i - max(x) - log(sum(exp(x - max(x))))
__device__ float log_softmax_stable(float* logits, int idx, int C) {
    // Find max
    float max_val = -INFINITY;
    for (int i = 0; i < C; i++) max_val = fmaxf(max_val, logits[i]);

    // Compute log-sum-exp
    float sum = 0.0f;
    for (int i = 0; i < C; i++) sum += expf(logits[i] - max_val);
    float log_sum_exp = logf(sum) + max_val;

    return logits[idx] - log_sum_exp;
}
```
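As a quick illustration of why the max subtraction matters, here is a small host-side check; it is not part of the original listing, and the logit values are made up for the demonstration:

```cuda
#include <cstdio>
#include <cmath>

int main() {
    // expf(1000.0f) overflows float32, so a naive softmax produces inf/NaN.
    printf("naive expf(1000) = %f\n", expf(1000.0f));   // inf

    // With the max subtracted, every exponent is <= 0 and nothing overflows.
    float logits[3] = {1000.0f, 1001.0f, 1002.0f};
    float max_val = fmaxf(fmaxf(logits[0], logits[1]), logits[2]);
    float sum = 0.0f;
    for (int i = 0; i < 3; i++) sum += expf(logits[i] - max_val);
    // log_softmax of the largest logit: 0 - log(e^-2 + e^-1 + e^0) ~= -0.408
    printf("stable log_softmax[2] = %f\n", logits[2] - max_val - logf(sum));
    return 0;
}
```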
The naive approach runs two separate kernels with an explicit softmax in between, and both steps can fail numerically: the softmax exponentials overflow for large logits, and taking the log of an underflowed probability yields `-inf`.

```cuda
// Separate kernels - numerically unstable.
// probs is a pre-allocated N*C scratch buffer for the intermediate softmax;
// softmax_kernel and nll_loss_kernel are the two standalone kernels being fused.
void cross_entropy_naive(float* logits, float* probs, int* labels,
                         float* loss, int N, int C) {
    softmax_kernel<<<N, C>>>(logits, probs, C);        // Can overflow!
    cudaDeviceSynchronize();
    nll_loss_kernel<<<N, 1>>>(probs, labels, loss, C); // log(0) = -inf!
}
```

The fused version computes the max, the log-sum-exp, and the loss for each sample in a single kernel using the stable log-softmax formulation, so no intermediate probabilities are ever written to global memory. It relies on two block-wide reduction helpers, `blockReduceMax` and `blockReduceSum`, sketched next.
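The two helpers are referenced but not defined in the listing; what follows is a minimal sketch of one possible implementation using warp shuffle intrinsics, assuming `blockDim.x` is a multiple of 32 and at most 1024 (the helper bodies are this sketch's assumption, not code from the original):

```cuda
// Warp-level reductions using shuffle intrinsics (warpSize == 32).
__device__ float warpReduceMax(float val) {
    for (int offset = 16; offset > 0; offset >>= 1)
        val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
    return val;
}

__device__ float warpReduceSum(float val) {
    for (int offset = 16; offset > 0; offset >>= 1)
        val += __shfl_down_sync(0xffffffff, val, offset);
    return val;
}

// Block-level reductions: reduce within each warp, stage one value per warp
// in shared memory, then reduce those in the first warp. The final result is
// valid in thread 0, which is the only thread that reads it in the kernel.
__device__ float blockReduceMax(float val) {
    static __shared__ float shared[32];        // one slot per warp
    int lane = threadIdx.x % 32;
    int wid  = threadIdx.x / 32;
    val = warpReduceMax(val);
    if (lane == 0) shared[wid] = val;
    __syncthreads();
    int nwarps = blockDim.x / 32;
    val = (threadIdx.x < nwarps) ? shared[lane] : -INFINITY;
    if (wid == 0) val = warpReduceMax(val);
    return val;
}

__device__ float blockReduceSum(float val) {
    static __shared__ float shared[32];
    int lane = threadIdx.x % 32;
    int wid  = threadIdx.x / 32;
    val = warpReduceSum(val);
    if (lane == 0) shared[wid] = val;
    __syncthreads();
    int nwarps = blockDim.x / 32;
    val = (threadIdx.x < nwarps) ? shared[lane] : 0.0f;
    if (wid == 0) val = warpReduceSum(val);
    return val;
}
```

With block-level reductions available, the fused kernel handles one sample per thread block: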
```cuda
__global__ void cross_entropy_fused(float* __restrict__ logits,
                                    int* __restrict__ labels,
                                    float* __restrict__ loss,
                                    int N, int C) {
    int sample = blockIdx.x;
    int tid = threadIdx.x;
    float* sample_logits = logits + sample * C;
    int label = labels[sample];

    // Step 1: Find max (parallel reduction)
    __shared__ float s_max;
    float local_max = -INFINITY;
    for (int i = tid; i < C; i += blockDim.x) {
        local_max = fmaxf(local_max, sample_logits[i]);
    }
    local_max = blockReduceMax(local_max);
    if (tid == 0) s_max = local_max;
    __syncthreads();

    // Step 2: Compute sum of exp(x - max)
    __shared__ float s_sum;
    float local_sum = 0.0f;
    for (int i = tid; i < C; i += blockDim.x) {
        local_sum += expf(sample_logits[i] - s_max);
    }
    local_sum = blockReduceSum(local_sum);
    if (tid == 0) s_sum = local_sum;
    __syncthreads();

    // Step 3: Compute loss = -log_softmax[label]
    if (tid == 0) {
        float log_softmax = sample_logits[label] - s_max - logf(s_sum);
        loss[sample] = -log_softmax;
    }
}
```

How the fused kernel compares with the naive two-kernel version:

| Metric | Naive | Optimized | Improvement |
|---|---|---|---|
| Latency (batch=256, C=1000) | 0.12ms | 0.03ms | 4x faster |
| Numerical stability | Fails for logits>80 | Stable for all float32 | Robust |
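For completeness, a host-side launch for the fused kernel could look like the sketch below; the wrapper name and the 256-thread block size are illustrative choices, not something specified above:

```cuda
// One block per sample; 256 threads per block (a multiple of 32, so the
// warp-shuffle reductions above only ever see full warps).
void launch_cross_entropy_fused(float* d_logits, int* d_labels,
                                float* d_loss, int N, int C) {
    int threads = 256;
    cross_entropy_fused<<<N, threads>>>(d_logits, d_labels, d_loss, N, C);
    // d_loss now holds one loss value per sample; average over N (on the host
    // or with a small reduction kernel) to get the batch mean loss.
}
```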
Is the explicit softmax ever worth computing on its own for the loss? No. Always fuse log-softmax with the loss: materializing the probabilities costs an extra N*C write and read of global memory for no benefit, and reintroduces the numerical issues the fused kernel avoids.
To add label smoothing, modify the loss to loss = (1 - smooth) * CE + smooth * uniform_CE, where uniform_CE is the negative mean of the log-softmax across all classes. A sketch of the resulting kernel follows.
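Here is one possible fused variant with label smoothing; it is a sketch rather than the article's implementation, and it adds one extra block reduction for the sum of the raw logits so that mean(log_softmax) = mean(x) - max - log(sum_exp) can be formed. It reuses `blockReduceMax` and `blockReduceSum` from above.

```cuda
__global__ void cross_entropy_smoothed(float* __restrict__ logits,
                                       int* __restrict__ labels,
                                       float* __restrict__ loss,
                                       int N, int C, float smooth) {
    int sample = blockIdx.x;
    int tid = threadIdx.x;
    float* sample_logits = logits + sample * C;
    int label = labels[sample];

    // Max over classes (same as step 1 above).
    __shared__ float s_max;
    float local_max = -INFINITY;
    for (int i = tid; i < C; i += blockDim.x)
        local_max = fmaxf(local_max, sample_logits[i]);
    local_max = blockReduceMax(local_max);
    if (tid == 0) s_max = local_max;
    __syncthreads();

    // Sum of exp(x - max) and sum of the raw logits in one pass.
    __shared__ float s_sum_exp, s_sum_logits;
    float local_exp = 0.0f, local_lin = 0.0f;
    for (int i = tid; i < C; i += blockDim.x) {
        local_exp += expf(sample_logits[i] - s_max);
        local_lin += sample_logits[i];
    }
    local_exp = blockReduceSum(local_exp);
    if (tid == 0) s_sum_exp = local_exp;
    __syncthreads();   // also needed before reusing blockReduceSum's staging buffer
    local_lin = blockReduceSum(local_lin);
    if (tid == 0) s_sum_logits = local_lin;

    // loss = (1 - smooth) * CE + smooth * uniform_CE
    if (tid == 0) {
        float log_sum_exp = logf(s_sum_exp);
        float ce = -(sample_logits[label] - s_max - log_sum_exp);
        float uniform_ce = -(s_sum_logits / C - s_max - log_sum_exp);
        loss[sample] = (1.0f - smooth) * ce + smooth * uniform_ce;
    }
}
```

Setting smooth = 0.0f recovers the plain cross-entropy kernel above.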
Ready to optimize your CUDA code? Download RightNow AI and get real-time performance analysis for your kernels.