Log-softmax computes log(softmax(x)) = x - log(sum(exp(x))). Computing softmax first and then taking the log is numerically unstable; computing log-softmax directly with the log-sum-exp trick is both stable and efficient. The trick rests on the identity log(sum(exp(x))) = max(x) + log(sum(exp(x - max(x)))): subtracting the max before exponentiating keeps every argument to exp at or below zero, so nothing overflows.
```cuda
__device__ void log_softmax(const float* logits, float* output, int C) {
    // 1. Find the max logit
    float max_val = logits[0];
    for (int i = 1; i < C; i++) max_val = fmaxf(max_val, logits[i]);
    // 2. Compute log-sum-exp with the max subtracted (stable)
    float sum = 0.0f;
    for (int i = 0; i < C; i++) sum += expf(logits[i] - max_val);
    float log_sum = logf(sum) + max_val;
    // 3. Write the log-softmax output
    for (int i = 0; i < C; i++) output[i] = logits[i] - log_sum;
}
```
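If you want to drive this per-row device function directly, a minimal wrapper kernel might look like the sketch below; the kernel name, the row count `N`, and the launch shape are illustrative assumptions, not part of the original code.

```cuda
// Hypothetical wrapper: one thread handles one row of the [N, C] logits.
// Fine for large N with modest C; the block-parallel kernel below is the
// faster path for wide rows.
__global__ void log_softmax_rows(const float* logits, float* output,
                                 int N, int C) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < N)
        log_softmax(logits + row * C, output + row * C, C);
}

// Launch example: 256 threads per block, enough blocks to cover N rows.
// log_softmax_rows<<<(N + 255) / 256, 256>>>(d_logits, d_output, N, C);
```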
The naive formulation overflows for typical logit values:

```cuda
// DON'T DO THIS - overflows!
__global__ void log_softmax_naive(float* x, float* y, int n, int C) {
    for (int i = 0; i < C; i++) {
        float sum = 0.0f;
        for (int j = 0; j < C; j++)
            sum += expf(x[j]);          // expf overflows float32 once x[j] > ~88.7
        y[i] = logf(expf(x[i]) / sum);  // Double overflow!
    }
}
```
The optimized kernel performs the stable computation in parallel, using block reductions:

```cuda
__global__ void log_softmax_opt(const float* x, float* y, int N, int C) {
    int sample = blockIdx.x;            // one block per sample (row)
    const float* in = x + sample * C;
    float* out = y + sample * C;
    __shared__ float s_max, s_log_sum;

    // Parallel max reduction across the row
    float local_max = -INFINITY;
    for (int i = threadIdx.x; i < C; i += blockDim.x)
        local_max = fmaxf(local_max, in[i]);
    local_max = blockReduceMax(local_max);
    if (threadIdx.x == 0) s_max = local_max;
    __syncthreads();

    // Parallel sum reduction of exp(x - max)
    float local_sum = 0.0f;
    for (int i = threadIdx.x; i < C; i += blockDim.x)
        local_sum += expf(in[i] - s_max);
    local_sum = blockReduceSum(local_sum);
    if (threadIdx.x == 0) s_log_sum = logf(local_sum) + s_max;
    __syncthreads();

    // Parallel write of the output
    for (int i = threadIdx.x; i < C; i += blockDim.x)
        out[i] = in[i] - s_log_sum;
}
```
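The kernel above assumes `blockReduceMax` and `blockReduceSum` helpers, which are not defined here. A common warp-shuffle implementation is sketched below, under the assumption that `blockDim.x` is a multiple of 32 (at most 1024); only thread 0 needs the final result, which matches how the kernel uses it.

```cuda
// Warp-level reductions via shuffle intrinsics (CUDA 9+).
__inline__ __device__ float warpReduceMax(float v) {
    for (int offset = 16; offset > 0; offset /= 2)
        v = fmaxf(v, __shfl_down_sync(0xffffffff, v, offset));
    return v;
}

__inline__ __device__ float warpReduceSum(float v) {
    for (int offset = 16; offset > 0; offset /= 2)
        v += __shfl_down_sync(0xffffffff, v, offset);
    return v;
}

// Block-level reduction: reduce within each warp, stage per-warp results
// in shared memory, then reduce those in the first warp.
// The returned value is valid on thread 0.
__inline__ __device__ float blockReduceMax(float v) {
    __shared__ float shared[32];        // one slot per warp
    int lane = threadIdx.x % 32;
    int warp = threadIdx.x / 32;
    v = warpReduceMax(v);
    if (lane == 0) shared[warp] = v;
    __syncthreads();
    int nWarps = (blockDim.x + 31) / 32;
    v = (threadIdx.x < nWarps) ? shared[lane] : -INFINITY;
    if (warp == 0) v = warpReduceMax(v);
    return v;
}

__inline__ __device__ float blockReduceSum(float v) {
    __shared__ float shared[32];
    int lane = threadIdx.x % 32;
    int warp = threadIdx.x / 32;
    v = warpReduceSum(v);
    if (lane == 0) shared[warp] = v;
    __syncthreads();
    int nWarps = (blockDim.x + 31) / 32;
    v = (threadIdx.x < nWarps) ? shared[lane] : 0.0f;
    if (warp == 0) v = warpReduceSum(v);
    return v;
}
```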
| Metric | Naive | Optimized | Improvement |
|---|---|---|---|
| Latency (batch=256, C=30000) | Fails (overflow) | 0.8 ms | Works vs. fails |
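For reference, the benchmark row above corresponds to a launch along these lines; the device buffers `d_x` and `d_y` and the 256-thread block size are assumptions for illustration.

```cuda
// One block per sample; 256 threads stride across the 30000 classes.
const int N = 256, C = 30000;
log_softmax_opt<<<N, 256>>>(d_x, d_y, N, C);
cudaDeviceSynchronize();
```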
Log-softmax is more numerically stable than softmax-then-log and feeds directly into NLL loss. Never compute softmax and then take its log.
Ready to optimize your CUDA code? Download RightNow AI and get real-time performance analysis for your kernels.