Projects & Blogs

  • LATEST TRANSFORMER

  • PROBABILISTIC APPR...

  • MINI-UNET

  • BASICS OF TRITON

  • TRITON FUSED ATTE...

  • MINI-ALEXNET

  • PYTHON CODE GENER...

  • SEQUENTIAL MONTE ...

  • TRUNCATED SVD

  • CUSTOM DATALOADER...

  • PROBABILITY

import torch
import os

# Use the GPU when PyTorch can see one. The device string must be spelled "cuda";
# torch rejects unknown device strings, so a typo here breaks every allocation
# that passes device=DEVICE on a GPU machine.
DEVICE = "cuda" if torch.cuda.is_available() else 'cpu'

## DEMO of how offsets are masked
# mask[i, j] is True where offs_m[i] >= offs_n[j]: the same row-vs-column
# comparison the attention kernels below use to build their causal masks.
offs_m = torch.arange(10, 20)
offs_n = 2+torch.arange(0, 5)
mask = (offs_m[:, None] >= offs_n[None, :])
print(f"Offset of n ({offs_n.shape}): {offs_n}")
print(f"Offset of m ({offs_m.shape}) : {offs_m}")
print(f"Mask ({mask.shape}) : {mask}")
Offset of n (torch.Size([5])): tensor([2, 3, 4, 5, 6])
Offset of m (torch.Size([10])) : tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
Mask (torch.Size([10, 5])) : tensor([[True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True],
        [True, True, True, True, True]])

def _attn_fwd_inner(acc, l_i, m_i, q,  #
                    desc_k, desc_v,  #
                    offset_y, dtype, start_m, qk_scale,  #
                    BLOCK_M, HEAD_DIM, BLOCK_N,  #
                    STAGE, offs_m, offs_n,  #
                    N_CTX, warp_specialize):
    # range of values handled by this stage
    print(STAGE)
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        # lo = tl.multiple_of(lo, BLOCK_M)
    # causal = False
    else:
        lo, hi = 0, N_CTX
    offsetk_y = offset_y + lo

    #dtype == tl.float8e5
    if dtype == torch.float8_e5m2:
        offsetv_y = offset_y * HEAD_DIM + lo
    else:
        offsetv_y = offset_y + lo

    # loop over k, v and update accumulator
    # TODO what is warp_specialize in tl.range ?
    # for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize):
    for start_n in torch.arange(lo, hi, BLOCK_N):
        # start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        print(f"start_n : {start_n} : offset of k [{offsetk_y},0]")
        # k = desc_k.load([offsetk_y, 0]).T
        # qk = tl.dot(q, k)
        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            print(f"mask is used {mask}")
            # qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
            # m_ij = tl.maximum(m_i, tl.max(qk, 1))
            # qk -= m_ij[:, None]
        else:
            print("no mask used")
            # m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
            # qk = qk * qk_scale - m_ij[:, None]
        # p = tl.math.exp2(qk)
        # -- compute correction factor
        # alpha = tl.math.exp2(m_i - m_ij)
        # l_ij = tl.sum(p, 1)
        # -- update output accumulator --
        # acc = acc * alpha[:, None]
        # prepare p and v for the dot
        print(f"Offset of v [0, {offsetv_y}]")
        # if dtype == tl.float8e5:
        #     v = desc_v.load([0, offsetv_y]).T
        # else:
        #     v = desc_v.load([offsetv_y, 0])
        # p = p.to(dtype)
        # note that this non transposed v for FP8 is only supported on Blackwell
        # acc = tl.dot(p, v, acc)
        # update m_i and l_i
        # place this at the end of the loop to reduce register pressure
    #     l_i = l_i * alpha + l_ij
    #     m_i = m_ij
        offsetk_y += BLOCK_N
        offsetv_y += BLOCK_N
    # return acc, l_i, m_i
    print("----inner-----")

def _attn_fwd(sm_scale, M, Z, H, desc_q, desc_k, desc_v, desc_o,
              HEAD_DIM,  #
              BLOCK_M,  #
              BLOCK_N,  #
              FP8_OUTPUT,  #
              STAGE,  #
              warp_specialize, N_CTX):
    """CPU trace of the per-program offset arithmetic of the Triton ``_attn_fwd`` kernel.

    The two loops stand in for the launch grid. ``start_m`` plays the role of
    ``tl.program_id(0)``, one of ``N_CTX // BLOCK_M`` query blocks. ``off_hz``
    plays the role of ``tl.program_id(1)``, one of ``Z * H`` flattened
    (batch, head) pairs. To keep the printed trace short, only ``start_m`` in
    {0, 1} and ``off_hz`` in {0, 1} are visited.

    Nothing is loaded or stored: ``M``, ``desc_o`` and the tl.* math appear only
    in comments.

    STAGE is a bit set:
      * bit 0 runs ``_attn_fwd_inner`` with stage ``4 - STAGE``;
      * bit 1 runs the masked diagonal pass (inner stage 2).
    """
    # Same selection as the kernel (tl.float8e5 if FP8_OUTPUT else tl.float16). The
    # FP8 choice switches _attn_fwd_inner to its transposed-V offset branch.
    dtype = torch.float8_e5m2 if FP8_OUTPUT else torch.float16
    assert BLOCK_N <= HEAD_DIM
    num_m_blocks = N_CTX // BLOCK_M  # extent of tl.program_id(0)
    num_hz = Z * H                   # extent of tl.program_id(1)
    for start_m in range(num_m_blocks):
        for off_hz in range(num_hz):
            print(f"start_m : {start_m}, off_hz : {off_hz}")
            off_z = off_hz // H
            off_h = off_hz % H
            print(f"off_z : {off_z},off_h : {off_h}")
            y_dim = Z * H * N_CTX
            print(f"y_dim : {y_dim}")
            # First row of this (batch, head) in a y_dim-row flattened view.
            offset_y = off_z * (N_CTX * H) + off_h * N_CTX
            print(f"offset_y : {offset_y}")
            qo_offset_y = offset_y + start_m * BLOCK_M
            print(f"qo_offset_y : {qo_offset_y}")
            # initialize offsets
            offs_m = start_m * BLOCK_M + torch.arange(0, BLOCK_M)
            offs_n = torch.arange(0, BLOCK_N)
            print(f"offs_m : {offs_m}")
            print(f"offs_n : {offs_n}")
            # online-softmax state: running max m_i, running sum l_i, accumulator acc
            m_i = torch.zeros([BLOCK_M], dtype=torch.float32) - float("inf")
            l_i = torch.zeros([BLOCK_M], dtype=torch.float32) + 1.0
            acc = torch.zeros([BLOCK_M, HEAD_DIM], dtype=torch.float32)
            # load scales: fold 1/ln(2) into the scale so the kernel can use exp2
            qk_scale = sm_scale * 1.44269504  # 1/log(2)
            # kernel: q = desc_q.load([qo_offset_y, 0])
            print(f"q load : {[qo_offset_y, 0]}")
            # The trace never reads q; pass the descriptor as a stand-in rather than
            # silently depending on a module-level `q`.
            q = desc_q
            # stage 1: off-band (or the full range when not causal)
            if STAGE & 1:
                _attn_fwd_inner(acc, l_i, m_i, q,  #
                                desc_k, desc_v,  #
                                offset_y, dtype, start_m, qk_scale,  #
                                BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                4 - STAGE, offs_m, offs_n, N_CTX,  #
                                warp_specialize)
            # stage 2: on-band
            if STAGE & 2:
                _attn_fwd_inner(acc, l_i, m_i, q,  #
                                desc_k, desc_v,  #
                                offset_y, dtype, start_m, qk_scale,  #
                                BLOCK_M, HEAD_DIM, BLOCK_N,  #
                                2, offs_m, offs_n, N_CTX,  #
                                warp_specialize)
            # epilogue
            # m_i += tl.math.log2(l_i)
            # acc = acc / l_i[:, None]
            # m_ptrs = M + off_hz * N_CTX + offs_m
            # tl.store(m_ptrs, m_i)
            # desc_o.store([qo_offset_y, 0], acc.to(dtype))
            # Debug truncation: only trace off_hz 0 and 1.
            if off_hz == 1:
                break
        # Debug truncation: stop after query blocks 0 and 1.
        if start_m == 1:
            print("------------------------------")
            return


def attention(q, k, v, causal, sm_scale, warp_specialize=True):
    """Mirror the Triton tutorial's forward launch wrapper, then run the CPU trace.

    Args:
        q, k, v: attention inputs. The last dim is the head dimension, and dims
            0..2 are read as (batch, heads, N_CTX). Assumed layout is
            (BATCH, H, N_CTX, HEAD_DIM), as built by the driver script; TODO
            confirm for other callers.
        causal: selects STAGE 3 (off-band plus masked on-band passes) when True,
            STAGE 1 otherwise.
        sm_scale: softmax scale forwarded to the kernel trace.
        warp_specialize: forwarded unchanged to the kernel trace.

    Raises:
        ValueError: if q, k and v do not share the same head dimension.
    """
    # shape constraints
    HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
    # when v is in float8_e5m2 its storage is transposed, but its logical last dim
    # is still expected to be the head dimension.
    HEAD_DIM_V = v.shape[-1]
    if not (HEAD_DIM_Q == HEAD_DIM_K == HEAD_DIM_V):
        raise ValueError(
            f"q, k and v must share a head dimension, got {HEAD_DIM_Q}, {HEAD_DIM_K}, {HEAD_DIM_V}"
        )
    o = torch.empty_like(q)
    stage = 3 if causal else 1
    # float32 buffer of shape (batch, heads, N_CTX); not written by the trace.
    M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
    # The "descriptors" are just the tensors themselves in this CPU trace.
    desc_q = q
    desc_v = v
    desc_k = k
    desc_o = o
    BLOCK_M = 64 #128
    BLOCK_N = 32 # 64 #128
    # Launch grid the kernel would use; printed only.
    grid = (q.shape[2]// BLOCK_M, q.shape[0] * q.shape[1], 1)
    print(f"grid : {grid}")
    _attn_fwd(sm_scale, M, q.shape[0], q.shape[1],  #
              desc_q, desc_k, desc_v, desc_o,  #
              N_CTX=q.shape[2],  #
              HEAD_DIM=HEAD_DIM_K,  #
              BLOCK_M=BLOCK_M,  #
              BLOCK_N=BLOCK_N,  #
              FP8_OUTPUT=q.dtype == torch.float8_e5m2,  #
              STAGE=stage,  #
              warp_specialize=warp_specialize)

# Forward-pass trace configuration. Trailing comments keep the alternative
# values from the original (4, 2, 1024, 64 for BATCH, H, N_CTX, HEAD_DIM).
sm_scale = 0.5
causal = True
warp_specialize = False  # True
BATCH, H, N_CTX, HEAD_DIM = 4, 8, 1024, 512
dtype = torch.float16
# Draw q, then k, then v from the same distribution and shape.
q, k, v = (
    torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True)
    for _ in range(3)
)

# FP8 inputs: q and k are cast directly. v keeps its logical shape, but its
# storage is first rewritten with the last two dims swapped (permute ->
# contiguous -> permute back), then cast.
q, k = q.to(torch.float8_e5m2), k.to(torch.float8_e5m2)
v = v.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2).to(torch.float8_e5m2)
print(f"V Shape : {v.shape}")

attention(q, k ,v, causal, sm_scale, warp_specialize)


V Shape : torch.Size([4, 8, 1024, 512])
grid : (16, 32, 1)
start_m : 0, off_hz : 0
off_z : 0,off_h : 0
y_dim : 32768
offset_y : 0
qo_offset_y : 0
offs_m : tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
offs_n : tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
q load : [0, 0]
1
----inner-----
2
start_n : 0 : offset of k [0,0]
mask is used tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of v [0, 0]
start_n : 32 : offset of k [32,0]
mask is used tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True, False, False],
        [ True,  True,  True,  ...,  True,  True, False],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of v [0, 32]
----inner-----
start_m : 0, off_hz : 1
off_z : 0,off_h : 1
y_dim : 32768
offset_y : 1024
qo_offset_y : 1024
offs_m : tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
offs_n : tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
q load : [1024, 0]
1
----inner-----
2
start_n : 0 : offset of k [1024,0]
mask is used tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of v [0, 1024]
start_n : 32 : offset of k [1056,0]
mask is used tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True, False, False],
        [ True,  True,  True,  ...,  True,  True, False],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of v [0, 1056]
----inner-----
start_m : 1, off_hz : 0
off_z : 0,off_h : 0
y_dim : 32768
offset_y : 0
qo_offset_y : 64
offs_m : tensor([ 64,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
         78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,
         92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103, 104, 105,
        106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
        120, 121, 122, 123, 124, 125, 126, 127])
offs_n : tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
q load : [64, 0]
1
start_n : 0 : offset of k [0,0]
no mask used
Offset of v [0, 0]
start_n : 32 : offset of k [32,0]
no mask used
Offset of v [0, 32]
----inner-----
2
start_n : 64 : offset of k [64,0]
mask is used tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of v [0, 64]
start_n : 96 : offset of k [96,0]
mask is used tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True, False, False],
        [ True,  True,  True,  ...,  True,  True, False],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of v [0, 96]
----inner-----
start_m : 1, off_hz : 1
off_z : 0,off_h : 1
y_dim : 32768
offset_y : 1024
qo_offset_y : 1088
offs_m : tensor([ 64,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
         78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,
         92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103, 104, 105,
        106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
        120, 121, 122, 123, 124, 125, 126, 127])
offs_n : tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
q load : [1088, 0]
1
start_n : 0 : offset of k [1024,0]
no mask used
Offset of v [0, 1024]
start_n : 32 : offset of k [1056,0]
no mask used
Offset of v [0, 1056]
----inner-----
2
start_n : 64 : offset of k [1088,0]
mask is used tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of v [0, 1088]
start_n : 96 : offset of k [1120,0]
mask is used tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True, False, False],
        [ True,  True,  True,  ...,  True,  True, False],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of v [0, 1120]
----inner-----
------------------------------
def _attn_bwd_preprocess(O, DO,  #
                         Delta,  #
                         Z, H, N_CTX,  #
                         BLOCK_M, HEAD_DIM  #
                         ):
    """Trace the addressing of the Triton backward pre-processing kernel.

    The real kernel loads a BLOCK_M x HEAD_DIM tile of O and dO for one
    (batch, head) pair, reduces ``o * do`` over the head dimension and stores
    the result into ``Delta``. Here only the element offsets are computed and
    printed; ``O``, ``DO`` and ``Delta`` are never touched.

    The two loops emulate the program ids (row block, flattened batch*head).
    Only the first two values of each are traced to keep the output short.
    """
    # Since PRE_BLOCK = BLOCK_M
    num_row_blocks = N_CTX // BLOCK_M
    # BATCH = Z & N_HEAD = H
    num_heads_total = Z * H
    for row_block in range(min(num_row_blocks, 2)):
        for head_idx in range(min(num_heads_total, 2)):
            # For a 4x8x1024x512 input with BLOCK_M = 128, each program covers a
            # 128x512 tile. There are 1024 // 128 = 8 row blocks, for each of the
            # 4 x 8 = 32 (batch, head) pairs.
            print(f"tl-program-ids : {row_block} of {num_row_blocks}, {head_idx} of {num_heads_total}")
            rows = row_block * BLOCK_M + torch.arange(0, BLOCK_M)
            cols = torch.arange(0, HEAD_DIM)
            # kernel: o = tl.load(O + offset_o_do); do = tl.load(DO + offset_o_do).to(tl.float32)
            offset_o_do = head_idx * HEAD_DIM * N_CTX + rows[:, None] * HEAD_DIM + cols[None, :]
            print(f"Offsets of o and do {offset_o_do.shape} : {offset_o_do}")
            # kernel: delta = tl.sum(o * do, axis=1); tl.store(Delta + delta_offset, delta)
            # One delta per row, so Delta has shape (Z, H, N_CTX).
            delta_offset = head_idx * N_CTX + rows
            print(f"Delta {delta_offset.shape} : {delta_offset}")


def _attn_bwd_dq(dq, q, K, V,  #
                 do, m, D,
                 # shared by Q/K/V/DO.
                 stride_tok, stride_d,  #
                 H, N_CTX,  #
                 BLOCK_M2,  #
                 BLOCK_N2,  #
                 HEAD_DIM,
                 # Filled in by the wrapper.
                 start_m, start_n, num_steps,  #
                 MASK):
    """CPU trace of one dQ pass of the Triton ``_attn_bwd_dq`` helper.

    For the query block ``offs_m = start_m + [0, BLOCK_M2)`` it walks
    ``num_steps`` key blocks of ``BLOCK_N2`` columns starting at ``start_n``. It
    prints the offsets of each block and, when ``MASK`` is set, the causal mask
    ``offs_m >= offs_n``. No tensor is read, and ``dq`` is not modified.

    Returns:
        ``dq`` unchanged, mirroring the kernel helper's return value.
    """
    print(f"-------------_attn_bwd_dq--start_m--{start_m}--start_n--{start_n}----MASK-{MASK}-----")
    offs_m = start_m + torch.arange(0, BLOCK_M2)
    offs_n = start_n + torch.arange(0, BLOCK_N2)
    offs_k = torch.arange(0, HEAD_DIM)
    # K and V are both read transposed, as (HEAD_DIM, BLOCK_N2) tiles, so that
    # qk = q @ kT and dp = do @ vT line up.
    offset_kT = offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    print(f"Offset of kT : {offset_kT.shape}")
    # kT_ptrs = K + offset_kT
    offset_vT = offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
    print(f"Offset of vT : {offset_vT.shape}")
    # vT_ptrs = V + offset_vT
    # D (= delta) is pre-divided by ds_scale.
    print(f"Di Offset : {offs_m.shape}")
    # Di = tl.load(D + offs_m)
    curr_n = start_n
    step_n = BLOCK_N2
    for blk_idx in range(num_steps):
        # kT = tl.load(kT_ptrs)
        # vT = tl.load(vT_ptrs)
        # qk = tl.dot(q, kT)
        # p = tl.math.exp2(qk - m)
        # Autoregressive masking.
        offs_n = curr_n + torch.arange(0, BLOCK_N2)
        print(f"Offset of n ({offs_n.shape}): {offs_n}")
        print(f"Offset of m ({offs_m.shape}) : {offs_m}")
        if MASK:
            mask = (offs_m[:, None] >= offs_n[None, :])
            print(f"Mask ({mask.shape}) : {mask}")
            # p = tl.where(mask, p, 0.0)
        # Compute dP and dS.
        # dp = tl.dot(do, vT).to(tl.float32)
        # ds = p * (dp - Di[:, None])
        # ds = ds.to(tl.float16)
        # Compute dQ.
        # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
        # dq += tl.dot(ds, tl.trans(kT))
        # Increment pointers.
        curr_n += step_n
        # kT_ptrs += step_n * stride_tok
        # vT_ptrs += step_n * stride_tok
    print(f"-------------end of_attn_bwd_dq-------")
    return dq

def _attn_bwd_dkdv(dk, dv,  #
                   Q, k, v, sm_scale,  #
                   DO,  #
                   M, D,  #
                   # shared by Q/K/V/DO.
                   stride_tok, stride_d,  #
                   H, N_CTX, BLOCK_M1,  #
                   BLOCK_N1,  #
                   HEAD_DIM,  #
                   # Filled in by the wrapper.
                   start_n, start_m, num_steps,  #
                   MASK):
    """CPU trace of one dK/dV pass of the Triton ``_attn_bwd_dkdv`` helper.

    For the key block ``offs_n = start_n + [0, BLOCK_N1)`` it walks ``num_steps``
    query blocks of ``BLOCK_M1`` rows starting at ``start_m`` and prints each
    block's row offsets. When ``MASK`` is set it also prints the causal mask
    ``offs_m >= offs_n``, laid out as (keys, queries). Only the first two steps
    are traced. No tensor is read, and ``dk``/``dv`` are not modified.

    Returns:
        ``(dk, dv)`` unchanged, mirroring the kernel helper's return value.
    """
    print(f"--------_attn_bwd_dkdv-(MASK : {MASK})-------------")
    offs_m = start_m + torch.arange(0, BLOCK_M1)
    offs_n = start_n + torch.arange(0, BLOCK_N1)
    offs_k = torch.arange(0, HEAD_DIM)
    # Element offsets of the transposed Q tile (HEAD_DIM, BLOCK_M1) and of the dO
    # tile (BLOCK_M1, HEAD_DIM). The kernel builds qT_ptrs = Q + qT_offset and
    # do_ptrs = DO + do_offset.
    qT_offset = offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
    do_offset = offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
    # print(f"({start_m}, {start_n}) : qT_offset({qT_offset.shape}) : {qT_offset} : do_offset({do_offset.shape}) : {do_offset}")
    # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the kernel wouldn't work.
    # tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
    curr_m = start_m
    step_m = BLOCK_M1
    for blk_idx in range(num_steps):
        # qT = tl.load(qT_ptrs)
        # Load m before computing qk to reduce pipeline stall.
        offs_m = curr_m + torch.arange(0, BLOCK_M1)
        print(f"Current Step : {blk_idx} : offs_m({offs_m.shape}) : {offs_m}")
        # m = tl.load(M + offs_m)
        # qkT = tl.dot(k, qT)
        # pT = tl.math.exp2(qkT - m[None, :])
        # Autoregressive masking.
        if MASK:
            mask = (offs_m[None, :] >= offs_n[:, None])
            print(f"Mask ({mask.shape}) : {mask}")
            # pT = tl.where(mask, pT, 0.0)
        # do = tl.load(do_ptrs)
        # Compute dV: dv += tl.dot(pT.to(tl.float16), do)
        # D (= delta) is pre-divided by ds_scale: Di = tl.load(D + offs_m)
        # Compute dP and dS:
        #     dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
        #     dsT = pT * (dpT - Di[None, :])
        # dk += tl.dot(dsT.to(tl.float16), tl.trans(qT))
        # Increment pointers: qT_ptrs and do_ptrs advance by step_m * stride_tok.
        curr_m += step_m
        # Debug truncation: trace at most two steps.
        if blk_idx == 1:
            break
    return dk, dv

def _attn_bwd(Q, K, V, sm_scale,  #
              DO,  #
              DQ, DK, DV,  #
              M, D,
              # shared by Q/K/V/DO.
              stride_z, stride_h, stride_tok, stride_d,  #
              H, N_CTX,  #
              causal,
              BLOCK_M1,  #
              BLOCK_N1,  #
              BLOCK_M2,  #
              BLOCK_N2,  #
              BLK_SLICE_FACTOR,  #
              HEAD_DIM):
    """CPU trace of the block scheduling and offsets of the Triton ``_attn_bwd`` kernel.

    Each (bhid, pid) pair stands in for one program. ``bhid`` is the flattened
    batch*head index (``tl.program_id(2)``) and ``pid`` is the block index
    (``tl.program_id(0)``). For every program the trace does two things:
      1. the dK/dV pass for key block ``pid``: first the masked diagonal part in
         sub-blocks of ``BLOCK_M1 // BLK_SLICE_FACTOR`` rows, then the remaining
         query blocks of ``BLOCK_M1`` rows;
      2. the dQ pass for query block ``pid``: first the masked diagonal part in
         sub-blocks of ``BLOCK_N2 // BLK_SLICE_FACTOR`` columns, then the key
         blocks to its left.

    Only offsets are printed; no tensor is read or written. Only bhid and pid in
    {0, 1, 2} are traced, to keep the output short.
    """
    LN2 = 0.6931471824645996  # = ln(2); used by the commented-out `dq *= LN2`
    # Grid extents: (N_CTX // BLOCK_N1, 1, BATCH * H). Assumes Q is laid out as
    # (BATCH, H, N_CTX, HEAD_DIM), as built by the driver script; TODO confirm
    # for other callers.
    num_pid = N_CTX // BLOCK_N1
    num_bhid = Q.shape[0] * H
    for bhid in range(num_bhid):          # bhid = tl.program_id(2)
        for pid in range(num_pid):        # pid = tl.program_id(0)
            off_chz = (bhid * N_CTX)
            row = bhid // H
            col = bhid % H
            # adj will move along per BxH
            adj = (stride_h * col + stride_z * row)
            print(f"(bhid,pid to {num_bhid},{num_pid}) : ({bhid},{pid}) : Row : {row}, Col : {col}, off_chz : {off_chz} adj : {adj}")
            offs_k = torch.arange(0, HEAD_DIM)
            start_n = pid * BLOCK_N1
            start_m = start_n
            MASK_BLOCK_M1 = BLOCK_M1 // BLK_SLICE_FACTOR
            offs_n = start_n + torch.arange(0, BLOCK_N1)

            dv = torch.zeros([BLOCK_N1, HEAD_DIM], dtype=torch.float32)
            dk = torch.zeros([BLOCK_N1, HEAD_DIM], dtype=torch.float32)

            # load K and V: they stay in SRAM throughout the inner loop.
            # k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
            # v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
            offset_k_v = offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
            print(f"Offsets({offset_k_v.shape}) of k or v {offset_k_v}")
            # Stand-ins for the loaded k/v tiles. The trace never reads them; binding
            # them here avoids silently picking up module-level globals named k / v.
            k, v = K, V

            # Compute dK and dV for the masked (diagonal) blocks.
            num_steps = BLOCK_N1 // MASK_BLOCK_M1
            print(f"MASK : True : start_m : {start_m} : Num Steps : {num_steps}")
            _attn_bwd_dkdv(dk, dv,  #
                           Q, k, v, sm_scale,  #
                           DO,  #
                           M, D,  #
                           stride_tok, stride_d,  #
                           H, N_CTX,  #
                           MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM,  #
                           start_n, start_m, num_steps,  #
                           MASK=True  #
                           )
            start_m += num_steps * MASK_BLOCK_M1
            num_steps = (N_CTX - start_m) // BLOCK_M1
            print(f"MASK : False : start_m : {start_m} : Num Steps : {num_steps}")

            # Compute dK and dV for the non-masked blocks.
            _attn_bwd_dkdv(  #
                dk, dv,  #
                Q, k, v, sm_scale,  #
                DO,  #
                M, D,  #
                stride_tok, stride_d,  #
                H, N_CTX,  #
                BLOCK_M1, BLOCK_N1, HEAD_DIM,  #
                start_n, start_m, num_steps,  #
                MASK=False  #
            )
            print("-------End of _attn_bwd_dkdv-----------")
            dv_offset = offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
            print(f"Offset of DV : {dv_offset}")
            # tl.store(DV + dv_offset, dv)

            # Write back dK.
            # dk *= sm_scale
            dk_offset = offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
            print(f"Offset of DK : {dk_offset}")
            # tl.store(DK + dk_offset, dk)

            # THIS BLOCK DOES DQ:
            start_m = pid * BLOCK_M2
            end_n = start_m + BLOCK_M2

            MASK_BLOCK_N2 = BLOCK_N2 // BLK_SLICE_FACTOR
            offs_m = start_m + torch.arange(0, BLOCK_M2)
            print(f"Offset of m ({offs_m.shape}) : {offs_m}")
            offs_q = offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
            print(f"Offset of q ({offs_q.shape}) : {offs_q}")
            # q = tl.load(Q + offs_q)
            dq = torch.zeros([BLOCK_M2, HEAD_DIM], dtype=torch.float32)
            offs_do = offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
            print(f"Offset of do ({offs_do.shape}) : {offs_do}")
            # do = tl.load(DO + offs_do)
            # m = tl.load(M + offs_m)
            # Stand-ins for the loaded q / do / m tiles; the trace never reads their
            # values. m is sized from BLOCK_M2 instead of a hard-coded 128.
            q, do = Q, DO
            m = torch.zeros((BLOCK_M2,))[:, None]

            # Compute dQ for masked (diagonal) blocks.
            # NOTE: This code scans each row of QK^T backward (from right to left,
            # but inside each call to _attn_bwd_dq, from left to right), but that's
            # not due to anything important.  I just wanted to reuse the loop
            # structure for dK & dV above as much as possible.
            num_steps = BLOCK_M2 // MASK_BLOCK_N2
            print(f"({bhid},{pid}) MASK : True : start_m : {start_m} start_n : {end_n - num_steps * MASK_BLOCK_N2} : Num Steps : {num_steps}")
            _attn_bwd_dq(dq, q, K, V,  #
                         do, m, D,  #
                         stride_tok, stride_d,  #
                         H, N_CTX,  #
                         BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM,  #
                         start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,  #
                         MASK=True  #
                         )

            if causal:
                end_n = start_m
            else:
                end_n -= num_steps * MASK_BLOCK_N2

            print(f"end_n : {end_n}")
            # stage 2: the unmasked key blocks to the left of the diagonal
            num_steps = end_n // BLOCK_N2
            print(f"({bhid},{pid}) MASK : False : start_m : {start_m} start_n : {end_n - num_steps * BLOCK_N2} : Num Steps : {num_steps}")
            _attn_bwd_dq(dq, q, K, V,  #
                         do, m, D,  #
                         stride_tok, stride_d,  #
                         H, N_CTX,  #
                         BLOCK_M2, BLOCK_N2, HEAD_DIM,  #
                         start_m, end_n - num_steps * BLOCK_N2, num_steps,  #
                         MASK=False  #
                         )
            # Write back dQ.
            # dq_offset = offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
            # dq_ptrs = DQ + dq_offset
            # The softmax used 2^x instead of e^x, and d(2^x)/dx = ln(2) * 2^x, so
            # dQ (= dS K, with dS proportional to 2^(qk)) is rescaled by ln(2):
            # dq *= LN2
            # tl.store(dq_ptrs, dq)
            print("-------End of _attn_bwd-----------")
            # Debug truncation: only trace pid 0..2.
            if pid == 2:
                break
        # Debug truncation: only trace bhid 0..2.
        if bhid == 2:
            return


# Backward-pass trace configuration. Trailing comments keep alternative values.
sm_scale = 0.5
causal = True
warp_specialize = False #True
BATCH = 4 #4
H = 8 #2
N_CTX = 1024 #1024
HEAD_DIM = 512#32 #64
dtype = torch.float16
q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE, requires_grad=True)

# The FP8 path used by the forward trace is disabled here:
# q = q.to(torch.float8_e5m2)
# k = k.to(torch.float8_e5m2)
# v = v.to(torch.float8_e5m2)
# V deliberately does NOT get the forward trace's transposed memory layout
# (permute -> contiguous -> permute). The backward kernel addresses Q, K, V and
# dO with a single set of strides (q's), so V must stay laid out exactly like Q.
print(f"V Shape : {v.shape}")
o = torch.randn_like(q)
do = torch.randn_like(o)
if not (q.stride() == k.stride() == v.stride() == o.stride() == do.stride()):
    raise ValueError("q, k, v, o and do must share strides: the backward kernel indexes them all with q.stride()")
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
BATCH, N_HEAD, N_CTX = q.shape[:3]
PRE_BLOCK = 128
NUM_WARPS, NUM_STAGES = 4, 5  # kernel launch settings; unused by the CPU trace
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
BLK_SLICE_FACTOR = 2
RCP_LN2 = 1.4426950408889634  # = 1.0 / ln(2)
# Pre-scale K by sm_scale / ln(2); the kernel's softmax works in base 2 (exp2).
arg_k = k * (sm_scale * RCP_LN2)
pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
print(f"Pre Grid : {pre_grid}")

# Placeholder statistics (random) and an uninitialized delta buffer.
M = torch.randn((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
delta = torch.empty_like(M)
print("-------Start of _attn_bwd_preprocess-----------")
_attn_bwd_preprocess(
            o, do,  #
            delta,  #
            BATCH, N_HEAD, N_CTX,  #
            BLOCK_M=PRE_BLOCK, HEAD_DIM=HEAD_DIM  #
        )
print("-------End of _attn_bwd_preprocess-----------")
grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
print(f"Grid (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD): {grid}")
print(f"Strides of q : {q.stride(0)}, {q.stride(1)}, {q.stride(2)}, {q.stride(3)}")
print("-------Start of _attn_bwd-----------")
_attn_bwd(
            q, arg_k, v, sm_scale, do, dq, dk, dv,  #
            M, delta,  #
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),  #
            N_HEAD, N_CTX,  #
            causal,
            BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1,  #
            BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2,  #
            BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
            HEAD_DIM=HEAD_DIM,  #
            # num_warps=NUM_WARPS,  #
            # num_stages=NUM_STAGES  #
        )
V Shape : torch.Size([4, 8, 1024, 512])
Pre Grid : (8, 32)
-------Start of _attn_bwd_preprocess-----------
tl-program-ids : 0 of 8, 0 of 32
Offsets of o and do torch.Size([128, 512]) : tensor([[    0,     1,     2,  ...,   509,   510,   511],
        [  512,   513,   514,  ...,  1021,  1022,  1023],
        [ 1024,  1025,  1026,  ...,  1533,  1534,  1535],
        ...,
        [64000, 64001, 64002,  ..., 64509, 64510, 64511],
        [64512, 64513, 64514,  ..., 65021, 65022, 65023],
        [65024, 65025, 65026,  ..., 65533, 65534, 65535]])
Delta torch.Size([128]) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
tl-program-ids : 0 of 8, 1 of 32
Offsets of o and do torch.Size([128, 512]) : tensor([[524288, 524289, 524290,  ..., 524797, 524798, 524799],
        [524800, 524801, 524802,  ..., 525309, 525310, 525311],
        [525312, 525313, 525314,  ..., 525821, 525822, 525823],
        ...,
        [588288, 588289, 588290,  ..., 588797, 588798, 588799],
        [588800, 588801, 588802,  ..., 589309, 589310, 589311],
        [589312, 589313, 589314,  ..., 589821, 589822, 589823]])
Delta torch.Size([128]) : tensor([1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1032, 1033, 1034, 1035,
        1036, 1037, 1038, 1039, 1040, 1041, 1042, 1043, 1044, 1045, 1046, 1047,
        1048, 1049, 1050, 1051, 1052, 1053, 1054, 1055, 1056, 1057, 1058, 1059,
        1060, 1061, 1062, 1063, 1064, 1065, 1066, 1067, 1068, 1069, 1070, 1071,
        1072, 1073, 1074, 1075, 1076, 1077, 1078, 1079, 1080, 1081, 1082, 1083,
        1084, 1085, 1086, 1087, 1088, 1089, 1090, 1091, 1092, 1093, 1094, 1095,
        1096, 1097, 1098, 1099, 1100, 1101, 1102, 1103, 1104, 1105, 1106, 1107,
        1108, 1109, 1110, 1111, 1112, 1113, 1114, 1115, 1116, 1117, 1118, 1119,
        1120, 1121, 1122, 1123, 1124, 1125, 1126, 1127, 1128, 1129, 1130, 1131,
        1132, 1133, 1134, 1135, 1136, 1137, 1138, 1139, 1140, 1141, 1142, 1143,
        1144, 1145, 1146, 1147, 1148, 1149, 1150, 1151])
tl-program-ids : 1 of 8, 0 of 32
Offsets of o and do torch.Size([128, 512]) : tensor([[ 65536,  65537,  65538,  ...,  66045,  66046,  66047],
        [ 66048,  66049,  66050,  ...,  66557,  66558,  66559],
        [ 66560,  66561,  66562,  ...,  67069,  67070,  67071],
        ...,
        [129536, 129537, 129538,  ..., 130045, 130046, 130047],
        [130048, 130049, 130050,  ..., 130557, 130558, 130559],
        [130560, 130561, 130562,  ..., 131069, 131070, 131071]])
Delta torch.Size([128]) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
tl-program-ids : 1 of 8, 1 of 32
Offsets of o and do torch.Size([128, 512]) : tensor([[589824, 589825, 589826,  ..., 590333, 590334, 590335],
        [590336, 590337, 590338,  ..., 590845, 590846, 590847],
        [590848, 590849, 590850,  ..., 591357, 591358, 591359],
        ...,
        [653824, 653825, 653826,  ..., 654333, 654334, 654335],
        [654336, 654337, 654338,  ..., 654845, 654846, 654847],
        [654848, 654849, 654850,  ..., 655357, 655358, 655359]])
Delta torch.Size([128]) : tensor([1152, 1153, 1154, 1155, 1156, 1157, 1158, 1159, 1160, 1161, 1162, 1163,
        1164, 1165, 1166, 1167, 1168, 1169, 1170, 1171, 1172, 1173, 1174, 1175,
        1176, 1177, 1178, 1179, 1180, 1181, 1182, 1183, 1184, 1185, 1186, 1187,
        1188, 1189, 1190, 1191, 1192, 1193, 1194, 1195, 1196, 1197, 1198, 1199,
        1200, 1201, 1202, 1203, 1204, 1205, 1206, 1207, 1208, 1209, 1210, 1211,
        1212, 1213, 1214, 1215, 1216, 1217, 1218, 1219, 1220, 1221, 1222, 1223,
        1224, 1225, 1226, 1227, 1228, 1229, 1230, 1231, 1232, 1233, 1234, 1235,
        1236, 1237, 1238, 1239, 1240, 1241, 1242, 1243, 1244, 1245, 1246, 1247,
        1248, 1249, 1250, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259,
        1260, 1261, 1262, 1263, 1264, 1265, 1266, 1267, 1268, 1269, 1270, 1271,
        1272, 1273, 1274, 1275, 1276, 1277, 1278, 1279])
-------End of _attn_bwd_preprocess-----------
Grid (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD): (8, 1, 32)
Strides of q : 4194304, 524288, 512, 1
-------Start of _attn_bwd-----------
(bhid,pid to 32,8) : (0,0) : Row : 0, Col : 0, off_chz : 0 adj : 0
Offsets(torch.Size([128, 512])) of k or v tensor([[    0,     1,     2,  ...,   509,   510,   511],
        [  512,   513,   514,  ...,  1021,  1022,  1023],
        [ 1024,  1025,  1026,  ...,  1533,  1534,  1535],
        ...,
        [64000, 64001, 64002,  ..., 64509, 64510, 64511],
        [64512, 64513, 64514,  ..., 65021, 65022, 65023],
        [65024, 65025, 65026,  ..., 65533, 65534, 65535]])
MASK : True : start_m : 0 : Num Steps : 8
--------_attn_bwd_dkdv-(MASK : True)-------------
Current Step : 0 : offs_m(torch.Size([16])) : tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
Mask (torch.Size([128, 16])) : tensor([[ True,  True,  True,  ...,  True,  True,  True],
        [False,  True,  True,  ...,  True,  True,  True],
        [False, False,  True,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
Current Step : 1 : offs_m(torch.Size([16])) : tensor([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Mask (torch.Size([128, 16])) : tensor([[ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
MASK : False : start_m : 128 : Num Steps : 28
--------_attn_bwd_dkdv-(MASK : False)-------------
Current Step : 0 : offs_m(torch.Size([32])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159])
Current Step : 1 : offs_m(torch.Size([32])) : tensor([160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173,
        174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
        188, 189, 190, 191])
-------End of _attn_bwd_dkdv-----------
Offset of DV : tensor([[    0,     1,     2,  ...,   509,   510,   511],
        [  512,   513,   514,  ...,  1021,  1022,  1023],
        [ 1024,  1025,  1026,  ...,  1533,  1534,  1535],
        ...,
        [64000, 64001, 64002,  ..., 64509, 64510, 64511],
        [64512, 64513, 64514,  ..., 65021, 65022, 65023],
        [65024, 65025, 65026,  ..., 65533, 65534, 65535]])
Offset of DK : tensor([[    0,     1,     2,  ...,   509,   510,   511],
        [  512,   513,   514,  ...,  1021,  1022,  1023],
        [ 1024,  1025,  1026,  ...,  1533,  1534,  1535],
        ...,
        [64000, 64001, 64002,  ..., 64509, 64510, 64511],
        [64512, 64513, 64514,  ..., 65021, 65022, 65023],
        [65024, 65025, 65026,  ..., 65533, 65534, 65535]])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Offset of q (torch.Size([128, 512])) : tensor([[    0,     1,     2,  ...,   509,   510,   511],
        [  512,   513,   514,  ...,  1021,  1022,  1023],
        [ 1024,  1025,  1026,  ...,  1533,  1534,  1535],
        ...,
        [64000, 64001, 64002,  ..., 64509, 64510, 64511],
        [64512, 64513, 64514,  ..., 65021, 65022, 65023],
        [65024, 65025, 65026,  ..., 65533, 65534, 65535]])
Offset of do (torch.Size([128, 512])) : tensor([[    0,     1,     2,  ...,   509,   510,   511],
        [  512,   513,   514,  ...,  1021,  1022,  1023],
        [ 1024,  1025,  1026,  ...,  1533,  1534,  1535],
        ...,
        [64000, 64001, 64002,  ..., 64509, 64510, 64511],
        [64512, 64513, 64514,  ..., 65021, 65022, 65023],
        [65024, 65025, 65026,  ..., 65533, 65534, 65535]])
(0,0) MASK : True : start_m : 0 start_n : 0 : Num Steps : 8
-------------_attn_bwd_dq--start_m--0--start_n--0----MASK-True-----
Offset of kT : torch.Size([16, 512])
Offset of vT : torch.Size([512, 16])
Di Offset : torch.Size([128])
Offset of n (torch.Size([16])): tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([ 96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
        110, 111])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True, False, False],
        [ True,  True,  True,  ...,  True,  True, False],
        [ True,  True,  True,  ...,  True,  True,  True]])
-------------end of_attn_bwd_dq-------
end_n : 0
(0,0) MASK : False : start_m : 0 start_n : 0 : Num Steps : 0
-------------_attn_bwd_dq--start_m--0--start_n--0----MASK-False-----
Offset of kT : torch.Size([32, 512])
Offset of vT : torch.Size([512, 32])
Di Offset : torch.Size([128])
-------------end of_attn_bwd_dq-------
-------End of _attn_bwd-----------
(bhid,pid to 32,8) : (0,1) : Row : 0, Col : 0, off_chz : 0 adj : 0
Offsets(torch.Size([128, 512])) of k or v tensor([[ 65536,  65537,  65538,  ...,  66045,  66046,  66047],
        [ 66048,  66049,  66050,  ...,  66557,  66558,  66559],
        [ 66560,  66561,  66562,  ...,  67069,  67070,  67071],
        ...,
        [129536, 129537, 129538,  ..., 130045, 130046, 130047],
        [130048, 130049, 130050,  ..., 130557, 130558, 130559],
        [130560, 130561, 130562,  ..., 131069, 131070, 131071]])
MASK : True : start_m : 128 : Num Steps : 8
--------_attn_bwd_dkdv-(MASK : True)-------------
Current Step : 0 : offs_m(torch.Size([16])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143])
Mask (torch.Size([128, 16])) : tensor([[ True,  True,  True,  ...,  True,  True,  True],
        [False,  True,  True,  ...,  True,  True,  True],
        [False, False,  True,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
Current Step : 1 : offs_m(torch.Size([16])) : tensor([144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157,
        158, 159])
Mask (torch.Size([128, 16])) : tensor([[ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
MASK : False : start_m : 256 : Num Steps : 24
--------_attn_bwd_dkdv-(MASK : False)-------------
Current Step : 0 : offs_m(torch.Size([32])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287])
Current Step : 1 : offs_m(torch.Size([32])) : tensor([288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301,
        302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315,
        316, 317, 318, 319])
-------End of _attn_bwd_dkdv-----------
Offset of DV : tensor([[ 65536,  65537,  65538,  ...,  66045,  66046,  66047],
        [ 66048,  66049,  66050,  ...,  66557,  66558,  66559],
        [ 66560,  66561,  66562,  ...,  67069,  67070,  67071],
        ...,
        [129536, 129537, 129538,  ..., 130045, 130046, 130047],
        [130048, 130049, 130050,  ..., 130557, 130558, 130559],
        [130560, 130561, 130562,  ..., 131069, 131070, 131071]])
Offset of DK : tensor([[ 65536,  65537,  65538,  ...,  66045,  66046,  66047],
        [ 66048,  66049,  66050,  ...,  66557,  66558,  66559],
        [ 66560,  66561,  66562,  ...,  67069,  67070,  67071],
        ...,
        [129536, 129537, 129538,  ..., 130045, 130046, 130047],
        [130048, 130049, 130050,  ..., 130557, 130558, 130559],
        [130560, 130561, 130562,  ..., 131069, 131070, 131071]])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Offset of q (torch.Size([128, 512])) : tensor([[ 65536,  65537,  65538,  ...,  66045,  66046,  66047],
        [ 66048,  66049,  66050,  ...,  66557,  66558,  66559],
        [ 66560,  66561,  66562,  ...,  67069,  67070,  67071],
        ...,
        [129536, 129537, 129538,  ..., 130045, 130046, 130047],
        [130048, 130049, 130050,  ..., 130557, 130558, 130559],
        [130560, 130561, 130562,  ..., 131069, 131070, 131071]])
Offset of do (torch.Size([128, 512])) : tensor([[ 65536,  65537,  65538,  ...,  66045,  66046,  66047],
        [ 66048,  66049,  66050,  ...,  66557,  66558,  66559],
        [ 66560,  66561,  66562,  ...,  67069,  67070,  67071],
        ...,
        [129536, 129537, 129538,  ..., 130045, 130046, 130047],
        [130048, 130049, 130050,  ..., 130557, 130558, 130559],
        [130560, 130561, 130562,  ..., 131069, 131070, 131071]])
(0,1) MASK : True : start_m : 128 start_n : 128 : Num Steps : 8
-------------_attn_bwd_dq--start_m--128--start_n--128----MASK-True-----
Offset of kT : torch.Size([16, 512])
Offset of vT : torch.Size([512, 16])
Di Offset : torch.Size([128])
Offset of n (torch.Size([16])): tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157,
        158, 159])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173,
        174, 175])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189,
        190, 191])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205,
        206, 207])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221,
        222, 223])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237,
        238, 239])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True, False, False],
        [ True,  True,  True,  ...,  True,  True, False],
        [ True,  True,  True,  ...,  True,  True,  True]])
-------------end of_attn_bwd_dq-------
end_n : 128
(0,1) MASK : False : start_m : 128 start_n : 0 : Num Steps : 4
-------------_attn_bwd_dq--start_m--128--start_n--0----MASK-False-----
Offset of kT : torch.Size([32, 512])
Offset of vT : torch.Size([512, 32])
Di Offset : torch.Size([128])
Offset of n (torch.Size([32])): tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Offset of n (torch.Size([32])): tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
        50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Offset of n (torch.Size([32])): tensor([64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
        82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Offset of n (torch.Size([32])): tensor([ 96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
        110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123,
        124, 125, 126, 127])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
-------------end of_attn_bwd_dq-------
-------End of _attn_bwd-----------
(bhid,pid to 32,8) : (0,2) : Row : 0, Col : 0, off_chz : 0 adj : 0
Offsets(torch.Size([128, 512])) of k or v tensor([[131072, 131073, 131074,  ..., 131581, 131582, 131583],
        [131584, 131585, 131586,  ..., 132093, 132094, 132095],
        [132096, 132097, 132098,  ..., 132605, 132606, 132607],
        ...,
        [195072, 195073, 195074,  ..., 195581, 195582, 195583],
        [195584, 195585, 195586,  ..., 196093, 196094, 196095],
        [196096, 196097, 196098,  ..., 196605, 196606, 196607]])
MASK : True : start_m : 256 : Num Steps : 8
--------_attn_bwd_dkdv-(MASK : True)-------------
Current Step : 0 : offs_m(torch.Size([16])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271])
Mask (torch.Size([128, 16])) : tensor([[ True,  True,  True,  ...,  True,  True,  True],
        [False,  True,  True,  ...,  True,  True,  True],
        [False, False,  True,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
Current Step : 1 : offs_m(torch.Size([16])) : tensor([272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285,
        286, 287])
Mask (torch.Size([128, 16])) : tensor([[ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
MASK : False : start_m : 384 : Num Steps : 20
--------_attn_bwd_dkdv-(MASK : False)-------------
Current Step : 0 : offs_m(torch.Size([32])) : tensor([384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397,
        398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411,
        412, 413, 414, 415])
Current Step : 1 : offs_m(torch.Size([32])) : tensor([416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429,
        430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443,
        444, 445, 446, 447])
-------End of _attn_bwd_dkdv-----------
Offset of DV : tensor([[131072, 131073, 131074,  ..., 131581, 131582, 131583],
        [131584, 131585, 131586,  ..., 132093, 132094, 132095],
        [132096, 132097, 132098,  ..., 132605, 132606, 132607],
        ...,
        [195072, 195073, 195074,  ..., 195581, 195582, 195583],
        [195584, 195585, 195586,  ..., 196093, 196094, 196095],
        [196096, 196097, 196098,  ..., 196605, 196606, 196607]])
Offset of DK : tensor([[131072, 131073, 131074,  ..., 131581, 131582, 131583],
        [131584, 131585, 131586,  ..., 132093, 132094, 132095],
        [132096, 132097, 132098,  ..., 132605, 132606, 132607],
        ...,
        [195072, 195073, 195074,  ..., 195581, 195582, 195583],
        [195584, 195585, 195586,  ..., 196093, 196094, 196095],
        [196096, 196097, 196098,  ..., 196605, 196606, 196607]])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of q (torch.Size([128, 512])) : tensor([[131072, 131073, 131074,  ..., 131581, 131582, 131583],
        [131584, 131585, 131586,  ..., 132093, 132094, 132095],
        [132096, 132097, 132098,  ..., 132605, 132606, 132607],
        ...,
        [195072, 195073, 195074,  ..., 195581, 195582, 195583],
        [195584, 195585, 195586,  ..., 196093, 196094, 196095],
        [196096, 196097, 196098,  ..., 196605, 196606, 196607]])
Offset of do (torch.Size([128, 512])) : tensor([[131072, 131073, 131074,  ..., 131581, 131582, 131583],
        [131584, 131585, 131586,  ..., 132093, 132094, 132095],
        [132096, 132097, 132098,  ..., 132605, 132606, 132607],
        ...,
        [195072, 195073, 195074,  ..., 195581, 195582, 195583],
        [195584, 195585, 195586,  ..., 196093, 196094, 196095],
        [196096, 196097, 196098,  ..., 196605, 196606, 196607]])
(0,2) MASK : True : start_m : 256 start_n : 256 : Num Steps : 8
-------------_attn_bwd_dq--start_m--256--start_n--256----MASK-True-----
Offset of kT : torch.Size([16, 512])
Offset of vT : torch.Size([512, 16])
Di Offset : torch.Size([128])
Offset of n (torch.Size([16])): tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285,
        286, 287])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301,
        302, 303])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317,
        318, 319])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333,
        334, 335])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349,
        350, 351])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365,
        366, 367])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True, False, False],
        [ True,  True,  True,  ...,  True,  True, False],
        [ True,  True,  True,  ...,  True,  True,  True]])
-------------end of_attn_bwd_dq-------
end_n : 256
(0,2) MASK : False : start_m : 256 start_n : 0 : Num Steps : 8
-------------_attn_bwd_dq--start_m--256--start_n--0----MASK-False-----
Offset of kT : torch.Size([32, 512])
Offset of vT : torch.Size([512, 32])
Di Offset : torch.Size([128])
Offset of n (torch.Size([32])): tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of n (torch.Size([32])): tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
        50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of n (torch.Size([32])): tensor([64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
        82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of n (torch.Size([32])): tensor([ 96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
        110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123,
        124, 125, 126, 127])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of n (torch.Size([32])): tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of n (torch.Size([32])): tensor([160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173,
        174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
        188, 189, 190, 191])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of n (torch.Size([32])): tensor([192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205,
        206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219,
        220, 221, 222, 223])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of n (torch.Size([32])): tensor([224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237,
        238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251,
        252, 253, 254, 255])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
-------------end of_attn_bwd_dq-------
-------End of _attn_bwd-----------
(bhid,pid to 32,8) : (1,0) : Row : 0, Col : 1, off_chz : 1024 adj : 524288
Offsets(torch.Size([128, 512])) of k or v tensor([[    0,     1,     2,  ...,   509,   510,   511],
        [  512,   513,   514,  ...,  1021,  1022,  1023],
        [ 1024,  1025,  1026,  ...,  1533,  1534,  1535],
        ...,
        [64000, 64001, 64002,  ..., 64509, 64510, 64511],
        [64512, 64513, 64514,  ..., 65021, 65022, 65023],
        [65024, 65025, 65026,  ..., 65533, 65534, 65535]])
MASK : True : start_m : 0 : Num Steps : 8
--------_attn_bwd_dkdv-(MASK : True)-------------
Current Step : 0 : offs_m(torch.Size([16])) : tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
Mask (torch.Size([128, 16])) : tensor([[ True,  True,  True,  ...,  True,  True,  True],
        [False,  True,  True,  ...,  True,  True,  True],
        [False, False,  True,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
Current Step : 1 : offs_m(torch.Size([16])) : tensor([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Mask (torch.Size([128, 16])) : tensor([[ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
MASK : False : start_m : 128 : Num Steps : 28
--------_attn_bwd_dkdv-(MASK : False)-------------
Current Step : 0 : offs_m(torch.Size([32])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159])
Current Step : 1 : offs_m(torch.Size([32])) : tensor([160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173,
        174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
        188, 189, 190, 191])
-------End of _attn_bwd_dkdv-----------
Offset of DV : tensor([[    0,     1,     2,  ...,   509,   510,   511],
        [  512,   513,   514,  ...,  1021,  1022,  1023],
        [ 1024,  1025,  1026,  ...,  1533,  1534,  1535],
        ...,
        [64000, 64001, 64002,  ..., 64509, 64510, 64511],
        [64512, 64513, 64514,  ..., 65021, 65022, 65023],
        [65024, 65025, 65026,  ..., 65533, 65534, 65535]])
Offset of DK : tensor([[    0,     1,     2,  ...,   509,   510,   511],
        [  512,   513,   514,  ...,  1021,  1022,  1023],
        [ 1024,  1025,  1026,  ...,  1533,  1534,  1535],
        ...,
        [64000, 64001, 64002,  ..., 64509, 64510, 64511],
        [64512, 64513, 64514,  ..., 65021, 65022, 65023],
        [65024, 65025, 65026,  ..., 65533, 65534, 65535]])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Offset of q (torch.Size([128, 512])) : tensor([[    0,     1,     2,  ...,   509,   510,   511],
        [  512,   513,   514,  ...,  1021,  1022,  1023],
        [ 1024,  1025,  1026,  ...,  1533,  1534,  1535],
        ...,
        [64000, 64001, 64002,  ..., 64509, 64510, 64511],
        [64512, 64513, 64514,  ..., 65021, 65022, 65023],
        [65024, 65025, 65026,  ..., 65533, 65534, 65535]])
Offset of do (torch.Size([128, 512])) : tensor([[    0,     1,     2,  ...,   509,   510,   511],
        [  512,   513,   514,  ...,  1021,  1022,  1023],
        [ 1024,  1025,  1026,  ...,  1533,  1534,  1535],
        ...,
        [64000, 64001, 64002,  ..., 64509, 64510, 64511],
        [64512, 64513, 64514,  ..., 65021, 65022, 65023],
        [65024, 65025, 65026,  ..., 65533, 65534, 65535]])
(1,0) MASK : True : start_m : 0 start_n : 0 : Num Steps : 8
-------------_attn_bwd_dq--start_m--0--start_n--0----MASK-True-----
Offset of kT : torch.Size([16, 512])
Offset of vT : torch.Size([512, 16])
Di Offset : torch.Size([128])
Offset of n (torch.Size([16])): tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([ 96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
        110, 111])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True, False, False],
        [ True,  True,  True,  ...,  True,  True, False],
        [ True,  True,  True,  ...,  True,  True,  True]])
-------------end of_attn_bwd_dq-------
end_n : 0
(1,0) MASK : False : start_m : 0 start_n : 0 : Num Steps : 0
-------------_attn_bwd_dq--start_m--0--start_n--0----MASK-False-----
Offset of kT : torch.Size([32, 512])
Offset of vT : torch.Size([512, 32])
Di Offset : torch.Size([128])
-------------end of_attn_bwd_dq-------
-------End of _attn_bwd-----------
(bhid,pid to 32,8) : (1,1) : Row : 0, Col : 1, off_chz : 1024 adj : 524288
Offsets(torch.Size([128, 512])) of k or v tensor([[ 65536,  65537,  65538,  ...,  66045,  66046,  66047],
        [ 66048,  66049,  66050,  ...,  66557,  66558,  66559],
        [ 66560,  66561,  66562,  ...,  67069,  67070,  67071],
        ...,
        [129536, 129537, 129538,  ..., 130045, 130046, 130047],
        [130048, 130049, 130050,  ..., 130557, 130558, 130559],
        [130560, 130561, 130562,  ..., 131069, 131070, 131071]])
MASK : True : start_m : 128 : Num Steps : 8
--------_attn_bwd_dkdv-(MASK : True)-------------
Current Step : 0 : offs_m(torch.Size([16])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143])
Mask (torch.Size([128, 16])) : tensor([[ True,  True,  True,  ...,  True,  True,  True],
        [False,  True,  True,  ...,  True,  True,  True],
        [False, False,  True,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
Current Step : 1 : offs_m(torch.Size([16])) : tensor([144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157,
        158, 159])
Mask (torch.Size([128, 16])) : tensor([[ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
MASK : False : start_m : 256 : Num Steps : 24
--------_attn_bwd_dkdv-(MASK : False)-------------
Current Step : 0 : offs_m(torch.Size([32])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287])
Current Step : 1 : offs_m(torch.Size([32])) : tensor([288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301,
        302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315,
        316, 317, 318, 319])
-------End of _attn_bwd_dkdv-----------
Offset of DV : tensor([[ 65536,  65537,  65538,  ...,  66045,  66046,  66047],
        [ 66048,  66049,  66050,  ...,  66557,  66558,  66559],
        [ 66560,  66561,  66562,  ...,  67069,  67070,  67071],
        ...,
        [129536, 129537, 129538,  ..., 130045, 130046, 130047],
        [130048, 130049, 130050,  ..., 130557, 130558, 130559],
        [130560, 130561, 130562,  ..., 131069, 131070, 131071]])
Offset of DK : tensor([[ 65536,  65537,  65538,  ...,  66045,  66046,  66047],
        [ 66048,  66049,  66050,  ...,  66557,  66558,  66559],
        [ 66560,  66561,  66562,  ...,  67069,  67070,  67071],
        ...,
        [129536, 129537, 129538,  ..., 130045, 130046, 130047],
        [130048, 130049, 130050,  ..., 130557, 130558, 130559],
        [130560, 130561, 130562,  ..., 131069, 131070, 131071]])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Offset of q (torch.Size([128, 512])) : tensor([[ 65536,  65537,  65538,  ...,  66045,  66046,  66047],
        [ 66048,  66049,  66050,  ...,  66557,  66558,  66559],
        [ 66560,  66561,  66562,  ...,  67069,  67070,  67071],
        ...,
        [129536, 129537, 129538,  ..., 130045, 130046, 130047],
        [130048, 130049, 130050,  ..., 130557, 130558, 130559],
        [130560, 130561, 130562,  ..., 131069, 131070, 131071]])
Offset of do (torch.Size([128, 512])) : tensor([[ 65536,  65537,  65538,  ...,  66045,  66046,  66047],
        [ 66048,  66049,  66050,  ...,  66557,  66558,  66559],
        [ 66560,  66561,  66562,  ...,  67069,  67070,  67071],
        ...,
        [129536, 129537, 129538,  ..., 130045, 130046, 130047],
        [130048, 130049, 130050,  ..., 130557, 130558, 130559],
        [130560, 130561, 130562,  ..., 131069, 131070, 131071]])
(1,1) MASK : True : start_m : 128 start_n : 128 : Num Steps : 8
-------------_attn_bwd_dq--start_m--128--start_n--128----MASK-True-----
Offset of kT : torch.Size([16, 512])
Offset of vT : torch.Size([512, 16])
Di Offset : torch.Size([128])
Offset of n (torch.Size([16])): tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157,
        158, 159])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173,
        174, 175])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189,
        190, 191])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205,
        206, 207])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221,
        222, 223])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237,
        238, 239])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True, False, False],
        [ True,  True,  True,  ...,  True,  True, False],
        [ True,  True,  True,  ...,  True,  True,  True]])
-------------end of_attn_bwd_dq-------
end_n : 128
(1,1) MASK : False : start_m : 128 start_n : 0 : Num Steps : 4
-------------_attn_bwd_dq--start_m--128--start_n--0----MASK-False-----
Offset of kT : torch.Size([32, 512])
Offset of vT : torch.Size([512, 32])
Di Offset : torch.Size([128])
Offset of n (torch.Size([32])): tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Offset of n (torch.Size([32])): tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
        50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Offset of n (torch.Size([32])): tensor([64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
        82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Offset of n (torch.Size([32])): tensor([ 96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
        110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123,
        124, 125, 126, 127])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
-------------end of_attn_bwd_dq-------
-------End of _attn_bwd-----------
(bhid,pid to 32,8) : (1,2) : Row : 0, Col : 1, off_chz : 1024 adj : 524288
Offsets(torch.Size([128, 512])) of k or v tensor([[131072, 131073, 131074,  ..., 131581, 131582, 131583],
        [131584, 131585, 131586,  ..., 132093, 132094, 132095],
        [132096, 132097, 132098,  ..., 132605, 132606, 132607],
        ...,
        [195072, 195073, 195074,  ..., 195581, 195582, 195583],
        [195584, 195585, 195586,  ..., 196093, 196094, 196095],
        [196096, 196097, 196098,  ..., 196605, 196606, 196607]])
MASK : True : start_m : 256 : Num Steps : 8
--------_attn_bwd_dkdv-(MASK : True)-------------
Current Step : 0 : offs_m(torch.Size([16])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271])
Mask (torch.Size([128, 16])) : tensor([[ True,  True,  True,  ...,  True,  True,  True],
        [False,  True,  True,  ...,  True,  True,  True],
        [False, False,  True,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
Current Step : 1 : offs_m(torch.Size([16])) : tensor([272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285,
        286, 287])
Mask (torch.Size([128, 16])) : tensor([[ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
MASK : False : start_m : 384 : Num Steps : 20
--------_attn_bwd_dkdv-(MASK : False)-------------
Current Step : 0 : offs_m(torch.Size([32])) : tensor([384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397,
        398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411,
        412, 413, 414, 415])
Current Step : 1 : offs_m(torch.Size([32])) : tensor([416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429,
        430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443,
        444, 445, 446, 447])
-------End of _attn_bwd_dkdv-----------
Offset of DV : tensor([[131072, 131073, 131074,  ..., 131581, 131582, 131583],
        [131584, 131585, 131586,  ..., 132093, 132094, 132095],
        [132096, 132097, 132098,  ..., 132605, 132606, 132607],
        ...,
        [195072, 195073, 195074,  ..., 195581, 195582, 195583],
        [195584, 195585, 195586,  ..., 196093, 196094, 196095],
        [196096, 196097, 196098,  ..., 196605, 196606, 196607]])
Offset of DK : tensor([[131072, 131073, 131074,  ..., 131581, 131582, 131583],
        [131584, 131585, 131586,  ..., 132093, 132094, 132095],
        [132096, 132097, 132098,  ..., 132605, 132606, 132607],
        ...,
        [195072, 195073, 195074,  ..., 195581, 195582, 195583],
        [195584, 195585, 195586,  ..., 196093, 196094, 196095],
        [196096, 196097, 196098,  ..., 196605, 196606, 196607]])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of q (torch.Size([128, 512])) : tensor([[131072, 131073, 131074,  ..., 131581, 131582, 131583],
        [131584, 131585, 131586,  ..., 132093, 132094, 132095],
        [132096, 132097, 132098,  ..., 132605, 132606, 132607],
        ...,
        [195072, 195073, 195074,  ..., 195581, 195582, 195583],
        [195584, 195585, 195586,  ..., 196093, 196094, 196095],
        [196096, 196097, 196098,  ..., 196605, 196606, 196607]])
Offset of do (torch.Size([128, 512])) : tensor([[131072, 131073, 131074,  ..., 131581, 131582, 131583],
        [131584, 131585, 131586,  ..., 132093, 132094, 132095],
        [132096, 132097, 132098,  ..., 132605, 132606, 132607],
        ...,
        [195072, 195073, 195074,  ..., 195581, 195582, 195583],
        [195584, 195585, 195586,  ..., 196093, 196094, 196095],
        [196096, 196097, 196098,  ..., 196605, 196606, 196607]])
(1,2) MASK : True : start_m : 256 start_n : 256 : Num Steps : 8
-------------_attn_bwd_dq--start_m--256--start_n--256----MASK-True-----
Offset of kT : torch.Size([16, 512])
Offset of vT : torch.Size([512, 16])
Di Offset : torch.Size([128])
Offset of n (torch.Size([16])): tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285,
        286, 287])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301,
        302, 303])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317,
        318, 319])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333,
        334, 335])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349,
        350, 351])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365,
        366, 367])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True, False, False],
        [ True,  True,  True,  ...,  True,  True, False],
        [ True,  True,  True,  ...,  True,  True,  True]])
-------------end of_attn_bwd_dq-------
end_n : 256
(1,2) MASK : False : start_m : 256 start_n : 0 : Num Steps : 8
-------------_attn_bwd_dq--start_m--256--start_n--0----MASK-False-----
Offset of kT : torch.Size([32, 512])
Offset of vT : torch.Size([512, 32])
Di Offset : torch.Size([128])
Offset of n (torch.Size([32])): tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of n (torch.Size([32])): tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
        50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of n (torch.Size([32])): tensor([64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
        82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of n (torch.Size([32])): tensor([ 96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
        110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123,
        124, 125, 126, 127])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of n (torch.Size([32])): tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of n (torch.Size([32])): tensor([160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173,
        174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
        188, 189, 190, 191])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of n (torch.Size([32])): tensor([192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205,
        206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219,
        220, 221, 222, 223])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of n (torch.Size([32])): tensor([224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237,
        238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251,
        252, 253, 254, 255])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
-------------end of_attn_bwd_dq-------
-------End of _attn_bwd-----------
(bhid,pid to 32,8) : (2,0) : Row : 0, Col : 2, off_chz : 2048 adj : 1048576
Offsets(torch.Size([128, 512])) of k or v tensor([[    0,     1,     2,  ...,   509,   510,   511],
        [  512,   513,   514,  ...,  1021,  1022,  1023],
        [ 1024,  1025,  1026,  ...,  1533,  1534,  1535],
        ...,
        [64000, 64001, 64002,  ..., 64509, 64510, 64511],
        [64512, 64513, 64514,  ..., 65021, 65022, 65023],
        [65024, 65025, 65026,  ..., 65533, 65534, 65535]])
MASK : True : start_m : 0 : Num Steps : 8
--------_attn_bwd_dkdv-(MASK : True)-------------
Current Step : 0 : offs_m(torch.Size([16])) : tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
Mask (torch.Size([128, 16])) : tensor([[ True,  True,  True,  ...,  True,  True,  True],
        [False,  True,  True,  ...,  True,  True,  True],
        [False, False,  True,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
Current Step : 1 : offs_m(torch.Size([16])) : tensor([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Mask (torch.Size([128, 16])) : tensor([[ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
MASK : False : start_m : 128 : Num Steps : 28
--------_attn_bwd_dkdv-(MASK : False)-------------
Current Step : 0 : offs_m(torch.Size([32])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159])
Current Step : 1 : offs_m(torch.Size([32])) : tensor([160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173,
        174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
        188, 189, 190, 191])
-------End of _attn_bwd_dkdv-----------
Offset of DV : tensor([[    0,     1,     2,  ...,   509,   510,   511],
        [  512,   513,   514,  ...,  1021,  1022,  1023],
        [ 1024,  1025,  1026,  ...,  1533,  1534,  1535],
        ...,
        [64000, 64001, 64002,  ..., 64509, 64510, 64511],
        [64512, 64513, 64514,  ..., 65021, 65022, 65023],
        [65024, 65025, 65026,  ..., 65533, 65534, 65535]])
Offset of DK : tensor([[    0,     1,     2,  ...,   509,   510,   511],
        [  512,   513,   514,  ...,  1021,  1022,  1023],
        [ 1024,  1025,  1026,  ...,  1533,  1534,  1535],
        ...,
        [64000, 64001, 64002,  ..., 64509, 64510, 64511],
        [64512, 64513, 64514,  ..., 65021, 65022, 65023],
        [65024, 65025, 65026,  ..., 65533, 65534, 65535]])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Offset of q (torch.Size([128, 512])) : tensor([[    0,     1,     2,  ...,   509,   510,   511],
        [  512,   513,   514,  ...,  1021,  1022,  1023],
        [ 1024,  1025,  1026,  ...,  1533,  1534,  1535],
        ...,
        [64000, 64001, 64002,  ..., 64509, 64510, 64511],
        [64512, 64513, 64514,  ..., 65021, 65022, 65023],
        [65024, 65025, 65026,  ..., 65533, 65534, 65535]])
Offset of do (torch.Size([128, 512])) : tensor([[    0,     1,     2,  ...,   509,   510,   511],
        [  512,   513,   514,  ...,  1021,  1022,  1023],
        [ 1024,  1025,  1026,  ...,  1533,  1534,  1535],
        ...,
        [64000, 64001, 64002,  ..., 64509, 64510, 64511],
        [64512, 64513, 64514,  ..., 65021, 65022, 65023],
        [65024, 65025, 65026,  ..., 65533, 65534, 65535]])
(2,0) MASK : True : start_m : 0 start_n : 0 : Num Steps : 8
-------------_attn_bwd_dq--start_m--0--start_n--0----MASK-True-----
Offset of kT : torch.Size([16, 512])
Offset of vT : torch.Size([512, 16])
Di Offset : torch.Size([128])
Offset of n (torch.Size([16])): tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([ 96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
        110, 111])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Offset of m (torch.Size([128])) : tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True, False, False],
        [ True,  True,  True,  ...,  True,  True, False],
        [ True,  True,  True,  ...,  True,  True,  True]])
-------------end of_attn_bwd_dq-------
end_n : 0
(2,0) MASK : False : start_m : 0 start_n : 0 : Num Steps : 0
-------------_attn_bwd_dq--start_m--0--start_n--0----MASK-False-----
Offset of kT : torch.Size([32, 512])
Offset of vT : torch.Size([512, 32])
Di Offset : torch.Size([128])
-------------end of_attn_bwd_dq-------
-------End of _attn_bwd-----------
(bhid,pid to 32,8) : (2,1) : Row : 0, Col : 2, off_chz : 2048 adj : 1048576
Offsets(torch.Size([128, 512])) of k or v tensor([[ 65536,  65537,  65538,  ...,  66045,  66046,  66047],
        [ 66048,  66049,  66050,  ...,  66557,  66558,  66559],
        [ 66560,  66561,  66562,  ...,  67069,  67070,  67071],
        ...,
        [129536, 129537, 129538,  ..., 130045, 130046, 130047],
        [130048, 130049, 130050,  ..., 130557, 130558, 130559],
        [130560, 130561, 130562,  ..., 131069, 131070, 131071]])
MASK : True : start_m : 128 : Num Steps : 8
--------_attn_bwd_dkdv-(MASK : True)-------------
Current Step : 0 : offs_m(torch.Size([16])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143])
Mask (torch.Size([128, 16])) : tensor([[ True,  True,  True,  ...,  True,  True,  True],
        [False,  True,  True,  ...,  True,  True,  True],
        [False, False,  True,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
Current Step : 1 : offs_m(torch.Size([16])) : tensor([144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157,
        158, 159])
Mask (torch.Size([128, 16])) : tensor([[ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
MASK : False : start_m : 256 : Num Steps : 24
--------_attn_bwd_dkdv-(MASK : False)-------------
Current Step : 0 : offs_m(torch.Size([32])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287])
Current Step : 1 : offs_m(torch.Size([32])) : tensor([288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301,
        302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315,
        316, 317, 318, 319])
-------End of _attn_bwd_dkdv-----------
Offset of DV : tensor([[ 65536,  65537,  65538,  ...,  66045,  66046,  66047],
        [ 66048,  66049,  66050,  ...,  66557,  66558,  66559],
        [ 66560,  66561,  66562,  ...,  67069,  67070,  67071],
        ...,
        [129536, 129537, 129538,  ..., 130045, 130046, 130047],
        [130048, 130049, 130050,  ..., 130557, 130558, 130559],
        [130560, 130561, 130562,  ..., 131069, 131070, 131071]])
Offset of DK : tensor([[ 65536,  65537,  65538,  ...,  66045,  66046,  66047],
        [ 66048,  66049,  66050,  ...,  66557,  66558,  66559],
        [ 66560,  66561,  66562,  ...,  67069,  67070,  67071],
        ...,
        [129536, 129537, 129538,  ..., 130045, 130046, 130047],
        [130048, 130049, 130050,  ..., 130557, 130558, 130559],
        [130560, 130561, 130562,  ..., 131069, 131070, 131071]])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Offset of q (torch.Size([128, 512])) : tensor([[ 65536,  65537,  65538,  ...,  66045,  66046,  66047],
        [ 66048,  66049,  66050,  ...,  66557,  66558,  66559],
        [ 66560,  66561,  66562,  ...,  67069,  67070,  67071],
        ...,
        [129536, 129537, 129538,  ..., 130045, 130046, 130047],
        [130048, 130049, 130050,  ..., 130557, 130558, 130559],
        [130560, 130561, 130562,  ..., 131069, 131070, 131071]])
Offset of do (torch.Size([128, 512])) : tensor([[ 65536,  65537,  65538,  ...,  66045,  66046,  66047],
        [ 66048,  66049,  66050,  ...,  66557,  66558,  66559],
        [ 66560,  66561,  66562,  ...,  67069,  67070,  67071],
        ...,
        [129536, 129537, 129538,  ..., 130045, 130046, 130047],
        [130048, 130049, 130050,  ..., 130557, 130558, 130559],
        [130560, 130561, 130562,  ..., 131069, 131070, 131071]])
(2,1) MASK : True : start_m : 128 start_n : 128 : Num Steps : 8
-------------_attn_bwd_dq--start_m--128--start_n--128----MASK-True-----
Offset of kT : torch.Size([16, 512])
Offset of vT : torch.Size([512, 16])
Di Offset : torch.Size([128])
Offset of n (torch.Size([16])): tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157,
        158, 159])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173,
        174, 175])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189,
        190, 191])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205,
        206, 207])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221,
        222, 223])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237,
        238, 239])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True, False, False],
        [ True,  True,  True,  ...,  True,  True, False],
        [ True,  True,  True,  ...,  True,  True,  True]])
-------------end of_attn_bwd_dq-------
end_n : 128
(2,1) MASK : False : start_m : 128 start_n : 0 : Num Steps : 4
-------------_attn_bwd_dq--start_m--128--start_n--0----MASK-False-----
Offset of kT : torch.Size([32, 512])
Offset of vT : torch.Size([512, 32])
Di Offset : torch.Size([128])
Offset of n (torch.Size([32])): tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Offset of n (torch.Size([32])): tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
        50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Offset of n (torch.Size([32])): tensor([64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
        82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
Offset of n (torch.Size([32])): tensor([ 96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
        110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123,
        124, 125, 126, 127])
Offset of m (torch.Size([128])) : tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211,
        212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225,
        226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
        240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253,
        254, 255])
-------------end of_attn_bwd_dq-------
-------End of _attn_bwd-----------
(bhid,pid to 32,8) : (2,2) : Row : 0, Col : 2, off_chz : 2048 adj : 1048576
Offsets(torch.Size([128, 512])) of k or v tensor([[131072, 131073, 131074,  ..., 131581, 131582, 131583],
        [131584, 131585, 131586,  ..., 132093, 132094, 132095],
        [132096, 132097, 132098,  ..., 132605, 132606, 132607],
        ...,
        [195072, 195073, 195074,  ..., 195581, 195582, 195583],
        [195584, 195585, 195586,  ..., 196093, 196094, 196095],
        [196096, 196097, 196098,  ..., 196605, 196606, 196607]])
MASK : True : start_m : 256 : Num Steps : 8
--------_attn_bwd_dkdv-(MASK : True)-------------
Current Step : 0 : offs_m(torch.Size([16])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271])
Mask (torch.Size([128, 16])) : tensor([[ True,  True,  True,  ...,  True,  True,  True],
        [False,  True,  True,  ...,  True,  True,  True],
        [False, False,  True,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
Current Step : 1 : offs_m(torch.Size([16])) : tensor([272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285,
        286, 287])
Mask (torch.Size([128, 16])) : tensor([[ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]])
MASK : False : start_m : 384 : Num Steps : 20
--------_attn_bwd_dkdv-(MASK : False)-------------
Current Step : 0 : offs_m(torch.Size([32])) : tensor([384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397,
        398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411,
        412, 413, 414, 415])
Current Step : 1 : offs_m(torch.Size([32])) : tensor([416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429,
        430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443,
        444, 445, 446, 447])
-------End of _attn_bwd_dkdv-----------
Offset of DV : tensor([[131072, 131073, 131074,  ..., 131581, 131582, 131583],
        [131584, 131585, 131586,  ..., 132093, 132094, 132095],
        [132096, 132097, 132098,  ..., 132605, 132606, 132607],
        ...,
        [195072, 195073, 195074,  ..., 195581, 195582, 195583],
        [195584, 195585, 195586,  ..., 196093, 196094, 196095],
        [196096, 196097, 196098,  ..., 196605, 196606, 196607]])
Offset of DK : tensor([[131072, 131073, 131074,  ..., 131581, 131582, 131583],
        [131584, 131585, 131586,  ..., 132093, 132094, 132095],
        [132096, 132097, 132098,  ..., 132605, 132606, 132607],
        ...,
        [195072, 195073, 195074,  ..., 195581, 195582, 195583],
        [195584, 195585, 195586,  ..., 196093, 196094, 196095],
        [196096, 196097, 196098,  ..., 196605, 196606, 196607]])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of q (torch.Size([128, 512])) : tensor([[131072, 131073, 131074,  ..., 131581, 131582, 131583],
        [131584, 131585, 131586,  ..., 132093, 132094, 132095],
        [132096, 132097, 132098,  ..., 132605, 132606, 132607],
        ...,
        [195072, 195073, 195074,  ..., 195581, 195582, 195583],
        [195584, 195585, 195586,  ..., 196093, 196094, 196095],
        [196096, 196097, 196098,  ..., 196605, 196606, 196607]])
Offset of do (torch.Size([128, 512])) : tensor([[131072, 131073, 131074,  ..., 131581, 131582, 131583],
        [131584, 131585, 131586,  ..., 132093, 132094, 132095],
        [132096, 132097, 132098,  ..., 132605, 132606, 132607],
        ...,
        [195072, 195073, 195074,  ..., 195581, 195582, 195583],
        [195584, 195585, 195586,  ..., 196093, 196094, 196095],
        [196096, 196097, 196098,  ..., 196605, 196606, 196607]])
(2,2) MASK : True : start_m : 256 start_n : 256 : Num Steps : 8
-------------_attn_bwd_dq--start_m--256--start_n--256----MASK-True-----
Offset of kT : torch.Size([16, 512])
Offset of vT : torch.Size([512, 16])
Di Offset : torch.Size([128])
Offset of n (torch.Size([16])): tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[ True, False, False,  ..., False, False, False],
        [ True,  True, False,  ..., False, False, False],
        [ True,  True,  True,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285,
        286, 287])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301,
        302, 303])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317,
        318, 319])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333,
        334, 335])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349,
        350, 351])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365,
        366, 367])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True],
        [ True,  True,  True,  ...,  True,  True,  True]])
Offset of n (torch.Size([16])): tensor([368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Mask (torch.Size([128, 16])) : tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [ True,  True,  True,  ...,  True, False, False],
        [ True,  True,  True,  ...,  True,  True, False],
        [ True,  True,  True,  ...,  True,  True,  True]])
-------------end of_attn_bwd_dq-------
end_n : 256
(2,2) MASK : False : start_m : 256 start_n : 0 : Num Steps : 8
-------------_attn_bwd_dq--start_m--256--start_n--0----MASK-False-----
Offset of kT : torch.Size([32, 512])
Offset of vT : torch.Size([512, 32])
Di Offset : torch.Size([128])
Offset of n (torch.Size([32])): tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of n (torch.Size([32])): tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
        50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of n (torch.Size([32])): tensor([64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81,
        82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of n (torch.Size([32])): tensor([ 96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109,
        110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123,
        124, 125, 126, 127])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of n (torch.Size([32])): tensor([128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of n (torch.Size([32])): tensor([160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173,
        174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187,
        188, 189, 190, 191])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of n (torch.Size([32])): tensor([192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205,
        206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219,
        220, 221, 222, 223])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
Offset of n (torch.Size([32])): tensor([224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237,
        238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251,
        252, 253, 254, 255])
Offset of m (torch.Size([128])) : tensor([256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269,
        270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283,
        284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297,
        298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311,
        312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
        326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339,
        340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353,
        354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367,
        368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381,
        382, 383])
-------------end of_attn_bwd_dq-------
-------End of _attn_bwd-----------
@2025 Debashis Blogs...
Contact
Privacy