import torch
import triton
import triton.language as tl
from .utils import get_shape_from_layout, get_strides_from_layout, DEBUG, PERF


@triton.jit
def _bwd_preprocess_use_o(
    Out,
    DO,
    Delta,
    stride_oz, stride_oh, stride_om, stride_ok,
    stride_doz, stride_doh, stride_dom, stride_dok,
    stride_deltaz, stride_deltah, stride_deltam,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    N_CTX_Q: tl.constexpr,
    Z: tl.constexpr,
    H: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_bh = tl.program_id(1)

    # compute batch and head indices
    off_z = pid_bh // H
    off_h = pid_bh % H

    if IS_VARLEN:
        # compute sequence lengths for the current batch
        q_start = tl.load(cu_seqlens_q + off_z)
        q_end = tl.load(cu_seqlens_q + off_z + 1)
        k_start = tl.load(cu_seqlens_k + off_z)
        k_end = tl.load(cu_seqlens_k + off_z + 1)

        # compute actual sequence lengths
        N_CTX_Q = q_end - q_start
        N_CTX_K = k_end - k_start
    else:
        q_start = 0
        k_start = 0
        N_CTX_Q = max_seqlen_q
        N_CTX_K = max_seqlen_k

    off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_d = tl.arange(0, BLOCK_DMODEL)

    # create masks
    mask_m = off_m < N_CTX_Q
    mask_d = off_d < ACTUAL_BLOCK_DMODEL

    # compute offsets
    o_offset = Out + off_z * stride_oz + off_h * stride_oh + q_start * stride_om
    do_offset = DO + off_z * stride_doz + off_h * stride_doh + q_start * stride_dom
    # compute pointers
    out_ptrs = o_offset + off_m[:, None] * stride_om + off_d[None, :] * stride_ok
    do_ptrs = do_offset + off_m[:, None] * stride_dom + off_d[None, :] * stride_dok

    # load
    o = tl.load(out_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32)
    do = tl.load(do_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32)

    # compute delta
    delta = tl.sum(o * do, axis=1)

    # write-back delta
    delta_offset = Delta + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam
    delta_ptrs = delta_offset + off_m * stride_deltam
    tl.store(delta_ptrs, delta, mask=mask_m)
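

# For reference: the kernel above computes delta[m] = sum_d O[m, d] * dO[m, d],
# the rowwise dot product of the forward output and its gradient, which the
# backward pass uses in ds = p * (dp - delta). A minimal PyTorch sketch of the
# same quantity (a hypothetical helper for testing, not used by the kernels):
def _reference_delta(o: torch.Tensor, do: torch.Tensor) -> torch.Tensor:
    # o, do: (..., seqlen, head_dim) -> delta: (..., seqlen)
    return (o.float() * do.float()).sum(dim=-1)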


@triton.jit
def _bwd_kernel_one_col_block(
    Q,
    K,
    V,
    sm_scale,
    Out,
    DO,
    DQ,
    DK,
    DV,
    L,
    D,
    q_offset,
    k_offset,
    v_offset,
    do_offset,
    dq_offset,
    dk_offset,
    dv_offset,
    d_offset,
    l_offset,
    stride_dq_all,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vn,
    stride_vk,
    stride_deltaz,
    stride_deltah,
    stride_deltam,
    Z,
    H,
    N_CTX_Q,
    N_CTX_K,
    off_h,
    off_z,
    off_hz,
    start_n,
    num_block_m,
    num_block_n,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    SEQUENCE_PARALLEL: tl.constexpr,
    CAUSAL: tl.constexpr,
    USE_EXP2: tl.constexpr,
):
    # TODO: for CAUSAL, lo could skip row blocks that are fully masked out
    # (something like lo = start_n * BLOCK_N); both branches currently start at 0
    lo = 0
    # initialize col and head offsets
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    # masks
    mask_n = offs_n < N_CTX_K
    mask_d = offs_d < ACTUAL_BLOCK_DMODEL
    kv_mask = mask_n[:, None] & mask_d[None, :]

    # initialize grad accumulators
    dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)

    # load k and v once per column block
    k_ptrs = k_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
    v_ptrs = v_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk
    k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
    v = tl.load(v_ptrs, mask=kv_mask, other=0.0)

    # loop over rows
    for start_m in range(lo, num_block_m * BLOCK_M, BLOCK_M):
        offs_m = start_m + tl.arange(0, BLOCK_M)
        q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
        dq_ptrs = dq_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
        do_ptrs = do_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk

        # update masks as the row block changes
        mask_m = offs_m < N_CTX_Q
        q_mask = mask_m[:, None] & mask_d[None, :]

        # load q and do on-chip (k and v were loaded once above)
        q = tl.load(q_ptrs, mask=q_mask, other=0.0)
        do = tl.load(do_ptrs, mask=q_mask, other=0.0)

        # recompute p = softmax(qk * sm_scale, dim=-1)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))

        if CAUSAL:
            col_offset = N_CTX_Q - N_CTX_K
            causal_mask = offs_m[:, None] >= (col_offset + offs_n[None, :])
            qk = tl.where(causal_mask, qk, float("-inf"))
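            # NOTE: col_offset aligns the causal mask to the bottom-right corner, so
            # when N_CTX_Q != N_CTX_K the last query row still attends to the last key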

        l_ptrs = l_offset + offs_m * stride_deltam
        l_i = tl.load(l_ptrs, mask=mask_m)

        # compute p
        if USE_EXP2:
            RCP_LN2: tl.constexpr = 1.4426950408889634
            qk *= sm_scale * RCP_LN2
            l_i *= RCP_LN2
            p = tl.math.exp2(qk - l_i[:, None])
        else:
            qk *= sm_scale
            p = tl.math.exp(qk - l_i[:, None])
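        # (both branches compute the same p: exp(x) = exp2(x * log2(e)), so scaling
        # qk and the stored logsumexp l_i by RCP_LN2 = log2(e) lets exp2 replace exp)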

        # mask out the block in cases where the data is smaller than the block size
        p_mask = mask_m[:, None] & mask_n[None, :]
        p = tl.where(p_mask, p, 0.0)

        # compute dv
        dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)

        # compute dp
        dp = tl.dot(do, tl.trans(v))

        # compute ds: ds = p * (dp - delta[:, None])
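        # (standard softmax backward: for a softmax row p with upstream grad dp,
        # d(scores) = p * (dp - sum(p * dp)), and sum(p * dp) == sum(o * do) == delta,
        # which is exactly what _bwd_preprocess_use_o precomputed)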
        d_ptrs = d_offset + offs_m * stride_deltam
        Di = tl.load(d_ptrs, mask=mask_m)
        ds = (p * (dp - Di[:, None])) * sm_scale
        ds = tl.where(p_mask, ds, 0.0).to(Q.dtype.element_ty)

        # compute dk = dot(ds.T, q)
        dk += tl.dot(tl.trans(ds), q)

        # compute dq
        if SEQUENCE_PARALLEL:
            dq = tl.dot(ds, k)
        else:
            dq = tl.load(dq_ptrs, mask=q_mask, other=0.0)
            dq += tl.dot(ds, k)
        tl.store(dq_ptrs, dq.to(Q.dtype.element_ty), mask=q_mask)

    # write-back dv and dk
    dk_ptrs = dk_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
    dv_ptrs = dv_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk
    tl.store(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask)
    tl.store(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask)
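

# The column-block kernel above implements the FlashAttention-style backward:
# it recomputes p from q, k and the saved logsumexp, then forms
# dv = p^T @ do, dp = do @ v^T, ds = p * (dp - delta) * sm_scale, dk = ds^T @ q,
# and dq += ds @ k. A minimal single-head PyTorch sketch of the same math
# (a hypothetical test helper, not used by the kernels):
def _reference_bwd_one_head(q, k, v, do, lse, sm_scale, causal=False):
    # q, do: (M, d); k, v: (N, d); lse: (M,) logsumexp of the scaled scores
    q, k, v, do = (t.float() for t in (q, k, v, do))
    qk = (q @ k.T) * sm_scale
    if causal:
        m = torch.arange(q.shape[0], device=q.device)[:, None]
        n = torch.arange(k.shape[0], device=q.device)[None, :]
        # bottom-right aligned causal mask, matching col_offset = N_CTX_Q - N_CTX_K
        qk = qk.masked_fill(m < (q.shape[0] - k.shape[0]) + n, float("-inf"))
    p = torch.exp(qk - lse.float()[:, None])  # recomputed softmax rows
    delta = ((p @ v) * do).sum(-1)            # rowsum(o * do), as in the preprocess
    dv = p.T @ do
    dp = do @ v.T
    ds = p * (dp - delta[:, None]) * sm_scale
    dk = ds.T @ q
    dq = ds @ k
    return dq, dk, dv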


@triton.jit
def _bwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    Out,
    DO,
    DQ,
    DK,
    DV,
    L,
    D,
    stride_dq_all,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vn,
    stride_vk,
    stride_deltaz,
    stride_deltah,
    stride_deltam,
    Z,
    H,
    num_block_m,
    num_block_n,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    SEQUENCE_PARALLEL: tl.constexpr,
    CAUSAL: tl.constexpr,
    USE_EXP2: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # program ids
    off_hz = tl.program_id(0)
    if SEQUENCE_PARALLEL:
        start_n = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H

    if IS_VARLEN:
        # compute sequence lengths for the current batch
        q_start = tl.load(cu_seqlens_q + off_z)
        q_end = tl.load(cu_seqlens_q + off_z + 1)
        k_start = tl.load(cu_seqlens_k + off_z)
        k_end = tl.load(cu_seqlens_k + off_z + 1)

        # compute actual sequence lengths
        N_CTX_Q = q_end - q_start
        N_CTX_K = k_end - k_start
    else:
        q_start = 0
        k_start = 0
        N_CTX_Q = max_seqlen_q
        N_CTX_K = max_seqlen_k

    # input tensor offsets
    q_offset = Q + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
    k_offset = K + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn
    v_offset = V + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn
    do_offset = DO + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
    l_offset = L + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam
    d_offset = D + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam

    # output tensor offsets
    dk_offset = DK + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn
    dv_offset = DV + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn
    if SEQUENCE_PARALLEL:
        dq_offset = DQ + start_n * stride_dq_all + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
    else:
        dq_offset = DQ + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
    # inner loop
    if SEQUENCE_PARALLEL:
        _bwd_kernel_one_col_block(
            Q,
            K,
            V,
            sm_scale,
            Out,
            DO,
            DQ,
            DK,
            DV,
            L,
            D,
            q_offset,
            k_offset,
            v_offset,
            do_offset,
            dq_offset,
            dk_offset,
            dv_offset,
            d_offset,
            l_offset,
            stride_dq_all,
            stride_qz,
            stride_qh,
            stride_qm,
            stride_qk,
            stride_kz,
            stride_kh,
            stride_kn,
            stride_kk,
            stride_vz,
            stride_vh,
            stride_vn,
            stride_vk,
            stride_deltaz,
            stride_deltah,
            stride_deltam,
            Z,
            H,
            N_CTX_Q,
            N_CTX_K,
            off_h,
            off_z,
            off_hz,
            start_n,
            num_block_m,
            num_block_n,
            BLOCK_M=BLOCK_M,
            BLOCK_DMODEL=BLOCK_DMODEL,
            ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
            BLOCK_N=BLOCK_N,
            SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
            CAUSAL=CAUSAL,
            USE_EXP2=USE_EXP2,
        )
    else:
        for start_n in range(0, num_block_n):
            _bwd_kernel_one_col_block(
                Q,
                K,
                V,
                sm_scale,
                Out,
                DO,
                DQ,
                DK,
                DV,
                L,
                D,
                q_offset,
                k_offset,
                v_offset,
                do_offset,
                dq_offset,
                dk_offset,
                dv_offset,
                d_offset,
                l_offset,
                stride_dq_all,
                stride_qz,
                stride_qh,
                stride_qm,
                stride_qk,
                stride_kz,
                stride_kh,
                stride_kn,
                stride_kk,
                stride_vz,
                stride_vh,
                stride_vn,
                stride_vk,
                stride_deltaz,
                stride_deltah,
                stride_deltam,
                Z,
                H,
                N_CTX_Q,
                N_CTX_K,
                off_h,
                off_z,
                off_hz,
                start_n,
                num_block_m,
                num_block_n,
                BLOCK_M=BLOCK_M,
                BLOCK_DMODEL=BLOCK_DMODEL,
                ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
                BLOCK_N=BLOCK_N,
                SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
                CAUSAL=CAUSAL,
                USE_EXP2=USE_EXP2,
            )


# NOTE: smaller blocks have lower accuracy, most likely from extra accumulation error.
# 128 x 128 seems accurate but leads to OOM; 64 x 64 shows some accumulation error but avoids OOM.
def attention_prefill_backward_triton_impl(
    do,
    q,
    k,
    v,
    o,
    softmax_lse,
    dq,
    dk,
    dv,
    sm_scale: float,
    alibi_slopes,
    causal,
    layout: str,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q: int,
    max_seqlen_k: int,
    use_exp2: bool,
    sequence_parallel=True,
):
    if DEBUG:
        print()
        print("attention_prefill_backward_triton_impl")
        print("do:", do, do.shape)
        print("q:", q, q.shape)
        print("k:", k, k.shape)
        print("v:", v, v.shape)
        print("o:", o, o.shape)
        print("softmax_lse:", softmax_lse, softmax_lse.shape)
        print("dq:", dq, dq.shape if dq is not None else None)
        print("dk:", dk, dk.shape if dk is not None else None)
        print("dv:", dv, dv.shape if dv is not None else None)
        print("sm_scale:", sm_scale)
        print("alibi_slopes:", alibi_slopes)
        print("causal:", causal)
        print("layout:", layout)
        print("cu_seqlens_q:", cu_seqlens_q)
        print("cu_seqlens_k:", cu_seqlens_k)
        print("max_seqlen_q:", max_seqlen_q)
        print("max_seqlen_k:", max_seqlen_k)
        print("use_exp2:", use_exp2)
        print("sequence_parallel:", sequence_parallel)
    # make contiguous
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    softmax_lse = softmax_lse.contiguous()

    # get strides and shape
    batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k = get_shape_from_layout(q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
    q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout)
    stride_qz, stride_qh, stride_qm, stride_qk = q_strides
    stride_kz, stride_kh, stride_kn, stride_kk = k_strides
    stride_vz, stride_vh, stride_vn, stride_vk = v_strides
    stride_oz, stride_oh, stride_om, stride_ok = o_strides
    batch_headsize = batch * nheads_q
    is_varlen = layout == "thd"
    # FIXME: some configs lead to OOM for some reason when using 64 x 64 blocks
    if max_seqlen_q <= 32 or max_seqlen_k <= 32:
        BLOCK_M = 32
        BLOCK_N = 32
    else:
        BLOCK_M = 64
        BLOCK_N = 64
    num_warps = 4  # NOTE: the original value is 8; changing this to 1 caused issues, be careful
    num_stages = 1
    waves_per_eu = 1

    # divide up the problem
    num_blocks_m = triton.cdiv(max_seqlen_q, BLOCK_M)
    num_blocks_n = triton.cdiv(max_seqlen_k, BLOCK_N)

    # pad head_size to the closest power of 2 greater than or equal to it, clamped to a minimum of 16
    padded_d_model = 1 << (head_size - 1).bit_length()
    padded_d_model = max(padded_d_model, 16)
    BLOCK_DMODEL = padded_d_model
    ACTUAL_BLOCK_DMODEL = head_size
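    # e.g. head_size = 72 gives (72 - 1).bit_length() = 7, so padded_d_model = 1 << 7 = 128;
    # head_size = 64 stays 64; head_size = 8 pads to 8 and is then clamped up to 16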
    do = do.contiguous()

    # NOTE: we might need to copy the output tensors back if they are not contiguous or have other issues
    copy_back = {"dq": False, "dk": False, "dv": False}

    # deal with dq
    if dq is None:
        if sequence_parallel:
            dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype)
        else:
            dq = torch.zeros(q.shape, device=q.device, dtype=q.dtype)
    else:
        dq_og = dq
        if not dq.is_contiguous():
            dq = dq.contiguous()
            copy_back["dq"] = True

        if sequence_parallel:
            dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype)
            copy_back["dq"] = True
        else:
            # NOTE: the kernel does in-place accumulation, so dq has to be zeroed. This avoids
            # the case where we are passed an empty dq that is not all zeros
            dq.zero_()
    stride_dq_all = dq.stride()[0]
    # deal with dk, dv
    if (dk is None) or (dv is None):
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
    else:
        if not dk.is_contiguous():
            dk_og = dk
            dk = dk.contiguous()
            copy_back["dk"] = True

        if not dv.is_contiguous():
            dv_og = dv
            dv = dv.contiguous()
            copy_back["dv"] = True

    if DEBUG:
        print("copy_back:", copy_back)

    # assert contiguous
    assert do.is_contiguous()
    assert q.is_contiguous()
    assert k.is_contiguous()
    assert v.is_contiguous()
    assert o.is_contiguous()
    assert softmax_lse.is_contiguous()
    # init delta
    delta = torch.empty_like(softmax_lse)
    if is_varlen:
        stride_deltam, stride_deltah = delta.stride()
        stride_deltaz = 0
    else:
        stride_deltaz, stride_deltah, stride_deltam = delta.stride()

    _bwd_preprocess_use_o[(num_blocks_m, batch_headsize)](
        o,
        do,
        delta,
        stride_oz, stride_oh, stride_om, stride_ok,
        stride_oz, stride_oh, stride_om, stride_ok,  # do shares o's strides: both are contiguous with the same shape
        stride_deltaz, stride_deltah, stride_deltam,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        BLOCK_M=BLOCK_M,
        BLOCK_DMODEL=BLOCK_DMODEL,
        ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
        N_CTX_Q=max_seqlen_q,
        Z=batch,
        H=nheads_q,
        IS_VARLEN=is_varlen,
    )
    if DEBUG:
        print("_bwd_kernel inputs")
        print("do:", do, do.shape)
        print("q:", q, q.shape)
        print("k:", k, k.shape)
        print("v:", v, v.shape)
        print("sm_scale", sm_scale)
        print("o:", o, o.shape)
        print("dq:", dq, dq.shape)
        print("dk:", dk, dk.shape)
        print("dv:", dv, dv.shape)
        print("L:", softmax_lse, softmax_lse.shape)
        print("delta:", delta, delta.shape)
        print("stride_qz, stride_qh, stride_qm, stride_qk:", stride_qz, stride_qh, stride_qm, stride_qk)
        print("stride_kz, stride_kh, stride_kn, stride_kk:", stride_kz, stride_kh, stride_kn, stride_kk)
        print("stride_vz, stride_vh, stride_vn, stride_vk:", stride_vz, stride_vh, stride_vn, stride_vk)
        print("batch_q:", batch)
        print("heads_q:", nheads_q)
        print("max_seqlen_q:", max_seqlen_q)
        print("max_seqlen_k:", max_seqlen_k)
        print("BLOCK_M:", BLOCK_M)
        print("BLOCK_N:", BLOCK_N)
        print("BLOCK_DMODEL:", BLOCK_DMODEL)
        print("ACTUAL_BLOCK_DMODEL:", ACTUAL_BLOCK_DMODEL)
        print("SEQUENCE_PARALLEL:", sequence_parallel)
        print("CAUSAL:", causal)
        print("num_warps:", num_warps)
        print("num_stages:", num_stages)
        print("USE_EXP2:", use_exp2)
        print("num_blocks_m:", num_blocks_m)
        print("num_blocks_n:", num_blocks_n)
    _bwd_kernel[(batch_headsize, num_blocks_n if sequence_parallel else 1)](
        q,
        k,
        v,
        sm_scale,
        o,
        do,
        dq,
        dk,
        dv,
        softmax_lse,
        delta,
        stride_dq_all,
        stride_qz, stride_qh, stride_qm, stride_qk,
        stride_kz, stride_kh, stride_kn, stride_kk,
        stride_vz, stride_vh, stride_vn, stride_vk,
        stride_deltaz, stride_deltah, stride_deltam,
        batch,
        nheads_q,
        num_blocks_m,
        num_blocks_n,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_DMODEL=BLOCK_DMODEL,
        ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
        SEQUENCE_PARALLEL=sequence_parallel,
        CAUSAL=causal,
        USE_EXP2=use_exp2,
        num_warps=num_warps,
        num_stages=num_stages,
        waves_per_eu=waves_per_eu,
        IS_VARLEN=is_varlen,
    )
    if DEBUG:
        print("_bwd_kernel outputs")
        print("dq:", dq, dq.shape)
        print("dk:", dk, dk.shape)
        print("dv:", dv, dv.shape)
        print("delta:", delta, delta.shape)

    if sequence_parallel:
        # reduce the per-column-block partial gradients into the final dq
        dq = dq.sum(dim=0)

    if DEBUG:
        print("attention_prefill_backward_triton_impl outputs")
        print("dq:", dq, dq.shape)
        print("dk:", dk, dk.shape)
        print("dv:", dv, dv.shape)
        print("delta:", delta, delta.shape)
        print("copy_back:", copy_back)

    if copy_back["dq"]:
        dq_og.copy_(dq)
        dq = dq_og
    if copy_back["dk"]:
        dk_og.copy_(dk)
        dk = dk_og
    if copy_back["dv"]:
        dv_og.copy_(dv)
        dv = dv_og

    return dq, dk, dv, delta, None, None
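

# A hedged usage sketch (illustrative only; assumes the "bshd" layout convention
# handled by .utils and tensors produced by the matching forward pass):
#
#   q, k, v, o, do: (batch, seqlen, nheads, head_dim) half-precision GPU tensors
#   softmax_lse: per-row logsumexp saved by the forward kernel
#
#   dq, dk, dv, delta, _, _ = attention_prefill_backward_triton_impl(
#       do, q, k, v, o, softmax_lse,
#       dq=None, dk=None, dv=None,
#       sm_scale=q.shape[-1] ** -0.5,
#       alibi_slopes=None, causal=True, layout="bshd",
#       cu_seqlens_q=None, cu_seqlens_k=None,
#       max_seqlen_q=q.shape[1], max_seqlen_k=k.shape[1],
#       use_exp2=True,
#   )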