flash_api.cpp 73 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
  5. #include <torch/python.h>
  6. #include <torch/nn/functional.h>
  7. #include <ATen/cuda/CUDAContext.h>
  8. #include <c10/cuda/CUDAGuard.h>
  9. #include <cutlass/numeric_types.h>
  10. #include "flash.h"
  11. #include "static_switch.h"
  // Argument-validation helpers. Each expands to a TORCH_CHECK whose message names the
  // offending expression via stringification.
  // Require tensor x to be on a CUDA device.
  #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
  // Require x.sizes() to equal exactly the listed dimensions.
  #define CHECK_SHAPE(x, ...) \
      TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
  // Require x to be contiguous.
  #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
  15. void set_params_fprop(Flash_fwd_params &params,
  16. // sizes
  17. const size_t b,
  18. const size_t b_k,
  19. const size_t seqlen_q,
  20. const size_t seqlen_k,
  21. const size_t seqlen_q_rounded,
  22. const size_t seqlen_k_rounded,
  23. const size_t h,
  24. const size_t h_k,
  25. const size_t d,
  26. const size_t d_rounded,
  27. // device pointers
  28. const at::Tensor q,
  29. const at::Tensor k,
  30. const at::Tensor v,
  31. at::Tensor out,
  32. void *cu_seqlens_q_d,
  33. void *cu_seqlens_k_d,
  34. void *seqused_q,
  35. void *seqused_k,
  36. void *p_d,
  37. void *softmax_lse_d,
  38. float p_dropout,
  39. float softmax_scale,
  40. int window_size_left,
  41. int window_size_right,
  42. bool seqlenq_ngroups_swapped=false,
  43. bool unpadded_lse=false) {
  44. // Reset the parameters
  45. params = {};
  46. params.is_bf16 = q.dtype() == torch::kBFloat16;
  47. params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;
  48. params.is_kv_cache = false;
  49. params.page_num_blocks = 0;
  50. // Set the pointers and strides.
  51. params.q_ptr = q.data_ptr();
  52. params.k_ptr = k.data_ptr();
  53. params.v_ptr = v.data_ptr();
  54. // All stride are in elements, not bytes.
  55. params.q_row_stride = q.stride(-3);
  56. params.k_row_stride = k.stride(-3);
  57. params.v_row_stride = v.stride(-3);
  58. params.q_head_stride = q.stride(-2);
  59. params.k_head_stride = k.stride(-2);
  60. params.v_head_stride = v.stride(-2);
  61. params.o_ptr = out.data_ptr();
  62. params.o_row_stride = out.stride(-3);
  63. params.o_head_stride = out.stride(-2);
  64. if (cu_seqlens_q_d == nullptr) {
  65. params.q_batch_stride = q.stride(0);
  66. params.k_batch_stride = k.stride(0);
  67. params.v_batch_stride = v.stride(0);
  68. params.o_batch_stride = out.stride(0);
  69. if (seqlenq_ngroups_swapped) {
  70. params.q_batch_stride *= seqlen_q;
  71. params.o_batch_stride *= seqlen_q;
  72. }
  73. }
  74. params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
  75. params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
  76. params.seqused_q = static_cast<int *>(seqused_q);
  77. params.seqused_k = static_cast<int *>(seqused_k);
  78. TORCH_CHECK(
  79. bool(params.cu_seqlens_q) == bool(params.cu_seqlens_k),
  80. "cu_seqlens_q and cu_seqlens_k must be both null or non-null"
  81. );
  82. // P = softmax(QK^T)
  83. params.p_ptr = p_d;
  84. // Softmax sum
  85. params.softmax_lse_ptr = softmax_lse_d;
  86. // Set the dimensions.
  87. params.b = b;
  88. params.b_k = b_k;
  89. params.h = h;
  90. params.h_k = h_k;
  91. params.h_h_k_ratio = h / h_k;
  92. params.seqlen_q = seqlen_q;
  93. params.seqlen_k = seqlen_k;
  94. params.seqlen_q_rounded = seqlen_q_rounded;
  95. params.seqlen_k_rounded = seqlen_k_rounded;
  96. params.d = d;
  97. params.d_rounded = d_rounded;
  98. // Set the different scale values.
  99. params.scale_softmax = softmax_scale;
  100. params.scale_softmax_log2 = softmax_scale * M_LOG2E;
  101. __half scale_softmax_log2_half = __float2half(params.scale_softmax_log2);
  102. __half2 scale_softmax_log2_half2 = __half2(scale_softmax_log2_half, scale_softmax_log2_half);
  103. params.scale_softmax_log2_half2 = reinterpret_cast<uint32_t&>(scale_softmax_log2_half2);
  104. // Set this to probability of keeping an element to simplify things.
  105. params.p_dropout = 1.f - p_dropout;
  106. // Convert p from float to int so we don't have to convert the random uint to float to compare.
  107. // [Minor] We want to round down since when we do the comparison we use <= instead of <
  108. // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
  109. // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
  110. params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
  111. params.rp_dropout = 1.f / params.p_dropout;
  112. params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
  113. TORCH_CHECK(p_dropout < 1.f);
  114. #ifdef FLASHATTENTION_DISABLE_DROPOUT
  115. TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
  116. #endif
  117. // Causal is the special case where window_size_right == 0 and window_size_left < 0.
  118. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
  119. window_size_left = std::min(int(seqlen_k), window_size_left);
  120. window_size_right = std::min(int(seqlen_k), window_size_right);
  121. if (window_size_left < 0) { window_size_left = seqlen_k; }
  122. if (window_size_right < 0) { window_size_right = seqlen_k; }
  123. params.window_size_left = window_size_left;
  124. params.window_size_right = window_size_right;
  125. params.is_causal = window_size_left == int(seqlen_k) && window_size_right == 0;
  126. if ((window_size_left < int(seqlen_k) || window_size_right < int(seqlen_k)) && !params.is_causal) {
  127. params.is_local = true;
  128. }
  129. #ifdef FLASHATTENTION_DISABLE_LOCAL
  130. TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
  131. "This flash attention build does not support local attention.");
  132. #endif
  133. #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
  134. TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
  135. #endif
  136. params.unpadded_lse = unpadded_lse;
  137. params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
  138. }
  139. void set_params_dgrad(Flash_bwd_params &params,
  140. // sizes
  141. const size_t b,
  142. const size_t seqlen_q,
  143. const size_t seqlen_k,
  144. const size_t seqlen_q_rounded,
  145. const size_t seqlen_k_rounded,
  146. const size_t h,
  147. const size_t h_k,
  148. const size_t d,
  149. const size_t d_rounded,
  150. // device pointers
  151. const at::Tensor q,
  152. const at::Tensor k,
  153. const at::Tensor v,
  154. const at::Tensor out,
  155. const at::Tensor dout,
  156. at::Tensor dq,
  157. at::Tensor dk,
  158. at::Tensor dv,
  159. void *cu_seqlens_q_d,
  160. void *cu_seqlens_k_d,
  161. void *seqused_q,
  162. void *seqused_k,
  163. void *dq_accum_d,
  164. void *dk_accum_d,
  165. void *dv_accum_d,
  166. void *softmax_lse_d,
  167. void *dsoftmax_sum_d,
  168. float p_dropout,
  169. float softmax_scale,
  170. int window_size_left,
  171. int window_size_right,
  172. bool deterministic) {
  173. set_params_fprop(params,
  174. b, b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
  175. q, k, v, out,
  176. cu_seqlens_q_d,
  177. cu_seqlens_k_d,
  178. seqused_q,
  179. seqused_k,
  180. nullptr,
  181. softmax_lse_d,
  182. p_dropout,
  183. softmax_scale,
  184. window_size_left,
  185. window_size_right);
  186. // Set the pointers and strides.
  187. params.do_ptr = dout.data_ptr();
  188. params.do_row_stride = dout.stride(-3);
  189. params.do_head_stride = dout.stride(-2);
  190. params.dq_ptr = dq.data_ptr();
  191. params.dk_ptr = dk.data_ptr();
  192. params.dv_ptr = dv.data_ptr();
  193. params.page_num_blocks = 0;
  194. params.dq_row_stride = dq.stride(-3);
  195. params.dk_row_stride = dk.stride(-3);
  196. params.dv_row_stride = dv.stride(-3);
  197. params.dq_head_stride = dq.stride(-2);
  198. params.dk_head_stride = dk.stride(-2);
  199. params.dv_head_stride = dv.stride(-2);
  200. if (cu_seqlens_q_d == nullptr) {
  201. params.do_batch_stride = dout.stride(0);
  202. params.dq_batch_stride = dq.stride(0);
  203. params.dk_batch_stride = dk.stride(0);
  204. params.dv_batch_stride = dv.stride(0);
  205. }
  206. params.dq_accum_ptr = dq_accum_d;
  207. params.dk_accum_ptr = dk_accum_d;
  208. params.dv_accum_ptr = dv_accum_d;
  209. // Softmax sum
  210. params.dsoftmax_sum = dsoftmax_sum_d;
  211. params.deterministic = deterministic;
  212. }
  // Choose how many KV splits to use for split-KV scheduling.
  //
  // The aim is to fill the GPU. With batch_nheads_mblocks work tiles on num_SMs SMs, a split
  // count s gives n_waves = tiles * s / num_SMs and a wave efficiency n_waves / ceil(n_waves).
  // For example, with 48 tiles and 108 SMs, 2 splits (efficiency 0.89) beat 3 splits (0.67).
  // More splits also cost extra HBM reads/writes, so the smallest split count whose
  // efficiency clears a fraction of the best achievable efficiency is returned.
  //
  // Returns 1 (no split) early when:
  //  - the unsplit work already fits one wave with an efficiency above a start threshold
  //    that depends on head_size and num_n_blocks, or has an efficiency above 0.8 anyway;
  //  - num_n_blocks is small (< 8, or < 10 with a single MMA warpgroup). For example, hdim
  //    128 with seqlen_k = 512 never splits, nor does seqlen_k = 1024 with one MMA warpgroup.
  // Otherwise split counts 1..min(max_splits, num_SMs, num_n_blocks) are scored. The
  // acceptance bar is 0.8 of the best efficiency. When there is one M block per head it is
  // min(0.3 + 0.1 * batch_nheads, 0.8) of the best instead, to avoid over-splitting.
  //
  // FA2's split-eligibility filter is intentionally not applied. FA3's dynamic tile
  // scheduler lets splits with no work exit early, and the filter caused efficiency
  // quantization issues.
  inline int num_splits_heuristic(int batch_nheads_mblocks, int batch_nheads, int num_SMs, int num_n_blocks,
                                  int max_splits, int head_size, bool use_one_mma_wg) {
    const int num_m_blocks = batch_nheads_mblocks / batch_nheads;
    // The empirical start thresholds are piecewise in log2(num_n_blocks).
    const float log_n = std::log2f(float(num_n_blocks));

    // Efficiency above which a single, partially filled wave is left unsplit.
    // Empirically this can be much lower than 80% depending on num_n_blocks.
    const float start_threshold = [&]() -> float {
      switch (head_size) {
        case 128: {
          float t;
          if (log_n <= 4) {          // 2048 -- .25
            t = .20f + (log_n - 3) * .05f;
          } else if (log_n <= 5) {   // 4096 -- .25
            t = .25f;
          } else if (log_n <= 6) {   // 8192 -- .36
            t = .28f + (log_n - 5) * .08f;
          } else if (log_n <= 7) {   // 16K -- .42
            t = .36f + (log_n - 6) * .06f;
          } else {                   // just split freely
            t = .8f;
          }
          if (num_m_blocks > 1 && t < .5f) { t += .05f * (log_n - 2); }
          return t;
        }
        case 256:  // TODO for hdim 256
          if (num_n_blocks <= 40) { return .24f; }
          if (log_n <= 8) { return .33f + std::max(0.f, (log_n - std::log2f(50)) * 0.02971f); }
          return .8f;  // just split freely
        case 64:
          if (!use_one_mma_wg) { return log_n <= 6 ? .5f : .8f; }  // 8K -- .5
          if (log_n <= 4) { return .33f; }                          // 2K -- .33
          if (log_n <= 5) { return .33f + (log_n - 4) * .04f; }     // 4K -- .37
          if (log_n <= 6) { return .37f + (log_n - 5) * .03f; }     // 8K -- .40
          if (log_n <= 7) { return .4f + (log_n - 6) * .03f; }      // 16K -- .43
          if (log_n <= 8) { return .43f + (log_n - 7) * .03f; }     // 32K -- .46
          return .8f;
        default:   // placeholder for other hdims
          return .8f;
      }
    }();

    // How full the (last) wave of unsplit work is.
    const float first_wave = float(batch_nheads_mblocks) / num_SMs;
    const auto first_wave_eff = first_wave / ceil(first_wave);
    const bool fits_one_wave = first_wave <= 1.f;
    if (first_wave_eff > .8f || (fits_one_wave && first_wave_eff > start_threshold)) { return 1; }

    const bool too_few_n_blocks = num_n_blocks < 8 || (use_one_mma_wg && num_n_blocks < 10);
    if (too_few_n_blocks) { return 1; }

    max_splits = std::min({max_splits, num_SMs, num_n_blocks});

    // Score every candidate split count by its wave efficiency.
    std::vector<float> efficiency;
    efficiency.reserve(max_splits);
    float max_efficiency = 0.f;
    for (int splits = 1; splits <= max_splits; ++splits) {
      const float n_waves = float(batch_nheads_mblocks * splits) / num_SMs;
      const float eff = n_waves / ceil(n_waves);
      max_efficiency = std::max(max_efficiency, eff);
      efficiency.push_back(eff);
    }

    // Correct for excessive splitting with e.g. 1 bsz*nheads*mblocks.
    // Empirically, the efficiency bar in these cases is about 40% for 64K seqlen_k.
    const float relative_bar = num_m_blocks == 1 ? std::min(0.3f + batch_nheads * 0.1f, 0.8f) : 0.8f;
    const float threshold = relative_bar * max_efficiency;
    for (std::size_t i = 0; i < efficiency.size(); ++i) {
      if (efficiency[i] > threshold) { return static_cast<int>(i) + 1; }
    }
    return 1;
  }
  331. std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params &params, const int batch_size,
  332. const int num_heads, const int num_heads_k, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
  333. const int head_size_rounded, const float p_dropout,
  334. const int num_splits, cudaDeviceProp *dprops, bool use_gqa_packing, bool is_causal, struct c10::TensorOptions opts) {
  335. auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
  336. params.num_splits = num_splits;
  337. at::Tensor softmax_lse_accum;
  338. at::Tensor out_accum;
  339. if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
  340. if (num_splits < 1) {
  341. const int gqa_ratio = num_heads / num_heads_k;
  342. const int block_h = 1 << static_cast<int>(std::ceil(std::log2(std::clamp(gqa_ratio, 1, 32))));
  343. const int block_m = head_size == 64 ? 192 : 128;
  344. const bool use_one_mma_wg = max_seqlen_q <= 64/block_h;
  345. int block_n = 128;
  346. if (head_size == 128 && !is_causal) {
  347. block_n = 176;
  348. } else if (head_size == 256) {
  349. block_n = use_one_mma_wg ? 96 : 80;
  350. }
  351. const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
  352. const int batch_nheads = use_gqa_packing ? batch_size * num_heads_k : batch_size * num_heads;
  353. const int batch_nheads_mblocks = use_gqa_packing
  354. ? ceildiv(max_seqlen_q, block_m / block_h) * batch_nheads
  355. : ceildiv(max_seqlen_q, block_m) * batch_nheads;
  356. params.num_splits = num_splits_heuristic(batch_nheads_mblocks, batch_nheads,
  357. dprops->multiProcessorCount, num_n_blocks, 128, head_size, use_one_mma_wg);
  358. // printf("Num splits heuristic = %d.\n", params.num_splits);
  359. }
  360. if (params.num_splits > 1) {
  361. softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
  362. out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
  363. params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
  364. params.oaccum_ptr = out_accum.data_ptr();
  365. params.oaccum_row_stride = out_accum.stride(-2);
  366. params.oaccum_head_stride = out_accum.stride(-3);
  367. params.oaccum_batch_stride = out_accum.stride(-4);
  368. params.oaccum_split_stride = out_accum.stride(0);
  369. }
  370. TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
  371. }
  372. return std::make_tuple(softmax_lse_accum, out_accum);
  373. }
  374. void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
  375. int dtype = 1;
  376. if (params.is_bf16) { dtype = 2; }
  377. else if (params.is_e4m3) { dtype = 3; }
  378. PREC_SWITCH(dtype, Element, [&] {
  379. HEADDIM_SWITCH(params.d, kHeadSize, [&] {
  380. if(!params.use_gqa_packing) {
  381. run_mha_fwd_<Element, kHeadSize>(params, stream);
  382. } else {
  383. QUERYHEAD_SWITCH(params.h_h_k_ratio, kBlockH, [&] {
  384. run_mha_fwd_gqa_<Element, kHeadSize, kBlockH>(params, stream);
  385. });
  386. }
  387. });
  388. });
  389. #if 0
  390. if (!params.is_e4m3) {
  391. if (params.is_bf16) {
  392. if (params.d == 64) {
  393. run_mha_fwd_<cutlass::bfloat16_t, 64>(params, stream);
  394. } else if (params.d == 128) {
  395. run_mha_fwd_<cutlass::bfloat16_t, 128>(params, stream);
  396. } else {
  397. run_mha_fwd_<cutlass::bfloat16_t, 256>(params, stream);
  398. }
  399. } else {
  400. if (params.d == 64) {
  401. run_mha_fwd_<cutlass::half_t, 64>(params, stream);
  402. } else if (params.d == 128) {
  403. run_mha_fwd_<cutlass::half_t, 128>(params, stream);
  404. } else {
  405. run_mha_fwd_<cutlass::half_t, 256>(params, stream);
  406. }
  407. }
  408. } else {
  409. if (params.d == 64) {
  410. run_mha_fwd_<cutlass::float_e4m3_t, 64>(params, stream);
  411. } else if (params.d == 128) {
  412. run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream);
  413. } else if (params.d == 256) {
  414. run_mha_fwd_<cutlass::float_e4m3_t, 256>(params, stream);
  415. }
  416. }
  417. #endif
  418. }
  419. std::vector<at::Tensor>
  420. mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  421. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  422. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  423. c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
  424. const float softmax_scale,
  425. c10::optional<at::Tensor> &descale_q_, // 1
  426. c10::optional<at::Tensor> &descale_k_, // 1
  427. c10::optional<at::Tensor> &descale_v_, // 1
  428. bool is_causal,
  429. int window_size_left,
  430. int window_size_right,
  431. bool use_gqa_packing = false
  432. ) {
  433. auto dprops = at::cuda::getCurrentDeviceProperties();
  434. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  435. TORCH_CHECK(is_sm90, "FlashAttention-3 only supports Hopper GPUs or newer.");
  436. auto q_dtype = q.dtype();
  437. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16 || q_dtype == at::ScalarType::Float8_e4m3fn,
  438. "FlashAttention-3 only support fp16, bf16, or fp8 e4m3 data type");
  439. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  440. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  441. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  442. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  443. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  444. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  445. const auto sizes = q.sizes();
  446. const int batch_size = sizes[0];
  447. int seqlen_q = sizes[1];
  448. int num_heads = sizes[2];
  449. const int head_size_og = sizes[3];
  450. const int seqlen_k = k.size(1);
  451. const int num_heads_k = k.size(2);
  452. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  453. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  454. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  455. // Guard against mistaken setting of gqa flag
  456. if (num_heads == num_heads_k) { use_gqa_packing = false; }
  457. TORCH_CHECK(head_size_og == 64 || head_size_og == 128 || head_size_og == 256, "Only support head size 64, 128, and 256 for now");
  458. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
  459. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
  460. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
  461. at::Tensor q_padded, k_padded, v_padded;
  462. if (head_size_og % 8 != 0) {
  463. q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  464. k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  465. v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  466. } else {
  467. q_padded = q;
  468. k_padded = k;
  469. v_padded = v;
  470. }
  471. at::Tensor out;
  472. if (out_.has_value()) {
  473. out = out_.value();
  474. // TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  475. TORCH_CHECK(q_dtype == at::ScalarType::Float8_e4m3fn
  476. ? (out.dtype() == at::kBFloat16)
  477. : (out.dtype() == q_dtype),
  478. "Output must have the same dtype as input dtype if dtype is "
  479. "not fp8, or fp16 for fp8 input.");
  480. CHECK_DEVICE(out);
  481. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  482. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
  483. if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
  484. } else {
  485. if (q_dtype == at::ScalarType::Float8_e4m3fn)
  486. out = torch::empty_like(q_padded, at::kBFloat16);
  487. else
  488. out = torch::empty_like(q_padded);
  489. }
  490. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  491. const int head_size = round_multiple(head_size_og, 8);
  492. const int head_size_rounded = round_multiple(head_size, 32);
  493. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  494. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  495. if (is_causal) { window_size_right = 0; }
  496. // Otherwise the kernel will be launched from cuda:0 device
  497. at::cuda::CUDAGuard device_guard{q.device()};
  498. auto opts = q.options();
  499. auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  500. at::Tensor p;
  501. Flash_fwd_params params;
  502. set_params_fprop(params,
  503. batch_size, batch_size,
  504. seqlen_q, seqlen_k,
  505. seqlen_q_rounded, seqlen_k_rounded,
  506. num_heads, num_heads_k,
  507. head_size, head_size_rounded,
  508. q_padded, k_padded, v_padded, out,
  509. /*cu_seqlens_q_d=*/nullptr,
  510. /*cu_seqlens_k_d=*/nullptr,
  511. /*seqused_q=*/nullptr,
  512. /*seqused_k=*/nullptr,
  513. nullptr,
  514. softmax_lse.data_ptr(),
  515. /*p_dropout=*/0.f,
  516. softmax_scale,
  517. /*window_size_left=*/window_size_left,
  518. /*window_size_right=*/window_size_right);
  519. auto tile_count_semaphore = is_causal || params.is_local
  520. ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
  521. params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
  522. at::Tensor descale_q, descale_k, descale_v;
  523. if(q_dtype == at::ScalarType::Float8_e4m3fn) {
  524. if (descale_q_.has_value()) {
  525. descale_q = descale_q_.value();
  526. CHECK_DEVICE(descale_q);
  527. CHECK_SHAPE(descale_q, 1);
  528. } else { descale_q = torch::ones({1}, opts.dtype(at::kFloat)); }
  529. if (descale_k_.has_value()) {
  530. descale_k = descale_k_.value();
  531. CHECK_DEVICE(descale_k);
  532. CHECK_SHAPE(descale_k, 1);
  533. } else { descale_k = torch::ones({1}, opts.dtype(at::kFloat)); }
  534. if (descale_v_.has_value()) {
  535. descale_v = descale_v_.value();
  536. CHECK_DEVICE(descale_v);
  537. CHECK_SHAPE(descale_v, 1);
  538. } else { descale_v = torch::ones({1}, opts.dtype(at::kFloat)); }
  539. params.descale_q_ptr = descale_q.data_ptr<float>();
  540. params.descale_k_ptr = descale_k.data_ptr<float>();
  541. params.descale_v_ptr = descale_v.data_ptr<float>();
  542. } else {
  543. params.descale_q_ptr = nullptr;
  544. params.descale_k_ptr = nullptr;
  545. params.descale_v_ptr = nullptr;
  546. }
  547. params.use_gqa_packing = use_gqa_packing;
  548. if (seqlen_k > 0) {
  549. auto stream = at::cuda::getCurrentCUDAStream().stream();
  550. run_mha_fwd(params, stream);
  551. } else {
  552. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  553. out.zero_();
  554. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  555. }
  556. at::Tensor out_padded = out;
  557. if (head_size_og % 8 != 0) {
  558. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  559. if (out_.has_value()) { out_.value().copy_(out); }
  560. }
  561. return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
  562. }
  563. std::vector<at::Tensor>
  564. mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  565. const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  566. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  567. c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
  568. const at::Tensor &cu_seqlens_q, // b+1
  569. const at::Tensor &cu_seqlens_k, // b+1
  570. c10::optional<at::Tensor> &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used.
  571. c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
  572. std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
  573. int max_seqlen_q,
  574. const int max_seqlen_k,
  575. const float softmax_scale,
  576. bool is_causal,
  577. int window_size_left,
  578. int window_size_right) {
  579. auto dprops = at::cuda::getCurrentDeviceProperties();
  580. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  581. TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
  582. auto q_dtype = q.dtype();
  583. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  584. "FlashAttention only support fp16 and bf16 data type");
  585. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  586. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  587. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
  588. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
  589. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  590. CHECK_DEVICE(cu_seqlens_q);
  591. CHECK_DEVICE(cu_seqlens_k);
  592. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  593. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  594. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  595. CHECK_CONTIGUOUS(cu_seqlens_q);
  596. CHECK_CONTIGUOUS(cu_seqlens_k);
  597. at::Tensor block_table;
  598. const bool paged_KV = block_table_.has_value();
  599. if (paged_KV) {
  600. block_table = block_table_.value();
  601. CHECK_DEVICE(block_table);
  602. TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
  603. TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
  604. }
  605. const auto sizes = q.sizes();
  606. const int batch_size = cu_seqlens_q.numel() - 1;
  607. int num_heads = sizes[1];
  608. const int head_size_og = sizes[2];
  609. const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
  610. void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
  611. const int total_q = q.sizes()[0];
  612. const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
  613. const int num_blocks = !paged_KV ? 0 : k.size(0);
  614. const int page_block_size = !paged_KV ? -1 : k.size(1);
  615. TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
  616. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  617. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  618. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  619. CHECK_SHAPE(q, total_q, num_heads, head_size_og);
  620. const int total_k = k.size(0);
  621. if (!paged_KV) {
  622. CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
  623. CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
  624. } else {
  625. CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
  626. CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
  627. CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
  628. }
  629. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  630. if (seqused_q.has_value()){
  631. auto seqused_q_ = seqused_q.value();
  632. TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32");
  633. TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device");
  634. TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous");
  635. CHECK_SHAPE(seqused_q_, batch_size);
  636. }
  637. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  638. if (seqused_k.has_value()){
  639. auto seqused_k_ = seqused_k.value();
  640. TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
  641. TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
  642. TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
  643. CHECK_SHAPE(seqused_k_, batch_size);
  644. }
  645. at::Tensor q_padded, k_padded, v_padded;
  646. if (head_size_og % 8 != 0) {
  647. q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  648. k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  649. v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  650. } else {
  651. q_padded = q;
  652. k_padded = k;
  653. v_padded = v;
  654. }
  655. at::Tensor out;
  656. if (out_.has_value()) {
  657. out = out_.value();
  658. TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  659. CHECK_DEVICE(out);
  660. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  661. CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
  662. if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
  663. } else {
  664. out = torch::empty_like(q_padded);
  665. }
  666. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  667. const int head_size = round_multiple(head_size_og, 8);
  668. const int head_size_rounded = round_multiple(head_size, 32);
  669. const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
  670. const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
  671. if (is_causal) { window_size_right = 0; }
  672. // Otherwise the kernel will be launched from cuda:0 device
  673. at::cuda::CUDAGuard device_guard{q.device()};
  674. auto opts = q.options();
  675. auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
  676. Flash_fwd_params params;
  677. set_params_fprop(params,
  678. batch_size, batch_size,
  679. max_seqlen_q, max_seqlen_k,
  680. seqlen_q_rounded, seqlen_k_rounded,
  681. num_heads, num_heads_k,
  682. head_size, head_size_rounded,
  683. q_padded, k_padded, v_padded, out,
  684. cu_seqlens_q_d,
  685. cu_seqlens_k.data_ptr(),
  686. seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr,
  687. seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
  688. /*p_d=*/nullptr,
  689. softmax_lse.data_ptr(),
  690. /*p_dropout=*/0.f,
  691. softmax_scale,
  692. window_size_left,
  693. window_size_right,
  694. /*seqlenq_ngroups_swapped=*/false,
  695. /*unpadded_lse=*/true);
  696. params.total_q = total_q;
  697. params.total_k = total_k;
  698. if (paged_KV) {
  699. params.block_table = block_table.data_ptr<int>();
  700. params.block_table_batch_stride = block_table.stride(0);
  701. params.k_batch_stride = k.stride(0);
  702. params.v_batch_stride = v.stride(0);
  703. params.page_num_blocks = k.size(0);
  704. }
  705. params.page_block_size = page_block_size;
  706. params.page_num_blocks = num_blocks;
  707. //printf("mha_varlen_fwd: params.seqlen_k=%d, max_seqlen_k=%d, params.page_num_blocks=%d\n", (int)params.seqlen_k, (int)max_seqlen_k, (int)params.page_num_blocks);
  708. if (max_seqlen_k > 0) {
  709. auto stream = at::cuda::getCurrentCUDAStream().stream();
  710. run_mha_fwd(params, stream);
  711. } else {
  712. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  713. out.zero_();
  714. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  715. }
  716. at::Tensor out_padded = out;
  717. if (head_size_og % 8 != 0) {
  718. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  719. if (out_.has_value()) { out_.value().copy_(out); }
  720. }
  721. return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse};
  722. }
  723. void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
  724. // FP16_SWITCH(!params.is_bf16, [&] {
  725. // HEADDIM_SWITCH(params.d, [&] {
  726. // run_mha_bwd_<elem_type, kHeadDim>(params, stream);
  727. // });
  728. // });
  729. if (!params.is_bf16) {
  730. if (params.d <= 64) {
  731. run_mha_bwd_<cutlass::half_t, 64>(params, stream);
  732. } else if (params.d <= 96) {
  733. run_mha_bwd_<cutlass::half_t, 96>(params, stream);
  734. } else {
  735. run_mha_bwd_<cutlass::half_t, 128>(params, stream);
  736. }
  737. } else {
  738. if (params.d <= 64) {
  739. run_mha_bwd_<cutlass::bfloat16_t, 64>(params, stream);
  740. } else if (params.d <= 96) {
  741. run_mha_bwd_<cutlass::bfloat16_t, 96>(params, stream);
  742. } else {
  743. run_mha_bwd_<cutlass::bfloat16_t, 128>(params, stream);
  744. }
  745. }
  746. }
  747. std::vector<at::Tensor>
  748. mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
  749. const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  750. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  751. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  752. const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
  753. const at::Tensor &softmax_lse, // b x h x seqlen_q
  754. c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
  755. c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
  756. c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
  757. const float softmax_scale,
  758. const bool is_causal,
  759. int window_size_left,
  760. int window_size_right,
  761. const bool deterministic) {
  762. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  763. TORCH_CHECK(false, "This flash attention build does not support backward.");
  764. #endif
  765. auto dprops = at::cuda::getCurrentDeviceProperties();
  766. bool is_sm9x = dprops->major == 9 && dprops->minor >= 0;
  767. TORCH_CHECK(is_sm9x, "FlashAttentionHopper only supports Hopper GPUs or newer.");
  768. auto stream = at::cuda::getCurrentCUDAStream().stream();
  769. auto q_dtype = q.dtype();
  770. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  771. "FlashAttention only support fp16 and bf16 data type");
  772. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  773. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  774. TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  775. TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
  776. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  777. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  778. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  779. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  780. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  781. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  782. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
  783. const auto sizes = q.sizes();
  784. const int batch_size = sizes[0];
  785. const int seqlen_q = sizes[1];
  786. const int num_heads = sizes[2];
  787. const int head_size_og = dout.size(3);
  788. const int head_size = sizes[3];
  789. const int seqlen_k = k.size(1);
  790. const int num_heads_k = k.size(2);
  791. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  792. TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  793. TORCH_CHECK(head_size <= 128, "FlashAttention backward only supports head dimension at most 128");
  794. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  795. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  796. const int head_size_rounded = head_size <= 64 ? 64 : round_multiple(head_size, 32);
  797. // This should match the kernel configs
  798. const int kBlockM = head_size <= 64 ? 128 : (head_size < 256 ? 64 : 32);
  799. const int seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);
  800. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  801. TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
  802. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
  803. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
  804. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
  805. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
  806. CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
  807. at::Tensor dq, dk, dv;
  808. if (dq_.has_value()) {
  809. dq = dq_.value();
  810. TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
  811. CHECK_DEVICE(dq);
  812. TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
  813. CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
  814. } else {
  815. dq = torch::empty_like(q);
  816. }
  817. if (dk_.has_value()) {
  818. dk = dk_.value();
  819. TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
  820. CHECK_DEVICE(dk);
  821. TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
  822. CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
  823. } else {
  824. dk = torch::empty_like(k);
  825. }
  826. if (dv_.has_value()) {
  827. dv = dv_.value();
  828. TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
  829. CHECK_DEVICE(dv);
  830. TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
  831. CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
  832. } else {
  833. dv = torch::empty_like(v);
  834. }
  835. at::Tensor dout_padded;
  836. if (head_size_og % 8 != 0) {
  837. dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  838. } else {
  839. dout_padded = dout;
  840. }
  841. // Otherwise the kernel will be launched from cuda:0 device
  842. at::cuda::CUDAGuard device_guard{q.device()};
  843. auto opts = q.options();
  844. // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
  845. auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
  846. auto softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
  847. at::Tensor dq_accum;
  848. at::Tensor dk_accum, dv_accum;
  849. dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat));
  850. // dk_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
  851. // dv_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
  852. at::Tensor dk_expanded, dv_expanded;
  853. if (num_heads_k != num_heads) { // MQA / GQA
  854. dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
  855. dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
  856. } else {
  857. dk_expanded = dk;
  858. dv_expanded = dv;
  859. }
  860. if (is_causal) { window_size_right = 0; }
  861. Flash_bwd_params params;
  862. set_params_dgrad(params,
  863. batch_size,
  864. seqlen_q, seqlen_k,
  865. seqlen_q_rounded, seqlen_k_rounded,
  866. num_heads, num_heads_k,
  867. head_size, head_size_rounded,
  868. q, k, v, out,
  869. dout_padded, dq, dk_expanded, dv_expanded,
  870. /*cu_seqlens_q_d=*/nullptr,
  871. /*cu_seqlens_k_d=*/nullptr,
  872. /*seqused_q=*/nullptr,
  873. /*seqused_k=*/nullptr,
  874. dq_accum.data_ptr(),
  875. // loop ? dk_accum.data_ptr() : nullptr,
  876. // loop ? dv_accum.data_ptr() : nullptr,
  877. nullptr,
  878. nullptr,
  879. softmax_lse.data_ptr(),
  880. softmax_d.data_ptr(),
  881. /*p_dropout=*/0.f,
  882. softmax_scale,
  883. /*window_size_left=*/window_size_left,
  884. /*window_size_right=*/window_size_right,
  885. deterministic);
  886. params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
  887. // Will be zero'ed out in the backward preprocess kernel
  888. at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
  889. params.dq_semaphore = dq_semaphore.data_ptr<int>();
  890. // printf("dq_semaphore: %p, [%d, %d, %d]\n", params.dq_semaphore, (seqlen_q + 64 - 1) / 64, batch_size, num_heads);
  891. if (seqlen_q > 0) {
  892. run_mha_bwd(params, stream);
  893. } else {
  894. // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
  895. dk_expanded.zero_();
  896. dv_expanded.zero_();
  897. softmax_d.zero_();
  898. }
  899. // For MQA/GQA we need to sum dK and dV across the groups
  900. if (num_heads_k != num_heads) {
  901. at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
  902. at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
  903. }
  904. if (head_size_og % 8 != 0) {
  905. dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  906. dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  907. dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  908. }
  909. return { dq, dk, dv, softmax_d, dq_accum};
  910. }
  911. std::vector<at::Tensor>
  912. mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
  913. const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  914. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  915. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  916. const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
  917. const at::Tensor &softmax_lse, // b x h x seqlen_q
  918. c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
  919. c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
  920. c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
  921. const at::Tensor &cu_seqlens_q, // b+1
  922. const at::Tensor &cu_seqlens_k, // b+1
  923. c10::optional<at::Tensor> &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used.
  924. c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
  925. const int max_seqlen_q,
  926. const int max_seqlen_k, // max sequence length to choose the kernel
  927. const float softmax_scale,
  928. const bool is_causal,
  929. int window_size_left,
  930. int window_size_right,
  931. const bool deterministic) {
  932. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  933. TORCH_CHECK(false, "This flash attention build does not support backward.");
  934. #endif
  935. auto dprops = at::cuda::getCurrentDeviceProperties();
  936. bool is_sm9x = dprops->major == 9 && dprops->minor >= 0;
  937. TORCH_CHECK(is_sm9x, "FlashAttentionHopper only supports Hopper GPUs or newer.");
  938. auto stream = at::cuda::getCurrentCUDAStream().stream();
  939. auto q_dtype = q.dtype();
  940. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  941. "FlashAttention only support fp16 and bf16 data type");
  942. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  943. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  944. TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  945. TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
  946. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
  947. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
  948. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  949. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  950. CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
  951. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  952. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  953. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  954. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  955. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
  956. CHECK_CONTIGUOUS(cu_seqlens_q);
  957. CHECK_CONTIGUOUS(cu_seqlens_k);
  958. const auto sizes = q.sizes();
  959. const int total_q = sizes[0];
  960. const int batch_size = cu_seqlens_q.numel() - 1;
  961. const int num_heads = sizes[1];
  962. const int head_size_og = dout.size(2);
  963. const int head_size = sizes[2];
  964. const int total_k = k.size(0);
  965. const int num_heads_k = k.size(1);
  966. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  967. TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  968. TORCH_CHECK(head_size <= 128, "FlashAttention backward only supports head dimension at most 128");
  969. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  970. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  971. const int head_size_rounded = head_size <= 64 ? 64 : round_multiple(head_size, 32);
  972. // This should match the kernel configs
  973. const int kBlockM = head_size <= 64 ? 128 : (head_size < 256 ? 64 : 32);
  974. const int seqlen_q_rounded = round_multiple(max_seqlen_q, kBlockM);
  975. const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
  976. int const total_q_padded_rounded = round_multiple(total_q + batch_size * 128, 128);
  977. TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
  978. CHECK_SHAPE(q, total_q, num_heads, head_size_og);
  979. CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
  980. CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
  981. CHECK_SHAPE(out, total_q, num_heads, head_size);
  982. CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
  983. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  984. if (seqused_q.has_value()){
  985. auto seqused_q_ = seqused_q.value();
  986. TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32");
  987. TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device");
  988. TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous");
  989. CHECK_SHAPE(seqused_q_, batch_size);
  990. }
  991. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  992. if (seqused_k.has_value()){
  993. auto seqused_k_ = seqused_k.value();
  994. TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
  995. TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
  996. TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
  997. CHECK_SHAPE(seqused_k_, batch_size);
  998. }
  999. at::Tensor dq, dk, dv;
  1000. if (dq_.has_value()) {
  1001. dq = dq_.value();
  1002. TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
  1003. CHECK_DEVICE(dq);
  1004. TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
  1005. CHECK_SHAPE(dq, total_q, num_heads, head_size);
  1006. } else {
  1007. dq = torch::empty_like(q);
  1008. }
  1009. if (dk_.has_value()) {
  1010. dk = dk_.value();
  1011. TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
  1012. CHECK_DEVICE(dk);
  1013. TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
  1014. CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
  1015. } else {
  1016. dk = torch::empty_like(k);
  1017. }
  1018. if (dv_.has_value()) {
  1019. dv = dv_.value();
  1020. TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
  1021. CHECK_DEVICE(dv);
  1022. TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
  1023. CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
  1024. } else {
  1025. dv = torch::empty_like(v);
  1026. }
  1027. at::Tensor dout_padded;
  1028. if (head_size_og % 8 != 0) {
  1029. dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1030. } else {
  1031. dout_padded = dout;
  1032. }
  1033. if (is_causal) { window_size_right = 0; }
  1034. // Otherwise the kernel will be launched from cuda:0 device
  1035. at::cuda::CUDAGuard device_guard{q.device()};
  1036. auto opts = q.options();
  1037. // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
  1038. auto softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
  1039. auto softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
  1040. at::Tensor dq_accum;
  1041. at::Tensor dk_accum, dv_accum;
  1042. dq_accum = torch::empty({num_heads, total_q_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
  1043. // dk_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
  1044. // dv_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
  1045. at::Tensor dk_expanded, dv_expanded;
  1046. if (num_heads_k != num_heads) { // MQA / GQA
  1047. dk_expanded = torch::empty({total_k, num_heads, head_size}, opts);
  1048. dv_expanded = torch::empty({total_k, num_heads, head_size}, opts);
  1049. } else {
  1050. dk_expanded = dk;
  1051. dv_expanded = dv;
  1052. }
  1053. Flash_bwd_params params;
  1054. set_params_dgrad(params,
  1055. batch_size,
  1056. max_seqlen_q, max_seqlen_k,
  1057. seqlen_q_rounded, seqlen_k_rounded,
  1058. num_heads, num_heads_k,
  1059. head_size, head_size_rounded,
  1060. q, k, v, out,
  1061. dout_padded, dq, dk_expanded, dv_expanded,
  1062. cu_seqlens_q.data_ptr(),
  1063. cu_seqlens_k.data_ptr(),
  1064. seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr,
  1065. seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
  1066. dq_accum.data_ptr(),
  1067. // loop ? dk_accum.data_ptr() : nullptr,
  1068. // loop ? dv_accum.data_ptr() : nullptr,
  1069. nullptr,
  1070. nullptr,
  1071. softmax_lse.data_ptr(),
  1072. softmax_d.data_ptr(),
  1073. /*p_dropout=*/0.f,
  1074. softmax_scale,
  1075. /*window_size_left=*/window_size_left,
  1076. /*window_size_right=*/window_size_right,
  1077. deterministic);
  1078. params.total_q = total_q;
  1079. params.total_k = total_k;
  1080. params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
  1081. // Will be zero'ed out in the backward preprocess kernel
  1082. at::Tensor dq_semaphore = torch::empty({(max_seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
  1083. params.dq_semaphore = dq_semaphore.data_ptr<int>();
  1084. if (max_seqlen_q > 0) {
  1085. run_mha_bwd(params, stream);
  1086. } else {
  1087. // If max_seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
  1088. dk_expanded.zero_();
  1089. dv_expanded.zero_();
  1090. softmax_d.zero_();
  1091. }
  1092. // For MQA/GQA we need to sum dK and dV across the groups
  1093. if (num_heads_k != num_heads) {
  1094. at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
  1095. at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
  1096. }
  1097. if (head_size_og % 8 != 0) {
  1098. dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  1099. dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  1100. dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  1101. }
  1102. return { dq, dk, dv, softmax_d, dq_accum, softmax_lse_log2 };
  1103. }
  1104. std::vector<at::Tensor>
  1105. mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  1106. const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  1107. const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  1108. c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
  1109. c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
  1110. c10::optional<const at::Tensor> &seqlens_k_, // batch_size
  1111. c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
  1112. c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
  1113. c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
  1114. c10::optional<const at::Tensor> &leftpad_k_, // batch_size
  1115. c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
  1116. c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
  1117. c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
  1118. const float softmax_scale,
  1119. c10::optional<at::Tensor> &descale_q_, // 1
  1120. c10::optional<at::Tensor> &descale_k_, // 1
  1121. c10::optional<at::Tensor> &descale_v_, // 1
  1122. bool is_causal,
  1123. int window_size_left,
  1124. int window_size_right,
  1125. const float softcap,
  1126. bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
  1127. int num_splits,
  1128. int max_seqlen_k_hint,
  1129. bool use_gqa_packing
  1130. ) {
  1131. auto dprops = at::cuda::getCurrentDeviceProperties();
  1132. // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
  1133. // bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  1134. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  1135. TORCH_CHECK(is_sm90, "FlashAttention-3 only supports Hopper GPUs or newer.");
  1136. auto q_dtype = q.dtype();
  1137. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16 || q_dtype == at::ScalarType::Float8_e4m3fn,
  1138. "FlashAttention-3 only support fp16, bf16, or fp8 e4m3 data type");
  1139. TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
  1140. TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");
  1141. CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
  1142. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1143. TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1144. TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1145. at::Tensor block_table;
  1146. const bool paged_KV = block_table_.has_value();
  1147. if (paged_KV) {
  1148. TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx");
  1149. block_table = block_table_.value();
  1150. CHECK_DEVICE(block_table);
  1151. TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
  1152. TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
  1153. }
  1154. const auto sizes = q.sizes();
  1155. const int batch_size = sizes[0];
  1156. int seqlen_q = sizes[1];
  1157. int num_heads = sizes[2];
  1158. const int head_size_og = sizes[3];
  1159. const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
  1160. const int num_blocks = !paged_KV ? 0 : kcache.size(0);
  1161. const int page_block_size = !paged_KV ? 1 : kcache.size(1);
  1162. TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
  1163. const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
  1164. const int num_heads_k = kcache.size(2);
  1165. const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
  1166. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  1167. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  1168. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  1169. // Guard against mistaken setting of gqa flag
  1170. if (num_heads == num_heads_k) { use_gqa_packing = false; }
  1171. // causal=true is the same as causal=false in this case
  1172. if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
  1173. if (is_causal) { window_size_right = 0; }
  1174. // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
  1175. // H/t Daniel Haziza
  1176. const int seqlenq_ngroups_swapped =
  1177. seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 &&
  1178. window_size_right < 0 && head_size_og % 8 == 0 &&
  1179. !alibi_slopes_.has_value() && !use_gqa_packing;
  1180. if (seqlenq_ngroups_swapped) {
  1181. const int ngroups = num_heads / num_heads_k;
  1182. q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
  1183. seqlen_q = ngroups;
  1184. num_heads = num_heads_k;
  1185. }
  1186. if (window_size_left >= seqlen_k) { window_size_left = -1; }
  1187. if (window_size_right >= seqlen_k) { window_size_right = -1; }
  1188. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
  1189. if (!paged_KV) {
  1190. CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
  1191. CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
  1192. } else {
  1193. CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
  1194. CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
  1195. CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
  1196. }
  1197. at::Tensor q_padded, kcache_padded, vcache_padded;
  1198. if (head_size_og % 8 != 0) {
  1199. q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1200. kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1201. vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1202. } else {
  1203. q_padded = q;
  1204. kcache_padded = kcache;
  1205. vcache_padded = vcache;
  1206. }
  1207. at::Tensor out;
  1208. if (out_.has_value()) {
  1209. out = out_.value();
  1210. // TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  1211. TORCH_CHECK(q_dtype == at::ScalarType::Float8_e4m3fn
  1212. ? (out.dtype() == at::kBFloat16)
  1213. : (out.dtype() == q_dtype),
  1214. "Output must have the same dtype as input dtype if dtype is "
  1215. "not fp8, or fp16 for fp8 input.");
  1216. CHECK_DEVICE(out);
  1217. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  1218. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
  1219. if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
  1220. } else {
  1221. if (q_dtype == at::ScalarType::Float8_e4m3fn) {
  1222. out = torch::empty_like(q_padded, at::kBFloat16);
  1223. }
  1224. else
  1225. out = torch::empty_like(q_padded);
  1226. }
  1227. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  1228. const int head_size = round_multiple(head_size_og, 8);
  1229. const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
  1230. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  1231. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  1232. // Otherwise the kernel will be launched from cuda:0 device
  1233. at::cuda::CUDAGuard device_guard{q.device()};
  1234. auto opts = q.options();
  1235. auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  1236. Flash_fwd_params params;
  1237. set_params_fprop(params,
  1238. batch_size, batch_size_c,
  1239. seqlen_q, seqlen_k,
  1240. seqlen_q_rounded, seqlen_k_rounded,
  1241. num_heads, num_heads_k,
  1242. head_size, head_size_rounded,
  1243. q_padded, kcache_padded, vcache_padded, out,
  1244. /*cu_seqlens_q_d=*/nullptr,
  1245. /*cu_seqlens_k_d=*/nullptr,
  1246. /*seqused_q=*/nullptr,
  1247. /*seqused_k=*/nullptr,
  1248. /*p_ptr=*/nullptr,
  1249. softmax_lse.data_ptr(),
  1250. /*p_dropout=*/0.f,
  1251. softmax_scale,
  1252. window_size_left,
  1253. window_size_right
  1254. );
  1255. at::Tensor descale_q, descale_k, descale_v;
  1256. if(q_dtype == at::ScalarType::Float8_e4m3fn) {
  1257. if (descale_q_.has_value()) {
  1258. descale_q = descale_q_.value();
  1259. CHECK_DEVICE(descale_q);
  1260. CHECK_SHAPE(descale_q, 1);
  1261. } else { descale_q = torch::ones({1}, opts.dtype(at::kFloat)); }
  1262. if (descale_k_.has_value()) {
  1263. descale_k = descale_k_.value();
  1264. CHECK_DEVICE(descale_k);
  1265. CHECK_SHAPE(descale_k, 1);
  1266. } else { descale_k = torch::ones({1}, opts.dtype(at::kFloat)); }
  1267. if (descale_v_.has_value()) {
  1268. descale_v = descale_v_.value();
  1269. CHECK_DEVICE(descale_v);
  1270. CHECK_SHAPE(descale_v, 1);
  1271. } else { descale_v = torch::ones({1}, opts.dtype(at::kFloat)); }
  1272. params.descale_q_ptr = descale_q.data_ptr<float>();
  1273. params.descale_k_ptr = descale_k.data_ptr<float>();
  1274. params.descale_v_ptr = descale_v.data_ptr<float>();
  1275. } else {
  1276. params.descale_q_ptr = nullptr;
  1277. params.descale_k_ptr = nullptr;
  1278. params.descale_v_ptr = nullptr;
  1279. }
  1280. params.is_kv_cache = true;
  1281. params.use_gqa_packing = use_gqa_packing;
  1282. at::Tensor k, v, k_padded, v_padded;
  1283. if (k_.has_value()) {
  1284. TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
  1285. TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
  1286. TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache");
  1287. k = k_.value();
  1288. v = v_.value();
  1289. TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
  1290. TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
  1291. CHECK_DEVICE(k); CHECK_DEVICE(v);
  1292. TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
  1293. TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
  1294. int seqlen_knew = k.size(1);
  1295. CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
  1296. CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
  1297. if (head_size_og % 8 != 0) {
  1298. k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1299. v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1300. } else {
  1301. k_padded = k;
  1302. v_padded = v;
  1303. }
  1304. params.seqlen_knew = seqlen_knew;
  1305. params.knew_ptr = k_padded.data_ptr();
  1306. params.vnew_ptr = v_padded.data_ptr();
  1307. // All stride are in elements, not bytes.
  1308. params.knew_batch_stride = k_padded.stride(0);
  1309. params.vnew_batch_stride = v_padded.stride(0);
  1310. params.knew_row_stride = k_padded.stride(-3);
  1311. params.vnew_row_stride = v_padded.stride(-3);
  1312. params.knew_head_stride = k_padded.stride(-2);
  1313. params.vnew_head_stride = v_padded.stride(-2);
  1314. }
  1315. if (seqlens_k_.has_value()) {
  1316. auto seqlens_k = seqlens_k_.value();
  1317. TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
  1318. CHECK_DEVICE(seqlens_k);
  1319. CHECK_CONTIGUOUS(seqlens_k);
  1320. CHECK_SHAPE(seqlens_k, batch_size);
  1321. params.seqused_k = static_cast<int *>(seqlens_k.data_ptr());
  1322. }
  1323. if (leftpad_k_.has_value()) {
  1324. TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
  1325. auto leftpad_k = leftpad_k_.value();
  1326. TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
  1327. CHECK_DEVICE(leftpad_k);
  1328. CHECK_CONTIGUOUS(leftpad_k);
  1329. CHECK_SHAPE(leftpad_k, batch_size);
  1330. TORCH_CHECK(false, "Left Padding K is not supported");
  1331. //params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
  1332. }
  1333. if (rotary_cos_.has_value()) {
  1334. TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
  1335. auto rotary_cos = rotary_cos_.value();
  1336. CHECK_DEVICE(rotary_cos);
  1337. params.rotary_dim = rotary_cos.size(1) * 2;
  1338. TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
  1339. TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
  1340. const int seqlen_ro = rotary_cos.size(0);
  1341. TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
  1342. CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
  1343. CHECK_CONTIGUOUS(rotary_cos);
  1344. TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
  1345. TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
  1346. auto rotary_sin = rotary_sin_.value();
  1347. CHECK_DEVICE(rotary_sin);
  1348. CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
  1349. CHECK_CONTIGUOUS(rotary_sin);
  1350. TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
  1351. params.rotary_cos_ptr = rotary_cos.data_ptr();
  1352. params.rotary_sin_ptr = rotary_sin.data_ptr();
  1353. params.is_rotary_interleaved = is_rotary_interleaved;
  1354. } else {
  1355. params.rotary_dim = 0;
  1356. }
  1357. if (cache_batch_idx_.has_value()) {
  1358. auto cache_batch_idx = cache_batch_idx_.value();
  1359. CHECK_DEVICE(cache_batch_idx);
  1360. CHECK_CONTIGUOUS(cache_batch_idx);
  1361. TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
  1362. params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
  1363. }
  1364. // Keep references to these tensors to extend their lifetime
  1365. at::Tensor softmax_lse_accum, out_accum;
  1366. std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
  1367. params, batch_size, num_heads, num_heads_k, head_size, max_seqlen_k_hint, seqlen_q,
  1368. head_size_rounded, /*dropout*/ 0.f, num_splits, dprops, use_gqa_packing, is_causal, opts);
  1369. auto tile_count_semaphore = is_causal || params.is_local || params.num_splits != 1
  1370. ? torch::zeros({1}, opts.dtype(torch::kInt32))
  1371. : torch::empty({1}, opts.dtype(torch::kInt32));
  1372. params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
  1373. if (paged_KV) {
  1374. params.block_table = block_table.data_ptr<int>();
  1375. params.block_table_batch_stride = block_table.stride(0);
  1376. }
  1377. params.page_block_size = page_block_size;
  1378. TORCH_CHECK(!alibi_slopes_.has_value(), "Alibi Slopes are not supported yet");
  1379. //set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  1380. auto stream = at::cuda::getCurrentCUDAStream().stream();
  1381. // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx,
  1382. // or paged KV cache
  1383. //run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV);
  1384. run_mha_fwd(params, stream);
  1385. if (head_size_og % 8 != 0) {
  1386. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  1387. if (out_.has_value()) { out_.value().copy_(out); }
  1388. if (k_.has_value()) {
  1389. // It's expensive to copy the KV cache here for the case where head size not divisible by 8,
  1390. // but we don't expect to get this case in practice. This is just so that the code works for that case.
  1391. kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
  1392. vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
  1393. }
  1394. }
  1395. if (seqlenq_ngroups_swapped) {
  1396. out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
  1397. softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
  1398. }
  1399. return {out, softmax_lse};
  1400. }
  1401. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  1402. m.doc() = "FlashAttention";
  1403. m.def("fwd", &mha_fwd, "Forward pass");
  1404. m.def("bwd", &mha_bwd, "Backward pass");
  1405. m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
  1406. m.def("varlen_bwd", &mha_varlen_bwd, "Varlen backward pass");
  1407. m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
  1408. }