Basics of Triton
%pip install triton
Requirement already satisfied: triton in /usr/local/lib/python3.12/dist-packages (3.4.0)
Requirement already satisfied: setuptools>=40.8.0 in /usr/local/lib/python3.12/dist-packages (from triton) (75.2.0)
import torch
import triton
import triton.language as tl
DEVICE = triton.runtime.driver.active.get_active_torch_device()
def get_autotune_config():
  return [triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3,
                      num_warps=8)]
@triton.autotune(
    configs=get_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
                   BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, ACTIVATION: tl.constexpr):
  # identify the program id of this kernel instance
  pid = tl.program_id(axis=0)
  # number of programs along the M and N dimensions of the output
  num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
  num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
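  # re-order program ids into groups of GROUP_SIZE_M rows so that programs sharing
  # the same blocks of A run close together, improving L2 cache reuse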
  num_pid_in_group = GROUP_SIZE_M*num_pid_n
  group_id = pid // num_pid_in_group
  first_pid_m = group_id * GROUP_SIZE_M
  group_size_m = min(num_pid_m-first_pid_m, GROUP_SIZE_M)
  pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
  pid_n = (pid % num_pid_in_group) // group_size_m
  offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
  offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
  offs_k = tl.arange(0, BLOCK_SIZE_K)
  a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
  b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
  accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
  for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
      # Load the next block of A and B, generate a mask by checking the K dimension.
      # If it is out of bounds, set it to 0.
      a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
      b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
      # We accumulate along the K dimension.
      accumulator = tl.dot(a, b, accumulator)
      # Advance the ptrs to the next K block.
      a_ptrs += BLOCK_SIZE_K * stride_ak
      b_ptrs += BLOCK_SIZE_K * stride_bk
  # You can fuse arbitrary activation functions here
  # while the accumulator is still in FP32!
  if ACTIVATION == "leaky_relu":
      accumulator = leaky_relu(accumulator)
  c = accumulator.to(tl.float16)
  offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
  offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
  c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
  c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
  tl.store(c_ptrs, c, mask=c_mask)
@triton.jit
def leaky_relu(x):
    return tl.where(x >= 0, x, 0.01 * x)
def matmul(a, b, activation=""):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # Allocates output.
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,  #
        M, N, K,  #
        a.stride(0), a.stride(1),  #
        b.stride(0), b.stride(1),  #
        c.stride(0), c.stride(1),  #
        ACTIVATION=activation  #
    )
    return c
torch.manual_seed(0)
a = torch.rand((512, 512), device=DEVICE, dtype=torch.float16) - 0.5
b = torch.rand((512, 512), device=DEVICE, dtype=torch.float16) - 0.5
triton_output = matmul(a, b)
torch_output = torch.matmul(a, b)
print(f"triton_output_with_fp16_inputs={triton_output}")
print(f"torch_output_with_fp16_inputs={torch_output}")
if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0):
    print("✅ Triton and Torch match")
else:
    print("❌ Triton and Torch differ")triton_output_with_fp16_inputs=tensor([[-0.5210,  2.8730, -1.0684,  ..., -0.6885,  0.5825,  0.4065],
        [ 2.3477, -1.5615,  1.9453,  ..., -0.5010, -1.8594, -2.7246],
        [ 0.3354, -1.3828,  2.9043,  ..., -0.8320, -1.8623,  3.4531],
        ...,
        [ 2.0586,  1.3125, -4.1484,  ..., -2.4297,  1.2734, -1.3037],
        [-0.3135, -2.9375, -1.8770,  ...,  1.1973, -1.7500,  4.0312],
        [ 0.3804, -0.9829,  0.4966,  ...,  1.4756,  0.3811,  2.7070]],
       device='cuda:0', dtype=torch.float16)
torch_output_with_fp16_inputs=tensor([[-0.5210,  2.8730, -1.0684,  ..., -0.6885,  0.5825,  0.4065],
        [ 2.3477, -1.5615,  1.9453,  ..., -0.5010, -1.8594, -2.7246],
        [ 0.3354, -1.3828,  2.9043,  ..., -0.8320, -1.8623,  3.4531],
        ...,
        [ 2.0586,  1.3125, -4.1484,  ..., -2.4297,  1.2734, -1.3037],
        [-0.3135, -2.9375, -1.8770,  ...,  1.1973, -1.7500,  4.0312],
        [ 0.3804, -0.9829,  0.4966,  ...,  1.4756,  0.3811,  2.7070]],
       device='cuda:0', dtype=torch.float16)
✅ Triton and Torch match
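To get a rough sense of performance, the snippet below (a minimal sketch added here, not part of the original notebook) times the Triton kernel against torch.matmul with triton.testing.do_bench, and also shows how the fused leaky_relu activation is requested; absolute numbers depend on the GPU and Triton version.
# Rough timing comparison (sketch; results vary by hardware).
ms_triton = triton.testing.do_bench(lambda: matmul(a, b))
ms_torch = triton.testing.do_bench(lambda: torch.matmul(a, b))
print(f"triton: {ms_triton:.3f} ms, torch: {ms_torch:.3f} ms")
# The activation is applied inside the same kernel while the accumulator is still FP32.
fused_output = matmul(a, b, activation="leaky_relu")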
from triton.runtime import driver
@triton.jit
def matrix_dropout_kernel(x_ptr, output_ptr, num_rows, num_cols, stride_rows, stride_cols, p, seed, BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr):
  row_start = tl.program_id(0)
  row_step = tl.num_programs(0)
  for row_idx in tl.range(row_start, num_rows, row_step, num_stages=num_stages):
    # pointer to the start of this row of the input
    row_start_ptr = x_ptr + row_idx * stride_rows
    offsets = tl.arange(0, BLOCK_SIZE)
    input_ptr = row_start_ptr + offsets
    mask = offsets < num_cols
    x = tl.load(input_ptr, mask=mask, other=float('inf'))
    # apply dropout: keep each element with probability 1 - p and rescale by 1 / (1 - p)
    random_tensor = tl.rand(seed + row_idx, offsets)
    x_keep = random_tensor > p
    output = tl.where(x_keep, x / (1 - p), 0.0)
    # pointer to the start of this row of the output
    output_row_start_ptr = output_ptr + row_idx * stride_rows
    output_ptr_vector = output_row_start_ptr + offsets
    tl.store(output_ptr_vector, output, mask=mask)
properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
def matrix_dropout(x):
  n_rows, n_cols = x.shape
  BLOCK_SIZE = triton.next_power_of_2(n_cols)
  # num_warps = 8
  num_stages = 4 if SIZE_SMEM > 200000 else 2
  y = torch.empty_like(x)
  p = 0.5
  seed = 123
  kernel = matrix_dropout_kernel.warmup(x, y, n_rows, n_cols, x.stride(0), x.stride(1), p, seed, BLOCK_SIZE, num_stages = num_stages, grid=(1,))
  kernel._init_handles()
  size_smem = kernel.metadata.shared
  num_programs = n_rows
  kernel[(num_programs, 1, 1)](x, y, n_rows, n_cols, x.stride(0), x.stride(1), p, seed, BLOCK_SIZE, num_stages)
  return y
x = torch.randn(5, 5, device=DEVICE)
output = matrix_dropout(x)
print(output)
tensor([[ 0.0000,  1.0012,  0.0000,  0.0000,  0.0000],
        [-2.5976,  0.6637, -3.1776,  2.0022,  2.9186],
        [ 1.4053,  0.0000,  2.2871,  1.0678, -0.1450],
        [ 1.0221,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  1.2551,  0.0000,  0.0000,  0.0000]], device='cuda:0')
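As a quick sanity check (a sketch added here, not part of the original post), the surviving entries should equal the inputs scaled by 1 / (1 - p), and roughly half of the entries should be dropped, since the wrapper hard-codes p = 0.5:
x_big = torch.randn(1024, 1024, device=DEVICE)
y_big = matrix_dropout(x_big)
p = 0.5  # matches the value hard-coded inside matrix_dropout
kept = y_big != 0
print(f"observed keep rate: {kept.float().mean().item():.3f} (expected ~{1 - p:.2f})")
# surviving elements are the inputs rescaled by 1 / (1 - p)
print(torch.allclose(y_big[kept], x_big[kept] / (1 - p)))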
@triton.jit
def _layer_norm_fwd_fused(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride,  # how much to increase the pointer when moving by 1 row
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    Y += row * stride
    X += row * stride
    # Compute mean
    mean = 0
    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
        _mean += a
    mean = tl.sum(_mean, axis=0) / N
    # Compute variance
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
        x = tl.where(cols < N, x - mean, 0.)
        _var += x * x
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    # Write mean / rstd
    tl.store(Mean + row, mean)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        w = tl.load(W + cols, mask=mask)
        b = tl.load(B + cols, mask=mask)
        x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
        x_hat = (x - mean) * rstd
        y = x_hat * w + b
        # Write output
        tl.store(Y + cols, y, mask=mask)
@triton.jit
def _layer_norm_bwd_dx_fused(DX,  # pointer to the input gradient
                             DY,  # pointer to the output gradient
                             DW,  # pointer to the partial sum of weights gradient
                             DB,  # pointer to the partial sum of biases gradient
                             X,  # pointer to the input
                             W,  # pointer to the weights
                             Mean,  # pointer to the mean
                             Rstd,  # pointer to the 1/std
                             Lock,  # pointer to the lock
                             stride,  # how much to increase the pointer when moving by 1 row
                             N,  # number of columns in X
                             GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
    # Map the program id to the elements of X, DX, and DY it should compute.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE_N)
    mask = cols < N
    X += row * stride
    DY += row * stride
    DX += row * stride
    # Offset locks and weights/biases gradient pointer for parallel reduction
    lock_id = row % GROUP_SIZE_M
    Lock += lock_id
    Count = Lock + GROUP_SIZE_M
    DW = DW + lock_id * N + cols
    DB = DB + lock_id * N + cols
    # Load data to SRAM
    x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
    dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    mean = tl.load(Mean + row)
    rstd = tl.load(Rstd + row)
    # Compute dx
    xhat = (x - mean) * rstd
    wdy = w * dy
    xhat = tl.where(mask, xhat, 0.)
    wdy = tl.where(mask, wdy, 0.)
    c1 = tl.sum(xhat * wdy, axis=0) / N
    c2 = tl.sum(wdy, axis=0) / N
    dx = (wdy - (xhat * c1 + c2)) * rstd
    # Write dx
    tl.store(DX + cols, dx, mask=mask)
    # Accumulate partial sums for dw/db
    partial_dw = (dy * xhat).to(w.dtype)
    partial_db = (dy).to(w.dtype)
    while tl.atomic_cas(Lock, 0, 1) == 1:
        pass
    count = tl.load(Count)
    # First store doesn't accumulate
    if count == 0:
        tl.atomic_xchg(Count, 1)
    else:
        partial_dw += tl.load(DW, mask=mask)
        partial_db += tl.load(DB, mask=mask)
    tl.store(DW, partial_dw, mask=mask)
    tl.store(DB, partial_db, mask=mask)
    # need a barrier to ensure all threads finished before
    # releasing the lock
    tl.debug_barrier()
    # Release the lock
    tl.atomic_xchg(Lock, 0)
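# The kernel above leaves GROUP_SIZE_M partial rows of dw/db; the kernel below
# reduces those partial rows into the final 1-D weight and bias gradients.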
@triton.jit
def _layer_norm_bwd_dwdb(DW,  # pointer to the partial sum of weights gradient
                         DB,  # pointer to the partial sum of biases gradient
                         FINAL_DW,  # pointer to the weights gradient
                         FINAL_DB,  # pointer to the biases gradient
                         M,  # GROUP_SIZE_M
                         N,  # number of columns
                         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
    # Map the program id to the elements of DW and DB it should compute.
    pid = tl.program_id(0)
    cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Iterate through the rows of DW and DB to sum the partial sums.
    for i in range(0, M, BLOCK_SIZE_M):
        rows = i + tl.arange(0, BLOCK_SIZE_M)
        mask = (rows[:, None] < M) & (cols[None, :] < N)
        offs = rows[:, None] * N + cols[None, :]
        dw += tl.load(DW + offs, mask=mask, other=0.)
        db += tl.load(DB + offs, mask=mask, other=0.)
    # Write the final sum to the output.
    sum_dw = tl.sum(dw, axis=0)
    sum_db = tl.sum(db, axis=0)
    tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)
    tl.store(FINAL_DB + cols, sum_db, mask=cols < N)
class LayerNorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, normalized_shape, weight, bias, eps):
        # allocate output
        y = torch.empty_like(x)
        # reshape input data into 2D tensor
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape
        mean = torch.empty((M, ), dtype=torch.float32, device=x.device)
        rstd = torch.empty((M, ), dtype=torch.float32, device=x.device)
        # Less than 64KB per feature: enqueue fused kernel
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        if N > BLOCK_SIZE:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
        # heuristics for number of warps
        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
        # enqueue kernel
        _layer_norm_fwd_fused[(M, )](  #
            x_arg, y, weight, bias, mean, rstd,  #
            x_arg.stride(0), N, eps,  #
            BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1)
        ctx.save_for_backward(x, weight, bias, mean, rstd)
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.eps = eps
        return y
    @staticmethod
    def backward(ctx, dy):
        x, w, b, m, v = ctx.saved_tensors
        # heuristics for amount of parallel reduction stream for DW/DB
        N = w.shape[0]
        GROUP_SIZE_M = 64
        if N <= 8192: GROUP_SIZE_M = 96
        if N <= 4096: GROUP_SIZE_M = 128
        if N <= 1024: GROUP_SIZE_M = 256
        # allocate output
        locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=w.device)
        _dw = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device)
        _db = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device)
        dw = torch.empty((N, ), dtype=w.dtype, device=w.device)
        db = torch.empty((N, ), dtype=w.dtype, device=w.device)
        dx = torch.empty_like(dy)
        # enqueue kernel using forward pass heuristics
        # also compute partial sums for DW and DB
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape
        _layer_norm_bwd_dx_fused[(M, )](  #
            dx, dy, _dw, _db, x, w, m, v, locks,  #
            x_arg.stride(0), N,  #
            BLOCK_SIZE_N=ctx.BLOCK_SIZE,  #
            GROUP_SIZE_M=GROUP_SIZE_M,  #
            num_warps=ctx.num_warps)
        grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE_N']), )
        # accumulate partial sums in separate kernel
        _layer_norm_bwd_dwdb[grid](
            _dw, _db, dw, db, min(GROUP_SIZE_M, M), N,  #
            BLOCK_SIZE_M=32,  #
            BLOCK_SIZE_N=128, num_ctas=1)
        return dx, None, dw, db, None
layer_norm = LayerNorm.apply
def test_layer_norm(M, N, dtype, eps=1e-5, device=DEVICE):
    # create data
    x_shape = (M, N)
    w_shape = (x_shape[-1], )
    weight = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
    bias = torch.rand(w_shape, dtype=dtype, device=device, requires_grad=True)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)
    dy = .1 * torch.randn_like(x)
    x.requires_grad_(True)
    # forward pass
    y_tri = layer_norm(x, w_shape, weight, bias, eps)
    y_ref = torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps).to(dtype)
    # backward pass (triton)
    y_tri.backward(dy, retain_graph=True)
    dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]
    x.grad, weight.grad, bias.grad = None, None, None
    # backward pass (torch)
    y_ref.backward(dy, retain_graph=True)
    dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]
    # compare
    assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0)
    assert torch.allclose(dx_tri, dx_ref, atol=1e-2, rtol=0)
    assert torch.allclose(db_tri, db_ref, atol=1e-2, rtol=0)
    assert torch.allclose(dw_tri, dw_ref, atol=1e-2, rtol=0)
    print(f'Triton gradient for x : {dx_tri}')
    print(f'Pytorch gradient for x : {dx_ref}')
    print(f'Triton gradient for bias : {db_tri}')
    print(f'Pytorch gradient for bias : {db_ref}')
    print(f'Triton gradient for weights : {dw_tri}')
    print(f'Pytorch gradient for weights : {dw_ref}')
test_layer_norm(1151, 8192, torch.float16)
Triton gradient for x : tensor([[ 0.0339,  0.2026,  0.1421,  ..., -0.0506, -0.0554, -0.0737],
        [-0.0833, -0.0019, -0.1012,  ..., -0.0261, -0.0986,  0.0994],
        [-0.1628,  0.0720,  0.0975,  ..., -0.0108, -0.0043, -0.1440],
        ...,
        [ 0.0421, -0.1232, -0.2598,  ...,  0.0523, -0.0726,  0.0177],
        [-0.0780,  0.0089,  0.1704,  ...,  0.0391, -0.0670, -0.0617],
        [-0.0776, -0.0699,  0.0343,  ..., -0.1139,  0.0667,  0.1949]],
       device='cuda:0', dtype=torch.float16)
Pytorch gradient for x : tensor([[ 0.0339,  0.2026,  0.1421,  ..., -0.0506, -0.0554, -0.0737],
        [-0.0833, -0.0019, -0.1012,  ..., -0.0261, -0.0986,  0.0994],
        [-0.1628,  0.0720,  0.0975,  ..., -0.0108, -0.0043, -0.1440],
        ...,
        [ 0.0421, -0.1232, -0.2598,  ...,  0.0523, -0.0726,  0.0177],
        [-0.0780,  0.0089,  0.1704,  ...,  0.0391, -0.0670, -0.0617],
        [-0.0776, -0.0699,  0.0343,  ..., -0.1139,  0.0667,  0.1949]],
       device='cuda:0', dtype=torch.float16)
Triton gradient for bias : tensor([ 1.4854,  0.3247,  6.0469,  ..., -1.7324, -0.5493,  3.8535],
       device='cuda:0', dtype=torch.float16)
Pytorch gradient for bias : tensor([ 1.4834,  0.3274,  6.0469,  ..., -1.7305, -0.5498,  3.8516],
       device='cuda:0', dtype=torch.float16)
Triton gradient for weights : tensor([ 0.4792,  2.4023, -5.3047,  ...,  0.9570, -4.5078, -2.9238],
       device='cuda:0', dtype=torch.float16)
Pytorch gradient for weights : tensor([ 0.4805,  2.4023, -5.3086,  ...,  0.9585, -4.5039, -2.9258],
       device='cuda:0', dtype=torch.float16)
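To use this layer norm inside a regular model, the autograd Function can be wrapped in an nn.Module (a minimal sketch added here, not from the original post; the class name TritonLayerNorm and its parameter initialization are illustrative and follow torch.nn.LayerNorm conventions):
class TritonLayerNorm(torch.nn.Module):
    # hypothetical wrapper; normalizes over the last dimension
    def __init__(self, normalized_shape, eps=1e-5):
        super().__init__()
        self.normalized_shape = (normalized_shape, )
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(normalized_shape))
        self.bias = torch.nn.Parameter(torch.zeros(normalized_shape))
    def forward(self, x):
        return layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)

ln = TritonLayerNorm(8192).to(device=DEVICE, dtype=torch.float16)
out = ln(torch.randn(16, 8192, device=DEVICE, dtype=torch.float16))
print(out.shape)  # torch.Size([16, 8192])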