%pip install triton
Requirement already satisfied: triton in /usr/local/lib/python3.12/dist-packages (3.4.0)
Requirement already satisfied: setuptools>=40.8.0 in /usr/local/lib/python3.12/dist-packages (from triton) (75.2.0)
import torch
import triton
import triton.language as tl
DEVICE = triton.runtime.driver.active.get_active_torch_device()
def get_autotune_config():
return [triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
num_warps=8)]
@triton.autotune(
configs=get_autotune_config(),
key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, ACTIVATION: tl.constexpr):
    # program id of this kernel instance
    pid = tl.program_id(axis=0)
    # number of programs along the M and N dimensions of the output grid
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Re-order program IDs into groups along M (grouped ordering) to improve L2 cache reuse.
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
accumulator = tl.dot(a, b, accumulator)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
# You can fuse arbitrary activation functions here
# while the accumulator is still in FP32!
if ACTIVATION == "leaky_relu":
accumulator = leaky_relu(accumulator)
c = accumulator.to(tl.float16)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
@triton.jit
def leaky_relu(x):
return tl.where(x >= 0, x, 0.01 * x)
def matmul(a, b, activation=""):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
assert a.is_contiguous(), "Matrix A must be contiguous"
M, K = a.shape
K, N = b.shape
# Allocates output.
c = torch.empty((M, N), device=a.device, dtype=torch.float16)
# 1D launch kernel where each block gets its own program.
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
matmul_kernel[grid](
a, b, c, #
M, N, K, #
a.stride(0), a.stride(1), #
b.stride(0), b.stride(1), #
c.stride(0), c.stride(1), #
ACTIVATION=activation #
)
return c
torch.manual_seed(0)
a = torch.rand((512, 512), device=DEVICE, dtype=torch.float16) - 0.5
b = torch.rand((512, 512), device=DEVICE, dtype=torch.float16) - 0.5
triton_output = matmul(a, b)
torch_output = torch.matmul(a, b)
print(f"triton_output_with_fp16_inputs={triton_output}")
print(f"torch_output_with_fp16_inputs={torch_output}")
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0):
print("✅ Triton and Torch match")
else:
print("❌ Triton and Torch differ")triton_output_with_fp16_inputs=tensor([[-0.5210, 2.8730, -1.0684, ..., -0.6885, 0.5825, 0.4065],
[ 2.3477, -1.5615, 1.9453, ..., -0.5010, -1.8594, -2.7246],
[ 0.3354, -1.3828, 2.9043, ..., -0.8320, -1.8623, 3.4531],
...,
[ 2.0586, 1.3125, -4.1484, ..., -2.4297, 1.2734, -1.3037],
[-0.3135, -2.9375, -1.8770, ..., 1.1973, -1.7500, 4.0312],
[ 0.3804, -0.9829, 0.4966, ..., 1.4756, 0.3811, 2.7070]],
device='cuda:0', dtype=torch.float16)
torch_output_with_fp16_inputs=tensor([[-0.5210, 2.8730, -1.0684, ..., -0.6885, 0.5825, 0.4065],
[ 2.3477, -1.5615, 1.9453, ..., -0.5010, -1.8594, -2.7246],
[ 0.3354, -1.3828, 2.9043, ..., -0.8320, -1.8623, 3.4531],
...,
[ 2.0586, 1.3125, -4.1484, ..., -2.4297, 1.2734, -1.3037],
[-0.3135, -2.9375, -1.8770, ..., 1.1973, -1.7500, 4.0312],
[ 0.3804, -0.9829, 0.4966, ..., 1.4756, 0.3811, 2.7070]],
device='cuda:0', dtype=torch.float16)
✅ Triton and Torch match
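To get a rough sense of how the fused kernel compares to torch.matmul, the snippet below times both on the same inputs. This is a minimal benchmarking sketch added for illustration rather than output from the original run; it assumes triton.testing.do_bench (which ships with Triton) and reuses the matmul wrapper and DEVICE defined above.
a = torch.rand((512, 512), device=DEVICE, dtype=torch.float16) - 0.5
b = torch.rand((512, 512), device=DEVICE, dtype=torch.float16) - 0.5
# do_bench runs each callable repeatedly and returns the average runtime in milliseconds.
ms_triton = triton.testing.do_bench(lambda: matmul(a, b))
ms_torch = triton.testing.do_bench(lambda: torch.matmul(a, b))
print(f"triton: {ms_triton:.3f} ms, torch: {ms_torch:.3f} ms")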
from triton.runtime import driver
@triton.jit
def matrix_dropout_kernel(x_ptr, output_ptr, num_rows, num_cols, stride_rows, stride_cols, p, seed, BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr):
row_start = tl.program_id(0)
row_step = tl.num_programs(0)
for row_idx in tl.range(row_start, num_rows, row_step, num_stages=num_stages):
        # pointer to the start of this input row
row_start_ptr = x_ptr + row_idx * stride_rows
offsets = tl.arange(0, BLOCK_SIZE)
input_ptr = row_start_ptr + offsets
mask = offsets < num_cols
x = tl.load(input_ptr, mask=mask, other=float('inf'))
        # draw per-element random numbers; keep each element with probability 1 - p
        random_tensor = tl.rand(seed + row_idx, offsets)
        x_keep = random_tensor > p
        output = tl.where(x_keep, x / (1 - p), 0.0)
        # pointer to the start of this output row
        output_row_start_ptr = output_ptr + row_idx * stride_rows
output_ptr_vector = output_row_start_ptr + offsets
tl.store(output_ptr_vector, output, mask=mask)
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
def matrix_dropout(x):
    n_rows, n_cols = x.shape
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    # num_warps = 8
    num_stages = 4 if SIZE_SMEM > 200000 else 2
y = torch.empty_like(x)
p = 0.5
seed = 123
kernel = matrix_dropout_kernel.warmup(x, y, n_rows, n_cols, x.stride(0), x.stride(1), p, seed, BLOCK_SIZE, num_stages = num_stages, grid=(1,))
kernel._init_handles()
size_smem = kernel.metadata.shared
num_programs = n_rows
kernel[(num_programs, 1, 1)](x, y, n_rows, n_cols, x.stride(0), x.stride(1), p, seed, BLOCK_SIZE, num_stages)
return y
x = torch.randn(5, 5, device=DEVICE)
output = matrix_dropout(x)
print(output)
tensor([[ 0.0000, 1.0012, 0.0000, 0.0000, 0.0000],
[-2.5976, 0.6637, -3.1776, 2.0022, 2.9186],
[ 1.4053, 0.0000, 2.2871, 1.0678, -0.1450],
[ 1.0221, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 1.2551, 0.0000, 0.0000, 0.0000]], device='cuda:0')
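Since matrix_dropout hard-codes seed = 123 and the kernel seeds tl.rand with seed + row_idx, the dropout mask is fully deterministic: calling it twice on the same input gives identical results, and every kept element is rescaled by 1/(1 - p). The check below is a small sketch of that property (not part of the original notebook), using the functions defined above.
x = torch.randn(5, 5, device=DEVICE)
out1 = matrix_dropout(x)
out2 = matrix_dropout(x)
# Same hard-coded seed inside matrix_dropout -> identical masks across calls.
assert torch.equal(out1, out2)
# Dropped entries are exactly zero; kept entries equal x / (1 - p) with p = 0.5.
kept = out1 != 0
assert torch.allclose(out1[kept], (x / 0.5)[kept])
print("deterministic mask and 1/(1 - p) scaling verified")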
@triton.jit
def _layer_norm_fwd_fused(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_SIZE: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
Y += row * stride
X += row * stride
# Compute mean
mean = 0
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
_mean += a
mean = tl.sum(_mean, axis=0) / N
# Compute variance
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
x = tl.where(cols < N, x - mean, 0.)
_var += x * x
var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
# Write mean / rstd
tl.store(Mean + row, mean)
tl.store(Rstd + row, rstd)
# Normalize and apply linear transformation
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
w = tl.load(W + cols, mask=mask)
b = tl.load(B + cols, mask=mask)
x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
x_hat = (x - mean) * rstd
y = x_hat * w + b
# Write output
tl.store(Y + cols, y, mask=mask)
@triton.jit
def _layer_norm_bwd_dx_fused(DX, # pointer to the input gradient
DY, # pointer to the output gradient
DW, # pointer to the partial sum of weights gradient
DB, # pointer to the partial sum of biases gradient
X, # pointer to the input
W, # pointer to the weights
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
Lock, # pointer to the lock
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
# Map the program id to the elements of X, DX, and DY it should compute.
row = tl.program_id(0)
cols = tl.arange(0, BLOCK_SIZE_N)
mask = cols < N
X += row * stride
DY += row * stride
DX += row * stride
# Offset locks and weights/biases gradient pointer for parallel reduction
lock_id = row % GROUP_SIZE_M
Lock += lock_id
Count = Lock + GROUP_SIZE_M
DW = DW + lock_id * N + cols
DB = DB + lock_id * N + cols
# Load data to SRAM
x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
w = tl.load(W + cols, mask=mask).to(tl.float32)
mean = tl.load(Mean + row)
rstd = tl.load(Rstd + row)
# Compute dx
xhat = (x - mean) * rstd
wdy = w * dy
xhat = tl.where(mask, xhat, 0.)
wdy = tl.where(mask, wdy, 0.)
c1 = tl.sum(xhat * wdy, axis=0) / N
c2 = tl.sum(wdy, axis=0) / N
dx = (wdy - (xhat * c1 + c2)) * rstd
# Write dx
tl.store(DX + cols, dx, mask=mask)
# Accumulate partial sums for dw/db
partial_dw = (dy * xhat).to(w.dtype)
partial_db = (dy).to(w.dtype)
while tl.atomic_cas(Lock, 0, 1) == 1:
pass
count = tl.load(Count)
# First store doesn't accumulate
if count == 0:
tl.atomic_xchg(Count, 1)
else:
partial_dw += tl.load(DW, mask=mask)
partial_db += tl.load(DB, mask=mask)
tl.store(DW, partial_dw, mask=mask)
tl.store(DB, partial_db, mask=mask)
# need a barrier to ensure all threads finished before
# releasing the lock
tl.debug_barrier()
# Release the lock
tl.atomic_xchg(Lock, 0)
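For reference (this note is not in the original notebook), the dx computed in the kernel above follows from differentiating $y = \hat{x} \odot w + b$ with $\hat{x} = (x - \mu)\,\mathrm{rstd}$. Writing $wdy = w \odot dy$ for a row of length $N$:
$$c_1 = \frac{1}{N}\sum_j \hat{x}_j\,(wdy)_j, \qquad c_2 = \frac{1}{N}\sum_j (wdy)_j, \qquad dx = \big(wdy - (\hat{x}\,c_1 + c_2)\big)\,\mathrm{rstd},$$
which is exactly what the variables c1, c2, and dx implement before the partial dw/db sums are accumulated under the lock.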
@triton.jit
def _layer_norm_bwd_dwdb(DW, # pointer to the partial sum of weights gradient
DB, # pointer to the partial sum of biases gradient
FINAL_DW, # pointer to the weights gradient
FINAL_DB, # pointer to the biases gradient
M, # GROUP_SIZE_M
N, # number of columns
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
# Map the program id to the elements of DW and DB it should compute.
pid = tl.program_id(0)
cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# Iterate through the rows of DW and DB to sum the partial sums.
for i in range(0, M, BLOCK_SIZE_M):
rows = i + tl.arange(0, BLOCK_SIZE_M)
mask = (rows[:, None] < M) & (cols[None, :] < N)
offs = rows[:, None] * N + cols[None, :]
dw += tl.load(DW + offs, mask=mask, other=0.)
db += tl.load(DB + offs, mask=mask, other=0.)
# Write the final sum to the output.
sum_dw = tl.sum(dw, axis=0)
sum_db = tl.sum(db, axis=0)
tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
    tl.store(FINAL_DB + cols, sum_db, mask=cols < N)
class LayerNorm(torch.autograd.Function):
@staticmethod
def forward(ctx, x, normalized_shape, weight, bias, eps):
# allocate output
y = torch.empty_like(x)
# reshape input data into 2D tensor
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
mean = torch.empty((M, ), dtype=torch.float32, device=x.device)
rstd = torch.empty((M, ), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_SIZE:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
# enqueue kernel
_layer_norm_fwd_fused[(M, )]( #
x_arg, y, weight, bias, mean, rstd, #
x_arg.stride(0), N, eps, #
BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1)
ctx.save_for_backward(x, weight, bias, mean, rstd)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_warps = num_warps
ctx.eps = eps
return y
@staticmethod
def backward(ctx, dy):
x, w, b, m, v = ctx.saved_tensors
# heuristics for amount of parallel reduction stream for DW/DB
N = w.shape[0]
GROUP_SIZE_M = 64
if N <= 8192: GROUP_SIZE_M = 96
if N <= 4096: GROUP_SIZE_M = 128
if N <= 1024: GROUP_SIZE_M = 256
# allocate output
locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=w.device)
_dw = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device)
_db = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device)
dw = torch.empty((N, ), dtype=w.dtype, device=w.device)
db = torch.empty((N, ), dtype=w.dtype, device=w.device)
dx = torch.empty_like(dy)
# enqueue kernel using forward pass heuristics
# also compute partial sums for DW and DB
x_arg = x.reshape(-1, x.shape[-1])
M, N = x_arg.shape
_layer_norm_bwd_dx_fused[(M, )]( #
dx, dy, _dw, _db, x, w, m, v, locks, #
x_arg.stride(0), N, #
BLOCK_SIZE_N=ctx.BLOCK_SIZE, #
GROUP_SIZE_M=GROUP_SIZE_M, #
num_warps=ctx.num_warps)
grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE_N']), )
# accumulate partial sums in separate kernel
_layer_norm_bwd_dwdb[grid](
_dw, _db, dw, db, min(GROUP_SIZE_M, M), N, #
BLOCK_SIZE_M=32, #
BLOCK_SIZE_N=128, num_ctas=1)
return dx, None, dw, db, None
layer_norm = LayerNorm.apply
def test_layer_norm(M, N, dtype, eps=1e-5, device=DEVICE):
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)
dy = .1 * torch.randn_like(x)
x.requires_grad_(True)
# forward pass
y_tri = layer_norm(x, w_shape, weight, bias, eps)
y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
# backward pass (triton)
y_tri.backward(dy, retain_graph=True)
dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]
x.grad, weight.grad, bias.grad = None, None, None
# backward pass (torch)
y_ref.backward(dy, retain_graph=True)
dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]
# compare
assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0)
assert torch.allclose(dx_tri, dx_ref, atol=1e-2, rtol=0)
assert torch.allclose(db_tri, db_ref, atol=1e-2, rtol=0)
assert torch.allclose(dw_tri, dw_ref, atol=1e-2, rtol=0)
print(f'Triton gradient for x : {dx_tri}')
print(f'Pytorch gradient for x : {dx_ref}')
print(f'Triton gradient for bias : {db_tri}')
print(f'Pytorch gradient for bias : {db_ref}')
print(f'Triton gradient for weights : {dw_tri}')
    print(f'Pytorch gradient for weights : {dw_ref}')
test_layer_norm(1151, 8192, torch.float16)
Triton gradient for x : tensor([[ 0.0339, 0.2026, 0.1421, ..., -0.0506, -0.0554, -0.0737],
[-0.0833, -0.0019, -0.1012, ..., -0.0261, -0.0986, 0.0994],
[-0.1628, 0.0720, 0.0975, ..., -0.0108, -0.0043, -0.1440],
...,
[ 0.0421, -0.1232, -0.2598, ..., 0.0523, -0.0726, 0.0177],
[-0.0780, 0.0089, 0.1704, ..., 0.0391, -0.0670, -0.0617],
[-0.0776, -0.0699, 0.0343, ..., -0.1139, 0.0667, 0.1949]],
device='cuda:0', dtype=torch.float16)
Pytorch gradient for x : tensor([[ 0.0339, 0.2026, 0.1421, ..., -0.0506, -0.0554, -0.0737],
[-0.0833, -0.0019, -0.1012, ..., -0.0261, -0.0986, 0.0994],
[-0.1628, 0.0720, 0.0975, ..., -0.0108, -0.0043, -0.1440],
...,
[ 0.0421, -0.1232, -0.2598, ..., 0.0523, -0.0726, 0.0177],
[-0.0780, 0.0089, 0.1704, ..., 0.0391, -0.0670, -0.0617],
[-0.0776, -0.0699, 0.0343, ..., -0.1139, 0.0667, 0.1949]],
device='cuda:0', dtype=torch.float16)
Triton gradient for bias : tensor([ 1.4854, 0.3247, 6.0469, ..., -1.7324, -0.5493, 3.8535],
device='cuda:0', dtype=torch.float16)
Pytorch gradient for bias : tensor([ 1.4834, 0.3274, 6.0469, ..., -1.7305, -0.5498, 3.8516],
device='cuda:0', dtype=torch.float16)
Triton gradient for weights : tensor([ 0.4792, 2.4023, -5.3047, ..., 0.9570, -4.5078, -2.9238],
device='cuda:0', dtype=torch.float16)
Pytorch gradient for weights : tensor([ 0.4805, 2.4023, -5.3086, ..., 0.9585, -4.5039, -2.9258],
device='cuda:0', dtype=torch.float16)
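As with the matmul example, a rough forward-pass timing can be obtained with triton.testing.do_bench. The snippet below is an illustrative sketch (not a measurement from the original run) using the same M = 1151, N = 8192, float16 shape as the test above and the layer_norm and DEVICE objects already defined.
M, N, dtype, eps = 1151, 8192, torch.float16, 1e-5
x = -2.3 + 0.5 * torch.randn((M, N), dtype=dtype, device=DEVICE)
weight = torch.rand((N, ), dtype=dtype, device=DEVICE)
bias = torch.rand((N, ), dtype=dtype, device=DEVICE)
# Forward-only timing of the fused Triton kernel vs. torch.nn.functional.layer_norm.
ms_triton = triton.testing.do_bench(lambda: layer_norm(x, (N, ), weight, bias, eps))
ms_torch = triton.testing.do_bench(lambda: torch.nn.functional.layer_norm(x, (N, ), weight, bias, eps))
print(f"triton: {ms_triton:.3f} ms, torch: {ms_torch:.3f} ms")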