Triton is a programming language and compiler for writing highly efficient GPU kernels in Python. Developed by OpenAI, it bridges the gap between high-level frameworks and low-level CUDA - you write Python-like code and Triton compiles it to optimized GPU assembly. For CUDA developers, Triton eliminates much of the complexity of GPU programming. Instead of managing thread blocks, shared memory, and memory coalescing manually, you express algorithms at a higher level and Triton handles the optimization. It's particularly powerful for custom attention mechanisms, quantization kernels, and operations not well-supported by cuDNN. This guide covers Triton's programming model, kernel development, integration with PyTorch, and optimization techniques for writing production-quality GPU kernels.
CUDA Integration: Triton compiles Python functions to GPU code that runs alongside CUDA kernels. It can directly operate on PyTorch tensors and integrates with the CUDA ecosystem. Triton-generated kernels often match or exceed hand-written CUDA performance for many operations, especially matrix operations and attention.
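For instance, a Triton kernel can consume the output of a cuBLAS matmul directly and runs on the same CUDA stream as the surrounding PyTorch ops. The sketch below is illustrative only; scale_inplace_kernel is a made-up name, not part of Triton or of this guide's examples.

import torch
import triton
import triton.language as tl

@triton.jit
def scale_inplace_kernel(x_ptr, n_elements, scale, BLOCK_SIZE: tl.constexpr):
    # Scale a contiguous tensor in place, one BLOCK_SIZE chunk per program
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(x_ptr + offsets, x * scale, mask=mask)

a = torch.randn(512, 512, device='cuda')
b = torch.randn(512, 512, device='cuda')
c = a @ b  # cuBLAS kernel launched by PyTorch
grid = (triton.cdiv(c.numel(), 1024),)
scale_inplace_kernel[grid](c, c.numel(), 0.5, BLOCK_SIZE=1024)  # Triton kernel on the same stream
d = torch.relu(c)  # back to regular PyTorch ops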
Triton is included with PyTorch 2.0+ or can be installed separately.
# Triton comes with PyTorch 2.0+
pip install torch # Includes triton
# Or install standalone
pip install triton
# Verify installation
python -c "import triton; print(f'Triton {triton.__version__}')"
# Test basic kernel
python -c "
import triton
import triton.language as tl
import torch
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n):
    pid = tl.program_id(0)
    offsets = pid * 1024 + tl.arange(0, 1024)
    mask = offsets < n
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)
x = torch.randn(10000, device='cuda')
y = torch.randn(10000, device='cuda')
out = torch.empty_like(x)
add_kernel[(10,)](x, y, out, x.numel())
print('Triton kernel works!')
"A simple Triton kernel demonstrating the basic programming model.
import torch
import triton
import triton.language as tl
@triton.jit
def add_kernel(
    x_ptr,  # Pointer to first input tensor
    y_ptr,  # Pointer to second input tensor
    out_ptr,  # Pointer to output tensor
    n_elements,  # Total number of elements
    BLOCK_SIZE: tl.constexpr,  # Compile-time constant
):
    # Each program handles BLOCK_SIZE elements
    pid = tl.program_id(axis=0)  # Which block am I?
    # Calculate offsets for this block
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask for bounds checking
    mask = offsets < n_elements
    # Load data (masked to handle edge cases)
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    # Compute
    output = x + y
    # Store result
    tl.store(out_ptr + offsets, output, mask=mask)
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    assert x.is_cuda and y.is_cuda
    output = torch.empty_like(x)
    n_elements = x.numel()
    # Calculate grid size
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    # Launch kernel
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return output
# Test
x = torch.randn(100000, device='cuda')
y = torch.randn(100000, device='cuda')
output = add(x, y)
assert torch.allclose(output, x + y)
print("Kernel verified!")An optimized fused softmax kernel with automatic tuning for best performance.
import torch
import triton
import triton.language as tl
@triton.autotune(
    configs=[
        # Tune only num_warps here: BLOCK_SIZE must cover the whole row,
        # so the wrapper below derives it from n_cols instead
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
    ],
    key=['n_cols'],  # Retune when n_cols changes
)
@triton.jit
def fused_softmax_kernel(
    output_ptr, input_ptr,
    input_row_stride, output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program handles one row
    row_idx = tl.program_id(0)
    # Pointer to current row
    row_start_ptr = input_ptr + row_idx * input_row_stride
    # Load the whole row in one block (BLOCK_SIZE must be >= n_cols)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets
    mask = col_offsets < n_cols
    # Load row; out-of-bounds lanes get -inf so they don't affect the max
    row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
    # Compute softmax
    row_max = tl.max(row, axis=0)
    row = row - row_max  # Numerical stability
    numerator = tl.exp(row)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    # Store result
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, softmax_output, mask=mask)
def fused_softmax(x: torch.Tensor) -> torch.Tensor:
    n_rows, n_cols = x.shape
    output = torch.empty_like(x)
    # The kernel loads one full row per program, so BLOCK_SIZE must be >= n_cols;
    # round up to the next power of two (tl.arange requires a power-of-two size)
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    # Launch one program per row
    grid = (n_rows,)
    fused_softmax_kernel[grid](
        output, x,
        x.stride(0), output.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output
# Benchmark against PyTorch
x = torch.randn(4096, 4096, device='cuda')
# Warmup
for _ in range(10):
    _ = fused_softmax(x)
    _ = torch.softmax(x, dim=-1)
torch.cuda.synchronize()
import time
start = time.time()
for _ in range(100):
    _ = fused_softmax(x)
torch.cuda.synchronize()
triton_time = time.time() - start
start = time.time()
for _ in range(100):
    _ = torch.softmax(x, dim=-1)
torch.cuda.synchronize()
torch_time = time.time() - start
print(f"Triton: {triton_time*10:.2f}ms, PyTorch: {torch_time*10:.2f}ms")
print(f"Speedup: {torch_time/triton_time:.2f}x")@triton.autotune tests multiple configurations and selects the best one. Include various BLOCK_SIZE and num_warps combinations.
Triton shines when you fuse multiple operations (like softmax) into a single kernel, eliminating intermediate memory reads/writes.
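For example, a bias-add followed by ReLU can be fused so the intermediate sum never leaves registers. This is an illustrative sketch (bias_relu_kernel is a made-up name, not one of the examples above):

import torch
import triton
import triton.language as tl

@triton.jit
def bias_relu_kernel(x_ptr, bias_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    b = tl.load(bias_ptr + offsets, mask=mask)
    # Both the add and the ReLU happen in registers; the intermediate
    # (x + b) is never written to global memory
    tl.store(out_ptr + offsets, tl.maximum(x + b, 0.0), mask=mask)

x = torch.randn(1 << 20, device='cuda')
bias = torch.randn(1 << 20, device='cuda')
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)
bias_relu_kernel[grid](x, bias, out, x.numel(), BLOCK_SIZE=1024)
assert torch.allclose(out, torch.relu(x + bias))

Run as separate PyTorch ops, x + bias and torch.relu would each read and write the full tensor in global memory; the fused kernel reads each input once and writes the output once.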
Mark block sizes and other constants as tl.constexpr to enable compiler optimizations like loop unrolling.
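As a sketch of why this matters (the kernel and its parameter names are made up for illustration), the loop below has a compile-time trip count because ITEMS_PER_BLOCK is a tl.constexpr, so the compiler can unroll it:

import torch
import triton
import triton.language as tl

@triton.jit
def scale_kernel(x_ptr, out_ptr, n_elements, scale,
                 BLOCK_SIZE: tl.constexpr, ITEMS_PER_BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    base = pid * BLOCK_SIZE * ITEMS_PER_BLOCK
    # Trip count is known at compile time, so this loop can be unrolled
    for i in range(ITEMS_PER_BLOCK):
        offsets = base + i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x * scale, mask=mask)

x = torch.randn(1 << 20, device='cuda')
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024 * 4),)
scale_kernel[grid](x, out, x.numel(), 2.0, BLOCK_SIZE=1024, ITEMS_PER_BLOCK=4)
assert torch.allclose(out, x * 2.0)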
Choose BLOCK_SIZE as a multiple of 32 (the warp size), and size it so each warp's load covers at least 128 bytes (for example, 32 fp32 elements) for well-coalesced memory access.
tl.dot uses Tensor Cores automatically when shapes are compatible. Ensure dimensions are multiples of 16.
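As an illustrative sketch (the kernel name and shapes are assumptions, not the guide's code), here is a single-program matmul that loads one tile of each operand and multiplies them with tl.dot; the tile sides are multiples of 16 so the dot can map to Tensor Core instructions:

import torch
import triton
import triton.language as tl

@triton.jit
def tile_matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                       stride_am, stride_ak, stride_bk, stride_bn,
                       stride_cm, stride_cn,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # Single program: assumes M <= BLOCK_M, N <= BLOCK_N, K <= BLOCK_K
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a = tl.load(a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak,
                mask=(offs_m[:, None] < M) & (offs_k[None, :] < K), other=0.0)
    b = tl.load(b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn,
                mask=(offs_k[:, None] < K) & (offs_n[None, :] < N), other=0.0)
    acc = tl.dot(a, b)  # accumulates in fp32; uses Tensor Cores for fp16 inputs
    tl.store(c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn, acc,
             mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))

a = torch.randn(64, 64, device='cuda', dtype=torch.float16)
b = torch.randn(64, 64, device='cuda', dtype=torch.float16)
c = torch.empty(64, 64, device='cuda', dtype=torch.float32)
tile_matmul_kernel[(1,)](a, b, c, 64, 64, 64,
                         a.stride(0), a.stride(1), b.stride(0), b.stride(1),
                         c.stride(0), c.stride(1),
                         BLOCK_M=64, BLOCK_N=64, BLOCK_K=64)
assert torch.allclose(c, a.float() @ b.float(), atol=1e-2)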
Set the TRITON_PRINT_AUTOTUNING=1 environment variable to see which configuration the autotuner selected.
| Task | Performance | Notes |
|---|---|---|
| Fused Softmax speedup | 1.5-3x | vs torch.softmax |
| Flash Attention speedup | 2-4x | vs vanilla attention |
| Quantized MatMul | 3-5x | INT8 vs FP16 |
| Compilation time | 1-5s | First call per kernel config |
Use Triton when: 1) You need to fuse multiple operations PyTorch does separately, 2) PyTorch doesn't have an efficient implementation for your operation, 3) You need custom quantization or precision handling. Stick with PyTorch for standard ops like matmul and conv.
Triton is 3-10x faster to develop and often achieves 80-100% of hand-tuned CUDA performance. CUDA gives more control but requires managing threads, shared memory, and synchronization manually. Triton is preferred unless you need that low-level control.
Triton kernels work with autograd: wrap your Triton kernel in a torch.autograd.Function and define forward() and backward() methods. The backward pass can also be a Triton kernel.
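A minimal sketch, assuming the add() wrapper defined earlier is in scope (the class name TritonAdd is made up); since addition's gradient just passes through, backward here needs no extra kernel, but it could launch one:

import torch

class TritonAdd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        # Forward pass runs the Triton-backed add() defined earlier
        return add(x, y)

    @staticmethod
    def backward(ctx, grad_output):
        # d(x + y)/dx = d(x + y)/dy = 1, so the upstream gradient passes through;
        # a more complex op would launch a second Triton kernel here
        return grad_output, grad_output

x = torch.randn(4096, device='cuda', requires_grad=True)
y = torch.randn(4096, device='cuda', requires_grad=True)
TritonAdd.apply(x, y).sum().backward()
assert torch.allclose(x.grad, torch.ones_like(x))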
Triton JIT compiles kernels on first use. Subsequent calls are fast. For benchmarking, always do warmup iterations. In production, you can cache compiled kernels.
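For timing, triton.testing.do_bench handles warmup and repeated measurement for you; the sketch below reuses the fused_softmax wrapper defined above:

import torch
import triton

x = torch.randn(4096, 4096, device='cuda')
# do_bench compiles on the first call, then warms up and times repeated runs,
# returning milliseconds per call
ms_triton = triton.testing.do_bench(lambda: fused_softmax(x))
ms_torch = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1))
print(f"Triton: {ms_triton:.3f} ms, PyTorch: {ms_torch:.3f} ms")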
| Alternative | Compared with Triton |
|---|---|
| torch.compile | Higher-level, no kernel writing needed for standard ops |
| CuPy | NumPy-like GPU arrays, less optimization control |
| Numba | JIT for Python, supports CUDA but less GPU-optimized |