# bwd_prefill.py

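"""
Triton implementation of the FlashAttention-style backward pass for prefill
(non-decode) attention.

The backward pass is split into two kernels:

1. ``_bwd_preprocess_use_o`` computes ``delta = rowsum(O * dO)`` per query row.
2. ``_bwd_kernel`` recomputes the attention probabilities from Q, K and the
   saved log-sum-exp (``L``), then accumulates dQ, dK and dV one K/V column
   block at a time.

``attention_prefill_backward_triton_impl`` is the host-side wrapper that
allocates outputs, picks block sizes and launches both kernels.
"""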
import torch
import triton
import triton.language as tl

from .utils import get_shape_from_layout, get_strides_from_layout, DEBUG, PERF


@triton.jit
def _bwd_preprocess_use_o(
    Out,
    DO,
    Delta,
    stride_oz, stride_oh, stride_om, stride_ok,
    stride_doz, stride_doh, stride_dom, stride_dok,
    stride_deltaz, stride_deltah, stride_deltam,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    N_CTX_Q: tl.constexpr,
    Z: tl.constexpr,
    H: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
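    """Compute delta = rowsum(O * dO) for each query row.

    This is the standard FlashAttention backward preprocessing step: delta
    appears in ds = p * (dp - delta[:, None]), so materializing it once here
    avoids recomputing the row sums inside the main backward loop.
    """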
    pid_m = tl.program_id(0)
    pid_bh = tl.program_id(1)

    # compute batch and head indices
    off_z = pid_bh // H
    off_h = pid_bh % H

    if IS_VARLEN:
        # compute the sequence bounds for the current batch from the
        # cumulative sequence-length arrays
        q_start = tl.load(cu_seqlens_q + off_z)
        q_end = tl.load(cu_seqlens_q + off_z + 1)
        k_start = tl.load(cu_seqlens_k + off_z)
        k_end = tl.load(cu_seqlens_k + off_z + 1)

        # actual sequence lengths
        N_CTX_Q = q_end - q_start
        N_CTX_K = k_end - k_start
    else:
        q_start = 0
        k_start = 0
        N_CTX_Q = max_seqlen_q
        N_CTX_K = max_seqlen_k

    off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_d = tl.arange(0, BLOCK_DMODEL)

    # create masks
    mask_m = off_m < N_CTX_Q
    mask_d = off_d < ACTUAL_BLOCK_DMODEL

    # compute offsets (DO uses its own strides; the wrapper currently passes
    # the same values as for Out since both are contiguous)
    o_offset = Out + off_z * stride_oz + off_h * stride_oh + q_start * stride_om
    do_offset = DO + off_z * stride_doz + off_h * stride_doh + q_start * stride_dom

    # compute pointers
    out_ptrs = o_offset + off_m[:, None] * stride_om + off_d[None, :] * stride_ok
    do_ptrs = do_offset + off_m[:, None] * stride_dom + off_d[None, :] * stride_dok

    # load
    o = tl.load(out_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32)
    do = tl.load(do_ptrs, mask=mask_m[:, None] & mask_d[None, :], other=0.0).to(tl.float32)

    # compute delta
    delta = tl.sum(o * do, axis=1)

    # write back delta
    delta_offset = Delta + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam
    delta_ptrs = delta_offset + off_m * stride_deltam
    tl.store(delta_ptrs, delta, mask=mask_m)


@triton.jit
def _bwd_kernel_one_col_block(
    Q,
    K,
    V,
    sm_scale,
    Out,
    DO,
    DQ,
    DK,
    DV,
    L,
    D,
    q_offset,
    k_offset,
    v_offset,
    do_offset,
    dq_offset,
    dk_offset,
    dv_offset,
    d_offset,
    l_offset,
    stride_dq_all,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vn,
    stride_vk,
    stride_deltaz,
    stride_deltah,
    stride_deltam,
    Z,
    H,
    N_CTX_Q,
    N_CTX_K,
    off_h,
    off_z,
    off_hz,
    start_n,
    num_block_m,
    num_block_n,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    SEQUENCE_PARALLEL: tl.constexpr,
    CAUSAL: tl.constexpr,
    USE_EXP2: tl.constexpr,
):
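    """Process one BLOCK_N column block of K/V against every query row block.

    K and V are loaded once per program; the loop over row blocks then
    recomputes p = softmax(q @ k.T * sm_scale) from the saved log-sum-exp
    and accumulates:

        dv += p.T @ do
        dp  = do @ v.T
        ds  = p * (dp - delta[:, None]) * sm_scale
        dk += ds.T @ q
        dq += ds @ k

    dq is accumulated across column blocks: either into a private slice per
    column block (SEQUENCE_PARALLEL, reduced on the host afterwards), or via
    read-modify-write on a shared buffer, which is safe because the
    non-parallel path runs all column blocks sequentially in one program.
    """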
    if CAUSAL:
        # TODO: causal could skip more blocks with something like lo = start_m * BLOCK_M
        lo = 0
    else:
        lo = 0

    # initialize col and head offsets
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    # masks
    mask_n = offs_n < N_CTX_K
    mask_d = offs_d < ACTUAL_BLOCK_DMODEL
    kv_mask = mask_n[:, None] & mask_d[None, :]

    # initialize grad accumulators
    dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)

    # load k and v once per column block
    k_ptrs = k_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
    v_ptrs = v_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk
    k = tl.load(k_ptrs, mask=kv_mask, other=0.0)
    v = tl.load(v_ptrs, mask=kv_mask, other=0.0)

    # loop over rows
    for start_m in range(lo, num_block_m * BLOCK_M, BLOCK_M):
        offs_m = start_m + tl.arange(0, BLOCK_M)
        q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
        dq_ptrs = dq_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
        do_ptrs = do_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk

        # update mask as the row block changes
        mask_m = offs_m < N_CTX_Q
        q_mask = mask_m[:, None] & mask_d[None, :]

        # load q and do on-chip (k and v were loaded above)
        q = tl.load(q_ptrs, mask=q_mask, other=0.0)
        do = tl.load(do_ptrs, mask=q_mask, other=0.0)

        # recompute p = softmax(qk * sm_scale, dim=-1)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        if CAUSAL:
            col_offset = N_CTX_Q - N_CTX_K
            causal_mask = offs_m[:, None] >= (col_offset + offs_n[None, :])
            qk = tl.where(causal_mask, qk, float("-inf"))

        l_ptrs = l_offset + offs_m * stride_deltam
        l_i = tl.load(l_ptrs, mask=mask_m)

        # compute p from the saved log-sum-exp; with USE_EXP2 both qk and the
        # log-sum-exp are rescaled by log2(e) so the faster exp2 can be used
        if USE_EXP2:
            RCP_LN2: tl.constexpr = 1.4426950408889634
            qk *= sm_scale * RCP_LN2
            l_i *= RCP_LN2
            p = tl.math.exp2(qk - l_i[:, None])
        else:
            qk *= sm_scale
            p = tl.math.exp(qk - l_i[:, None])

        # mask out-of-bounds entries when the data is smaller than the block size
        p_mask = mask_m[:, None] & mask_n[None, :]
        p = tl.where(p_mask, p, 0.0)

        # compute dv
        dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)

        # compute dp
        dp = tl.dot(do, tl.trans(v))

        # compute ds = p * (dp - delta[:, None])
        d_ptrs = d_offset + offs_m * stride_deltam
        Di = tl.load(d_ptrs, mask=mask_m)
        ds = (p * (dp - Di[:, None])) * sm_scale
        ds = tl.where(p_mask, ds, 0.0).to(Q.dtype.element_ty)

        # compute dk = dot(ds.T, q)
        dk += tl.dot(tl.trans(ds), q)

        # compute dq
        if SEQUENCE_PARALLEL:
            dq = tl.dot(ds, k)
        else:
            dq = tl.load(dq_ptrs, mask=q_mask, other=0.0)
            dq += tl.dot(ds, k)
        tl.store(dq_ptrs, dq.to(Q.dtype.element_ty), mask=q_mask)

    # write back dv and dk
    dk_ptrs = dk_offset + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
    dv_ptrs = dv_offset + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vk
    tl.store(dk_ptrs, dk.to(K.dtype.element_ty), mask=kv_mask)
    tl.store(dv_ptrs, dv.to(V.dtype.element_ty), mask=kv_mask)


@triton.jit
def _bwd_kernel(
    Q,
    K,
    V,
    sm_scale,
    Out,
    DO,
    DQ,
    DK,
    DV,
    L,
    D,
    stride_dq_all,
    stride_qz,
    stride_qh,
    stride_qm,
    stride_qk,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_kk,
    stride_vz,
    stride_vh,
    stride_vn,
    stride_vk,
    stride_deltaz,
    stride_deltah,
    stride_deltam,
    Z,
    H,
    num_block_m,
    num_block_n,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    ACTUAL_BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    SEQUENCE_PARALLEL: tl.constexpr,
    CAUSAL: tl.constexpr,
    USE_EXP2: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
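    """Top-level backward kernel.

    Launched on a (batch * heads, num_blocks_n) grid when SEQUENCE_PARALLEL,
    so each program owns one K/V column block; otherwise on a
    (batch * heads, 1) grid, where a single program loops over all column
    blocks. This kernel only resolves the per-(batch, head) base pointers
    and sequence lengths, then delegates to ``_bwd_kernel_one_col_block``.
    """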
    # program ids
    off_hz = tl.program_id(0)
    if SEQUENCE_PARALLEL:
        start_n = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H

    if IS_VARLEN:
        # compute the sequence bounds for the current batch
        q_start = tl.load(cu_seqlens_q + off_z)
        q_end = tl.load(cu_seqlens_q + off_z + 1)
        k_start = tl.load(cu_seqlens_k + off_z)
        k_end = tl.load(cu_seqlens_k + off_z + 1)

        # actual sequence lengths
        N_CTX_Q = q_end - q_start
        N_CTX_K = k_end - k_start
    else:
        q_start = 0
        k_start = 0
        N_CTX_Q = max_seqlen_q
        N_CTX_K = max_seqlen_k

    # input tensor offsets
    q_offset = Q + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
    k_offset = K + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn
    v_offset = V + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn
    do_offset = DO + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
    l_offset = L + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam
    d_offset = D + off_z * stride_deltaz + off_h * stride_deltah + q_start * stride_deltam

    # output tensor offsets
    dk_offset = DK + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn
    dv_offset = DV + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn
    if SEQUENCE_PARALLEL:
        dq_offset = DQ + start_n * stride_dq_all + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
    else:
        dq_offset = DQ + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm

    # inner loop
    if SEQUENCE_PARALLEL:
        _bwd_kernel_one_col_block(
            Q,
            K,
            V,
            sm_scale,
            Out,
            DO,
            DQ,
            DK,
            DV,
            L,
            D,
            q_offset,
            k_offset,
            v_offset,
            do_offset,
            dq_offset,
            dk_offset,
            dv_offset,
            d_offset,
            l_offset,
            stride_dq_all,
            stride_qz,
            stride_qh,
            stride_qm,
            stride_qk,
            stride_kz,
            stride_kh,
            stride_kn,
            stride_kk,
            stride_vz,
            stride_vh,
            stride_vn,
            stride_vk,
            stride_deltaz,
            stride_deltah,
            stride_deltam,
            Z,
            H,
            N_CTX_Q,
            N_CTX_K,
            off_h,
            off_z,
            off_hz,
            start_n,
            num_block_m,
            num_block_n,
            BLOCK_M=BLOCK_M,
            BLOCK_DMODEL=BLOCK_DMODEL,
            ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
            BLOCK_N=BLOCK_N,
            SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
            CAUSAL=CAUSAL,
            USE_EXP2=USE_EXP2,
        )
    else:
        for start_n in range(0, num_block_n):
            _bwd_kernel_one_col_block(
                Q,
                K,
                V,
                sm_scale,
                Out,
                DO,
                DQ,
                DK,
                DV,
                L,
                D,
                q_offset,
                k_offset,
                v_offset,
                do_offset,
                dq_offset,
                dk_offset,
                dv_offset,
                d_offset,
                l_offset,
                stride_dq_all,
                stride_qz,
                stride_qh,
                stride_qm,
                stride_qk,
                stride_kz,
                stride_kh,
                stride_kn,
                stride_kk,
                stride_vz,
                stride_vh,
                stride_vn,
                stride_vk,
                stride_deltaz,
                stride_deltah,
                stride_deltam,
                Z,
                H,
                N_CTX_Q,
                N_CTX_K,
                off_h,
                off_z,
                off_hz,
                start_n,
                num_block_m,
                num_block_n,
                BLOCK_M=BLOCK_M,
                BLOCK_DMODEL=BLOCK_DMODEL,
                ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
                BLOCK_N=BLOCK_N,
                SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
                CAUSAL=CAUSAL,
                USE_EXP2=USE_EXP2,
            )


# NOTE: smaller blocks have lower accuracy, probably due to more accumulation
# error. 128 x 128 seems accurate enough but leads to OOM; 64 x 64 has some
# accumulation error but no OOM.
def attention_prefill_backward_triton_impl(
    do,
    q,
    k,
    v,
    o,
    softmax_lse,
    dq,
    dk,
    dv,
    sm_scale: float,
    alibi_slopes,
    causal,
    layout: str,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q: int,
    max_seqlen_k: int,
    use_exp2: bool,
    sequence_parallel=True,
):
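    """Host-side wrapper for the Triton backward pass.

    Takes the tensors saved from the forward pass (``o`` and ``softmax_lse``)
    plus the incoming gradient ``do``, and produces ``dq``, ``dk`` and ``dv``.
    If the gradient tensors are ``None`` they are allocated here; if they are
    provided but non-contiguous, they are replaced by contiguous copies and
    copied back at the end. ``layout`` selects the tensor memory layout
    (``"thd"`` enables varlen mode); shapes and strides are resolved by
    ``get_shape_from_layout`` / ``get_strides_from_layout``. Returns
    ``(dq, dk, dv, delta, None, None)``.
    """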
    if DEBUG:
        print()
        print("attention_prefill_backward_triton_impl")
        print("do:", do, do.shape)
        print("q:", q, q.shape)
        print("k:", k, k.shape)
        print("v:", v, v.shape)
        print("o:", o, o.shape)
        print("softmax_lse:", softmax_lse, softmax_lse.shape)
        print("dq:", dq, dq.shape if dq is not None else None)
        print("dk:", dk, dk.shape if dk is not None else None)
        print("dv:", dv, dv.shape if dv is not None else None)
        print("sm_scale:", sm_scale)
        print("alibi_slopes:", alibi_slopes)
        print("causal:", causal)
        print("layout:", layout)
        print("cu_seqlens_q:", cu_seqlens_q)
        print("cu_seqlens_k:", cu_seqlens_k)
        print("max_seqlen_q:", max_seqlen_q)
        print("max_seqlen_k:", max_seqlen_k)
        print("use_exp2:", use_exp2)
        print("sequence_parallel:", sequence_parallel)
    # make contiguous
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    softmax_lse = softmax_lse.contiguous()

    # get strides and shape
    batch, nheads_q, nheads_k, head_size, max_seqlen_q, max_seqlen_k = get_shape_from_layout(
        q, k, layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
    )
    q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, layout)
    stride_qz, stride_qh, stride_qm, stride_qk = q_strides
    stride_kz, stride_kh, stride_kn, stride_kk = k_strides
    stride_vz, stride_vh, stride_vn, stride_vk = v_strides
    stride_oz, stride_oh, stride_om, stride_ok = o_strides
    batch_headsize = batch * nheads_q
    is_varlen = layout == "thd"

    # FIXME: some configs lead to OOM for some reason when using 64 x 64 blocks
    if max_seqlen_q <= 32 or max_seqlen_k <= 32:
        BLOCK_M = 32
        BLOCK_N = 32
    else:
        BLOCK_M = 64
        BLOCK_N = 64
    num_warps = 4  # NOTE: the original value was 8; changing it to 1 caused issues, so be careful
    num_stages = 1
    waves_per_eu = 1

    # divide up the problem
    num_blocks_m = triton.cdiv(max_seqlen_q, BLOCK_M)
    num_blocks_n = triton.cdiv(max_seqlen_k, BLOCK_N)

    # pad the head dimension to the next power of 2, with a minimum of 16
    # (e.g. head_size = 72 -> BLOCK_DMODEL = 128)
    padded_d_model = 1 << (head_size - 1).bit_length()
    padded_d_model = max(padded_d_model, 16)
    BLOCK_DMODEL = padded_d_model
    ACTUAL_BLOCK_DMODEL = head_size

    do = do.contiguous()
    # NOTE: we might need to copy the output tensors if they are not contiguous or have other issues
    copy_back = {"dq": False, "dk": False, "dv": False}

    # deal with dq
    if dq is None:
        if sequence_parallel:
            dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype)
        else:
            dq = torch.zeros(q.shape, device=q.device, dtype=q.dtype)
    else:
        dq_og = dq
        if not dq.is_contiguous():
            dq = dq.contiguous()
            copy_back["dq"] = True

        if sequence_parallel:
            dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype)
            copy_back["dq"] = True
        else:
            # NOTE: the kernel accumulates into dq in place, so dq has to be all
            # zeros; this guards against being passed an uninitialized dq
            dq.zero_()
    stride_dq_all = dq.stride()[0]

    # deal with dk, dv
    if (dk is None) or (dv is None):
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
    else:
        if not dk.is_contiguous():
            dk_og = dk
            dk = dk.contiguous()
            copy_back["dk"] = True

        if not dv.is_contiguous():
            dv_og = dv
            dv = dv.contiguous()
            copy_back["dv"] = True

    if DEBUG:
        print("copy_back:", copy_back)
    # assert contiguous
    assert do.is_contiguous()
    assert q.is_contiguous()
    assert k.is_contiguous()
    assert v.is_contiguous()
    assert o.is_contiguous()
    assert softmax_lse.is_contiguous()

    # init delta
    delta = torch.empty_like(softmax_lse)
    if is_varlen:
        # in varlen mode the batch offset is applied via q_start rather than
        # through a separate batch stride
        stride_deltam, stride_deltah = delta.stride()
        stride_deltaz = 0
    else:
        stride_deltaz, stride_deltah, stride_deltam = delta.stride()
    _bwd_preprocess_use_o[(num_blocks_m, batch_headsize)](
        o,
        do,
        delta,
        stride_oz, stride_oh, stride_om, stride_ok,
        stride_oz, stride_oh, stride_om, stride_ok,  # do is contiguous with the same layout as o
        stride_deltaz, stride_deltah, stride_deltam,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        BLOCK_M=BLOCK_M,
        BLOCK_DMODEL=BLOCK_DMODEL,
        ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
        N_CTX_Q=max_seqlen_q,
        Z=batch,
        H=nheads_q,
        IS_VARLEN=is_varlen,
    )
    if DEBUG:
        print("_bwd_kernel inputs")
        print("do:", do, do.shape)
        print("q:", q, q.shape)
        print("k:", k, k.shape)
        print("v:", v, v.shape)
        print("sm_scale", sm_scale)
        print("o:", o, o.shape)
        print("dq:", dq, dq.shape)
        print("dk:", dk, dk.shape)
        print("dv:", dv, dv.shape)
        print("L:", softmax_lse, softmax_lse.shape)
        print("delta:", delta, delta.shape)
        print("stride_qz, stride_qh, stride_qm, stride_qk:", stride_qz, stride_qh, stride_qm, stride_qk)
        print("stride_kz, stride_kh, stride_kn, stride_kk:", stride_kz, stride_kh, stride_kn, stride_kk)
        print("stride_vz, stride_vh, stride_vn, stride_vk:", stride_vz, stride_vh, stride_vn, stride_vk)
        print("batch_q:", batch)
        print("heads_q:", nheads_q)
        print("max_seqlen_q:", max_seqlen_q)
        print("max_seqlen_k:", max_seqlen_k)
        print("BLOCK_M:", BLOCK_M)
        print("BLOCK_N:", BLOCK_N)
        print("BLOCK_DMODEL:", BLOCK_DMODEL)
        print("ACTUAL_BLOCK_DMODEL:", ACTUAL_BLOCK_DMODEL)
        print("SEQUENCE_PARALLEL:", sequence_parallel)
        print("CAUSAL:", causal)
        print("num_warps:", num_warps)
        print("num_stages:", num_stages)
        print("USE_EXP2:", use_exp2)
        print("num_blocks_m:", num_blocks_m)
        print("num_blocks_n:", num_blocks_n)
    _bwd_kernel[(batch_headsize, num_blocks_n if sequence_parallel else 1)](
        q,
        k,
        v,
        sm_scale,
        o,
        do,
        dq,
        dk,
        dv,
        softmax_lse,
        delta,
        stride_dq_all,
        stride_qz, stride_qh, stride_qm, stride_qk,
        stride_kz, stride_kh, stride_kn, stride_kk,
        stride_vz, stride_vh, stride_vn, stride_vk,
        stride_deltaz, stride_deltah, stride_deltam,
        batch,
        nheads_q,
        num_blocks_m,
        num_blocks_n,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_DMODEL=BLOCK_DMODEL,
        ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
        SEQUENCE_PARALLEL=sequence_parallel,
        CAUSAL=causal,
        USE_EXP2=use_exp2,
        num_warps=num_warps,
        num_stages=num_stages,
        waves_per_eu=waves_per_eu,
        IS_VARLEN=is_varlen,
    )
    if DEBUG:
        print("_bwd_kernel outputs")
        print("dq:", dq, dq.shape)
        print("dk:", dk, dk.shape)
        print("dv:", dv, dv.shape)
        print("delta:", delta, delta.shape)

    # reduce the per-column-block partial dq results
    if sequence_parallel:
        dq = dq.sum(dim=0)

    if DEBUG:
        print("attention_prefill_backward_triton_impl outputs")
        print("dq:", dq, dq.shape)
        print("dk:", dk, dk.shape)
        print("dv:", dv, dv.shape)
        print("delta:", delta, delta.shape)
        print("copy_back:", copy_back)

    # copy results back into the caller-provided tensors if we had to replace them
    if copy_back["dq"]:
        dq_og.copy_(dq)
        dq = dq_og
    if copy_back["dk"]:
        dk_og.copy_(dk)
        dk = dk_og
    if copy_back["dv"]:
        dv_og.copy_(dv)
        dv = dv_og

    return dq, dk, dv, delta, None, None
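

if __name__ == "__main__":
    # Minimal smoke test -- a sketch, not part of the library API. It assumes
    # the "bshd" layout, i.e. q/k/v/o shaped (batch, seqlen, nheads, head_size)
    # with softmax_lse shaped (batch, nheads, seqlen), which matches how the
    # strides are unpacked above but is not otherwise guaranteed here, and that
    # the layout helpers accept cu_seqlens_q/k=None for non-varlen layouts.
    # Run as a module (python -m <package>.bwd_prefill) so the relative import resolves.
    torch.manual_seed(0)
    batch, seqlen, nheads, head_size = 2, 128, 4, 64
    device, dtype = "cuda", torch.float16
    q = torch.randn(batch, seqlen, nheads, head_size, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn_like(q, requires_grad=True)
    v = torch.randn_like(q, requires_grad=True)
    sm_scale = head_size ** -0.5

    # reference forward in fp32 to produce o and softmax_lse for the backward
    qf, kf, vf = (t.float().permute(0, 2, 1, 3) for t in (q, k, v))  # (b, h, s, d)
    scores = torch.einsum("bhqd,bhkd->bhqk", qf, kf) * sm_scale
    lse = torch.logsumexp(scores, dim=-1)                            # (b, h, s)
    p_ref = torch.exp(scores - lse.unsqueeze(-1))
    o_ref = torch.einsum("bhqk,bhkd->bhqd", p_ref, vf)               # (b, h, s, d)
    o = o_ref.detach().permute(0, 2, 1, 3).to(dtype).contiguous()    # back to bshd

    do = torch.randn_like(o)
    dq, dk, dv, delta, _, _ = attention_prefill_backward_triton_impl(
        do, q, k, v, o, lse.contiguous(), None, None, None, sm_scale,
        None, False, "bshd", None, None, seqlen, seqlen, use_exp2=True,
    )

    # compare against autograd through the fp32 reference forward
    o_ref.backward(do.float().permute(0, 2, 1, 3))
    for name, got, ref in (("dq", dq, q.grad), ("dk", dk, k.grad), ("dv", dv, v.grad)):
        print(f"{name} max abs err: {(got.float() - ref.float()).abs().max().item():.4e}")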