flash_api.cpp 86 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
  5. #include <torch/python.h>
  6. #include <torch/nn/functional.h>
  7. #include <torch/version.h> // For TORCH_VERSION* macros
  8. #include <ATen/cuda/CUDAContext.h>
  9. #include <c10/cuda/CUDAGuard.h>
  10. #include <cutlass/numeric_types.h>
  11. #include "flash.h"
  12. #include "static_switch.h"
  13. #include "tile_size.h"
  14. #include "heuristics.h"
  15. // Copied from https://github.com/pytorch/pytorch/commit/7931eee5c5ebcdf468bff4d308510b03355cd909
  16. // This is so that we can pass in torch.dtype as a parameter to the function.
  17. #if TORCH_VERSION_MAJOR < 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR < 4)
  18. #include <pybind11/pybind11.h>
  19. #include <pybind11/stl.h>
  20. namespace pybind11::detail {
  21. template <>
  22. struct type_caster<at::ScalarType> {
  23. public:
  24. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  25. PYBIND11_TYPE_CASTER(at::ScalarType, _("torch.dtype"));
  26. // PYBIND11_TYPE_CASTER defines a member field called value. at::ScalarType
  27. // cannot be default-initialized, we provide this constructor to explicitly
  28. // initialize that field. The value doesn't matter as it will be overwritten
  29. // after a successful call to load.
  30. type_caster() : value(at::kFloat) {}
  31. bool load(handle src, bool) {
  32. PyObject* obj = src.ptr();
  33. if (THPDtype_Check(obj)) {
  34. value = reinterpret_cast<THPDtype*>(obj)->scalar_type;
  35. return true;
  36. }
  37. return false;
  38. }
  39. static handle cast(
  40. const at::ScalarType& src,
  41. return_value_policy /* policy */,
  42. handle /* parent */) {
  43. return Py_NewRef(torch::getTHPDtype(src));
  44. }
  45. };
  46. } // namespace pybind11::detail
  47. #endif
  48. #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
  49. #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
  50. #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
  51. void set_params_fprop(Flash_fwd_params &params,
  52. // sizes
  53. const size_t b,
  54. const size_t seqlen_q,
  55. const size_t seqlen_k,
  56. const size_t seqlen_q_rounded,
  57. const size_t seqlen_k_rounded,
  58. const size_t h,
  59. const size_t h_k,
  60. const size_t d,
  61. const size_t d_rounded,
  62. // device pointers
  63. const at::Tensor q,
  64. const at::Tensor k,
  65. const at::Tensor v,
  66. at::Tensor out,
  67. void *cu_seqlens_q_d,
  68. void *cu_seqlens_k_d,
  69. void *seqused_q,
  70. void *seqused_k,
  71. void *softmax_lse_d,
  72. float p_dropout,
  73. float softmax_scale,
  74. int window_size_left,
  75. int window_size_right,
  76. const float softcap=0.f) {
  77. // Reset the parameters
  78. params = {};
  79. params.is_bf16 = q.dtype() == torch::kBFloat16;
  80. params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;
  81. // Set the pointers and strides.
  82. params.q_ptr = q.data_ptr();
  83. params.k_ptr = k.data_ptr();
  84. params.v_ptr = v.data_ptr();
  85. // All stride are in elements, not bytes.
  86. params.q_row_stride = q.stride(-3);
  87. params.k_row_stride = k.stride(-3);
  88. params.v_row_stride = v.stride(-3);
  89. params.q_head_stride = q.stride(-2);
  90. params.k_head_stride = k.stride(-2);
  91. params.v_head_stride = v.stride(-2);
  92. params.v_dim_stride = v.stride(-1);
  93. params.o_ptr = out.data_ptr();
  94. params.o_row_stride = out.stride(-3);
  95. params.o_head_stride = out.stride(-2);
  96. if (cu_seqlens_q_d == nullptr) {
  97. params.q_batch_stride = q.stride(0);
  98. params.o_batch_stride = out.stride(0);
  99. }
  100. if (cu_seqlens_k_d == nullptr) {
  101. params.k_batch_stride = k.stride(0);
  102. params.v_batch_stride = v.stride(0);
  103. }
  104. params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
  105. params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
  106. params.seqused_q = static_cast<int *>(seqused_q);
  107. params.seqused_k = static_cast<int *>(seqused_k);
  108. // Softmax sum
  109. params.softmax_lse_ptr = softmax_lse_d;
  110. // Set the dimensions.
  111. params.b = b;
  112. params.h = h;
  113. params.h_k = h_k;
  114. params.seqlen_q = seqlen_q;
  115. params.seqlen_k = seqlen_k;
  116. params.seqlen_q_rounded = seqlen_q_rounded;
  117. params.seqlen_k_rounded = seqlen_k_rounded;
  118. params.d = d;
  119. params.d_rounded = d_rounded;
  120. // Set the different scale values.
  121. params.scale_softmax = softmax_scale;
  122. params.softcap = softcap;
  123. // Set this to probability of keeping an element to simplify things.
  124. params.p_dropout = 1.f - p_dropout;
  125. // Convert p from float to int so we don't have to convert the random uint to float to compare.
  126. // [Minor] We want to round down since when we do the comparison we use <= instead of <
  127. // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
  128. // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
  129. params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
  130. params.rp_dropout = 1.f / params.p_dropout;
  131. TORCH_CHECK(p_dropout < 1.f);
  132. #ifdef FLASHATTENTION_DISABLE_DROPOUT
  133. TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
  134. #endif
  135. // Causal is the special case where window_size_right == 0 and window_size_left < 0.
  136. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
  137. params.is_causal = window_size_left < 0 && window_size_right == 0;
  138. params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal;
  139. // TODO: check this
  140. if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k - 1; }
  141. if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_q - 1; }
  142. params.window_size_left = window_size_left;
  143. params.window_size_right = window_size_right;
  144. #ifdef FLASHATTENTION_DISABLE_LOCAL
  145. TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
  146. "This flash attention build does not support local attention.");
  147. #endif
  148. }
  149. void set_params_dgrad(Flash_bwd_params &params,
  150. // sizes
  151. const size_t b,
  152. const size_t seqlen_q,
  153. const size_t seqlen_k,
  154. const size_t seqlen_q_rounded,
  155. const size_t seqlen_k_rounded,
  156. const size_t h,
  157. const size_t h_k,
  158. const size_t d,
  159. const size_t d_rounded,
  160. // device pointers
  161. const at::Tensor q,
  162. const at::Tensor k,
  163. const at::Tensor v,
  164. const at::Tensor out,
  165. const at::Tensor dout,
  166. at::Tensor dq,
  167. at::Tensor dk,
  168. at::Tensor dv,
  169. void *cu_seqlens_q_d,
  170. void *cu_seqlens_k_d,
  171. void *seqused_q,
  172. void *seqused_k,
  173. void *dq_accum_d,
  174. void *dk_accum_d,
  175. void *dv_accum_d,
  176. void *softmax_lse_d,
  177. void *dsoftmax_sum_d,
  178. float p_dropout,
  179. float softmax_scale,
  180. int window_size_left,
  181. int window_size_right,
  182. const float softcap=0.f,
  183. bool deterministic=false) {
  184. set_params_fprop(params,
  185. b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
  186. q, k, v, out,
  187. cu_seqlens_q_d,
  188. cu_seqlens_k_d,
  189. seqused_q,
  190. seqused_k,
  191. softmax_lse_d,
  192. p_dropout,
  193. softmax_scale,
  194. window_size_left,
  195. window_size_right,
  196. softcap);
  197. // Set the pointers and strides.
  198. params.do_ptr = dout.data_ptr();
  199. params.do_row_stride = dout.stride(-3);
  200. params.do_head_stride = dout.stride(-2);
  201. params.dq_ptr = dq.data_ptr();
  202. params.dk_ptr = dk.data_ptr();
  203. params.dv_ptr = dv.data_ptr();
  204. params.dq_row_stride = dq.stride(-3);
  205. params.dk_row_stride = dk.stride(-3);
  206. params.dv_row_stride = dv.stride(-3);
  207. params.dq_head_stride = dq.stride(-2);
  208. params.dk_head_stride = dk.stride(-2);
  209. params.dv_head_stride = dv.stride(-2);
  210. if (cu_seqlens_q_d == nullptr) {
  211. params.do_batch_stride = dout.stride(0);
  212. params.dq_batch_stride = dq.stride(0);
  213. params.dk_batch_stride = dk.stride(0);
  214. params.dv_batch_stride = dv.stride(0);
  215. }
  216. params.dq_accum_ptr = dq_accum_d;
  217. params.dk_accum_ptr = dk_accum_d;
  218. params.dv_accum_ptr = dv_accum_d;
  219. // Softmax sum
  220. params.dsoftmax_sum = dsoftmax_sum_d;
  221. params.deterministic = deterministic;
  222. }
  223. void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
  224. // HEADDIM_SWITCH(params.d, [&] {
  225. // run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
  226. // });
  227. TORCH_CHECK(params.num_splits >= 1);
  228. SPLIT_SWITCH(params.num_splits > 1, Split, [&] {
  229. PAGEDKV_SWITCH(params.page_table, PagedKV, [&] {
  230. if (!params.is_e4m3) {
  231. if (params.is_bf16) {
  232. if (params.d <= 64) {
  233. run_mha_fwd_<cutlass::bfloat16_t, 64, Split, PagedKV>(params, stream);
  234. } else if (params.d <= 96) {
  235. run_mha_fwd_<cutlass::bfloat16_t, 96, Split, PagedKV>(params, stream);
  236. } else if (params.d <= 128) {
  237. run_mha_fwd_<cutlass::bfloat16_t, 128, Split, PagedKV>(params, stream);
  238. } else if (params.d <= 192) {
  239. run_mha_fwd_<cutlass::bfloat16_t, 192, Split, PagedKV>(params, stream);
  240. } else {
  241. run_mha_fwd_<cutlass::bfloat16_t, 256, Split, PagedKV>(params, stream);
  242. }
  243. } else {
  244. #ifndef FLASHATTENTION_DISABLE_FP16
  245. if (params.d <= 64) {
  246. run_mha_fwd_<cutlass::half_t, 64, Split, PagedKV>(params, stream);
  247. } else if (params.d <= 96) {
  248. run_mha_fwd_<cutlass::half_t, 96, Split, PagedKV>(params, stream);
  249. } else if (params.d <= 128) {
  250. run_mha_fwd_<cutlass::half_t, 128, Split, PagedKV>(params, stream);
  251. } else if (params.d <= 192) {
  252. run_mha_fwd_<cutlass::half_t, 192, Split, PagedKV>(params, stream);
  253. } else {
  254. run_mha_fwd_<cutlass::half_t, 256, Split, PagedKV>(params, stream);
  255. }
  256. #else
  257. TORCH_CHECK(false, "This flash attention build does not support FP16.");
  258. #endif
  259. }
  260. } else {
  261. #ifndef FLASHATTENTION_DISABLE_FP8
  262. if (params.d <= 64) {
  263. run_mha_fwd_<cutlass::float_e4m3_t, 64, Split, PagedKV>(params, stream);
  264. } else if (params.d <= 96) {
  265. run_mha_fwd_<cutlass::float_e4m3_t, 96, Split, PagedKV>(params, stream);
  266. } else if (params.d <= 128) {
  267. run_mha_fwd_<cutlass::float_e4m3_t, 128, Split, PagedKV>(params, stream);
  268. } else if (params.d <= 192) {
  269. run_mha_fwd_<cutlass::float_e4m3_t, 192, Split, PagedKV>(params, stream);
  270. } else {
  271. run_mha_fwd_<cutlass::float_e4m3_t, 256, Split, PagedKV>(params, stream);
  272. }
  273. #else
  274. TORCH_CHECK(false, "This flash attention build does not support FP8.");
  275. #endif
  276. }
  277. });
  278. });
  279. }
  280. void run_mha_fwd_combine(Flash_fwd_params &params, cudaStream_t stream) {
  281. #ifndef FLASHATTENTION_DISABLE_SPLIT
  282. // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively
  283. // so that kBlockM is smaller and we have more parallelism.
  284. if (params.is_fp32) {
  285. if (params.d <= 64) {
  286. run_mha_fwd_combine_<float, float, 64>(params, stream);
  287. } else if (params.d <= 128) {
  288. run_mha_fwd_combine_<float, float, 128>(params, stream);
  289. } else {
  290. run_mha_fwd_combine_<float, float, 256>(params, stream);
  291. }
  292. } else if (params.is_bf16) {
  293. if (params.d <= 64) {
  294. run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(params, stream);
  295. } else if (params.d <= 128) {
  296. run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(params, stream);
  297. } else {
  298. run_mha_fwd_combine_<cutlass::bfloat16_t, float, 256>(params, stream);
  299. }
  300. } else {
  301. if (params.d <= 64) {
  302. run_mha_fwd_combine_<cutlass::half_t, float, 64>(params, stream);
  303. } else if (params.d <= 128) {
  304. run_mha_fwd_combine_<cutlass::half_t, float, 128>(params, stream);
  305. } else {
  306. run_mha_fwd_combine_<cutlass::half_t, float, 256>(params, stream);
  307. }
  308. }
  309. #else
  310. TORCH_CHECK(false, "This flash attention build does not support combine kernels.");
  311. #endif
  312. }
  313. inline bool get_pack_gqa(Flash_fwd_params const& params) {
  314. #ifdef FLASHATTENTION_DISABLE_PACKGQA
  315. return false;
  316. #else
  317. // params.page_table must already be set
  318. if (params.h == params.h_k) { return false; }
  319. // This needs to match the kernel configs
  320. auto [kBlockM, kBlockN, Mma1_is_RS, IntraWGOverlap] = tile_size_fwd(params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f);
  321. return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM);
  322. #endif
  323. }
  324. inline int get_num_splits(Flash_fwd_params const& params) {
  325. #ifdef FLASHATTENTION_DISABLE_SPLIT
  326. return 1;
  327. #else
  328. // params.pack_gqa must already be set
  329. // params.page_table must already be set
  330. auto dprops = at::cuda::getCurrentDeviceProperties();
  331. // This needs to match the kernel configs
  332. auto [kBlockM, kBlockN, Mma1_is_RS, IntraWGOverlap] = tile_size_fwd(params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f);
  333. int seqlen_q_packgqa = params.seqlen_q * (params.pack_gqa ? params.h / params.h_k : 1);
  334. const int num_n_blocks = (params.seqlen_k + kBlockN - 1) / kBlockN;
  335. const int num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM;
  336. return num_splits_heuristic(params.b * params.h_k * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
  337. #endif
  338. }
  339. std::vector<at::Tensor>
  340. mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  341. // batch_size x seqlen_k x num_heads_k x head_size or num_pages x page_size x num_heads_k x head_size if there's a page_table.
  342. const at::Tensor &k,
  343. const at::Tensor &v,
  344. c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
  345. const float softmax_scale,
  346. bool is_causal,
  347. c10::optional<at::Tensor> &q_descale_, // batch_size x *num_heads_k* (not num_heads_q)
  348. c10::optional<at::Tensor> &k_descale_, // batch_size x num_heads_k
  349. c10::optional<at::Tensor> &v_descale_, // batch_size x num_heads_k
  350. int window_size_left,
  351. int window_size_right,
  352. int sink_token_length,
  353. const float softcap,
  354. int num_splits,
  355. c10::optional<bool> pack_gqa_
  356. ) {
  357. auto dprops = at::cuda::getCurrentDeviceProperties();
  358. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  359. TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
  360. auto q_type = q.scalar_type();
  361. TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn,
  362. "FlashAttention only support fp16, bf16, and fp8_e4m3 data type");
  363. TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype");
  364. TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype");
  365. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  366. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  367. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  368. TORCH_CHECK(v.stride(-1) == 1 || v.stride(-3) == 1, "Input tensor V must have contiguous last dimension or contiguous seqlen dimension");
  369. if (v.stride(-1) != 1) {
  370. TORCH_CHECK(q_type == at::ScalarType::Float8_e4m3fn, "Only fp8_e4m3 data type supports input tensor V having contiguous seqlen dimension")
  371. #ifndef FLASHATTENTION_ENABLE_VCOLMAJOR
  372. TORCH_CHECK(false, "This flash attention build does not support V having contiguous seqlen dimension.");
  373. #endif
  374. }
  375. const auto sizes = q.sizes();
  376. const int batch_size = sizes[0];
  377. int seqlen_q = sizes[1];
  378. int num_heads = sizes[2];
  379. const int head_size_og = sizes[3];
  380. const int seqlen_k = k.size(1);
  381. const int num_heads_k = k.size(2);
  382. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  383. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  384. // TODO: check this
  385. if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
  386. if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
  387. if (is_causal) {
  388. window_size_left = -1;
  389. window_size_right = 0;
  390. }
  391. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
  392. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
  393. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
  394. int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8;
  395. at::Tensor q_padded, k_padded, v_padded;
  396. auto pad = [](at::Tensor x, int alignment) {
  397. return x.size(-1) % alignment == 0 ? x : torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment}));
  398. };
  399. q_padded = pad(q, alignment);
  400. k_padded = pad(k, alignment);
  401. v_padded = pad(v, alignment);
  402. if (v_padded.stride(-1) != 1) {
  403. TORCH_CHECK(v_padded.stride(-1) % 16 == 0 && v_padded.stride(-2) % 16 == 0 && v_padded.stride(-4) % 16 == 0,
  404. "If input tensor V has contiguous seqlen dimension, the others dimension must have stride divisible by 16");
  405. }
  406. auto opts = q.options();
  407. auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type;
  408. at::Tensor out;
  409. if (out_.has_value()) {
  410. out = out_.value();
  411. TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16");
  412. CHECK_DEVICE(out);
  413. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  414. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
  415. if (head_size_og % alignment != 0) { out = torch::empty_like(q_padded, opts.dtype(out_type)); }
  416. } else {
  417. out = torch::empty_like(q_padded, opts.dtype(out_type));
  418. }
  419. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  420. const int head_size = round_multiple(head_size_og, alignment);
  421. const int head_size_rounded = head_size <= 64 ? 64 : (head_size <= 128 ? round_multiple(head_size, 32) : round_multiple(head_size, 64));
  422. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  423. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  424. // Otherwise the kernel will be launched from cuda:0 device
  425. // Cast to char to avoid compiler warning about narrowing
  426. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  427. auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  428. Flash_fwd_params params;
  429. set_params_fprop(params,
  430. batch_size,
  431. seqlen_q, seqlen_k,
  432. seqlen_q_rounded, seqlen_k_rounded,
  433. num_heads, num_heads_k,
  434. head_size, head_size_rounded,
  435. q_padded, k_padded, v_padded, out,
  436. /*cu_seqlens_q_d=*/nullptr,
  437. /*cu_seqlens_k_d=*/nullptr,
  438. /*seqused_q_=*/nullptr,
  439. /*seqused_k=*/nullptr,
  440. softmax_lse.data_ptr(),
  441. /*p_dropout=*/0.f,
  442. softmax_scale,
  443. window_size_left,
  444. window_size_right,
  445. softcap);
  446. params.sink_token_length = sink_token_length;
  447. params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
  448. params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
  449. at::Tensor out_accum, softmax_lse_accum;
  450. auto outaccum_type = at::ScalarType::Float;
  451. if (params.num_splits > 1) {
  452. TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported");
  453. out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size}, opts.dtype(outaccum_type));
  454. softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  455. params.is_fp32 = false;
  456. params.oaccum_ptr = out_accum.data_ptr();
  457. params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
  458. params.oaccum_split_stride = out_accum.stride(0);
  459. params.oaccum_row_stride = out_accum.stride(3);
  460. params.oaccum_head_stride = out_accum.stride(2);
  461. params.oaccum_batch_stride = out_accum.stride(1);
  462. params.lseaccum_split_stride = softmax_lse_accum.stride(0);
  463. params.lseaccum_head_stride = softmax_lse_accum.stride(2);
  464. params.lseaccum_batch_stride = softmax_lse_accum.stride(1);
  465. }
  466. auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
  467. params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
  468. if (q_type == at::ScalarType::Float8_e4m3fn) {
  469. if (q_descale_.has_value()) {
  470. auto q_descale = q_descale_.value();
  471. CHECK_DEVICE(q_descale);
  472. CHECK_SHAPE(q_descale, batch_size, num_heads_k);
  473. params.q_descale_ptr = q_descale.data_ptr<float>();
  474. params.q_descale_batch_stride = q_descale.stride(0);
  475. params.q_descale_head_stride = q_descale.stride(1);
  476. } else {
  477. params.q_descale_ptr = nullptr;
  478. }
  479. if (k_descale_.has_value()) {
  480. auto k_descale = k_descale_.value();
  481. CHECK_DEVICE(k_descale);
  482. CHECK_SHAPE(k_descale, batch_size, num_heads_k);
  483. params.k_descale_ptr = k_descale.data_ptr<float>();
  484. params.k_descale_batch_stride = k_descale.stride(0);
  485. params.k_descale_head_stride = k_descale.stride(1);
  486. } else {
  487. params.k_descale_ptr = nullptr;
  488. }
  489. if (v_descale_.has_value()) {
  490. auto v_descale = v_descale_.value();
  491. CHECK_DEVICE(v_descale);
  492. CHECK_SHAPE(v_descale, batch_size, num_heads_k);
  493. params.v_descale_ptr = v_descale.data_ptr<float>();
  494. params.v_descale_batch_stride = v_descale.stride(0);
  495. params.v_descale_head_stride = v_descale.stride(1);
  496. } else {
  497. params.v_descale_ptr = nullptr;
  498. }
  499. }
  500. #ifdef FLASHATTENTION_DISABLE_LOCAL
  501. TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
  502. #endif
  503. #ifdef FLASHATTENTION_DISABLE_SPLIT
  504. TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits.");
  505. #endif
  506. #ifdef FLASHATTENTION_DISABLE_PACKGQA
  507. TORCH_CHECK(!params.pack_gqa, "This flash attention build does not support pack_gqa.");
  508. #endif
  509. #ifdef FLASHATTENTION_DISABLE_SOFTCAP
  510. TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
  511. #endif
  512. if (seqlen_k > 0 && batch_size > 0) {
  513. auto stream = at::cuda::getCurrentCUDAStream().stream();
  514. run_mha_fwd(params, stream);
  515. if (params.num_splits > 1) {
  516. if (out_type == at::ScalarType::BFloat16) {
  517. // Since we want output in BF16. Otherwise fwd_combine will output to FP16
  518. params.is_bf16 = true;
  519. }
  520. run_mha_fwd_combine(params, stream);
  521. }
  522. } else if (batch_size > 0) {
  523. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  524. out.zero_();
  525. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  526. }
  527. at::Tensor out_padded = out;
  528. if (head_size_og % alignment != 0) {
  529. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  530. if (out_.has_value()) { out_.value().copy_(out); }
  531. }
  532. return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse};
  533. // return {out, q_padded, k_padded, v_padded, out_accum, softmax_lse_accum};
  534. }
  535. std::vector<at::Tensor>
  536. mha_varlen_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  537. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  538. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  539. c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
  540. const at::Tensor &cu_seqlens_q, // b+1
  541. const at::Tensor &cu_seqlens_k, // b+1
  542. c10::optional<at::Tensor> &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
  543. c10::optional<at::Tensor> &seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
  544. int const max_seqlen_q,
  545. int const max_seqlen_k,
  546. const float softmax_scale,
  547. bool is_causal,
  548. c10::optional<at::Tensor> &q_descale_, // batch_size x *num_heads_k* (not num_heads_q)
  549. c10::optional<at::Tensor> &k_descale_, // batch_size x num_heads_k
  550. c10::optional<at::Tensor> &v_descale_, // batch_size x num_heads_k
  551. int window_size_left,
  552. int window_size_right,
  553. const float softcap,
  554. int num_splits,
  555. c10::optional<bool> pack_gqa_
  556. ) {
  557. #ifdef FLASHATTENTION_DISABLE_VARLEN
  558. TORCH_CHECK(false, "This flash attention build does not support varlen.");
  559. #endif
  560. auto dprops = at::cuda::getCurrentDeviceProperties();
  561. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  562. TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
  563. auto q_type = q.scalar_type();
  564. TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn,
  565. "FlashAttention only support fp16, bf16, and fp8_e4m3 data type");
  566. TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype");
  567. TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype");
  568. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  569. CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
  570. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  571. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  572. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  573. TORCH_CHECK(cu_seqlens_q.stride(-1) == 1, "cu_seqlens_q must have contiguous last dimension");
  574. TORCH_CHECK(cu_seqlens_k.stride(-1) == 1, "cu_seqlens_q must have contiguous last dimension");
  575. const auto sizes = q.sizes();
  576. const int batch_size = cu_seqlens_q.numel() - 1;
  577. int num_heads = sizes[1];
  578. const int head_size_og = sizes[2];
  579. const int num_heads_k = k.size(1);
  580. const int total_q = q.sizes()[0];
  581. const int total_k = k.sizes()[0];
  582. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  583. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  584. if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; }
  585. if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; }
  586. if (is_causal) {
  587. window_size_left = -1;
  588. window_size_right = 0;
  589. }
  590. CHECK_SHAPE(q, total_q, num_heads, head_size_og);
  591. CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
  592. CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
  593. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  594. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  595. if (seqused_q_.has_value()){
  596. auto seqused_q = seqused_q_.value();
  597. TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
  598. TORCH_CHECK(seqused_q.is_cuda(), "seqused_q must be on CUDA device");
  599. TORCH_CHECK(seqused_q.is_contiguous(), "seqused_q must be contiguous");
  600. CHECK_SHAPE(seqused_q, batch_size);
  601. }
  602. if (seqused_k_.has_value()){
  603. auto seqused_k = seqused_k_.value();
  604. TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
  605. TORCH_CHECK(seqused_k.is_cuda(), "seqused_k must be on CUDA device");
  606. TORCH_CHECK(seqused_k.is_contiguous(), "seqused_k must be contiguous");
  607. CHECK_SHAPE(seqused_k, batch_size);
  608. }
  609. int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8;
  610. at::Tensor q_padded, k_padded, v_padded;
  611. auto pad = [](at::Tensor x, int alignment) {
  612. return x.size(-1) % alignment == 0 ? x : torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment}));
  613. };
  614. q_padded = pad(q, alignment);
  615. k_padded = pad(k, alignment);
  616. v_padded = pad(v, alignment);
  617. auto opts = q.options();
  618. auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type;
  619. at::Tensor out;
  620. if (out_.has_value()) {
  621. out = out_.value();
  622. TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16");
  623. CHECK_DEVICE(out);
  624. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  625. CHECK_SHAPE(out, total_q, num_heads, head_size_og);
  626. if (head_size_og % alignment != 0) { out = torch::empty_like(q_padded, opts.dtype(out_type)); }
  627. } else {
  628. out = torch::empty_like(q_padded, opts.dtype(out_type));
  629. }
  630. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  631. const int head_size = round_multiple(head_size_og, alignment);
  632. const int head_size_rounded = head_size <= 64 ? 64 : (head_size <= 128 ? round_multiple(head_size, 32) : round_multiple(head_size, 64));
  633. const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
  634. const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
  635. // Otherwise the kernel will be launched from cuda:0 device
  636. // Cast to char to avoid compiler warning about narrowing
  637. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  638. auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
  639. Flash_fwd_params params;
  640. set_params_fprop(params,
  641. batch_size,
  642. max_seqlen_q, max_seqlen_k,
  643. seqlen_q_rounded, seqlen_k_rounded,
  644. num_heads, num_heads_k,
  645. head_size, head_size_rounded,
  646. q_padded, k_padded, v_padded, out,
  647. cu_seqlens_q.data_ptr(),
  648. cu_seqlens_k.data_ptr(),
  649. seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
  650. seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
  651. softmax_lse.data_ptr(),
  652. /*p_dropout=*/0.f,
  653. softmax_scale,
  654. window_size_left,
  655. window_size_right,
  656. softcap);
  657. params.total_q = total_q;
  658. params.total_k = total_k;
  659. params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
  660. params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
  661. at::Tensor out_accum, softmax_lse_accum;
  662. auto outaccum_type = at::ScalarType::Float;
  663. if (params.num_splits > 1) {
  664. TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported");
  665. out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size}, opts.dtype(outaccum_type));
  666. softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat));
  667. params.is_fp32 = false;
  668. params.oaccum_ptr = out_accum.data_ptr();
  669. params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
  670. params.oaccum_split_stride = out_accum.stride(0);
  671. params.oaccum_row_stride = out_accum.stride(2);
  672. params.oaccum_head_stride = out_accum.stride(1);
  673. params.oaccum_batch_stride = 0;
  674. params.lseaccum_split_stride = softmax_lse_accum.stride(0);
  675. params.lseaccum_head_stride = softmax_lse_accum.stride(1);
  676. params.lseaccum_batch_stride = 0;
  677. }
  678. auto tile_count_semaphore = torch::zeros({1}, opts.dtype(torch::kInt32));
  679. params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
  680. if (q_type == at::ScalarType::Float8_e4m3fn) {
  681. if (q_descale_.has_value()) {
  682. auto q_descale = q_descale_.value();
  683. CHECK_DEVICE(q_descale);
  684. CHECK_SHAPE(q_descale, batch_size, num_heads_k);
  685. params.q_descale_ptr = q_descale.data_ptr<float>();
  686. params.q_descale_batch_stride = q_descale.stride(0);
  687. params.q_descale_head_stride = q_descale.stride(1);
  688. } else {
  689. params.q_descale_ptr = nullptr;
  690. }
  691. if (k_descale_.has_value()) {
  692. auto k_descale = k_descale_.value();
  693. CHECK_DEVICE(k_descale);
  694. CHECK_SHAPE(k_descale, batch_size, num_heads_k);
  695. params.k_descale_ptr = k_descale.data_ptr<float>();
  696. params.k_descale_batch_stride = k_descale.stride(0);
  697. params.k_descale_head_stride = k_descale.stride(1);
  698. } else {
  699. params.k_descale_ptr = nullptr;
  700. }
  701. if (v_descale_.has_value()) {
  702. auto v_descale = v_descale_.value();
  703. CHECK_DEVICE(v_descale);
  704. CHECK_SHAPE(v_descale, batch_size, num_heads_k);
  705. params.v_descale_ptr = v_descale.data_ptr<float>();
  706. params.v_descale_batch_stride = v_descale.stride(0);
  707. params.v_descale_head_stride = v_descale.stride(1);
  708. } else {
  709. params.v_descale_ptr = nullptr;
  710. }
  711. }
  712. #ifdef FLASHATTENTION_DISABLE_LOCAL
  713. TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
  714. #endif
  715. #ifdef FLASHATTENTION_DISABLE_SPLIT
  716. TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits.");
  717. #endif
  718. #ifdef FLASHATTENTION_DISABLE_PACKGQA
  719. TORCH_CHECK(!params.pack_gqa, "This flash attention build does not support pack_gqa.");
  720. #endif
  721. #ifdef FLASHATTENTION_DISABLE_SOFTCAP
  722. TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
  723. #endif
  724. if (max_seqlen_k > 0 && batch_size > 0) {
  725. auto stream = at::cuda::getCurrentCUDAStream().stream();
  726. run_mha_fwd(params, stream);
  727. if (params.num_splits > 1) {
  728. if (out_type == at::ScalarType::BFloat16) {
  729. // Since we want output in BF16. Otherwise fwd_combine will output to FP16
  730. params.is_bf16 = true;
  731. }
  732. // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1
  733. // and seqlen = total_q, and don't need to dispatch to Varlen there.
  734. if (!seqused_q_.has_value()) {
  735. params.b = 1;
  736. params.seqlen_q = total_q;
  737. }
  738. run_mha_fwd_combine(params, stream);
  739. }
  740. } else if (batch_size > 0) {
  741. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  742. out.zero_();
  743. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  744. }
  745. at::Tensor out_padded = out;
  746. if (head_size_og % 8 != 0) {
  747. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  748. if (out_.has_value()) { out_.value().copy_(out); }
  749. }
  750. return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse};
  751. // return {out, q_padded, k_padded, v_padded, out_accum, softmax_lse_accum};
  752. }
  753. void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
  754. #ifndef FLASHATTENTION_DISABLE_BACKWARD
  755. // FP16_SWITCH(!params.is_bf16, [&] {
  756. // HEADDIM_SWITCH(params.d, [&] {
  757. // run_mha_bwd_<elem_type, kHeadDim>(params, stream);
  758. // });
  759. // });
  760. if (!params.is_bf16) {
  761. #ifndef FLASHATTENTION_DISABLE_FP16
  762. if (params.d <= 64) {
  763. run_mha_bwd_<cutlass::half_t, 64>(params, stream);
  764. } else if (params.d <= 96) {
  765. run_mha_bwd_<cutlass::half_t, 96>(params, stream);
  766. } else if (params.d <= 128) {
  767. run_mha_bwd_<cutlass::half_t, 128>(params, stream);
  768. } else if (params.d <= 192) {
  769. run_mha_bwd_<cutlass::half_t, 192>(params, stream);
  770. } else {
  771. run_mha_bwd_<cutlass::half_t, 256>(params, stream);
  772. }
  773. #else
  774. TORCH_CHECK(false, "This flash attention build does not support FP16.");
  775. #endif
  776. } else {
  777. if (params.d <= 64) {
  778. run_mha_bwd_<cutlass::bfloat16_t, 64>(params, stream);
  779. } else if (params.d <= 96) {
  780. run_mha_bwd_<cutlass::bfloat16_t, 96>(params, stream);
  781. } else if (params.d <= 128) {
  782. run_mha_bwd_<cutlass::bfloat16_t, 128>(params, stream);
  783. } else if (params.d <= 192) {
  784. run_mha_bwd_<cutlass::bfloat16_t, 192>(params, stream);
  785. } else {
  786. run_mha_bwd_<cutlass::bfloat16_t, 256>(params, stream);
  787. }
  788. }
  789. #endif
  790. }
  791. std::vector<at::Tensor>
  792. mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
  793. const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  794. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  795. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  796. const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
  797. const at::Tensor &softmax_lse, // b x h x seqlen_q
  798. c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
  799. c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
  800. c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
  801. const float softmax_scale,
  802. bool is_causal,
  803. int window_size_left,
  804. int window_size_right,
  805. int sink_token_length,
  806. const float softcap,
  807. const bool deterministic) {
  808. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  809. TORCH_CHECK(false, "This flash attention build does not support backward.");
  810. #endif
  811. auto dprops = at::cuda::getCurrentDeviceProperties();
  812. bool is_sm9x = dprops->major == 9 && dprops->minor >= 0;
  813. TORCH_CHECK(is_sm9x, "FlashAttentionHopper only supports Hopper GPUs or newer.");
  814. auto stream = at::cuda::getCurrentCUDAStream().stream();
  815. auto q_type = q.dtype();
  816. TORCH_CHECK(q_type == torch::kFloat16 || q_type == torch::kBFloat16,
  817. "FlashAttention only support fp16 and bf16 data type");
  818. TORCH_CHECK(k.dtype() == q_type, "query and key must have the same dtype");
  819. TORCH_CHECK(v.dtype() == q_type, "query and value must have the same dtype");
  820. TORCH_CHECK(out.dtype() == q_type, "query and out must have the same dtype");
  821. TORCH_CHECK(dout.dtype() == q_type, "query and dout must have the same dtype");
  822. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  823. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  824. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  825. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  826. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  827. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  828. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
  829. const auto sizes = q.sizes();
  830. const int batch_size = sizes[0];
  831. const int seqlen_q = sizes[1];
  832. const int num_heads = sizes[2];
  833. const int head_size_og = dout.size(3);
  834. const int head_size = sizes[3];
  835. const int seqlen_k = k.size(1);
  836. const int num_heads_k = k.size(2);
  837. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  838. TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  839. TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
  840. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  841. // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
  842. if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
  843. if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
  844. if (is_causal) {
  845. window_size_left = -1;
  846. window_size_right = 0;
  847. }
  848. // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_bprop will set params.is_causal=true.
  849. // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA).
  850. is_causal = window_size_left < 0 && window_size_right == 0;
  851. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  852. const int head_size_rounded = head_size <= 64 ? 64 : (head_size <= 128 ? round_multiple(head_size, 32) : round_multiple(head_size, 64));
  853. // Very important that these match the kernel configs
  854. bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal;
  855. const int kBlockM = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128)
  856. : (head_size_rounded <= 96 ? 64
  857. : (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80)
  858. : 64));
  859. const int kBlockN = head_size_rounded <= 128 ? 128 : (head_size_rounded <= 192 ? 96 : 80);
  860. const int seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);
  861. const int seqlen_k_rounded = round_multiple(seqlen_k, kBlockN);
  862. TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
  863. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
  864. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
  865. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
  866. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
  867. CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
  868. at::Tensor dq, dk, dv;
  869. if (dq_.has_value()) {
  870. dq = dq_.value();
  871. TORCH_CHECK(dq.dtype() == q_type, "dq must have the same dtype as q");
  872. CHECK_DEVICE(dq);
  873. TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
  874. CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
  875. } else {
  876. dq = torch::empty_like(q);
  877. }
  878. if (dk_.has_value()) {
  879. dk = dk_.value();
  880. TORCH_CHECK(dk.dtype() == q_type, "dk must have the same dtype as q");
  881. CHECK_DEVICE(dk);
  882. TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
  883. CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
  884. } else {
  885. dk = torch::empty_like(k);
  886. }
  887. if (dv_.has_value()) {
  888. dv = dv_.value();
  889. TORCH_CHECK(dv.dtype() == q_type, "dv must have the same dtype as q");
  890. CHECK_DEVICE(dv);
  891. TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
  892. CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
  893. } else {
  894. dv = torch::empty_like(v);
  895. }
  896. at::Tensor dout_padded;
  897. if (head_size_og % 8 != 0) {
  898. dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  899. } else {
  900. dout_padded = dout;
  901. }
  902. // Otherwise the kernel will be launched from cuda:0 device
  903. // Cast to char to avoid compiler warning about narrowing
  904. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  905. auto opts = q.options();
  906. // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
  907. auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
  908. auto softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
  909. at::Tensor dq_accum;
  910. at::Tensor dk_accum, dv_accum;
  911. dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, opts.dtype(at::kFloat));
  912. if (num_heads_k != num_heads) { // MQA / GQA
  913. dk_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat));
  914. dv_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat));
  915. }
  916. Flash_bwd_params params;
  917. set_params_dgrad(params,
  918. batch_size,
  919. seqlen_q, seqlen_k,
  920. seqlen_q_rounded, seqlen_k_rounded,
  921. num_heads, num_heads_k,
  922. head_size, head_size_rounded,
  923. q, k, v, out,
  924. dout_padded, dq, dk, dv,
  925. nullptr /*cu_seqlens_q*/,
  926. nullptr /*cu_seqlens_k*/,
  927. nullptr /*seqused_q_*/,
  928. nullptr /*seqused_k*/,
  929. dq_accum.data_ptr(),
  930. num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr,
  931. num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr,
  932. // nullptr,
  933. // nullptr,
  934. softmax_lse.data_ptr(),
  935. softmax_d.data_ptr(),
  936. /*p_dropout=*/0.f,
  937. softmax_scale,
  938. window_size_left,
  939. window_size_right,
  940. softcap,
  941. deterministic);
  942. params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
  943. params.sink_token_length = sink_token_length;
  944. // Will be zero'ed out in the backward preprocess kernel
  945. at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
  946. params.dq_semaphore = dq_semaphore.data_ptr<int>();
  947. if (num_heads_k != num_heads && params.deterministic) {
  948. at::Tensor dk_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
  949. at::Tensor dv_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
  950. params.dk_semaphore = dk_semaphore.data_ptr<int>();
  951. params.dv_semaphore = dv_semaphore.data_ptr<int>();
  952. }
  953. #ifdef FLASHATTENTION_DISABLE_LOCAL
  954. TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
  955. #endif
  956. #ifdef FLASHATTENTION_DISABLE_SOFTCAP
  957. TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
  958. #endif
  959. if (seqlen_q > 0) {
  960. run_mha_bwd(params, stream);
  961. } else {
  962. // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
  963. dk.zero_();
  964. dv.zero_();
  965. softmax_d.zero_();
  966. }
  967. if (head_size_og % 8 != 0) {
  968. dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  969. dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  970. dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  971. }
  972. return { dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum};
  973. }
  974. std::vector<at::Tensor>
  975. mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
  976. const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  977. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  978. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  979. const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
  980. const at::Tensor &softmax_lse, // b x h x seqlen_q
  981. c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
  982. c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
  983. c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
  984. const at::Tensor &cu_seqlens_q, // b+1
  985. const at::Tensor &cu_seqlens_k, // b+1
  986. c10::optional<at::Tensor> &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
  987. c10::optional<at::Tensor> &seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
  988. const int max_seqlen_q,
  989. const int max_seqlen_k, // max sequence length to choose the kernel
  990. const float softmax_scale,
  991. bool is_causal,
  992. int window_size_left,
  993. int window_size_right,
  994. const float softcap,
  995. const bool deterministic) {
  996. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  997. TORCH_CHECK(false, "This flash attention build does not support backward.");
  998. #endif
  999. #ifdef FLASHATTENTION_DISABLE_VARLEN
  1000. TORCH_CHECK(false, "This flash attention build does not support varlen.");
  1001. #endif
  1002. auto dprops = at::cuda::getCurrentDeviceProperties();
  1003. bool is_sm9x = dprops->major == 9 && dprops->minor >= 0;
  1004. TORCH_CHECK(is_sm9x, "FlashAttentionHopper only supports Hopper GPUs or newer.");
  1005. auto stream = at::cuda::getCurrentCUDAStream().stream();
  1006. auto q_type = q.dtype();
  1007. TORCH_CHECK(q_type == torch::kFloat16 || q_type == torch::kBFloat16,
  1008. "FlashAttention only support fp16 and bf16 data type");
  1009. TORCH_CHECK(k.dtype() == q_type, "query and key must have the same dtype");
  1010. TORCH_CHECK(v.dtype() == q_type, "query and value must have the same dtype");
  1011. TORCH_CHECK(out.dtype() == q_type, "query and out must have the same dtype");
  1012. TORCH_CHECK(dout.dtype() == q_type, "query and dout must have the same dtype");
  1013. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
  1014. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
  1015. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  1016. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  1017. CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
  1018. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1019. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1020. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1021. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  1022. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
  1023. CHECK_CONTIGUOUS(cu_seqlens_q);
  1024. CHECK_CONTIGUOUS(cu_seqlens_k);
  1025. const auto sizes = q.sizes();
  1026. const int total_q = sizes[0];
  1027. const int batch_size = cu_seqlens_q.numel() - 1;
  1028. const int num_heads = sizes[1];
  1029. const int head_size_og = dout.size(2);
  1030. const int head_size = sizes[2];
  1031. const int total_k = k.size(0);
  1032. const int num_heads_k = k.size(1);
  1033. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  1034. TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  1035. TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
  1036. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  1037. // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
  1038. if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; }
  1039. if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; }
  1040. if (is_causal) {
  1041. window_size_left = -1;
  1042. window_size_right = 0;
  1043. }
  1044. // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_bprop will set params.is_causal=true.
  1045. // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA).
  1046. is_causal = window_size_left < 0 && window_size_right == 0;
  1047. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  1048. const int head_size_rounded = head_size <= 64 ? 64 : (head_size <= 128 ? round_multiple(head_size, 32) : round_multiple(head_size, 64));
  1049. // Very important that these match the kernel configs
  1050. // const int kBlockM = head_size_rounded <= 64 ? 128 : 64;
  1051. bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal;
  1052. const int kBlockM = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128)
  1053. : (head_size_rounded <= 96 ? 64
  1054. : (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80)
  1055. : 64));
  1056. const int kBlockN = head_size_rounded <= 128 ? 128 : (head_size_rounded <= 192 ? 96 : 80);
  1057. const int seqlen_q_rounded = round_multiple(max_seqlen_q, kBlockM);
  1058. const int seqlen_k_rounded = round_multiple(max_seqlen_k, kBlockN);
  1059. int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM);
  1060. int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN);
  1061. TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
  1062. CHECK_SHAPE(q, total_q, num_heads, head_size_og);
  1063. CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
  1064. CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
  1065. CHECK_SHAPE(out, total_q, num_heads, head_size);
  1066. CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
  1067. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  1068. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  1069. if (seqused_q_.has_value()){
  1070. auto seqused_q = seqused_q_.value();
  1071. TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
  1072. TORCH_CHECK(seqused_q.is_cuda(), "seqused_q must be on CUDA device");
  1073. TORCH_CHECK(seqused_q.is_contiguous(), "seqused_q must be contiguous");
  1074. CHECK_SHAPE(seqused_q, batch_size);
  1075. }
  1076. if (seqused_k_.has_value()){
  1077. auto seqused_k = seqused_k_.value();
  1078. TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
  1079. TORCH_CHECK(seqused_k.is_cuda(), "seqused_k must be on CUDA device");
  1080. TORCH_CHECK(seqused_k.is_contiguous(), "seqused_k must be contiguous");
  1081. CHECK_SHAPE(seqused_k, batch_size);
  1082. }
  1083. at::Tensor dq, dk, dv;
  1084. if (dq_.has_value()) {
  1085. dq = dq_.value();
  1086. TORCH_CHECK(dq.dtype() == q_type, "dq must have the same dtype as q");
  1087. CHECK_DEVICE(dq);
  1088. TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
  1089. CHECK_SHAPE(dq, total_q, num_heads, head_size);
  1090. } else {
  1091. dq = torch::empty_like(q);
  1092. }
  1093. if (dk_.has_value()) {
  1094. dk = dk_.value();
  1095. TORCH_CHECK(dk.dtype() == q_type, "dk must have the same dtype as q");
  1096. CHECK_DEVICE(dk);
  1097. TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
  1098. CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
  1099. } else {
  1100. dk = torch::empty_like(k);
  1101. }
  1102. if (dv_.has_value()) {
  1103. dv = dv_.value();
  1104. TORCH_CHECK(dv.dtype() == q_type, "dv must have the same dtype as q");
  1105. CHECK_DEVICE(dv);
  1106. TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
  1107. CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
  1108. } else {
  1109. dv = torch::empty_like(v);
  1110. }
  1111. at::Tensor dout_padded;
  1112. if (head_size_og % 8 != 0) {
  1113. dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1114. } else {
  1115. dout_padded = dout;
  1116. }
  1117. // Otherwise the kernel will be launched from cuda:0 device
  1118. // Cast to char to avoid compiler warning about narrowing
  1119. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  1120. auto opts = q.options();
  1121. // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
  1122. auto softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
  1123. auto softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
  1124. at::Tensor dq_accum;
  1125. at::Tensor dk_accum, dv_accum;
  1126. dq_accum = torch::empty({num_heads, total_q_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
  1127. if (num_heads_k != num_heads) { // MQA / GQA
  1128. dk_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
  1129. dv_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
  1130. }
  1131. Flash_bwd_params params;
  1132. set_params_dgrad(params,
  1133. batch_size,
  1134. max_seqlen_q, max_seqlen_k,
  1135. seqlen_q_rounded, seqlen_k_rounded,
  1136. num_heads, num_heads_k,
  1137. head_size, head_size_rounded,
  1138. q, k, v, out,
  1139. dout_padded, dq, dk, dv,
  1140. cu_seqlens_q.data_ptr(),
  1141. cu_seqlens_k.data_ptr(),
  1142. seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
  1143. seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
  1144. dq_accum.data_ptr(),
  1145. num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr,
  1146. num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr,
  1147. // nullptr,
  1148. // nullptr,
  1149. softmax_lse.data_ptr(),
  1150. softmax_d.data_ptr(),
  1151. /*p_dropout=*/0.f,
  1152. softmax_scale,
  1153. window_size_left,
  1154. window_size_right,
  1155. softcap,
  1156. deterministic);
  1157. params.total_q = total_q;
  1158. params.total_k = total_k;
  1159. params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
  1160. // Will be zero'ed out in the backward preprocess kernel
  1161. at::Tensor dq_semaphore = torch::empty({(max_seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
  1162. params.dq_semaphore = dq_semaphore.data_ptr<int>();
  1163. if (num_heads_k != num_heads && params.deterministic) {
  1164. // TODO: do we need to zero them out?
  1165. at::Tensor dk_semaphore = torch::empty({(max_seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
  1166. at::Tensor dv_semaphore = torch::empty({(max_seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
  1167. params.dk_semaphore = dk_semaphore.data_ptr<int>();
  1168. params.dv_semaphore = dv_semaphore.data_ptr<int>();
  1169. }
  1170. #ifdef FLASHATTENTION_DISABLE_LOCAL
  1171. TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
  1172. #endif
  1173. #ifdef FLASHATTENTION_DISABLE_SOFTCAP
  1174. TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
  1175. #endif
  1176. if (max_seqlen_q > 0) {
  1177. run_mha_bwd(params, stream);
  1178. } else {
  1179. // If max_seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
  1180. dk.zero_();
  1181. dv.zero_();
  1182. softmax_d.zero_();
  1183. }
  1184. if (head_size_og % 8 != 0) {
  1185. dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  1186. dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  1187. dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  1188. }
  1189. return { dq, dk, dv, softmax_d, dq_accum, softmax_lse_log2 };
  1190. }
  1191. std::vector<at::Tensor>
  1192. mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size or total_q x num_heads x head_size
  1193. // batch_size_k x seqlen_k x num_heads_k x head_size or num_pages x page_size x num_heads_k x head_size if there's a page_table.
  1194. const at::Tensor &kcache,
  1195. const at::Tensor &vcache,
  1196. c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
  1197. c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
  1198. // batch_size x seqlen_q x num_heads x head_size or total_q x num_heads x head_size
  1199. c10::optional<at::Tensor> &out_,
  1200. c10::optional<const at::Tensor> &seqused_k_, // batch_size
  1201. c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
  1202. c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
  1203. c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
  1204. c10::optional<const at::Tensor> &leftpad_k_, // batch_size
  1205. c10::optional<const at::Tensor> &page_table_, // batch_size_k x max_num_pages_per_seq
  1206. c10::optional<const at::Tensor> &cu_seqlens_q_, // b+1
  1207. c10::optional<int> max_seqlen_q_,
  1208. float const softmax_scale,
  1209. bool is_causal,
  1210. c10::optional<at::Tensor> &q_descale_, // batch_size x *num_heads_k* (not num_heads_q)
  1211. c10::optional<at::Tensor> &k_descale_, // batch_size x num_heads_k
  1212. c10::optional<at::Tensor> &v_descale_, // batch_size x num_heads_k
  1213. int window_size_left,
  1214. int window_size_right,
  1215. int sink_token_length,
  1216. float const softcap,
  1217. bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
  1218. int num_splits,
  1219. c10::optional<bool> pack_gqa_
  1220. ) {
  1221. auto dprops = at::cuda::getCurrentDeviceProperties();
  1222. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  1223. TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
  1224. auto q_type = q.scalar_type();
  1225. TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn,
  1226. "FlashAttention only support fp16, bf16, and fp8_e4m3 data type");
  1227. TORCH_CHECK(kcache.scalar_type() == q_type, "query and key must have the same dtype");
  1228. TORCH_CHECK(vcache.scalar_type() == q_type, "query and value must have the same dtype");
  1229. CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
  1230. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1231. TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1232. TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1233. at::Tensor page_table;
  1234. const bool paged_KV = page_table_.has_value();
  1235. if (paged_KV) {
  1236. page_table = page_table_.value();
  1237. CHECK_DEVICE(page_table);
  1238. TORCH_CHECK(page_table.dtype() == torch::kInt32, "page_table must have dtype torch.int32");
  1239. TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension");
  1240. }
  1241. at::Tensor cu_seqlens_q;
  1242. bool const is_varlen_q = cu_seqlens_q_.has_value();
  1243. if (is_varlen_q) {
  1244. cu_seqlens_q = cu_seqlens_q_.value();
  1245. CHECK_DEVICE(cu_seqlens_q);
  1246. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
  1247. TORCH_CHECK(cu_seqlens_q.stride(-1) == 1, "cu_seqlens_q must have contiguous last dimension");
  1248. TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
  1249. }
  1250. const auto sizes = q.sizes();
  1251. const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
  1252. int seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
  1253. int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
  1254. int num_heads = q.size(-2);
  1255. const int head_size_og = q.size(-1);
  1256. const int num_heads_k = kcache.size(2);
  1257. const int batch_size_k = !paged_KV ? kcache.size(0) : page_table.size(0);
  1258. if (!cache_batch_idx_.has_value()) {
  1259. TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k");
  1260. }
  1261. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  1262. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  1263. const int max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1);
  1264. const int num_pages = !paged_KV ? 0 : kcache.size(0);
  1265. const int page_size = !paged_KV ? 1 : kcache.size(1);
  1266. const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_pages_per_seq * page_size;
  1267. // TODO: check this
  1268. if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
  1269. if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
  1270. if (is_causal) {
  1271. window_size_left = -1;
  1272. window_size_right = 0;
  1273. }
  1274. if (!is_varlen_q) {
  1275. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
  1276. } else {
  1277. CHECK_SHAPE(q, total_q, num_heads, head_size_og);
  1278. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  1279. }
  1280. if (!paged_KV) {
  1281. CHECK_SHAPE(kcache, batch_size_k, seqlen_k, num_heads_k, head_size_og);
  1282. CHECK_SHAPE(vcache, batch_size_k, seqlen_k, num_heads_k, head_size_og);
  1283. } else {
  1284. CHECK_SHAPE(kcache, num_pages, page_size, num_heads_k, head_size_og);
  1285. CHECK_SHAPE(vcache, num_pages, page_size, num_heads_k, head_size_og);
  1286. CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq);
  1287. }
  1288. if (seqused_k_.has_value()) {
  1289. auto seqused_k = seqused_k_.value();
  1290. TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
  1291. CHECK_DEVICE(seqused_k);
  1292. CHECK_CONTIGUOUS(seqused_k);
  1293. CHECK_SHAPE(seqused_k, batch_size);
  1294. }
  1295. int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8;
  1296. at::Tensor q_padded, k_padded, v_padded;
  1297. auto pad = [](at::Tensor x, int alignment) {
  1298. return x.size(-1) % alignment == 0 ? x : torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment}));
  1299. };
  1300. q_padded = pad(q, alignment);
  1301. k_padded = pad(kcache, alignment);
  1302. v_padded = pad(vcache, alignment);
  1303. auto opts = q.options();
  1304. auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type;
  1305. at::Tensor out;
  1306. if (out_.has_value()) {
  1307. out = out_.value();
  1308. TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16");
  1309. CHECK_DEVICE(out);
  1310. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  1311. if (!is_varlen_q) {
  1312. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
  1313. } else {
  1314. CHECK_SHAPE(out, total_q, num_heads, head_size_og);
  1315. }
  1316. if (head_size_og % alignment != 0) { out = torch::empty_like(q_padded, opts.dtype(out_type)); }
  1317. } else {
  1318. out = torch::empty_like(q_padded, opts.dtype(out_type));
  1319. }
  1320. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  1321. const int head_size = round_multiple(head_size_og, alignment);
  1322. const int head_size_rounded = head_size <= 64 ? 64 : (head_size <= 128 ? round_multiple(head_size, 32) : round_multiple(head_size, 64));
  1323. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  1324. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  1325. // Otherwise the kernel will be launched from cuda:0 device
  1326. // Cast to char to avoid compiler warning about narrowing
  1327. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  1328. auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  1329. Flash_fwd_params params;
  1330. set_params_fprop(params,
  1331. batch_size,
  1332. seqlen_q, seqlen_k,
  1333. seqlen_q_rounded, seqlen_k_rounded,
  1334. num_heads, num_heads_k,
  1335. head_size, head_size_rounded,
  1336. q_padded, k_padded, v_padded, out,
  1337. /*cu_seqlens_q_d=*/!is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
  1338. /*cu_seqlens_k_d=*/nullptr,
  1339. /*seqused_q_=*/nullptr,
  1340. /*seqused_k=*/seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
  1341. softmax_lse.data_ptr(),
  1342. /*p_dropout=*/0.f,
  1343. softmax_scale,
  1344. window_size_left,
  1345. window_size_right,
  1346. softcap);
  1347. params.total_q = total_q;
  1348. params.sink_token_length = sink_token_length;
  1349. params.b_k = batch_size_k;
  1350. if (paged_KV) {
  1351. params.page_table = page_table.data_ptr<int>();
  1352. params.page_table_batch_stride = page_table.stride(0);
  1353. }
  1354. params.page_size = page_size;
  1355. params.num_pages = num_pages;
  1356. params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
  1357. params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
  1358. if (k_.has_value()) {
  1359. at::Tensor k, v, k_padded, v_padded;
  1360. TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
  1361. TORCH_CHECK(seqused_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
  1362. TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache");
  1363. k = k_.value();
  1364. v = v_.value();
  1365. TORCH_CHECK(k.dtype() == q_type, "Key must have the same dtype as query");
  1366. TORCH_CHECK(v.dtype() == q_type, "Value must have the same dtype as query");
  1367. CHECK_DEVICE(k); CHECK_DEVICE(v);
  1368. TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
  1369. TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
  1370. int seqlen_knew = k.size(1);
  1371. CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
  1372. CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
  1373. if (head_size_og % 8 != 0) {
  1374. k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1375. v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1376. } else {
  1377. k_padded = k;
  1378. v_padded = v;
  1379. }
  1380. params.seqlen_knew = seqlen_knew;
  1381. params.knew_ptr = k_padded.data_ptr();
  1382. params.vnew_ptr = v_padded.data_ptr();
  1383. // All stride are in elements, not bytes.
  1384. params.knew_batch_stride = k_padded.stride(0);
  1385. params.vnew_batch_stride = v_padded.stride(0);
  1386. params.knew_row_stride = k_padded.stride(-3);
  1387. params.vnew_row_stride = v_padded.stride(-3);
  1388. params.knew_head_stride = k_padded.stride(-2);
  1389. params.vnew_head_stride = v_padded.stride(-2);
  1390. }
  1391. if (leftpad_k_.has_value()) {
  1392. auto leftpad_k = leftpad_k_.value();
  1393. TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
  1394. CHECK_DEVICE(leftpad_k);
  1395. CHECK_CONTIGUOUS(leftpad_k);
  1396. CHECK_SHAPE(leftpad_k, batch_size);
  1397. params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
  1398. }
  1399. if (rotary_cos_.has_value()) {
  1400. TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
  1401. auto rotary_cos = rotary_cos_.value();
  1402. CHECK_DEVICE(rotary_cos);
  1403. params.rotary_dim = rotary_cos.size(1) * 2;
  1404. TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
  1405. TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
  1406. const int seqlen_ro = rotary_cos.size(0);
  1407. if (paged_KV) {
  1408. TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
  1409. }
  1410. CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
  1411. CHECK_CONTIGUOUS(rotary_cos);
  1412. TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query");
  1413. TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
  1414. auto rotary_sin = rotary_sin_.value();
  1415. CHECK_DEVICE(rotary_sin);
  1416. CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
  1417. CHECK_CONTIGUOUS(rotary_sin);
  1418. TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query");
  1419. params.rotary_cos_ptr = rotary_cos.data_ptr();
  1420. params.rotary_sin_ptr = rotary_sin.data_ptr();
  1421. params.is_rotary_interleaved = is_rotary_interleaved;
  1422. } else {
  1423. params.rotary_dim = 0;
  1424. }
  1425. if (cache_batch_idx_.has_value()) {
  1426. auto cache_batch_idx = cache_batch_idx_.value();
  1427. CHECK_DEVICE(cache_batch_idx);
  1428. CHECK_CONTIGUOUS(cache_batch_idx);
  1429. TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
  1430. params.kv_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
  1431. }
  1432. at::Tensor out_accum, softmax_lse_accum;
  1433. auto outaccum_type = at::ScalarType::Float;
  1434. if (params.num_splits > 1) {
  1435. TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported");
  1436. if (!is_varlen_q) {
  1437. out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size}, opts.dtype(outaccum_type));
  1438. softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  1439. params.oaccum_batch_stride = out_accum.stride(1);
  1440. params.lseaccum_batch_stride = softmax_lse_accum.stride(1);
  1441. } else {
  1442. out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size}, opts.dtype(outaccum_type));
  1443. softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat));
  1444. }
  1445. params.is_fp32 = false;
  1446. params.oaccum_ptr = out_accum.data_ptr();
  1447. params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
  1448. params.oaccum_split_stride = out_accum.stride(0);
  1449. params.oaccum_row_stride = out_accum.stride(-2);
  1450. params.oaccum_head_stride = out_accum.stride(-3);
  1451. params.lseaccum_split_stride = softmax_lse_accum.stride(0);
  1452. params.lseaccum_head_stride = softmax_lse_accum.stride(-2);
  1453. }
  1454. at::Tensor tile_count_semaphore;
  1455. // We don't use the persistent scheduler if Split or PagedKV or AppendKV
  1456. if ((params.is_causal || params.is_local || seqused_k_.has_value() || leftpad_k_.has_value()) && params.num_splits == 1 && !paged_KV && !k_.has_value()) {
  1457. tile_count_semaphore = torch::zeros({1}, opts.dtype(torch::kInt32));
  1458. params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
  1459. } else {
  1460. params.tile_count_semaphore = nullptr;
  1461. }
  1462. if (q_type == at::ScalarType::Float8_e4m3fn) {
  1463. if (q_descale_.has_value()) {
  1464. auto q_descale = q_descale_.value();
  1465. CHECK_DEVICE(q_descale);
  1466. CHECK_SHAPE(q_descale, batch_size, num_heads_k);
  1467. params.q_descale_ptr = q_descale.data_ptr<float>();
  1468. params.q_descale_batch_stride = q_descale.stride(0);
  1469. params.q_descale_head_stride = q_descale.stride(1);
  1470. } else {
  1471. params.q_descale_ptr = nullptr;
  1472. }
  1473. if (k_descale_.has_value()) {
  1474. auto k_descale = k_descale_.value();
  1475. CHECK_DEVICE(k_descale);
  1476. CHECK_SHAPE(k_descale, batch_size, num_heads_k);
  1477. params.k_descale_ptr = k_descale.data_ptr<float>();
  1478. params.k_descale_batch_stride = k_descale.stride(0);
  1479. params.k_descale_head_stride = k_descale.stride(1);
  1480. } else {
  1481. params.k_descale_ptr = nullptr;
  1482. }
  1483. if (v_descale_.has_value()) {
  1484. auto v_descale = v_descale_.value();
  1485. CHECK_DEVICE(v_descale);
  1486. CHECK_SHAPE(v_descale, batch_size, num_heads_k);
  1487. params.v_descale_ptr = v_descale.data_ptr<float>();
  1488. params.v_descale_batch_stride = v_descale.stride(0);
  1489. params.v_descale_head_stride = v_descale.stride(1);
  1490. } else {
  1491. params.v_descale_ptr = nullptr;
  1492. }
  1493. }
  1494. #ifdef FLASHATTENTION_DISABLE_LOCAL
  1495. TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
  1496. #endif
  1497. #ifdef FLASHATTENTION_DISABLE_SOFTCAP
  1498. TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
  1499. #endif
  1500. #ifdef FLASHATTENTION_DISABLE_SPLIT
  1501. TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits.");
  1502. #endif
  1503. #ifdef FLASHATTENTION_DISABLE_PACKGQA
  1504. TORCH_CHECK(!params.pack_gqa, "This flash attention build does not support pack_gqa.");
  1505. #endif
  1506. #ifdef FLASHATTENTION_DISABLE_PAGEDKV
  1507. TORCH_CHECK(!paged_KV, "This flash attention build does not support paged KV.");
  1508. #endif
  1509. #ifdef FLASHATTENTION_DISABLE_APPENDKV
  1510. TORCH_CHECK(!k_.has_value(), "This flash attention build does not support appending KV.");
  1511. #endif
  1512. #ifdef FLASHATTENTION_DISABLE_VARLEN
  1513. TORCH_CHECK(!seqused_k_.has_value() && !leftpad_k_.has_value() && !k_.has_value() && !cu_seqlens_q_.has_value(), "This flash attention build does not support varlen.");
  1514. #endif
  1515. if (seqlen_q > 0 && total_q > 0 && seqlen_k > 0 && batch_size > 0) {
  1516. auto stream = at::cuda::getCurrentCUDAStream().stream();
  1517. run_mha_fwd(params, stream);
  1518. if (params.num_splits > 1) {
  1519. if (out_type == at::ScalarType::BFloat16) {
  1520. // Since we want output in BF16. Otherwise fwd_combine will output to FP16
  1521. params.is_bf16 = true;
  1522. }
  1523. // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1
  1524. // and seqlen = total_q, and don't need to dispatch to Varlen there.
  1525. // if (is_varlen_q && !seqused_q_.has_value()) {
  1526. if (is_varlen_q) {
  1527. params.b = 1;
  1528. params.seqlen_q = total_q;
  1529. }
  1530. run_mha_fwd_combine(params, stream);
  1531. }
  1532. } else if (seqlen_q > 0 && total_q > 0 && batch_size > 0) {
  1533. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  1534. out.zero_();
  1535. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  1536. }
  1537. at::Tensor out_padded = out;
  1538. if (head_size_og % alignment != 0) {
  1539. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  1540. if (out_.has_value()) { out_.value().copy_(out); }
  1541. }
  1542. // return {out, softmax_lse};
  1543. return {out, softmax_lse, out_accum, softmax_lse_accum};
  1544. }
  1545. std::vector<at::Tensor>
  1546. mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x seqlen x num_heads x head_size
  1547. const at::Tensor &lse_partial, // num_splits x batch_size x seqlen x num_heads
  1548. std::optional<at::Tensor> out_, // batch_size x seqlen x num_heads x head_size
  1549. std::optional<at::ScalarType> out_dtype_
  1550. ) {
  1551. auto dprops = at::cuda::getCurrentDeviceProperties();
  1552. bool is_sm80 = dprops->major >= 8;
  1553. TORCH_CHECK(is_sm80, "Attention combine function only supports Ampere GPUs or newer.");
  1554. auto out_partial_type = out_partial.scalar_type();
  1555. TORCH_CHECK(out_partial_type == at::ScalarType::Float, "Attention combine function only support fp32 data type");
  1556. TORCH_CHECK(lse_partial.scalar_type() == at::ScalarType::Float, "Attention combine function only support fp32 data type");
  1557. CHECK_DEVICE(out_partial); CHECK_DEVICE(lse_partial);
  1558. TORCH_CHECK(out_partial.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1559. TORCH_CHECK(lse_partial.stride(-2) == 1, "LSE tensor must be contiguous in the seqlen dimension");
  1560. const auto sizes = out_partial.sizes();
  1561. const int num_splits = sizes[0];
  1562. const int batch_size = sizes[1];
  1563. const int seqlen = sizes[2];
  1564. const int num_heads = sizes[3];
  1565. const int head_size_og = sizes[4];
  1566. TORCH_CHECK(head_size_og <= 256, "FlashAttention combine only supports head dimension at most 256");
  1567. TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256");
  1568. CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og);
  1569. CHECK_SHAPE(lse_partial, num_splits, batch_size, seqlen, num_heads);
  1570. int const alignment = 4;
  1571. at::Tensor out_partial_padded;
  1572. auto pad = [](at::Tensor x, int alignment) {
  1573. return x.size(-1) % alignment == 0 ? x : torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment}));
  1574. };
  1575. out_partial_padded = pad(out_partial, alignment);
  1576. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  1577. const int head_size = round_multiple(head_size_og, alignment);
  1578. auto opts = out_partial.options();
  1579. at::ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type());
  1580. TORCH_CHECK(out_type == at::ScalarType::Float || out_type == at::ScalarType::BFloat16 || out_type == at::ScalarType::Half, "Output type must be FP32, FP16 or BF16");
  1581. at::Tensor out;
  1582. if (out_.has_value()) {
  1583. out = out_.value();
  1584. TORCH_CHECK(out.scalar_type() == out_type);
  1585. CHECK_DEVICE(out);
  1586. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  1587. CHECK_SHAPE(out, batch_size, seqlen, num_heads, head_size_og);
  1588. if (head_size_og % alignment != 0) {
  1589. out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
  1590. }
  1591. } else {
  1592. out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
  1593. }
  1594. // Otherwise the kernel will be launched from cuda:0 device
  1595. // Cast to char to avoid compiler warning about narrowing
  1596. at::cuda::CUDAGuard device_guard{(char)out_partial.get_device()};
  1597. auto softmax_lse = torch::empty({batch_size, num_heads, seqlen}, opts.dtype(at::kFloat)).transpose(1, 2);
  1598. Flash_fwd_params params {}; // Need to reset the params to set everything to zero
  1599. params.is_fp32 = out_type == at::ScalarType::Float;
  1600. params.is_bf16 = out_type == at::ScalarType::BFloat16;
  1601. params.oaccum_ptr = out_partial_padded.data_ptr();
  1602. params.softmax_lseaccum_ptr = lse_partial.data_ptr();
  1603. params.o_ptr = out.data_ptr();
  1604. params.softmax_lse_ptr = softmax_lse.data_ptr();
  1605. params.b = batch_size;
  1606. params.h = num_heads;
  1607. params.seqlen_q = seqlen;
  1608. params.d = head_size;
  1609. params.num_splits = num_splits;
  1610. params.oaccum_split_stride = out_partial_padded.stride(0);
  1611. params.oaccum_row_stride = out_partial_padded.stride(2);
  1612. params.oaccum_head_stride = out_partial_padded.stride(3);
  1613. params.oaccum_batch_stride = out_partial_padded.stride(1);
  1614. params.lseaccum_split_stride = lse_partial.stride(0);
  1615. params.lseaccum_head_stride = lse_partial.stride(3);
  1616. params.lseaccum_batch_stride = lse_partial.stride(1);
  1617. params.o_row_stride = out.stride(1);
  1618. params.o_head_stride = out.stride(2);
  1619. params.o_batch_stride = out.stride(0);
  1620. if (seqlen > 0 && batch_size > 0) {
  1621. auto stream = at::cuda::getCurrentCUDAStream().stream();
  1622. run_mha_fwd_combine(params, stream);
  1623. }
  1624. at::Tensor out_padded = out;
  1625. if (head_size_og % alignment != 0) {
  1626. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  1627. // if (out_.has_value()) { out_.value().copy_(out); }
  1628. }
  1629. return {out, softmax_lse};
  1630. }
  1631. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  1632. m.doc() = "FlashAttention";
  1633. m.def("fwd", &mha_fwd, "Forward pass");
  1634. m.def("fwd_varlen", &mha_varlen_fwd, "Varlen forward pass");
  1635. m.def("bwd", &mha_bwd, "Backward pass");
  1636. m.def("bwd_varlen", &mha_varlen_bwd, "Varlen backward pass");
  1637. m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
  1638. m.def("fwd_combine", &mha_combine, "Combine partial attention outputs");
  1639. }