# interface_fa.py

import torch
import os

from .fwd_prefill import attention_prefill_forward_triton_impl
from .bwd_prefill import attention_prefill_backward_triton_impl
from .fwd_decode import attention_decode_forward_triton_impl
from .fwd_ref import attention_forward_pytorch_ref_impl
from .bwd_ref import attention_backward_pytorch_ref_impl
from .utils import MetaData, get_shape_from_layout, DEBUG

USE_REF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_REF', '0').lower() in ('1', 'true', 'yes')
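
# The toggle above routes every entry point in this file through the PyTorch
# reference implementations instead of the Triton kernels, which is handy when
# bisecting numerical issues. A minimal shell sketch (the script name is
# illustrative, not part of this module):
#
#   FLASH_ATTENTION_TRITON_AMD_REF=1 python run_model.py
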
def fwd(q,
        k,
        v,
        o,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen_):
    if DEBUG:
        print()
        print("flash_attn_triton_amd.py::fwd")
        print("q:", q, q.shape)
        print("k:", k, k.shape)
        print("v:", v, v.shape)
        print("o:", o)
        print("alibi_slopes:", alibi_slopes)
        print("dropout_p:", dropout_p)
        print("softmax_scale:", softmax_scale)
        print("causal:", causal)
        print("window_size_left:", window_size_left)
        print("window_size_right:", window_size_right)
        print("softcap:", softcap)
  37. print("softcap:", softcap)
  38. print("return_softmax:", return_softmax)
  39. if dropout_p != 0.0:
  40. raise ValueError("dropout is not supported on AMD's Triton Backend yet")
  41. if o is None:
  42. o = torch.empty_like(q)
  43. # Setup metadata
  44. metadata = MetaData(sm_scale=softmax_scale)
  45. metadata.max_seqlens_q = q.shape[1]
  46. metadata.max_seqlens_k = k.shape[1]
  47. metadata.layout = "bshd"
  48. if return_softmax:
  49. metadata.return_scores = True
  50. batch, nheads_q, nheads_k, head_size, _, _ = get_shape_from_layout(q, k, metadata.layout)
  51. if causal:
  52. metadata.need_causal()
  53. if alibi_slopes is not None:
  54. metadata.need_alibi(alibi_slopes, batch, nheads_q)
  55. if dropout_p > 0.0:
  56. metadata.need_dropout(dropout_p, return_softmax)
  57. # Check arguments
  58. metadata.check_args(q, k, v, o)
  59. if USE_REF:
  60. if DEBUG:
  61. print("Using reference implementation")
  62. (output,
  63. softmax_lse,
  64. exp_scores,
  65. _,
  66. _,
  67. _,
  68. _) = attention_forward_pytorch_ref_impl(
  69. q,
  70. k,
  71. v,
  72. metadata.sm_scale,
  73. metadata.causal,
  74. metadata.layout,
  75. metadata.cu_seqlens_q,
  76. metadata.cu_seqlens_k,
  77. metadata.max_seqlens_q,
  78. metadata.max_seqlens_k,
  79. metadata.use_exp2)
  80. o.copy_(output)
  81. else:
  82. if DEBUG:
  83. print("Using Triton implementation")
  84. (_,
  85. softmax_lse,
  86. exp_scores,
  87. _,
  88. _,
  89. _,
  90. _,
  91. _,
  92. _) = attention_prefill_forward_triton_impl(
  93. q,
  94. k,
  95. v,
  96. o,
  97. metadata.sm_scale,
  98. metadata.alibi_slopes,
  99. metadata.causal,
  100. metadata.bias,
  101. metadata.dropout_p,
  102. metadata.layout,
  103. metadata.cu_seqlens_q,
  104. metadata.cu_seqlens_k,
  105. metadata.max_seqlens_q,
  106. metadata.max_seqlens_k,
  107. metadata.return_scores,
  108. metadata.use_exp2)
  109. if DEBUG:
  110. print("fwd outputs")
  111. print("o:", o, o.shape)
  112. print("softmax_lse:", softmax_lse, softmax_lse.shape)
  113. print("exp_scores:", exp_scores, exp_scores.shape if exp_scores is not None else None )
  114. return o, softmax_lse, exp_scores, None
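
# A minimal call sketch for fwd, assuming fp16 tensors on a ROCm/CUDA device in
# "bshd" layout (batch, seqlen, nheads, head_dim); the shapes, device string,
# and scale 1/sqrt(64) = 0.125 are illustrative, not requirements:
#
#   q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
#   k = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
#   v = torch.randn(2, 128, 8, 64, dtype=torch.float16, device="cuda")
#   o, softmax_lse, exp_scores, _ = fwd(q, k, v, None, None, 0.0, 0.125,
#                                       True, -1, -1, 0.0, False, None)
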
def bwd(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    alibi_slopes,
    dropout_p,
    softmax_scale,
    causal,
    window_size_left,
    window_size_right,
    softcap,
    deterministic,
    gen_,
    rng_state,
):
    if DEBUG:
        print()
        print("flash_attn_triton_amd.py::bwd")
        print("dout:", dout, dout.shape)
        print("q:", q, q.shape)
        print("k:", k, k.shape)
        print("v:", v, v.shape)
        print("out:", out, out.shape)
        print("softmax_lse:", softmax_lse, softmax_lse.shape)
        print("dq:", dq, dq.shape)
        print("dk:", dk, dk.shape)
        print("dv:", dv, dv.shape)
        print("alibi_slopes:", alibi_slopes)
        print("dropout_p:", dropout_p)
  150. print("out:", out)
  151. print("softmax_scale:", softmax_scale)
  152. print("causal:", causal)
  153. print("window_size_left:", window_size_left)
  154. print("window_size_right:", window_size_right)
  155. print("deterministic:", deterministic)
  156. print("gen_:", gen_)
  157. print("rng_state:", rng_state)
  158. if dropout_p != 0.0:
  159. raise ValueError("dropout is not supported on AMD yet")
  160. if USE_REF:
  161. if DEBUG:
  162. print("Using reference implementation")
  163. dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl(
  164. dout,
  165. q,
  166. k,
  167. v,
  168. out,
  169. softmax_lse,
  170. softmax_scale,
  171. causal,
  172. "bshd",
  173. None,
  174. None,
  175. None,
  176. None,
  177. False,
  178. )
  179. dq.copy_(dq_ref)
  180. dk.copy_(dk_ref)
  181. dv.copy_(dv_ref)
  182. delta = delta_ref
  183. else:
  184. if DEBUG:
  185. print("Using Triton implementation")
  186. dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl(
  187. dout,
  188. q,
  189. k,
  190. v,
  191. out,
  192. softmax_lse,
  193. dq,
  194. dk,
  195. dv,
  196. softmax_scale,
  197. alibi_slopes,
  198. causal,
  199. "bshd",
  200. None,
  201. None,
  202. None,
  203. None,
  204. False,
  205. )
  206. delta = delta_triton
  207. if DEBUG:
  208. print("bwd outputs")
  209. print("dv:", dv, dv.shape)
  210. print("dk:", dk, dk.shape)
  211. print("dq:", dq, dq.shape)
  212. return dq, dk, dv, delta
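
# A sketch of driving bwd directly, continuing the fwd example above (in normal
# use the flash_attn autograd wrapper typically calls this). dq/dk/dv must be
# preallocated because gradients are written into them in place:
#
#   dout = torch.randn_like(o)
#   dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
#   dq, dk, dv, delta = bwd(dout, q, k, v, o, softmax_lse, dq, dk, dv,
#                           None, 0.0, 0.125, True, -1, -1, 0.0, False,
#                           None, None)
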
def varlen_fwd(
        q,
        k,
        v,
        o,
        cu_seqlens_q,
        cu_seqlens_k,
        seqused_k,
        leftpad_k,
        block_table_,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        zero_tensors,
        causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen_):
    if DEBUG:
        print()
        print("flash_attn_triton_amd.py::varlen_fwd")
        print("q:", q, q.shape)
        print("k:", k, k.shape)
        print("v:", v, v.shape)
        print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape)
        print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape)
        print("alibi_slopes:", alibi_slopes)
        print("max_seqlen_q:", max_seqlen_q)
        print("max_seqlen_k:", max_seqlen_k)
        print("dropout_p:", dropout_p)
        print("softmax_scale:", softmax_scale)
        print("causal:", causal)
        print("window_size_left:", window_size_left)
        print("window_size_right:", window_size_right)
        print("gen_:", gen_)

    if dropout_p != 0.0:
        raise ValueError("dropout is not supported on AMD's Triton Backend yet")

    if o is None:
        o = torch.empty_like(q)

    # Setup metadata
    metadata = MetaData(sm_scale=softmax_scale)
    if return_softmax:
        metadata.return_scores = True
    metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k)  # set layout to "thd" and other metadata
    # get shapes
    batch, nheads_q, nheads_k, head_size, seqlen_q, seqlen_k = get_shape_from_layout(q, k, metadata.layout, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)

    if causal:
        metadata.need_causal()

    if alibi_slopes is not None:
        metadata.need_alibi(alibi_slopes, batch, nheads_q)

    if dropout_p > 0.0:
        metadata.need_dropout(dropout_p, return_softmax)

    # Check arguments
    metadata.check_args(q, k, v, o)
    if USE_REF:
        if DEBUG:
            print("Using reference implementation")
        (output,
         softmax_lse,
         exp_scores,
         _,
         _,
         _,
         _) = attention_forward_pytorch_ref_impl(
            q,
            k,
            v,
            metadata.sm_scale,
            metadata.causal,
            metadata.layout,
            metadata.cu_seqlens_q,
            metadata.cu_seqlens_k,
            metadata.max_seqlens_q,
            metadata.max_seqlens_k,
            metadata.use_exp2)
        o.copy_(output)
    else:
        if DEBUG:
            print("Using Triton implementation")
        (_,
         softmax_lse,
         exp_scores,
         _,
         _,
         _,
         _,
         _,
         _) = attention_prefill_forward_triton_impl(
            q,
            k,
            v,
            o,
            metadata.sm_scale,
            metadata.alibi_slopes,
            metadata.causal,
            metadata.bias,
            metadata.dropout_p,
            metadata.layout,
            metadata.cu_seqlens_q,
            metadata.cu_seqlens_k,
            metadata.max_seqlens_q,
            metadata.max_seqlens_k,
            metadata.return_scores,
            metadata.use_exp2)

    if DEBUG:
        print("varlen_fwd outputs")
        print("o:", o, o.shape)
        print("softmax_lse:", softmax_lse, softmax_lse.shape)
        print("exp_scores:", exp_scores, exp_scores.shape if exp_scores is not None else None)

    return o, softmax_lse, exp_scores, None
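
# A packing sketch for varlen_fwd. In the "thd" layout all sequences are
# concatenated along dim 0 as (total_tokens, nheads, head_dim), and
# cu_seqlens_* holds the cumulative boundaries, e.g. lengths 100 and 28 pack
# into 128 tokens with boundaries [0, 100, 128]. Shapes are illustrative:
#
#   q = torch.randn(128, 8, 64, dtype=torch.float16, device="cuda")
#   k = torch.randn(128, 8, 64, dtype=torch.float16, device="cuda")
#   v = torch.randn(128, 8, 64, dtype=torch.float16, device="cuda")
#   cu_seqlens = torch.tensor([0, 100, 128], dtype=torch.int32, device="cuda")
#   o, softmax_lse, exp_scores, _ = varlen_fwd(
#       q, k, v, None, cu_seqlens, cu_seqlens, None, None, None, None,
#       100, 100, 0.0, 0.125, False, True, -1, -1, 0.0, False, None)
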
def varlen_bwd(
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    dq,
    dk,
    dv,
    cu_seqlens_q,
    cu_seqlens_k,
    alibi_slopes,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    zero_tensors,
    causal,
    window_size_left,
    window_size_right,
    softcap,
    deterministic,
    gen_,
    rng_state,
):
    if DEBUG:
        print()
  357. print("varlen_bwd")
  358. print("dout:", dout, dout.shape)
  359. print("q:", q, q.shape)
  360. print("k:", k, k.shape)
  361. print("v:", v, v.shape)
  362. print("softmax_lse:", softmax_lse, softmax_lse.shape)
  363. print("dq:", dq, dq.shape)
  364. print("dk:", dk, dk.shape)
  365. print("dv:", dv, dv.shape)
  366. print("cu_seqlens_q:", cu_seqlens_q, cu_seqlens_q.shape)
  367. print("cu_seqlens_k:", cu_seqlens_k, cu_seqlens_k.shape)
  368. print("alibi_slopes:", alibi_slopes)
  369. print("max_seqlen_q:", max_seqlen_q)
  370. print("max_seqlen_k:", max_seqlen_k)
  371. print("dropout_p:", dropout_p)
  372. print("out:", out)
  373. print("softmax_scale:", softmax_scale)
  374. print("causal:", causal)
  375. print("window_size_left:", window_size_left)
  376. print("window_size_right:", window_size_right)
  377. print("deterministic:", deterministic)
  378. print("gen_:", gen_)
  379. print("rng_state:", rng_state)
  380. if dropout_p != 0.0:
  381. raise ValueError("dropout is not supported on AMD yet")
  382. if USE_REF:
  383. if DEBUG:
  384. print("Using reference implementation")
  385. dq_ref, dk_ref, dv_ref, delta_ref = attention_backward_pytorch_ref_impl(
  386. dout,
  387. q,
  388. k,
  389. v,
  390. out,
  391. softmax_lse,
  392. softmax_scale,
  393. causal,
  394. "thd",
  395. cu_seqlens_q,
  396. cu_seqlens_k,
  397. max_seqlen_q,
  398. max_seqlen_k,
  399. False,
  400. )
  401. dq.copy_(dq_ref)
  402. dk.copy_(dk_ref)
  403. dv.copy_(dv_ref)
  404. delta = delta_ref
  405. else:
  406. if DEBUG:
  407. print("Using Triton implementation")
  408. dq_triton, dk_triton, dv_triton, delta_triton, _, _ = attention_prefill_backward_triton_impl(
  409. dout,
  410. q,
  411. k,
  412. v,
  413. out,
  414. softmax_lse,
  415. dq,
  416. dk,
  417. dv,
  418. softmax_scale,
  419. alibi_slopes,
  420. causal,
  421. "thd",
  422. cu_seqlens_q,
  423. cu_seqlens_k,
  424. max_seqlen_q,
  425. max_seqlen_k,
  426. False,
  427. )
  428. delta = delta_triton
  429. if DEBUG:
  430. print("varlen_bwd outputs")
  431. print("delta:", delta, delta.shape)
  432. print("dv:", dv, dv.shape)
  433. print("dk:", dk, dk.shape)
  434. print("dq:", dq, dq.shape)
  435. return dq, dk, dv, delta
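
# The returned delta is the per-token row sum of dout * out that the
# FlashAttention backward recomputation uses. A call sketch mirroring the
# varlen_fwd example above, with preallocated gradient buffers (values are
# illustrative):
#
#   dq, dk, dv, delta = varlen_bwd(
#       torch.randn_like(o), q, k, v, o, softmax_lse,
#       torch.empty_like(q), torch.empty_like(k), torch.empty_like(v),
#       cu_seqlens, cu_seqlens, None, 100, 100, 0.0, 0.125,
#       False, True, -1, -1, 0.0, False, None, None)
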
def fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        cache_leftpad,
        block_table,
        alibi_slopes,
        out,
        softmax_scale,
        causal,
        window_size_left,
        window_size_right,
        softcap,
        rotary_interleaved,
        num_splits):
    if out is None:
        out = torch.empty_like(q)

    # fill metadata
    metadata = MetaData(sm_scale=softmax_scale)
    metadata.layout = "bshd"
    metadata.max_seqlens_q = q.shape[1]
    metadata.max_seqlens_k = k_cache.shape[1]
    metadata.cache_seqlens = cache_seqlens
    metadata.cache_batch_idx = cache_batch_idx

    if k is not None and v is not None:
        metadata.new_kv = True
        metadata.seqlen_new = k.shape[1]
        metadata.k_new = k
        metadata.v_new = v

    if causal:
        metadata.need_causal()

    if alibi_slopes is not None:
        batch, _, nheads_q, _ = q.shape
        metadata.need_alibi(alibi_slopes, batch, nheads_q)

    # launch kernel
    # TODO: pass output as an arg. Maybe we are copying output which is causing slow down
    output, softmax_lse = attention_decode_forward_triton_impl(
        q,
        k_cache,
        v_cache,
        metadata.sm_scale,
        metadata.causal,
        metadata.alibi_slopes,
        metadata.layout,
        metadata.cache_seqlens,
        metadata.cache_batch_idx,
        metadata.new_kv,
        metadata.k_new,
        metadata.v_new,
    )
    return output, softmax_lse
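
# A decode-time sketch for fwd_kvcache: single new query tokens attend over a
# preallocated KV cache whose valid lengths come from cache_seqlens. The cache
# capacity of 2048, the shapes, and the device string are illustrative:
#
#   q = torch.randn(2, 1, 8, 64, dtype=torch.float16, device="cuda")
#   k_cache = torch.zeros(2, 2048, 8, 64, dtype=torch.float16, device="cuda")
#   v_cache = torch.zeros(2, 2048, 8, 64, dtype=torch.float16, device="cuda")
#   cache_seqlens = torch.tensor([512, 300], dtype=torch.int32, device="cuda")
#   out, softmax_lse = fwd_kvcache(
#       q, k_cache, v_cache, None, None, cache_seqlens, None, None, None,
#       None, None, None, None, 0.125, True, -1, -1, 0.0, False, 0)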