Projects & Blogs
LATEST TRANSFORMER
PROBABILISTIC APPR...
MINI-UNET
BASICS OF TRITON
TRITON FUSED ATTE...
MINI-ALEXNET
PYTHON CODE GENER...
SEQUENTIAL MONTE ...
TRUNCATED SVD
CUSTOM DATALOADER...
PROBABILITY
import torch
import os
# Run on the GPU when one is available. The device string must be "cuda";
# the previous spelling ("cude") is rejected by torch on CUDA machines.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

## DEMO of how offsets are masked
# Row offsets (10..19) are compared against column offsets (2..6) through
# broadcasting, giving a (10, 5) boolean matrix. Every row offset is >= every
# column offset here, so the whole mask is True.
offs_m = torch.arange(10, 20)
offs_n = 2 + torch.arange(0, 5)
mask = offs_m[:, None] >= offs_n[None, :]
print(f"Offset of n ({offs_n.shape}): {offs_n}")
print(f"Offset of m ({offs_m.shape}) : {offs_m}")
print(f"Mask ({mask.shape}) : {mask}")
Offset of n (torch.Size([5])): tensor([2, 3, 4, 5, 6])
Offset of m (torch.Size([10])) : tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
Mask (torch.Size([10, 5])) : tensor([[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True],
[True, True, True, True, True]])
def _attn_fwd_inner(acc, l_i, m_i, q, #
desc_k, desc_v, #
offset_y, dtype, start_m, qk_scale, #
BLOCK_M, HEAD_DIM, BLOCK_N, #
STAGE, offs_m, offs_n, #
N_CTX, warp_specialize):
# range of values handled by this stage
print(STAGE)
if STAGE == 1:
lo, hi = 0, start_m * BLOCK_M
elif STAGE == 2:
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
# lo = tl.multiple_of(lo, BLOCK_M)
# causal = False
else:
lo, hi = 0, N_CTX
offsetk_y = offset_y + lo
#dtype == tl.float8e5
if dtype == torch.float8_e5m2:
offsetv_y = offset_y * HEAD_DIM + lo
else:
offsetv_y = offset_y + lo
# loop over k, v and update accumulator
# TODO what is warp_specialize in tl.range ?
# for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize):
for start_n in torch.arange(lo, hi, BLOCK_N):
# start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
print(f"start_n : {start_n} : offset of k [{offsetk_y},0]")
# k = desc_k.load([offsetk_y, 0]).T
# qk = tl.dot(q, k)
if STAGE == 2:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
print(f"mask is used {mask}")
# qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
# m_ij = tl.maximum(m_i, tl.max(qk, 1))
# qk -= m_ij[:, None]
else:
print("no mask used")
# m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
# qk = qk * qk_scale - m_ij[:, None]
# p = tl.math.exp2(qk)
# -- compute correction factor
# alpha = tl.math.exp2(m_i - m_ij)
# l_ij = tl.sum(p, 1)
# -- update output accumulator --
# acc = acc * alpha[:, None]
# prepare p and v for the dot
print(f"Offset of v [0, {offsetv_y}]")
# if dtype == tl.float8e5:
# v = desc_v.load([0, offsetv_y]).T
# else:
# v = desc_v.load([offsetv_y, 0])
# p = p.to(dtype)
# note that this non transposed v for FP8 is only supported on Blackwell
# acc = tl.dot(p, v, acc)
# update m_i and l_i
# place this at the end of the loop to reduce register pressure
# l_i = l_i * alpha + l_ij
# m_i = m_ij
offsetk_y += BLOCK_N
offsetv_y += BLOCK_N
# return acc, l_i, m_i
print("----inner-----")
def _attn_fwd(sm_scale, M, Z, H, desc_q, desc_k, desc_v, desc_o,
              HEAD_DIM,  #
              BLOCK_M,  #
              BLOCK_N,  #
              FP8_OUTPUT,  #
              STAGE,  #
              warp_specialize, N_CTX):
    """Trace the fused-attention forward kernel over a few program ids.

    Print-only PyTorch walkthrough of the Triton kernel: the two nested
    loops stand in for ``tl.program_id(0)`` (query block, ``start_m``) and
    ``tl.program_id(1)`` (flattened batch*head index, ``off_hz``). For each
    simulated program the row offsets are printed and ``_attn_fwd_inner`` is
    called for the off-band stage (``STAGE & 1``) and the on-band stage
    (``STAGE & 2``).

    Only program ids 0 and 1 on each axis are traced: the inner loop breaks
    after ``off_hz == 1`` and the function returns after ``start_m == 1``.

    ``M``, ``desc_o`` and ``FP8_OUTPUT`` mirror the kernel signature but are
    not read here. Returns ``None``.
    """
    # The traced dtype is fixed to fp16, so the inner helper always takes its
    # non-FP8 V-offset branch.
    dtype = torch.float16
    assert BLOCK_N <= HEAD_DIM

    # Stand-in for the Q tile. The real kernel does
    # `q = desc_q.load([qo_offset_y, 0])`; the trace never reads q, but it
    # must be bound locally instead of leaking in from a module-level global.
    q = desc_q

    # Grid extents derived from the arguments, matching the launcher's
    # (N_CTX // BLOCK_M, Z * H) grid.
    # Triton: start_m = tl.program_id(0)
    # Triton: off_hz = tl.program_id(1)
    for start_m in range(N_CTX // BLOCK_M):
        for off_hz in range(Z * H):
            print(f"start_m : {start_m}, off_hz : {off_hz}")
            # Split the flattened index into (batch, head).
            off_z = off_hz // H
            off_h = off_hz % H
            print(f"off_z : {off_z},off_h : {off_h}")
            y_dim = Z * H * N_CTX
            print(f"y_dim : {y_dim}")
            # First row of this (batch, head) slice in the flattened row axis.
            offset_y = off_z * (N_CTX * H) + off_h * N_CTX
            print(f"offset_y : {offset_y}")
            qo_offset_y = offset_y + start_m * BLOCK_M
            print(f"qo_offset_y : {qo_offset_y}")
            # initialize offsets
            offs_m = start_m * BLOCK_M + torch.arange(0, BLOCK_M)
            offs_n = torch.arange(0, BLOCK_N)
            print(f"offs_m : {offs_m}")
            print(f"offs_n : {offs_n}")
            # initialize pointer to m and l
            m_i = torch.zeros([BLOCK_M], dtype=torch.float32) - float("inf")
            l_i = torch.zeros([BLOCK_M], dtype=torch.float32) + 1.0
            acc = torch.zeros([BLOCK_M, HEAD_DIM], dtype=torch.float32)
            # load scales
            qk_scale = sm_scale
            qk_scale *= 1.44269504  # 1/log(2)
            # Triton: q = desc_q.load([qo_offset_y, 0])
            print(f"q load : {[qo_offset_y, 0]}")
            # stage 1: off-band. The inner helper receives 4 - STAGE:
            # 1 when STAGE == 3, 3 when STAGE == 1.
            if STAGE & 1:
                _attn_fwd_inner(acc, l_i, m_i, q,  #
                                desc_k, desc_v,  #
                                offset_y, dtype, start_m, qk_scale,  #
                                BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                4 - STAGE, offs_m, offs_n, N_CTX,  #
                                warp_specialize)
            # stage 2: on-band
            if STAGE & 2:
                _attn_fwd_inner(acc, l_i, m_i, q,  #
                                desc_k, desc_v,  #
                                offset_y, dtype, start_m, qk_scale,  #
                                BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                2, offs_m, offs_n, N_CTX,  #
                                warp_specialize)
            # epilogue
            # Triton: m_i += tl.math.log2(l_i)
            # Triton: acc = acc / l_i[:, None]
            # Triton: m_ptrs = M + off_hz * N_CTX + offs_m
            # Triton: tl.store(m_ptrs, m_i)
            # Triton: desc_o.store([qo_offset_y, 0], acc.to(dtype))
            if off_hz == 1:
                # Keep the trace short: only two batch*head programs.
                break
        if start_m == 1:
            # Keep the trace short: only two query blocks.
            print("------------------------------")
            return
def attention(q, k, v, causal, sm_scale, warp_specialize=True):
    """Host-side launcher for the traced forward pass.

    Allocates the output and ``M`` buffers, picks the stage code (3 when
    ``causal`` is truthy, otherwise 1), prints the launch grid and calls
    ``_attn_fwd`` with ``q``, ``k``, ``v`` standing in for the tensor
    descriptors. The batch, head and context sizes are taken from the first
    three dimensions of ``q``. Returns ``None``.
    """
    # Tile sizes (other values tried: 128 for M; 64 / 128 for N).
    BLOCK_M = 64
    BLOCK_N = 32

    # shape constraints
    HEAD_DIM_Q = q.shape[-1]  # read but not otherwise used
    HEAD_DIM_K = k.shape[-1]
    # when v is in float8_e5m2 it is transposed.
    HEAD_DIM_V = v.shape[-1]  # read but not otherwise used

    n_batch, n_head, n_ctx = q.shape[0], q.shape[1], q.shape[2]

    o = torch.empty_like(q)
    if causal:
        stage = 3
    else:
        stage = 1
    M = torch.empty((n_batch, n_head, n_ctx), device=q.device, dtype=torch.float32)

    # The tensors themselves play the role of the kernel's descriptors.
    desc_q, desc_k, desc_v, desc_o = q, k, v, o

    grid = (n_ctx // BLOCK_M, n_batch * n_head, 1)
    print(f"grid : {grid}")

    _attn_fwd(
        sm_scale, M, n_batch, n_head,  #
        desc_q, desc_k, desc_v, desc_o,  #
        N_CTX=n_ctx,  #
        HEAD_DIM=HEAD_DIM_K,  #
        BLOCK_M=BLOCK_M,  #
        BLOCK_N=BLOCK_N,  #
        FP8_OUTPUT=q.dtype == torch.float8_e5m2,  #
        STAGE=stage,  #
        warp_specialize=warp_specialize,
    )
# --- Driver for the forward-pass walkthrough ---
# Configuration for the trace; the trailing comments record other values
# that were tried.
sm_scale = 0.5
causal = True
warp_specialize = False #True
BATCH = 4 #4
H = 8 #2
N_CTX = 1024 #1024
HEAD_DIM = 512 #64
dtype = torch.float16
# Random fp16 inputs of shape (BATCH, H, N_CTX, HEAD_DIM) created on DEVICE.
q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True)
# Cast q and k to float8_e5m2.
q = q.to(torch.float8_e5m2)
k = k.to(torch.float8_e5m2)
# Swap the last two dims, materialize that layout, then swap back: v keeps its
# logical (BATCH, H, N_CTX, HEAD_DIM) shape (see the print below) while the
# materialized copy was made in the swapped order.
v = v.permute(0, 1, 3, 2).contiguous()
v = v.permute(0, 1, 3, 2)
v = v.to(torch.float8_e5m2)
print(f"V Shape : {v.shape}")
# Run the print-only forward trace.
attention(q, k ,v, causal, sm_scale, warp_specialize)
V Shape : torch.Size([4, 8, 1024, 512])
grid : (16, 32, 1)
start_m : 0, off_hz : 0
off_z : 0,off_h : 0
y_dim : 32768
offset_y : 0
qo_offset_y : 0
offs_m : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
offs_n : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
q load : [0, 0]
1
----inner-----
2
start_n : 0 : offset of k [0,0]
mask is used tensor([[ True, False, False, ..., False, False, False],
[ True, True, False, ..., False, False, False],
[ True, True, True, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of v [0, 0]
start_n : 32 : offset of k [32,0]
mask is used tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, False, False],
[ True, True, True, ..., True, True, False],
[ True, True, True, ..., True, True, True]])
Offset of v [0, 32]
----inner-----
start_m : 0, off_hz : 1
off_z : 0,off_h : 1
y_dim : 32768
offset_y : 1024
qo_offset_y : 1024
offs_m : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
offs_n : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
q load : [1024, 0]
1
----inner-----
2
start_n : 0 : offset of k [1024,0]
mask is used tensor([[ True, False, False, ..., False, False, False],
[ True, True, False, ..., False, False, False],
[ True, True, True, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of v [0, 1024]
start_n : 32 : offset of k [1056,0]
mask is used tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, False, False],
[ True, True, True, ..., True, True, False],
[ True, True, True, ..., True, True, True]])
Offset of v [0, 1056]
----inner-----
start_m : 1, off_hz : 0
off_z : 0,off_h : 0
y_dim : 32768
offset_y : 0
qo_offset_y : 64
offs_m : tensor([ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91,
92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
120, 121, 122, 123, 124, 125, 126, 127])
offs_n : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
q load : [64, 0]
1
start_n : 0 : offset of k [0,0]
no mask used
Offset of v [0, 0]
start_n : 32 : offset of k [32,0]
no mask used
Offset of v [0, 32]
----inner-----
2
start_n : 64 : offset of k [64,0]
mask is used tensor([[ True, False, False, ..., False, False, False],
[ True, True, False, ..., False, False, False],
[ True, True, True, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of v [0, 64]
start_n : 96 : offset of k [96,0]
mask is used tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, False, False],
[ True, True, True, ..., True, True, False],
[ True, True, True, ..., True, True, True]])
Offset of v [0, 96]
----inner-----
start_m : 1, off_hz : 1
off_z : 0,off_h : 1
y_dim : 32768
offset_y : 1024
qo_offset_y : 1088
offs_m : tensor([ 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91,
92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105,
106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
120, 121, 122, 123, 124, 125, 126, 127])
offs_n : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
q load : [1088, 0]
1
start_n : 0 : offset of k [1024,0]
no mask used
Offset of v [0, 1024]
start_n : 32 : offset of k [1056,0]
no mask used
Offset of v [0, 1056]
----inner-----
2
start_n : 64 : offset of k [1088,0]
mask is used tensor([[ True, False, False, ..., False, False, False],
[ True, True, False, ..., False, False, False],
[ True, True, True, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of v [0, 1088]
start_n : 96 : offset of k [1120,0]
mask is used tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, False, False],
[ True, True, True, ..., True, True, False],
[ True, True, True, ..., True, True, True]])
Offset of v [0, 1120]
----inner-----
------------------------------
def _attn_bwd_preprocess(O, DO,  #
                         Delta,  #
                         Z, H, N_CTX,  #
                         BLOCK_M, HEAD_DIM  #
                         ):
    """Trace the offsets used by the backward pre-processing kernel.

    Print-only walkthrough: the nested loops stand in for the two Triton
    program ids (row block along N_CTX, and flattened batch*head index).
    For each simulated program it prints the flat offsets of the
    ``(BLOCK_M, HEAD_DIM)`` tile of O / DO that would be loaded and the
    ``BLOCK_M`` offsets of Delta that would be written.

    At most the first two ids on each axis are traced. ``O``, ``DO`` and
    ``Delta`` are not read. Returns ``None``.
    """
    # The trace stops after id 1 on each axis, i.e. at most two ids per axis.
    traced_row_blocks = min(N_CTX // BLOCK_M, 2)
    # Since PRE_BLOCK = BLOCK_M
    for pre_bloc_by_nctx in range(traced_row_blocks):
        # BATCH = Z & N_HEAD = H
        for off_hz in range(min(Z * H, 2)):
            # Example: for a 4x8x1024x512 input, each (batch, head) slice is
            # 1024x512 and is processed in 128x512 chunks, giving
            # 1024 / 128 = 8 row blocks for each of the 4 * 8 = 32 slices.
            print(f"tl-program-ids : {pre_bloc_by_nctx} of {N_CTX // BLOCK_M}, {off_hz} of {Z * H}")
            row_ids = pre_bloc_by_nctx * BLOCK_M + torch.arange(0, BLOCK_M)
            col_ids = torch.arange(0, HEAD_DIM)

            # load: flat row-major offsets of the tile inside slice off_hz.
            slice_base = off_hz * HEAD_DIM * N_CTX
            offset_o_do = slice_base + row_ids[:, None] * HEAD_DIM + col_ids[None, :]
            print(f"Offsets of o and do {offset_o_do.shape} : {offset_o_do}")
            # Triton: o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :])
            # Triton: do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32)
            # Triton: delta = tl.sum(o * do, axis=1)

            # write-back: delta is summed over the hidden dimension, so it
            # holds one value per row of the block.
            delta_offset = off_hz * N_CTX + row_ids
            print(f"Delta {delta_offset.shape} : {delta_offset}")
            # Triton: tl.store(Delta + off_hz * N_CTX + off_m, delta)
def _attn_bwd_dq(dq, q, K, V,  #
                 do, m, D,
                 # shared by Q/K/V/DO.
                 stride_tok, stride_d,  #
                 H, N_CTX,  #
                 BLOCK_M2,  #
                 BLOCK_N2,  #
                 HEAD_DIM,
                 # Filled in by the wrapper.
                 start_m, start_n, num_steps,  #
                 MASK):
    """Trace the key-block walk of the dQ part of the backward kernel.

    Print-only walkthrough: it prints the shapes of the K / V / Di offset
    grids, then, for each of ``num_steps`` key blocks of width ``BLOCK_N2``
    starting at ``start_n``, the key and query offsets and (when ``MASK`` is
    truthy) the autoregressive mask ``offs_m[:, None] >= offs_n[None, :]``.

    No tensors are loaded or allocated as stand-ins, so the global RNG state
    is left untouched. ``dq``, ``q``, ``K``, ``V``, ``do``, ``m``, ``D``,
    ``H`` and ``N_CTX`` mirror the kernel signature but are not read.
    Returns ``None``.
    """
    print(f"-------------_attn_bwd_dq--start_m--{start_m}--start_n--{start_n}----MASK-{MASK}-----")
    offs_m = start_m + torch.arange(0, BLOCK_M2)
    offs_n = start_n + torch.arange(0, BLOCK_N2)
    offs_k = torch.arange(0, HEAD_DIM)

    # Offsets of the first K tile, shape (BLOCK_N2, HEAD_DIM).
    offset_kT = offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
    print(f"Offset of kT : {offset_kT.shape}")
    # Triton: kT_ptrs = K + offset_kT

    # Offsets of the first V tile, shape (HEAD_DIM, BLOCK_N2).
    offset_vT = offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    print(f"Offset of vT : {offset_vT.shape}")
    # Triton: vT_ptrs = V + offset_vT

    # D (= delta) is pre-divided by ds_scale.
    print(f"Di Offset : {offs_m.shape}")
    # Triton: Di = tl.load(D + offs_m)

    curr_n = start_n
    step_n = BLOCK_N2
    for _ in range(num_steps):
        # Triton: kT = tl.load(kT_ptrs)
        # Triton: vT = tl.load(vT_ptrs)
        # Triton: qk = tl.dot(q, kT)
        # Triton: p = tl.math.exp2(qk - m)
        # Autoregressive masking.
        offs_n = curr_n + torch.arange(0, BLOCK_N2)
        print(f"Offset of n ({offs_n.shape}): {offs_n}")
        print(f"Offset of m ({offs_m.shape}) : {offs_m}")
        if MASK:
            mask = offs_m[:, None] >= offs_n[None, :]
            print(f"Mask ({mask.shape}) : {mask}")
            # Triton: p = tl.where(mask, p, 0.0)
        # Compute dP and dS.
        # Triton: dp = tl.dot(do, vT).to(tl.float32)
        # Triton: ds = p * (dp - Di[:, None])
        # Triton: ds = ds.to(tl.float16)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        # Triton: dq += tl.dot(ds, tl.trans(kT))
        # Increment pointers.
        curr_n += step_n
        # Triton: kT_ptrs += step_n * stride_tok
        # Triton: vT_ptrs += step_n * stride_tok
    # Triton: return dq
    print("-------------end of_attn_bwd_dq-------")
def _attn_bwd_dkdv(dk, dv,  #
                   Q, k, v, sm_scale,  #
                   DO,  #
                   M, D,  #
                   # shared by Q/K/V/DO.
                   stride_tok, stride_d,  #
                   H, N_CTX, BLOCK_M1,  #
                   BLOCK_N1,  #
                   HEAD_DIM,  #
                   # Filled in by the wrapper.
                   start_n, start_m, num_steps,  #
                   MASK):
    """Trace the query-block walk of the dK/dV part of the backward kernel.

    Print-only walkthrough: for each step it prints the query offsets of the
    current ``BLOCK_M1``-wide block (starting at ``start_m``) and, when
    ``MASK`` is truthy, the transposed autoregressive mask
    ``offs_m[None, :] >= offs_n[:, None]``. The trace stops after at most two
    steps, even when ``num_steps`` is larger.

    No tensors are loaded or allocated as stand-ins, so the global RNG state
    is left untouched. ``dk``, ``dv``, ``Q``, ``k``, ``v``, ``sm_scale``,
    ``DO``, ``M``, ``D``, ``H`` and ``N_CTX`` mirror the kernel signature but
    are not read. Returns ``None``.
    """
    print(f"--------_attn_bwd_dkdv-(MASK : {MASK})-------------")
    offs_m = start_m + torch.arange(0, BLOCK_M1)
    offs_n = start_n + torch.arange(0, BLOCK_N1)
    offs_k = torch.arange(0, HEAD_DIM)

    # Offset grids the kernel adds to Q and DO. They are kept to show the
    # pointer arithmetic and are not printed by default.
    qT_offset = offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
    do_offset = offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    # print(f"({start_m}, {start_n}) : qT_offset({qT_offset.shape}) : {qT_offset} : do_offset({do_offset.shape}) : {do_offset}")
    # Triton: qT_ptrs = Q + qT_offset
    # Triton: do_ptrs = DO + do_offset

    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
    # Triton: tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        # Triton: qT = tl.load(qT_ptrs)
        # Load m before computing qk to reduce pipeline stall.
        offs_m = curr_m + torch.arange(0, BLOCK_M1)
        print(f"Current Step : {blk_idx} : offs_m({offs_m.shape}) : {offs_m}")
        # Triton: m = tl.load(M + offs_m)
        # Triton: qkT = tl.dot(k, qT)
        # Triton: pT = tl.math.exp2(qkT - m[None, :])
        # Autoregressive masking.
        if MASK:
            mask = offs_m[None, :] >= offs_n[:, None]
            print(f"Mask ({mask.shape}) : {mask}")
            # Triton: pT = tl.where(mask, pT, 0.0)
        # Triton: do = tl.load(do_ptrs)
        # Compute dV.
        # Triton: ppT = pT
        # Triton: ppT = ppT.to(tl.float16)
        # Triton: dv += tl.dot(ppT, do)
        # D (= delta) is pre-divided by ds_scale.
        # Triton: Di = tl.load(D + offs_m)
        # Compute dP and dS.
        # Triton: dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
        # Triton: dsT = pT * (dpT - Di[None, :])
        # Triton: dsT = dsT.to(tl.float16)
        # Triton: dk += tl.dot(dsT, tl.trans(qT))
        # Increment pointers.
        curr_m += step_m
        # Triton: qT_ptrs += step_m * stride_tok
        # Triton: do_ptrs += step_m * stride_tok
        if blk_idx == 1:
            # Keep the trace short: only the first two steps are shown.
            break
    # Triton: return dk, dv
def _attn_bwd(Q, K, V, sm_scale,  #
              DO,  #
              DQ, DK, DV,  #
              M, D,
              # shared by Q/K/V/DO.
              stride_z, stride_h, stride_tok, stride_d,  #
              H, N_CTX,  #
              causal,
              BLOCK_M1,  #
              BLOCK_N1,  #
              BLOCK_M2,  #
              BLOCK_N2,  #
              BLK_SLICE_FACTOR,  #
              HEAD_DIM):
    """Trace the fused-attention backward kernel over a few program ids.

    Print-only PyTorch walkthrough of the Triton kernel. The nested loops
    stand in for ``tl.program_id(2)`` (``bhid``, flattened batch*head index)
    and ``tl.program_id(0)`` (``pid``, block index along the sequence). For
    each simulated program it prints the K/V, Q and DO offsets and calls
    ``_attn_bwd_dkdv`` and ``_attn_bwd_dq`` twice each: once for the masked
    (diagonal) blocks and once for the unmasked ones.

    Only ``pid`` 0..2 and ``bhid`` 0..2 are traced. ``DQ``, ``DK`` and ``DV``
    mirror the kernel signature but are not read. Returns ``None``.
    """
    # ln(2); used only by the commented-out `dq *= LN2` step near the end.
    LN2 = 0.6931471824645996  # = ln(2)

    # Stand-ins for the tiles the kernel would `tl.load`. The helpers never
    # read them, but they must be bound locally instead of leaking in from
    # module-level globals of the same (lower-case) names.
    k, v = K, V
    q, do = Q, DO

    # The loop bounds are fixed to the walkthrough's grid, 32 batch*head
    # slices by 8 sequence blocks; the early exits below stop the trace well
    # before those bounds.
    # Triton: bhid = tl.program_id(2)
    for bhid in range(32):
        # Triton: pid = tl.program_id(0)
        for pid in range(8):
            off_chz = bhid * N_CTX
            row = bhid // H
            col = bhid % H
            # adj will move along per BxH
            adj = stride_h * col + stride_z * row
            print(f"(bhid,pid to 32,8) : ({bhid},{pid}) : Row : {row}, Col : {col}, off_chz : {off_chz} adj : {adj}")
            offs_k = torch.arange(0, HEAD_DIM)

            # ---- dK / dV for the key block owned by this program ----
            start_n = pid * BLOCK_N1
            start_m = start_n
            MASK_BLOCK_M1 = BLOCK_M1 // BLK_SLICE_FACTOR
            # print(f"MASK_BLOCK_M1 : {MASK_BLOCK_M1}")
            offs_n = start_n + torch.arange(0, BLOCK_N1)
            dv = torch.zeros([BLOCK_N1, HEAD_DIM], dtype=torch.float32)
            dk = torch.zeros([BLOCK_N1, HEAD_DIM], dtype=torch.float32)
            # load K and V: they stay in SRAM throughout the inner loop.
            # Triton: k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
            # Triton: v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
            offset_k_v = offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
            print(f"Offsets({offset_k_v.shape}) of k or v {offset_k_v}")

            # Masked (diagonal) blocks, walked in MASK_BLOCK_M1-wide steps.
            num_steps = BLOCK_N1 // MASK_BLOCK_M1
            print(f"MASK : True : start_m : {start_m} : Num Steps : {num_steps}")
            _attn_bwd_dkdv(dk, dv,  #
                           Q, k, v, sm_scale,  #
                           DO,  #
                           M, D,  #
                           stride_tok, stride_d,  #
                           H, N_CTX,  #
                           MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM,  #
                           start_n, start_m, num_steps,  #
                           MASK=True  #
                           )

            # Compute dK and dV for non-masked blocks: everything after the
            # diagonal, in BLOCK_M1-wide steps.
            start_m += num_steps * MASK_BLOCK_M1
            num_steps = (N_CTX - start_m) // BLOCK_M1
            print(f"MASK : False : start_m : {start_m} : Num Steps : {num_steps}")
            _attn_bwd_dkdv(  #
                dk, dv,  #
                Q, k, v, sm_scale,  #
                DO,  #
                M, D,  #
                stride_tok, stride_d,  #
                H, N_CTX,  #
                BLOCK_M1, BLOCK_N1, HEAD_DIM,  #
                start_n, start_m, num_steps,  #
                MASK=False  #
            )
            print("-------End of _attn_bwd_dkdv-----------")

            # Write-back offsets for dV and dK (same grid as the K/V load).
            dv_offset = offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
            print(f"Offset of DV : {dv_offset}")
            # Triton: dv_ptrs = DV + dv_offset
            # Triton: tl.store(dv_ptrs, dv)
            # Write back dK.
            # Triton: dk *= sm_scale
            dk_offset = offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
            print(f"Offset of DK : {dk_offset}")
            # Triton: dk_ptrs = DK + dk_offset
            # Triton: tl.store(dk_ptrs, dk)

            # ---- THIS BLOCK DOES DQ ----
            start_m = pid * BLOCK_M2
            end_n = start_m + BLOCK_M2
            MASK_BLOCK_N2 = BLOCK_N2 // BLK_SLICE_FACTOR
            offs_m = start_m + torch.arange(0, BLOCK_M2)
            print(f"Offset of m ({offs_m.shape}) : {offs_m}")
            offs_q = offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
            print(f"Offset of q ({offs_q.shape}) : {offs_q}")
            # Triton: q = tl.load(Q + offs_q)
            dq = torch.zeros([BLOCK_M2, HEAD_DIM], dtype=torch.float32)
            offs_do = offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
            print(f"Offset of do ({offs_do.shape}) : {offs_do}")
            # Triton: do = tl.load(DO + offs_do)
            # Triton: m = tl.load(M + offs_m)
            # Placeholder with one entry per query row of the block.
            m = torch.randn((BLOCK_M2,))
            m = m[:, None]

            # Compute dQ for masked (diagonal) blocks.
            # NOTE: This code scans each row of QK^T backward (from right to left,
            # but inside each call to _attn_bwd_dq, from left to right), but that's
            # not due to anything important. I just wanted to reuse the loop
            # structure for dK & dV above as much as possible.
            num_steps = BLOCK_M2 // MASK_BLOCK_N2
            print(f"({bhid},{pid}) MASK : True : start_m : {start_m} start_n : {end_n - num_steps * MASK_BLOCK_N2} : Num Steps : {num_steps}")
            _attn_bwd_dq(dq, q, K, V,  #
                         do, m, D,  #
                         stride_tok, stride_d,  #
                         H, N_CTX,  #
                         BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM,  #
                         start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,  #
                         MASK=True  #
                         )
            if causal:
                end_n = start_m
            else:
                end_n -= num_steps * MASK_BLOCK_N2
            print(f"end_n : {end_n}")

            # stage 2: unmasked blocks to the left of the diagonal.
            num_steps = end_n // BLOCK_N2
            print(f"({bhid},{pid}) MASK : False : start_m : {start_m} start_n : {end_n - num_steps * BLOCK_N2} : Num Steps : {num_steps}")
            _attn_bwd_dq(dq, q, K, V,  #
                         do, m, D,  #
                         stride_tok, stride_d,  #
                         H, N_CTX,  #
                         BLOCK_M2, BLOCK_N2, HEAD_DIM,  #
                         start_m, end_n - num_steps * BLOCK_N2, num_steps,  #
                         MASK=False  #
                         )
            # Write back dQ.
            # dq_offset = offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
            # print(f"Offset of DQ : {dq_offset}")
            # Triton: dq_ptrs = DQ + dq_offset
            # The softmax was computed with 2^x instead of e^x, so its
            # derivative carries an extra factor of ln 2: d/dx e^x = e^x, but
            # d/dx 2^x = (ln 2) * 2^x. Since dQ = dS K, and dS is proportional
            # to (ln 2) * 2^(qk), dq is multiplied by LN2 at the end.
            # Triton: dq *= LN2
            # Triton: tl.store(dq_ptrs, dq)
            print("-------End of _attn_bwd-----------")
            if pid == 2:
                # Keep the trace short: three sequence blocks per slice.
                break
        if bhid == 2:
            # Keep the trace short: three batch*head slices.
            return
# --- Driver for the backward-pass walkthrough ---
# Configuration for the trace; the trailing comments record other values
# that were tried.
sm_scale = 0.5
causal = True
warp_specialize = False #True
BATCH = 4 #4
H = 8 #2
N_CTX = 1024 #1024
HEAD_DIM = 512#32 #64
dtype = torch.float16
# Random fp16 inputs of shape (BATCH, H, N_CTX, HEAD_DIM) created on DEVICE.
q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True)
# Unlike the forward demo, the inputs stay in fp16 here (the fp8 casts are
# disabled).
# q = q.to(torch.float8_e5m2)
# k = k.to(torch.float8_e5m2)
# Swap the last two dims, materialize that layout, then swap back: v keeps its
# logical shape (see the print below).
v = v.permute(0, 1, 3, 2).contiguous()
v = v.permute(0, 1, 3, 2)
# v = v.to(torch.float8_e5m2)
print(f"V Shape : {v.shape}")
# Random stand-ins for the forward output and its incoming gradient, plus
# uninitialized buffers for the three gradients.
o = torch.randn_like(q)
do = torch.randn_like(o)
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
# Re-derive the sizes from q itself (rebinds BATCH and N_CTX).
BATCH, N_HEAD, N_CTX = q.shape[:3]
PRE_BLOCK = 128
# NUM_WARPS / NUM_STAGES are only referenced by the commented-out keyword
# arguments at the end of the _attn_bwd call.
NUM_WARPS, NUM_STAGES = 4, 5
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
BLK_SLICE_FACTOR = 2
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
# Pre-scale K by sm_scale / ln(2); the scaled copy is what _attn_bwd receives.
arg_k = k
arg_k = arg_k * (sm_scale * RCP_LN2)
# NOTE(review): PRE_BLOCK is assigned the same value a second time here.
PRE_BLOCK = 128
pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
print(f"Pre Grid : {pre_grid}")
# Random stand-in for the M buffer that the forward pass would have produced
# (presumably the per-row softmax statistics; confirm against the forward
# kernel), and an uninitialized delta buffer of the same shape.
M = torch.randn((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
delta = torch.empty_like(M)
print("-------Start of _attn_bwd_preprocess-----------")
# Trace the pre-processing step with BLOCK_M = PRE_BLOCK.
_attn_bwd_preprocess(
o, do, #
delta, #
BATCH, N_HEAD, N_CTX, #
BLOCK_M=PRE_BLOCK, HEAD_DIM=HEAD_DIM #
)
print("-------End of _attn_bwd_preprocess-----------")
grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
print(f"Grid (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD): {grid}")
print(f"Strides of q : {q.stride(0)}, {q.stride(1)}, {q.stride(2)}, {q.stride(3)}")
print("-------Start of _attn_bwd-----------")
# Trace the main backward step; q's strides are passed as the strides
# shared by Q/K/V/DO.
_attn_bwd(
q, arg_k, v, sm_scale, do, dq, dk, dv, #
M, delta, #
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
N_HEAD, N_CTX, #
causal,
BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, #
BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, #
BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
HEAD_DIM=HEAD_DIM, #
# num_warps=NUM_WARPS, #
# num_stages=NUM_STAGES #
)
V Shape : torch.Size([4, 8, 1024, 512])
Pre Grid : (8, 32)
-------Start of _attn_bwd_preprocess-----------
tl-program-ids : 0 of 8, 0 of 32
Offsets of o and do torch.Size([128, 512]) : tensor([[ 0, 1, 2, ..., 509, 510, 511],
[ 512, 513, 514, ..., 1021, 1022, 1023],
[ 1024, 1025, 1026, ..., 1533, 1534, 1535],
...,
[64000, 64001, 64002, ..., 64509, 64510, 64511],
[64512, 64513, 64514, ..., 65021, 65022, 65023],
[65024, 65025, 65026, ..., 65533, 65534, 65535]])
Delta torch.Size([128]) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
tl-program-ids : 0 of 8, 1 of 32
Offsets of o and do torch.Size([128, 512]) : tensor([[524288, 524289, 524290, ..., 524797, 524798, 524799],
[524800, 524801, 524802, ..., 525309, 525310, 525311],
[525312, 525313, 525314, ..., 525821, 525822, 525823],
...,
[588288, 588289, 588290, ..., 588797, 588798, 588799],
[588800, 588801, 588802, ..., 589309, 589310, 589311],
[589312, 589313, 589314, ..., 589821, 589822, 589823]])
Delta torch.Size([128]) : tensor([1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1032, 1033, 1034, 1035,
1036, 1037, 1038, 1039, 1040, 1041, 1042, 1043, 1044, 1045, 1046, 1047,
1048, 1049, 1050, 1051, 1052, 1053, 1054, 1055, 1056, 1057, 1058, 1059,
1060, 1061, 1062, 1063, 1064, 1065, 1066, 1067, 1068, 1069, 1070, 1071,
1072, 1073, 1074, 1075, 1076, 1077, 1078, 1079, 1080, 1081, 1082, 1083,
1084, 1085, 1086, 1087, 1088, 1089, 1090, 1091, 1092, 1093, 1094, 1095,
1096, 1097, 1098, 1099, 1100, 1101, 1102, 1103, 1104, 1105, 1106, 1107,
1108, 1109, 1110, 1111, 1112, 1113, 1114, 1115, 1116, 1117, 1118, 1119,
1120, 1121, 1122, 1123, 1124, 1125, 1126, 1127, 1128, 1129, 1130, 1131,
1132, 1133, 1134, 1135, 1136, 1137, 1138, 1139, 1140, 1141, 1142, 1143,
1144, 1145, 1146, 1147, 1148, 1149, 1150, 1151])
tl-program-ids : 1 of 8, 0 of 32
Offsets of o and do torch.Size([128, 512]) : tensor([[ 65536, 65537, 65538, ..., 66045, 66046, 66047],
[ 66048, 66049, 66050, ..., 66557, 66558, 66559],
[ 66560, 66561, 66562, ..., 67069, 67070, 67071],
...,
[129536, 129537, 129538, ..., 130045, 130046, 130047],
[130048, 130049, 130050, ..., 130557, 130558, 130559],
[130560, 130561, 130562, ..., 131069, 131070, 131071]])
Delta torch.Size([128]) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
tl-program-ids : 1 of 8, 1 of 32
Offsets of o and do torch.Size([128, 512]) : tensor([[589824, 589825, 589826, ..., 590333, 590334, 590335],
[590336, 590337, 590338, ..., 590845, 590846, 590847],
[590848, 590849, 590850, ..., 591357, 591358, 591359],
...,
[653824, 653825, 653826, ..., 654333, 654334, 654335],
[654336, 654337, 654338, ..., 654845, 654846, 654847],
[654848, 654849, 654850, ..., 655357, 655358, 655359]])
Delta torch.Size([128]) : tensor([1152, 1153, 1154, 1155, 1156, 1157, 1158, 1159, 1160, 1161, 1162, 1163,
1164, 1165, 1166, 1167, 1168, 1169, 1170, 1171, 1172, 1173, 1174, 1175,
1176, 1177, 1178, 1179, 1180, 1181, 1182, 1183, 1184, 1185, 1186, 1187,
1188, 1189, 1190, 1191, 1192, 1193, 1194, 1195, 1196, 1197, 1198, 1199,
1200, 1201, 1202, 1203, 1204, 1205, 1206, 1207, 1208, 1209, 1210, 1211,
1212, 1213, 1214, 1215, 1216, 1217, 1218, 1219, 1220, 1221, 1222, 1223,
1224, 1225, 1226, 1227, 1228, 1229, 1230, 1231, 1232, 1233, 1234, 1235,
1236, 1237, 1238, 1239, 1240, 1241, 1242, 1243, 1244, 1245, 1246, 1247,
1248, 1249, 1250, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259,
1260, 1261, 1262, 1263, 1264, 1265, 1266, 1267, 1268, 1269, 1270, 1271,
1272, 1273, 1274, 1275, 1276, 1277, 1278, 1279])
-------End of _attn_bwd_preprocess-----------
Grid (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD): (8, 1, 32)
Strides of q : 4194304, 524288, 512, 1
-------Start of _attn_bwd-----------
(bhid,pid to 32,8) : (0,0) : Row : 0, Col : 0, off_chz : 0 adj : 0
Offsets(torch.Size([128, 512])) of k or v tensor([[ 0, 1, 2, ..., 509, 510, 511],
[ 512, 513, 514, ..., 1021, 1022, 1023],
[ 1024, 1025, 1026, ..., 1533, 1534, 1535],
...,
[64000, 64001, 64002, ..., 64509, 64510, 64511],
[64512, 64513, 64514, ..., 65021, 65022, 65023],
[65024, 65025, 65026, ..., 65533, 65534, 65535]])
MASK : True : start_m : 0 : Num Steps : 8
--------_attn_bwd_dkdv-(MASK : True)-------------
Current Step : 0 : offs_m(torch.Size([16])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
Mask (torch.Size([128, 16])) : tensor([[ True, True, True, ..., True, True, True],
[False, True, True, ..., True, True, True],
[False, False, True, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]])
Current Step : 1 : offs_m(torch.Size([16])) : tensor([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Mask (torch.Size([128, 16])) : tensor([[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]])
MASK : False : start_m : 128 : Num Steps : 28
--------_attn_bwd_dkdv-(MASK : False)-------------
Current Step : 0 : offs_m(torch.Size([32])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159])
Current Step : 1 : offs_m(torch.Size([32])) : tensor([160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173,
174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
188, 189, 190, 191])
-------End of _attn_bwd_dkdv-----------
Offset of DV : tensor([[ 0, 1, 2, ..., 509, 510, 511],
[ 512, 513, 514, ..., 1021, 1022, 1023],
[ 1024, 1025, 1026, ..., 1533, 1534, 1535],
...,
[64000, 64001, 64002, ..., 64509, 64510, 64511],
[64512, 64513, 64514, ..., 65021, 65022, 65023],
[65024, 65025, 65026, ..., 65533, 65534, 65535]])
Offset of DK : tensor([[ 0, 1, 2, ..., 509, 510, 511],
[ 512, 513, 514, ..., 1021, 1022, 1023],
[ 1024, 1025, 1026, ..., 1533, 1534, 1535],
...,
[64000, 64001, 64002, ..., 64509, 64510, 64511],
[64512, 64513, 64514, ..., 65021, 65022, 65023],
[65024, 65025, 65026, ..., 65533, 65534, 65535]])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Offset of q (torch.Size([128, 512])) : tensor([[ 0, 1, 2, ..., 509, 510, 511],
[ 512, 513, 514, ..., 1021, 1022, 1023],
[ 1024, 1025, 1026, ..., 1533, 1534, 1535],
...,
[64000, 64001, 64002, ..., 64509, 64510, 64511],
[64512, 64513, 64514, ..., 65021, 65022, 65023],
[65024, 65025, 65026, ..., 65533, 65534, 65535]])
Offset of do (torch.Size([128, 512])) : tensor([[ 0, 1, 2, ..., 509, 510, 511],
[ 512, 513, 514, ..., 1021, 1022, 1023],
[ 1024, 1025, 1026, ..., 1533, 1534, 1535],
...,
[64000, 64001, 64002, ..., 64509, 64510, 64511],
[64512, 64513, 64514, ..., 65021, 65022, 65023],
[65024, 65025, 65026, ..., 65533, 65534, 65535]])
(0,0) MASK : True : start_m : 0 start_n : 0 : Num Steps : 8
-------------_attn_bwd_dq--start_m--0--start_n--0----MASK-True-----
Offset of kT : torch.Size([16, 512])
Offset of vT : torch.Size([512, 16])
Di Offset : torch.Size([128])
Offset of n (torch.Size([16])): tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[ True, False, False, ..., False, False, False],
[ True, True, False, ..., False, False, False],
[ True, True, True, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([ 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
110, 111])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, False, False],
[ True, True, True, ..., True, True, False],
[ True, True, True, ..., True, True, True]])
-------------end of_attn_bwd_dq-------
end_n : 0
(0,0) MASK : False : start_m : 0 start_n : 0 : Num Steps : 0
-------------_attn_bwd_dq--start_m--0--start_n--0----MASK-False-----
Offset of kT : torch.Size([32, 512])
Offset of vT : torch.Size([512, 32])
Di Offset : torch.Size([128])
-------------end of_attn_bwd_dq-------
-------End of _attn_bwd-----------
(bhid,pid to 32,8) : (0,1) : Row : 0, Col : 0, off_chz : 0 adj : 0
Offsets(torch.Size([128, 512])) of k or v tensor([[ 65536, 65537, 65538, ..., 66045, 66046, 66047],
[ 66048, 66049, 66050, ..., 66557, 66558, 66559],
[ 66560, 66561, 66562, ..., 67069, 67070, 67071],
...,
[129536, 129537, 129538, ..., 130045, 130046, 130047],
[130048, 130049, 130050, ..., 130557, 130558, 130559],
[130560, 130561, 130562, ..., 131069, 131070, 131071]])
MASK : True : start_m : 128 : Num Steps : 8
--------_attn_bwd_dkdv-(MASK : True)-------------
Current Step : 0 : offs_m(torch.Size([16])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143])
Mask (torch.Size([128, 16])) : tensor([[ True, True, True, ..., True, True, True],
[False, True, True, ..., True, True, True],
[False, False, True, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]])
Current Step : 1 : offs_m(torch.Size([16])) : tensor([144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157,
158, 159])
Mask (torch.Size([128, 16])) : tensor([[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]])
MASK : False : start_m : 256 : Num Steps : 24
--------_attn_bwd_dkdv-(MASK : False)-------------
Current Step : 0 : offs_m(torch.Size([32])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287])
Current Step : 1 : offs_m(torch.Size([32])) : tensor([288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301,
302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315,
316, 317, 318, 319])
-------End of _attn_bwd_dkdv-----------
Offset of DV : tensor([[ 65536, 65537, 65538, ..., 66045, 66046, 66047],
[ 66048, 66049, 66050, ..., 66557, 66558, 66559],
[ 66560, 66561, 66562, ..., 67069, 67070, 67071],
...,
[129536, 129537, 129538, ..., 130045, 130046, 130047],
[130048, 130049, 130050, ..., 130557, 130558, 130559],
[130560, 130561, 130562, ..., 131069, 131070, 131071]])
Offset of DK : tensor([[ 65536, 65537, 65538, ..., 66045, 66046, 66047],
[ 66048, 66049, 66050, ..., 66557, 66558, 66559],
[ 66560, 66561, 66562, ..., 67069, 67070, 67071],
...,
[129536, 129537, 129538, ..., 130045, 130046, 130047],
[130048, 130049, 130050, ..., 130557, 130558, 130559],
[130560, 130561, 130562, ..., 131069, 131070, 131071]])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Offset of q (torch.Size([128, 512])) : tensor([[ 65536, 65537, 65538, ..., 66045, 66046, 66047],
[ 66048, 66049, 66050, ..., 66557, 66558, 66559],
[ 66560, 66561, 66562, ..., 67069, 67070, 67071],
...,
[129536, 129537, 129538, ..., 130045, 130046, 130047],
[130048, 130049, 130050, ..., 130557, 130558, 130559],
[130560, 130561, 130562, ..., 131069, 131070, 131071]])
Offset of do (torch.Size([128, 512])) : tensor([[ 65536, 65537, 65538, ..., 66045, 66046, 66047],
[ 66048, 66049, 66050, ..., 66557, 66558, 66559],
[ 66560, 66561, 66562, ..., 67069, 67070, 67071],
...,
[129536, 129537, 129538, ..., 130045, 130046, 130047],
[130048, 130049, 130050, ..., 130557, 130558, 130559],
[130560, 130561, 130562, ..., 131069, 131070, 131071]])
(0,1) MASK : True : start_m : 128 start_n : 128 : Num Steps : 8
-------------_attn_bwd_dq--start_m--128--start_n--128----MASK-True-----
Offset of kT : torch.Size([16, 512])
Offset of vT : torch.Size([512, 16])
Di Offset : torch.Size([128])
Offset of n (torch.Size([16])): tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[ True, False, False, ..., False, False, False],
[ True, True, False, ..., False, False, False],
[ True, True, True, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157,
158, 159])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173,
174, 175])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189,
190, 191])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205,
206, 207])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221,
222, 223])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237,
238, 239])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, False, False],
[ True, True, True, ..., True, True, False],
[ True, True, True, ..., True, True, True]])
-------------end of_attn_bwd_dq-------
end_n : 128
(0,1) MASK : False : start_m : 128 start_n : 0 : Num Steps : 4
-------------_attn_bwd_dq--start_m--128--start_n--0----MASK-False-----
Offset of kT : torch.Size([32, 512])
Offset of vT : torch.Size([512, 32])
Di Offset : torch.Size([128])
Offset of n (torch.Size([32])): tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Offset of n (torch.Size([32])): tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Offset of n (torch.Size([32])): tensor([64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Offset of n (torch.Size([32])): tensor([ 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123,
124, 125, 126, 127])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
-------------end of_attn_bwd_dq-------
-------End of _attn_bwd-----------
(bhid,pid to 32,8) : (0,2) : Row : 0, Col : 0, off_chz : 0 adj : 0
Offsets(torch.Size([128, 512])) of k or v tensor([[131072, 131073, 131074, ..., 131581, 131582, 131583],
[131584, 131585, 131586, ..., 132093, 132094, 132095],
[132096, 132097, 132098, ..., 132605, 132606, 132607],
...,
[195072, 195073, 195074, ..., 195581, 195582, 195583],
[195584, 195585, 195586, ..., 196093, 196094, 196095],
[196096, 196097, 196098, ..., 196605, 196606, 196607]])
MASK : True : start_m : 256 : Num Steps : 8
--------_attn_bwd_dkdv-(MASK : True)-------------
Current Step : 0 : offs_m(torch.Size([16])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271])
Mask (torch.Size([128, 16])) : tensor([[ True, True, True, ..., True, True, True],
[False, True, True, ..., True, True, True],
[False, False, True, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]])
Current Step : 1 : offs_m(torch.Size([16])) : tensor([272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285,
286, 287])
Mask (torch.Size([128, 16])) : tensor([[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]])
MASK : False : start_m : 384 : Num Steps : 20
--------_attn_bwd_dkdv-(MASK : False)-------------
Current Step : 0 : offs_m(torch.Size([32])) : tensor([384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397,
398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411,
412, 413, 414, 415])
Current Step : 1 : offs_m(torch.Size([32])) : tensor([416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429,
430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443,
444, 445, 446, 447])
-------End of _attn_bwd_dkdv-----------
Offset of DV : tensor([[131072, 131073, 131074, ..., 131581, 131582, 131583],
[131584, 131585, 131586, ..., 132093, 132094, 132095],
[132096, 132097, 132098, ..., 132605, 132606, 132607],
...,
[195072, 195073, 195074, ..., 195581, 195582, 195583],
[195584, 195585, 195586, ..., 196093, 196094, 196095],
[196096, 196097, 196098, ..., 196605, 196606, 196607]])
Offset of DK : tensor([[131072, 131073, 131074, ..., 131581, 131582, 131583],
[131584, 131585, 131586, ..., 132093, 132094, 132095],
[132096, 132097, 132098, ..., 132605, 132606, 132607],
...,
[195072, 195073, 195074, ..., 195581, 195582, 195583],
[195584, 195585, 195586, ..., 196093, 196094, 196095],
[196096, 196097, 196098, ..., 196605, 196606, 196607]])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of q (torch.Size([128, 512])) : tensor([[131072, 131073, 131074, ..., 131581, 131582, 131583],
[131584, 131585, 131586, ..., 132093, 132094, 132095],
[132096, 132097, 132098, ..., 132605, 132606, 132607],
...,
[195072, 195073, 195074, ..., 195581, 195582, 195583],
[195584, 195585, 195586, ..., 196093, 196094, 196095],
[196096, 196097, 196098, ..., 196605, 196606, 196607]])
Offset of do (torch.Size([128, 512])) : tensor([[131072, 131073, 131074, ..., 131581, 131582, 131583],
[131584, 131585, 131586, ..., 132093, 132094, 132095],
[132096, 132097, 132098, ..., 132605, 132606, 132607],
...,
[195072, 195073, 195074, ..., 195581, 195582, 195583],
[195584, 195585, 195586, ..., 196093, 196094, 196095],
[196096, 196097, 196098, ..., 196605, 196606, 196607]])
(0,2) MASK : True : start_m : 256 start_n : 256 : Num Steps : 8
-------------_attn_bwd_dq--start_m--256--start_n--256----MASK-True-----
Offset of kT : torch.Size([16, 512])
Offset of vT : torch.Size([512, 16])
Di Offset : torch.Size([128])
Offset of n (torch.Size([16])): tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[ True, False, False, ..., False, False, False],
[ True, True, False, ..., False, False, False],
[ True, True, True, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285,
286, 287])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301,
302, 303])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317,
318, 319])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333,
334, 335])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349,
350, 351])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365,
366, 367])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, False, False],
[ True, True, True, ..., True, True, False],
[ True, True, True, ..., True, True, True]])
-------------end of_attn_bwd_dq-------
end_n : 256
(0,2) MASK : False : start_m : 256 start_n : 0 : Num Steps : 8
-------------_attn_bwd_dq--start_m--256--start_n--0----MASK-False-----
Offset of kT : torch.Size([32, 512])
Offset of vT : torch.Size([512, 32])
Di Offset : torch.Size([128])
Offset of n (torch.Size([32])): tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of n (torch.Size([32])): tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of n (torch.Size([32])): tensor([64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of n (torch.Size([32])): tensor([ 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123,
124, 125, 126, 127])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of n (torch.Size([32])): tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of n (torch.Size([32])): tensor([160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173,
174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
188, 189, 190, 191])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of n (torch.Size([32])): tensor([192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205,
206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219,
220, 221, 222, 223])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of n (torch.Size([32])): tensor([224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237,
238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251,
252, 253, 254, 255])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
-------------end of_attn_bwd_dq-------
-------End of _attn_bwd-----------
(bhid,pid to 32,8) : (1,0) : Row : 0, Col : 1, off_chz : 1024 adj : 524288
Offsets(torch.Size([128, 512])) of k or v tensor([[ 0, 1, 2, ..., 509, 510, 511],
[ 512, 513, 514, ..., 1021, 1022, 1023],
[ 1024, 1025, 1026, ..., 1533, 1534, 1535],
...,
[64000, 64001, 64002, ..., 64509, 64510, 64511],
[64512, 64513, 64514, ..., 65021, 65022, 65023],
[65024, 65025, 65026, ..., 65533, 65534, 65535]])
MASK : True : start_m : 0 : Num Steps : 8
--------_attn_bwd_dkdv-(MASK : True)-------------
Current Step : 0 : offs_m(torch.Size([16])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
Mask (torch.Size([128, 16])) : tensor([[ True, True, True, ..., True, True, True],
[False, True, True, ..., True, True, True],
[False, False, True, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]])
Current Step : 1 : offs_m(torch.Size([16])) : tensor([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Mask (torch.Size([128, 16])) : tensor([[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]])
MASK : False : start_m : 128 : Num Steps : 28
--------_attn_bwd_dkdv-(MASK : False)-------------
Current Step : 0 : offs_m(torch.Size([32])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159])
Current Step : 1 : offs_m(torch.Size([32])) : tensor([160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173,
174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
188, 189, 190, 191])
-------End of _attn_bwd_dkdv-----------
Offset of DV : tensor([[ 0, 1, 2, ..., 509, 510, 511],
[ 512, 513, 514, ..., 1021, 1022, 1023],
[ 1024, 1025, 1026, ..., 1533, 1534, 1535],
...,
[64000, 64001, 64002, ..., 64509, 64510, 64511],
[64512, 64513, 64514, ..., 65021, 65022, 65023],
[65024, 65025, 65026, ..., 65533, 65534, 65535]])
Offset of DK : tensor([[ 0, 1, 2, ..., 509, 510, 511],
[ 512, 513, 514, ..., 1021, 1022, 1023],
[ 1024, 1025, 1026, ..., 1533, 1534, 1535],
...,
[64000, 64001, 64002, ..., 64509, 64510, 64511],
[64512, 64513, 64514, ..., 65021, 65022, 65023],
[65024, 65025, 65026, ..., 65533, 65534, 65535]])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Offset of q (torch.Size([128, 512])) : tensor([[ 0, 1, 2, ..., 509, 510, 511],
[ 512, 513, 514, ..., 1021, 1022, 1023],
[ 1024, 1025, 1026, ..., 1533, 1534, 1535],
...,
[64000, 64001, 64002, ..., 64509, 64510, 64511],
[64512, 64513, 64514, ..., 65021, 65022, 65023],
[65024, 65025, 65026, ..., 65533, 65534, 65535]])
Offset of do (torch.Size([128, 512])) : tensor([[ 0, 1, 2, ..., 509, 510, 511],
[ 512, 513, 514, ..., 1021, 1022, 1023],
[ 1024, 1025, 1026, ..., 1533, 1534, 1535],
...,
[64000, 64001, 64002, ..., 64509, 64510, 64511],
[64512, 64513, 64514, ..., 65021, 65022, 65023],
[65024, 65025, 65026, ..., 65533, 65534, 65535]])
(1,0) MASK : True : start_m : 0 start_n : 0 : Num Steps : 8
-------------_attn_bwd_dq--start_m--0--start_n--0----MASK-True-----
Offset of kT : torch.Size([16, 512])
Offset of vT : torch.Size([512, 16])
Di Offset : torch.Size([128])
Offset of n (torch.Size([16])): tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[ True, False, False, ..., False, False, False],
[ True, True, False, ..., False, False, False],
[ True, True, True, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([ 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
110, 111])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, False, False],
[ True, True, True, ..., True, True, False],
[ True, True, True, ..., True, True, True]])
-------------end of_attn_bwd_dq-------
end_n : 0
(1,0) MASK : False : start_m : 0 start_n : 0 : Num Steps : 0
-------------_attn_bwd_dq--start_m--0--start_n--0----MASK-False-----
Offset of kT : torch.Size([32, 512])
Offset of vT : torch.Size([512, 32])
Di Offset : torch.Size([128])
-------------end of_attn_bwd_dq-------
-------End of _attn_bwd-----------
(bhid,pid to 32,8) : (1,1) : Row : 0, Col : 1, off_chz : 1024 adj : 524288
Offsets(torch.Size([128, 512])) of k or v tensor([[ 65536, 65537, 65538, ..., 66045, 66046, 66047],
[ 66048, 66049, 66050, ..., 66557, 66558, 66559],
[ 66560, 66561, 66562, ..., 67069, 67070, 67071],
...,
[129536, 129537, 129538, ..., 130045, 130046, 130047],
[130048, 130049, 130050, ..., 130557, 130558, 130559],
[130560, 130561, 130562, ..., 131069, 131070, 131071]])
MASK : True : start_m : 128 : Num Steps : 8
--------_attn_bwd_dkdv-(MASK : True)-------------
Current Step : 0 : offs_m(torch.Size([16])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143])
Mask (torch.Size([128, 16])) : tensor([[ True, True, True, ..., True, True, True],
[False, True, True, ..., True, True, True],
[False, False, True, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]])
Current Step : 1 : offs_m(torch.Size([16])) : tensor([144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157,
158, 159])
Mask (torch.Size([128, 16])) : tensor([[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]])
MASK : False : start_m : 256 : Num Steps : 24
--------_attn_bwd_dkdv-(MASK : False)-------------
Current Step : 0 : offs_m(torch.Size([32])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287])
Current Step : 1 : offs_m(torch.Size([32])) : tensor([288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301,
302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315,
316, 317, 318, 319])
-------End of _attn_bwd_dkdv-----------
Offset of DV : tensor([[ 65536, 65537, 65538, ..., 66045, 66046, 66047],
[ 66048, 66049, 66050, ..., 66557, 66558, 66559],
[ 66560, 66561, 66562, ..., 67069, 67070, 67071],
...,
[129536, 129537, 129538, ..., 130045, 130046, 130047],
[130048, 130049, 130050, ..., 130557, 130558, 130559],
[130560, 130561, 130562, ..., 131069, 131070, 131071]])
Offset of DK : tensor([[ 65536, 65537, 65538, ..., 66045, 66046, 66047],
[ 66048, 66049, 66050, ..., 66557, 66558, 66559],
[ 66560, 66561, 66562, ..., 67069, 67070, 67071],
...,
[129536, 129537, 129538, ..., 130045, 130046, 130047],
[130048, 130049, 130050, ..., 130557, 130558, 130559],
[130560, 130561, 130562, ..., 131069, 131070, 131071]])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Offset of q (torch.Size([128, 512])) : tensor([[ 65536, 65537, 65538, ..., 66045, 66046, 66047],
[ 66048, 66049, 66050, ..., 66557, 66558, 66559],
[ 66560, 66561, 66562, ..., 67069, 67070, 67071],
...,
[129536, 129537, 129538, ..., 130045, 130046, 130047],
[130048, 130049, 130050, ..., 130557, 130558, 130559],
[130560, 130561, 130562, ..., 131069, 131070, 131071]])
Offset of do (torch.Size([128, 512])) : tensor([[ 65536, 65537, 65538, ..., 66045, 66046, 66047],
[ 66048, 66049, 66050, ..., 66557, 66558, 66559],
[ 66560, 66561, 66562, ..., 67069, 67070, 67071],
...,
[129536, 129537, 129538, ..., 130045, 130046, 130047],
[130048, 130049, 130050, ..., 130557, 130558, 130559],
[130560, 130561, 130562, ..., 131069, 131070, 131071]])
(1,1) MASK : True : start_m : 128 start_n : 128 : Num Steps : 8
-------------_attn_bwd_dq--start_m--128--start_n--128----MASK-True-----
Offset of kT : torch.Size([16, 512])
Offset of vT : torch.Size([512, 16])
Di Offset : torch.Size([128])
Offset of n (torch.Size([16])): tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[ True, False, False, ..., False, False, False],
[ True, True, False, ..., False, False, False],
[ True, True, True, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157,
158, 159])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173,
174, 175])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189,
190, 191])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205,
206, 207])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221,
222, 223])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237,
238, 239])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, False, False],
[ True, True, True, ..., True, True, False],
[ True, True, True, ..., True, True, True]])
-------------end of_attn_bwd_dq-------
end_n : 128
(1,1) MASK : False : start_m : 128 start_n : 0 : Num Steps : 4
-------------_attn_bwd_dq--start_m--128--start_n--0----MASK-False-----
Offset of kT : torch.Size([32, 512])
Offset of vT : torch.Size([512, 32])
Di Offset : torch.Size([128])
Offset of n (torch.Size([32])): tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Offset of n (torch.Size([32])): tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Offset of n (torch.Size([32])): tensor([64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Offset of n (torch.Size([32])): tensor([ 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123,
124, 125, 126, 127])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
-------------end of_attn_bwd_dq-------
-------End of _attn_bwd-----------
(bhid,pid to 32,8) : (1,2) : Row : 0, Col : 1, off_chz : 1024 adj : 524288
Offsets(torch.Size([128, 512])) of k or v tensor([[131072, 131073, 131074, ..., 131581, 131582, 131583],
[131584, 131585, 131586, ..., 132093, 132094, 132095],
[132096, 132097, 132098, ..., 132605, 132606, 132607],
...,
[195072, 195073, 195074, ..., 195581, 195582, 195583],
[195584, 195585, 195586, ..., 196093, 196094, 196095],
[196096, 196097, 196098, ..., 196605, 196606, 196607]])
MASK : True : start_m : 256 : Num Steps : 8
--------_attn_bwd_dkdv-(MASK : True)-------------
Current Step : 0 : offs_m(torch.Size([16])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271])
Mask (torch.Size([128, 16])) : tensor([[ True, True, True, ..., True, True, True],
[False, True, True, ..., True, True, True],
[False, False, True, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]])
Current Step : 1 : offs_m(torch.Size([16])) : tensor([272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285,
286, 287])
Mask (torch.Size([128, 16])) : tensor([[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]])
MASK : False : start_m : 384 : Num Steps : 20
--------_attn_bwd_dkdv-(MASK : False)-------------
Current Step : 0 : offs_m(torch.Size([32])) : tensor([384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397,
398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411,
412, 413, 414, 415])
Current Step : 1 : offs_m(torch.Size([32])) : tensor([416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429,
430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443,
444, 445, 446, 447])
-------End of _attn_bwd_dkdv-----------
Offset of DV : tensor([[131072, 131073, 131074, ..., 131581, 131582, 131583],
[131584, 131585, 131586, ..., 132093, 132094, 132095],
[132096, 132097, 132098, ..., 132605, 132606, 132607],
...,
[195072, 195073, 195074, ..., 195581, 195582, 195583],
[195584, 195585, 195586, ..., 196093, 196094, 196095],
[196096, 196097, 196098, ..., 196605, 196606, 196607]])
Offset of DK : tensor([[131072, 131073, 131074, ..., 131581, 131582, 131583],
[131584, 131585, 131586, ..., 132093, 132094, 132095],
[132096, 132097, 132098, ..., 132605, 132606, 132607],
...,
[195072, 195073, 195074, ..., 195581, 195582, 195583],
[195584, 195585, 195586, ..., 196093, 196094, 196095],
[196096, 196097, 196098, ..., 196605, 196606, 196607]])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of q (torch.Size([128, 512])) : tensor([[131072, 131073, 131074, ..., 131581, 131582, 131583],
[131584, 131585, 131586, ..., 132093, 132094, 132095],
[132096, 132097, 132098, ..., 132605, 132606, 132607],
...,
[195072, 195073, 195074, ..., 195581, 195582, 195583],
[195584, 195585, 195586, ..., 196093, 196094, 196095],
[196096, 196097, 196098, ..., 196605, 196606, 196607]])
Offset of do (torch.Size([128, 512])) : tensor([[131072, 131073, 131074, ..., 131581, 131582, 131583],
[131584, 131585, 131586, ..., 132093, 132094, 132095],
[132096, 132097, 132098, ..., 132605, 132606, 132607],
...,
[195072, 195073, 195074, ..., 195581, 195582, 195583],
[195584, 195585, 195586, ..., 196093, 196094, 196095],
[196096, 196097, 196098, ..., 196605, 196606, 196607]])
(1,2) MASK : True : start_m : 256 start_n : 256 : Num Steps : 8
-------------_attn_bwd_dq--start_m--256--start_n--256----MASK-True-----
Offset of kT : torch.Size([16, 512])
Offset of vT : torch.Size([512, 16])
Di Offset : torch.Size([128])
Offset of n (torch.Size([16])): tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[ True, False, False, ..., False, False, False],
[ True, True, False, ..., False, False, False],
[ True, True, True, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285,
286, 287])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301,
302, 303])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317,
318, 319])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333,
334, 335])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349,
350, 351])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365,
366, 367])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, False, False],
[ True, True, True, ..., True, True, False],
[ True, True, True, ..., True, True, True]])
-------------end of_attn_bwd_dq-------
end_n : 256
(1,2) MASK : False : start_m : 256 start_n : 0 : Num Steps : 8
-------------_attn_bwd_dq--start_m--256--start_n--0----MASK-False-----
Offset of kT : torch.Size([32, 512])
Offset of vT : torch.Size([512, 32])
Di Offset : torch.Size([128])
Offset of n (torch.Size([32])): tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of n (torch.Size([32])): tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of n (torch.Size([32])): tensor([64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of n (torch.Size([32])): tensor([ 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123,
124, 125, 126, 127])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of n (torch.Size([32])): tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of n (torch.Size([32])): tensor([160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173,
174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
188, 189, 190, 191])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of n (torch.Size([32])): tensor([192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205,
206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219,
220, 221, 222, 223])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of n (torch.Size([32])): tensor([224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237,
238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251,
252, 253, 254, 255])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
-------------end of_attn_bwd_dq-------
-------End of _attn_bwd-----------
(bhid,pid to 32,8) : (2,0) : Row : 0, Col : 2, off_chz : 2048 adj : 1048576
Offsets(torch.Size([128, 512])) of k or v tensor([[ 0, 1, 2, ..., 509, 510, 511],
[ 512, 513, 514, ..., 1021, 1022, 1023],
[ 1024, 1025, 1026, ..., 1533, 1534, 1535],
...,
[64000, 64001, 64002, ..., 64509, 64510, 64511],
[64512, 64513, 64514, ..., 65021, 65022, 65023],
[65024, 65025, 65026, ..., 65533, 65534, 65535]])
MASK : True : start_m : 0 : Num Steps : 8
--------_attn_bwd_dkdv-(MASK : True)-------------
Current Step : 0 : offs_m(torch.Size([16])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
Mask (torch.Size([128, 16])) : tensor([[ True, True, True, ..., True, True, True],
[False, True, True, ..., True, True, True],
[False, False, True, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]])
Current Step : 1 : offs_m(torch.Size([16])) : tensor([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Mask (torch.Size([128, 16])) : tensor([[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]])
MASK : False : start_m : 128 : Num Steps : 28
--------_attn_bwd_dkdv-(MASK : False)-------------
Current Step : 0 : offs_m(torch.Size([32])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159])
Current Step : 1 : offs_m(torch.Size([32])) : tensor([160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173,
174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
188, 189, 190, 191])
-------End of _attn_bwd_dkdv-----------
Offset of DV : tensor([[ 0, 1, 2, ..., 509, 510, 511],
[ 512, 513, 514, ..., 1021, 1022, 1023],
[ 1024, 1025, 1026, ..., 1533, 1534, 1535],
...,
[64000, 64001, 64002, ..., 64509, 64510, 64511],
[64512, 64513, 64514, ..., 65021, 65022, 65023],
[65024, 65025, 65026, ..., 65533, 65534, 65535]])
Offset of DK : tensor([[ 0, 1, 2, ..., 509, 510, 511],
[ 512, 513, 514, ..., 1021, 1022, 1023],
[ 1024, 1025, 1026, ..., 1533, 1534, 1535],
...,
[64000, 64001, 64002, ..., 64509, 64510, 64511],
[64512, 64513, 64514, ..., 65021, 65022, 65023],
[65024, 65025, 65026, ..., 65533, 65534, 65535]])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Offset of q (torch.Size([128, 512])) : tensor([[ 0, 1, 2, ..., 509, 510, 511],
[ 512, 513, 514, ..., 1021, 1022, 1023],
[ 1024, 1025, 1026, ..., 1533, 1534, 1535],
...,
[64000, 64001, 64002, ..., 64509, 64510, 64511],
[64512, 64513, 64514, ..., 65021, 65022, 65023],
[65024, 65025, 65026, ..., 65533, 65534, 65535]])
Offset of do (torch.Size([128, 512])) : tensor([[ 0, 1, 2, ..., 509, 510, 511],
[ 512, 513, 514, ..., 1021, 1022, 1023],
[ 1024, 1025, 1026, ..., 1533, 1534, 1535],
...,
[64000, 64001, 64002, ..., 64509, 64510, 64511],
[64512, 64513, 64514, ..., 65021, 65022, 65023],
[65024, 65025, 65026, ..., 65533, 65534, 65535]])
(2,0) MASK : True : start_m : 0 start_n : 0 : Num Steps : 8
-------------_attn_bwd_dq--start_m--0--start_n--0----MASK-True-----
Offset of kT : torch.Size([16, 512])
Offset of vT : torch.Size([512, 16])
Di Offset : torch.Size([128])
Offset of n (torch.Size([16])): tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[ True, False, False, ..., False, False, False],
[ True, True, False, ..., False, False, False],
[ True, True, True, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([ 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
110, 111])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Offset of m (torch.Size([128])) : tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, False, False],
[ True, True, True, ..., True, True, False],
[ True, True, True, ..., True, True, True]])
-------------end of_attn_bwd_dq-------
end_n : 0
(2,0) MASK : False : start_m : 0 start_n : 0 : Num Steps : 0
-------------_attn_bwd_dq--start_m--0--start_n--0----MASK-False-----
Offset of kT : torch.Size([32, 512])
Offset of vT : torch.Size([512, 32])
Di Offset : torch.Size([128])
-------------end of_attn_bwd_dq-------
-------End of _attn_bwd-----------
(bhid,pid to 32,8) : (2,1) : Row : 0, Col : 2, off_chz : 2048 adj : 1048576
Offsets(torch.Size([128, 512])) of k or v tensor([[ 65536, 65537, 65538, ..., 66045, 66046, 66047],
[ 66048, 66049, 66050, ..., 66557, 66558, 66559],
[ 66560, 66561, 66562, ..., 67069, 67070, 67071],
...,
[129536, 129537, 129538, ..., 130045, 130046, 130047],
[130048, 130049, 130050, ..., 130557, 130558, 130559],
[130560, 130561, 130562, ..., 131069, 131070, 131071]])
MASK : True : start_m : 128 : Num Steps : 8
--------_attn_bwd_dkdv-(MASK : True)-------------
Current Step : 0 : offs_m(torch.Size([16])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143])
Mask (torch.Size([128, 16])) : tensor([[ True, True, True, ..., True, True, True],
[False, True, True, ..., True, True, True],
[False, False, True, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]])
Current Step : 1 : offs_m(torch.Size([16])) : tensor([144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157,
158, 159])
Mask (torch.Size([128, 16])) : tensor([[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]])
MASK : False : start_m : 256 : Num Steps : 24
--------_attn_bwd_dkdv-(MASK : False)-------------
Current Step : 0 : offs_m(torch.Size([32])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287])
Current Step : 1 : offs_m(torch.Size([32])) : tensor([288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301,
302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315,
316, 317, 318, 319])
-------End of _attn_bwd_dkdv-----------
Offset of DV : tensor([[ 65536, 65537, 65538, ..., 66045, 66046, 66047],
[ 66048, 66049, 66050, ..., 66557, 66558, 66559],
[ 66560, 66561, 66562, ..., 67069, 67070, 67071],
...,
[129536, 129537, 129538, ..., 130045, 130046, 130047],
[130048, 130049, 130050, ..., 130557, 130558, 130559],
[130560, 130561, 130562, ..., 131069, 131070, 131071]])
Offset of DK : tensor([[ 65536, 65537, 65538, ..., 66045, 66046, 66047],
[ 66048, 66049, 66050, ..., 66557, 66558, 66559],
[ 66560, 66561, 66562, ..., 67069, 67070, 67071],
...,
[129536, 129537, 129538, ..., 130045, 130046, 130047],
[130048, 130049, 130050, ..., 130557, 130558, 130559],
[130560, 130561, 130562, ..., 131069, 131070, 131071]])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Offset of q (torch.Size([128, 512])) : tensor([[ 65536, 65537, 65538, ..., 66045, 66046, 66047],
[ 66048, 66049, 66050, ..., 66557, 66558, 66559],
[ 66560, 66561, 66562, ..., 67069, 67070, 67071],
...,
[129536, 129537, 129538, ..., 130045, 130046, 130047],
[130048, 130049, 130050, ..., 130557, 130558, 130559],
[130560, 130561, 130562, ..., 131069, 131070, 131071]])
Offset of do (torch.Size([128, 512])) : tensor([[ 65536, 65537, 65538, ..., 66045, 66046, 66047],
[ 66048, 66049, 66050, ..., 66557, 66558, 66559],
[ 66560, 66561, 66562, ..., 67069, 67070, 67071],
...,
[129536, 129537, 129538, ..., 130045, 130046, 130047],
[130048, 130049, 130050, ..., 130557, 130558, 130559],
[130560, 130561, 130562, ..., 131069, 131070, 131071]])
(2,1) MASK : True : start_m : 128 start_n : 128 : Num Steps : 8
-------------_attn_bwd_dq--start_m--128--start_n--128----MASK-True-----
Offset of kT : torch.Size([16, 512])
Offset of vT : torch.Size([512, 16])
Di Offset : torch.Size([128])
Offset of n (torch.Size([16])): tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[ True, False, False, ..., False, False, False],
[ True, True, False, ..., False, False, False],
[ True, True, True, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157,
158, 159])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173,
174, 175])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189,
190, 191])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205,
206, 207])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221,
222, 223])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237,
238, 239])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, False, False],
[ True, True, True, ..., True, True, False],
[ True, True, True, ..., True, True, True]])
-------------end of_attn_bwd_dq-------
end_n : 128
(2,1) MASK : False : start_m : 128 start_n : 0 : Num Steps : 4
-------------_attn_bwd_dq--start_m--128--start_n--0----MASK-False-----
Offset of kT : torch.Size([32, 512])
Offset of vT : torch.Size([512, 32])
Di Offset : torch.Size([128])
Offset of n (torch.Size([32])): tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Offset of n (torch.Size([32])): tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Offset of n (torch.Size([32])): tensor([64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
Offset of n (torch.Size([32])): tensor([ 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123,
124, 125, 126, 127])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
254, 255])
-------------end of_attn_bwd_dq-------
-------End of _attn_bwd-----------
(bhid,pid to 32,8) : (2,2) : Row : 0, Col : 2, off_chz : 2048 adj : 1048576
Offsets(torch.Size([128, 512])) of k or v tensor([[131072, 131073, 131074, ..., 131581, 131582, 131583],
[131584, 131585, 131586, ..., 132093, 132094, 132095],
[132096, 132097, 132098, ..., 132605, 132606, 132607],
...,
[195072, 195073, 195074, ..., 195581, 195582, 195583],
[195584, 195585, 195586, ..., 196093, 196094, 196095],
[196096, 196097, 196098, ..., 196605, 196606, 196607]])
MASK : True : start_m : 256 : Num Steps : 8
--------_attn_bwd_dkdv-(MASK : True)-------------
Current Step : 0 : offs_m(torch.Size([16])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271])
Mask (torch.Size([128, 16])) : tensor([[ True, True, True, ..., True, True, True],
[False, True, True, ..., True, True, True],
[False, False, True, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]])
Current Step : 1 : offs_m(torch.Size([16])) : tensor([272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285,
286, 287])
Mask (torch.Size([128, 16])) : tensor([[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]])
MASK : False : start_m : 384 : Num Steps : 20
--------_attn_bwd_dkdv-(MASK : False)-------------
Current Step : 0 : offs_m(torch.Size([32])) : tensor([384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397,
398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411,
412, 413, 414, 415])
Current Step : 1 : offs_m(torch.Size([32])) : tensor([416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429,
430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443,
444, 445, 446, 447])
-------End of _attn_bwd_dkdv-----------
Offset of DV : tensor([[131072, 131073, 131074, ..., 131581, 131582, 131583],
[131584, 131585, 131586, ..., 132093, 132094, 132095],
[132096, 132097, 132098, ..., 132605, 132606, 132607],
...,
[195072, 195073, 195074, ..., 195581, 195582, 195583],
[195584, 195585, 195586, ..., 196093, 196094, 196095],
[196096, 196097, 196098, ..., 196605, 196606, 196607]])
Offset of DK : tensor([[131072, 131073, 131074, ..., 131581, 131582, 131583],
[131584, 131585, 131586, ..., 132093, 132094, 132095],
[132096, 132097, 132098, ..., 132605, 132606, 132607],
...,
[195072, 195073, 195074, ..., 195581, 195582, 195583],
[195584, 195585, 195586, ..., 196093, 196094, 196095],
[196096, 196097, 196098, ..., 196605, 196606, 196607]])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of q (torch.Size([128, 512])) : tensor([[131072, 131073, 131074, ..., 131581, 131582, 131583],
[131584, 131585, 131586, ..., 132093, 132094, 132095],
[132096, 132097, 132098, ..., 132605, 132606, 132607],
...,
[195072, 195073, 195074, ..., 195581, 195582, 195583],
[195584, 195585, 195586, ..., 196093, 196094, 196095],
[196096, 196097, 196098, ..., 196605, 196606, 196607]])
Offset of do (torch.Size([128, 512])) : tensor([[131072, 131073, 131074, ..., 131581, 131582, 131583],
[131584, 131585, 131586, ..., 132093, 132094, 132095],
[132096, 132097, 132098, ..., 132605, 132606, 132607],
...,
[195072, 195073, 195074, ..., 195581, 195582, 195583],
[195584, 195585, 195586, ..., 196093, 196094, 196095],
[196096, 196097, 196098, ..., 196605, 196606, 196607]])
(2,2) MASK : True : start_m : 256 start_n : 256 : Num Steps : 8
-------------_attn_bwd_dq--start_m--256--start_n--256----MASK-True-----
Offset of kT : torch.Size([16, 512])
Offset of vT : torch.Size([512, 16])
Di Offset : torch.Size([128])
Offset of n (torch.Size([16])): tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[ True, False, False, ..., False, False, False],
[ True, True, False, ..., False, False, False],
[ True, True, True, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285,
286, 287])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301,
302, 303])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317,
318, 319])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333,
334, 335])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349,
350, 351])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365,
366, 367])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True],
[ True, True, True, ..., True, True, True]])
Offset of n (torch.Size([16])): tensor([368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[ True, True, True, ..., True, False, False],
[ True, True, True, ..., True, True, False],
[ True, True, True, ..., True, True, True]])
-------------end of_attn_bwd_dq-------
end_n : 256
(2,2) MASK : False : start_m : 256 start_n : 0 : Num Steps : 8
-------------_attn_bwd_dq--start_m--256--start_n--0----MASK-False-----
Offset of kT : torch.Size([32, 512])
Offset of vT : torch.Size([512, 32])
Di Offset : torch.Size([128])
Offset of n (torch.Size([32])): tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of n (torch.Size([32])): tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of n (torch.Size([32])): tensor([64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of n (torch.Size([32])): tensor([ 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123,
124, 125, 126, 127])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of n (torch.Size([32])): tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
156, 157, 158, 159])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of n (torch.Size([32])): tensor([160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173,
174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
188, 189, 190, 191])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of n (torch.Size([32])): tensor([192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205,
206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219,
220, 221, 222, 223])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
Offset of n (torch.Size([32])): tensor([224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237,
238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251,
252, 253, 254, 255])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
382, 383])
-------------end of_attn_bwd_dq-------
-------End of _attn_bwd-----------