flash_api.cpp

  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
  5. #include <torch/python.h>
  6. #include <torch/nn/functional.h>
  7. #include <torch/version.h> // For TORCH_VERSION* macros
  8. #include <ATen/cuda/CUDAContext.h>
  9. #include <c10/cuda/CUDAGuard.h>
  10. #include <cutlass/numeric_types.h>
  11. #include "flash.h"
  12. #include "static_switch.h"
  13. #include "tile_size.h"
  14. #include "heuristics.h"
  15. // Copied from https://github.com/pytorch/pytorch/commit/7931eee5c5ebcdf468bff4d308510b03355cd909
  16. // This is so that we can pass in torch.dtype as a parameter to the function.
  17. #if TORCH_VERSION_MAJOR < 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR < 4)
  18. #include <pybind11/pybind11.h>
  19. #include <pybind11/stl.h>
  20. namespace pybind11::detail {
  21. template <>
  22. struct type_caster<at::ScalarType> {
  23. public:
  24. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  25. PYBIND11_TYPE_CASTER(at::ScalarType, _("torch.dtype"));
  26. // PYBIND11_TYPE_CASTER defines a member field called value. at::ScalarType
  27. // cannot be default-initialized, so we provide this constructor to explicitly
  28. // initialize that field. The value doesn't matter as it will be overwritten
  29. // after a successful call to load.
  30. type_caster() : value(at::kFloat) {}
  31. bool load(handle src, bool) {
  32. PyObject* obj = src.ptr();
  33. if (THPDtype_Check(obj)) {
  34. value = reinterpret_cast<THPDtype*>(obj)->scalar_type;
  35. return true;
  36. }
  37. return false;
  38. }
  39. static handle cast(
  40. const at::ScalarType& src,
  41. return_value_policy /* policy */,
  42. handle /* parent */) {
  43. return Py_NewRef(torch::getTHPDtype(src));
  44. }
  45. };
  46. } // namespace pybind11::detail
  47. #endif
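  // Note: with the caster above, Python callers can pass a torch.dtype value (e.g. torch.bfloat16)
  // directly wherever an at::ScalarType parameter is exposed through pybind11.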
  48. #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
  49. #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
  50. #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
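  // Usage sketch: CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size) expands to
  // TORCH_CHECK(q.sizes() == torch::IntArrayRef({batch_size, seqlen_q, num_heads, head_size}),
  //             "q must have shape (batch_size, seqlen_q, num_heads, head_size)"),
  // so a shape mismatch raises an error naming the tensor and the expected shape.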
  51. void set_params_fprop(Flash_fwd_params &params,
  52. // sizes
  53. const size_t b,
  54. const size_t seqlen_q,
  55. const size_t seqlen_k,
  56. const size_t seqlen_q_rounded,
  57. const size_t seqlen_k_rounded,
  58. const size_t h,
  59. const size_t h_k,
  60. const size_t d,
  61. const size_t d_rounded,
  62. // device pointers
  63. const at::Tensor q,
  64. const at::Tensor k,
  65. const at::Tensor v,
  66. at::Tensor out,
  67. void *cu_seqlens_q_d,
  68. void *cu_seqlens_k_d,
  69. void *seqused_q,
  70. void *seqused_k,
  71. void *softmax_lse_d,
  72. float p_dropout,
  73. float softmax_scale,
  74. int window_size_left,
  75. int window_size_right,
  76. const float softcap=0.f,
  77. const int sm_margin=0) {
  78. // Reset the parameters
  79. params = {};
  80. params.is_bf16 = q.dtype() == torch::kBFloat16;
  81. params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;
  82. // Set the pointers and strides.
  83. params.q_ptr = q.data_ptr();
  84. params.k_ptr = k.data_ptr();
  85. params.v_ptr = v.data_ptr();
  86. // All strides are in elements, not bytes.
  87. params.q_row_stride = q.stride(-3);
  88. params.k_row_stride = k.stride(-3);
  89. params.v_row_stride = v.stride(-3);
  90. params.q_head_stride = q.stride(-2);
  91. params.k_head_stride = k.stride(-2);
  92. params.v_head_stride = v.stride(-2);
  93. params.v_dim_stride = v.stride(-1);
  94. params.o_ptr = out.data_ptr();
  95. params.o_row_stride = out.stride(-3);
  96. params.o_head_stride = out.stride(-2);
  97. if (cu_seqlens_q_d == nullptr) {
  98. params.q_batch_stride = q.stride(0);
  99. params.o_batch_stride = out.stride(0);
  100. }
  101. if (cu_seqlens_k_d == nullptr) {
  102. params.k_batch_stride = k.stride(0);
  103. params.v_batch_stride = v.stride(0);
  104. }
  105. params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
  106. params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
  107. params.seqused_q = static_cast<int *>(seqused_q);
  108. params.seqused_k = static_cast<int *>(seqused_k);
  109. // Softmax sum
  110. params.softmax_lse_ptr = softmax_lse_d;
  111. // Set the dimensions.
  112. params.b = b;
  113. params.h = h;
  114. params.h_k = h_k;
  115. params.seqlen_q = seqlen_q;
  116. params.seqlen_k = seqlen_k;
  117. params.seqlen_q_rounded = seqlen_q_rounded;
  118. params.seqlen_k_rounded = seqlen_k_rounded;
  119. params.d = d;
  120. params.d_rounded = d_rounded;
  121. // Set the different scale values.
  122. params.scale_softmax = softmax_scale;
  123. params.softcap = softcap;
  124. // Set this to probability of keeping an element to simplify things.
  125. params.p_dropout = 1.f - p_dropout;
  126. // Convert p from float to int so we don't have to convert the random uint to float to compare.
  127. // [Minor] We want to round down since when we do the comparison we use <= instead of <
  128. // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
  129. // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
  130. params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
  131. params.rp_dropout = 1.f / params.p_dropout;
  132. TORCH_CHECK(p_dropout < 1.f);
  133. #ifdef FLASHATTENTION_DISABLE_DROPOUT
  134. TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
  135. #endif
  136. // Causal is the special case where window_size_right == 0 and window_size_left < 0.
  137. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
  138. params.is_causal = window_size_left < 0 && window_size_right == 0;
  139. params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal;
  140. // TODO: check this
  141. if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k - 1; }
  142. if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_q - 1; }
  143. params.window_size_left = window_size_left;
  144. params.window_size_right = window_size_right;
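  // Illustrative examples of the window convention (assuming seqlen_q == seqlen_k):
  //   window_size = (-1, 0)  -> is_causal = true (full causal mask)
  //   window_size = (2, 0)   -> is_local = true; query i attends to keys in [i - 2, i]
  //   window_size = (-1, -1) -> neither causal nor local (full attention)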
  145. params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
  146. params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;
  147. #ifdef FLASHATTENTION_DISABLE_LOCAL
  148. TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
  149. #endif
  150. }
  151. void set_params_dgrad(Flash_bwd_params &params,
  152. // sizes
  153. const size_t b,
  154. const size_t seqlen_q,
  155. const size_t seqlen_k,
  156. const size_t seqlen_q_rounded,
  157. const size_t seqlen_k_rounded,
  158. const size_t h,
  159. const size_t h_k,
  160. const size_t d,
  161. const size_t d_rounded,
  162. // device pointers
  163. const at::Tensor q,
  164. const at::Tensor k,
  165. const at::Tensor v,
  166. const at::Tensor out,
  167. const at::Tensor dout,
  168. at::Tensor dq,
  169. at::Tensor dk,
  170. at::Tensor dv,
  171. void *cu_seqlens_q_d,
  172. void *cu_seqlens_k_d,
  173. void *seqused_q,
  174. void *seqused_k,
  175. void *dq_accum_d,
  176. void *dk_accum_d,
  177. void *dv_accum_d,
  178. void *softmax_lse_d,
  179. void *dsoftmax_sum_d,
  180. float p_dropout,
  181. float softmax_scale,
  182. int window_size_left,
  183. int window_size_right,
  184. const float softcap=0.f,
  185. bool deterministic=false,
  186. int const sm_margin=0) {
  187. set_params_fprop(params,
  188. b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
  189. q, k, v, out,
  190. cu_seqlens_q_d,
  191. cu_seqlens_k_d,
  192. seqused_q,
  193. seqused_k,
  194. softmax_lse_d,
  195. p_dropout,
  196. softmax_scale,
  197. window_size_left,
  198. window_size_right,
  199. softcap,
  200. sm_margin);
  201. // Set the pointers and strides.
  202. params.do_ptr = dout.data_ptr();
  203. params.do_row_stride = dout.stride(-3);
  204. params.do_head_stride = dout.stride(-2);
  205. params.dq_ptr = dq.data_ptr();
  206. params.dk_ptr = dk.data_ptr();
  207. params.dv_ptr = dv.data_ptr();
  208. params.dq_row_stride = dq.stride(-3);
  209. params.dk_row_stride = dk.stride(-3);
  210. params.dv_row_stride = dv.stride(-3);
  211. params.dq_head_stride = dq.stride(-2);
  212. params.dk_head_stride = dk.stride(-2);
  213. params.dv_head_stride = dv.stride(-2);
  214. if (cu_seqlens_q_d == nullptr) {
  215. params.do_batch_stride = dout.stride(0);
  216. params.dq_batch_stride = dq.stride(0);
  217. params.dk_batch_stride = dk.stride(0);
  218. params.dv_batch_stride = dv.stride(0);
  219. }
  220. params.dq_accum_ptr = dq_accum_d;
  221. params.dk_accum_ptr = dk_accum_d;
  222. params.dv_accum_ptr = dv_accum_d;
  223. // Softmax sum
  224. params.dsoftmax_sum = dsoftmax_sum_d;
  225. params.deterministic = deterministic;
  226. }
  227. void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
  228. // HEADDIM_SWITCH(params.d, [&] {
  229. // run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
  230. // });
  231. TORCH_CHECK(params.num_splits >= 1);
  232. ARCH_SWITCH(params.arch, Arch, [&] {
  233. SPLIT_SWITCH(params.num_splits > 1, Split, [&] {
  234. PAGEDKV_SWITCH(params.page_table, PagedKV, [&] {
  235. PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] {
  236. // Always enable PackGQA for Sm8x or PagedKV or Split to reduce compilation
  237. static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKV || Split;
  238. SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] {
  239. if (!params.is_e4m3) {
  240. if (params.is_bf16) {
  241. #ifndef FLASHATTENTION_DISABLE_HDIM64
  242. if (params.d <= 64) {
  243. if (params.dv > 64 && Arch == 90) {
  244. return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 512, Split, PagedKV, Has_softcap, PackGQA>(params, stream);
  245. }
  246. else {
  247. return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream);
  248. }
  249. }
  250. #endif
  251. #ifndef FLASHATTENTION_DISABLE_HDIM96
  252. if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 96, 96, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  253. #endif
  254. #ifndef FLASHATTENTION_DISABLE_HDIM128
  255. if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 128, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  256. #endif
  257. #ifndef FLASHATTENTION_DISABLE_HDIM192
  258. if (params.d <= 192) {
  259. if (params.dv <= 128 && Arch == 90) {
  260. return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream);
  261. } else {
  262. return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 192, Split, PagedKV, Has_softcap, PackGQA>(params, stream);
  263. }
  264. }
  265. #endif
  266. #ifndef FLASHATTENTION_DISABLE_HDIM256
  267. if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 256, 256, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  268. #endif
  269. } else {
  270. #ifndef FLASHATTENTION_DISABLE_FP16
  271. #ifndef FLASHATTENTION_DISABLE_HDIM64
  272. if (params.d <= 64) {
  273. if (params.dv > 64 && Arch == 90) {
  274. return run_mha_fwd_<Arch, cutlass::half_t, 64, 512, Split, PagedKV, Has_softcap, PackGQA>(params, stream);
  275. }
  276. else {
  277. return run_mha_fwd_<Arch, cutlass::half_t, 64, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream);
  278. }
  279. }
  280. #endif
  281. #ifndef FLASHATTENTION_DISABLE_HDIM96
  282. if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::half_t, 96, 96, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  283. #endif
  284. #ifndef FLASHATTENTION_DISABLE_HDIM128
  285. if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::half_t, 128, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  286. #endif
  287. #ifndef FLASHATTENTION_DISABLE_HDIM192
  288. if (params.d <= 192) {
  289. if (params.dv <= 128 && Arch == 90) {
  290. return run_mha_fwd_<Arch, cutlass::half_t, 192, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream);
  291. } else {
  292. return run_mha_fwd_<Arch, cutlass::half_t, 192, 192, Split, PagedKV, Has_softcap, PackGQA>(params, stream);
  293. }
  294. }
  295. #endif
  296. #ifndef FLASHATTENTION_DISABLE_HDIM256
  297. if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::half_t, 256, 256, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  298. #endif
  299. #else
  300. TORCH_CHECK(false, "This flash attention build does not support FP16.");
  301. #endif
  302. }
  303. } else {
  304. #ifndef FLASHATTENTION_DISABLE_FP8
  305. #ifndef FLASHATTENTION_DISABLE_HDIM64
  306. if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  307. #endif
  308. #ifndef FLASHATTENTION_DISABLE_HDIM96
  309. if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  310. #endif
  311. #ifndef FLASHATTENTION_DISABLE_HDIM128
  312. if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  313. #endif
  314. #ifndef FLASHATTENTION_DISABLE_HDIM192
  315. if (params.d <= 192) {
  316. if (params.dv <= 128 && Arch == 90) {
  317. return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream);
  318. } else {
  319. return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKV, Has_softcap, PackGQA>(params, stream);
  320. }
  321. }
  322. #endif
  323. #ifndef FLASHATTENTION_DISABLE_HDIM256
  324. if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  325. #endif
  326. #else
  327. TORCH_CHECK(false, "This flash attention build does not support FP8.");
  328. #endif
  329. }
  330. });
  331. });
  332. });
  333. });
  334. });
  335. }
  336. void run_mha_fwd_combine(Flash_fwd_params &params, cudaStream_t stream) {
  337. #ifndef FLASHATTENTION_DISABLE_SPLIT
  338. // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively
  339. // so that kBlockM is smaller and we have more parallelism.
  340. if (params.is_fp32) {
  341. if (params.dv <= 64) {
  342. run_mha_fwd_combine_<float, float, 64>(params, stream);
  343. } else {
  344. run_mha_fwd_combine_<float, float, 128>(params, stream);
  345. }
  346. } else if (params.is_bf16) {
  347. if (params.dv <= 64) {
  348. run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(params, stream);
  349. } else {
  350. run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(params, stream);
  351. }
  352. } else {
  353. if (params.dv <= 64) {
  354. run_mha_fwd_combine_<cutlass::half_t, float, 64>(params, stream);
  355. } else {
  356. run_mha_fwd_combine_<cutlass::half_t, float, 128>(params, stream);
  357. }
  358. }
  359. #else
  360. TORCH_CHECK(false, "This flash attention build does not support combine kernels.");
  361. #endif
  362. }
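  // run_mha_fwd_combine above reduces the per-split partial outputs and partial log-sum-exp
  // values (out_accum / softmax_lse_accum, allocated in mha_fwd below) into the final output,
  // presumably via the usual log-sum-exp rescaling across splits.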
  363. inline bool get_pack_gqa(Flash_fwd_params const& params) {
  364. // Always enable PackGQA for Sm8x or PagedKV or Split to reduce compilation and binary size.
  365. // Has little effect on speed.
  366. if (params.arch < 90 || params.page_table || params.num_splits > 1) { return true; }
  367. #ifdef FLASHATTENTION_DISABLE_PACKGQA
  368. return false;
  369. #else
  370. // params.page_table must already be set
  371. if (params.h == params.h_k) { return false; }
  372. // This needs to match the kernel configs
  373. auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f);
  374. int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);
  375. return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM);
  376. #endif
  377. }
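  // PackGQA presumably packs the h / h_k query heads that share a KV head along the M (query)
  // dimension of a tile; should_pack_gqa above decides based on seqlen_q, the h / h_k ratio,
  // and kBlockM.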
  378. inline int get_num_splits(Flash_fwd_params const& params) {
  379. #ifdef FLASHATTENTION_DISABLE_SPLIT
  380. return 1;
  381. #else
  382. // Always enable PackGQA for Split
  383. // params.page_table must already be set
  384. // This needs to match the kernel configs
  385. bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k;
  386. auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f);
  387. // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits
  388. // has not been set here. It's OK though because we might just underestimate kBlockN a bit
  389. auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr);
  390. int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
  391. int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
  392. int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k);
  393. // If is_local, we're not going to load all of seqlen_k
  394. int const seqlen_k_loaded = !params.is_local
  395. ? params.seqlen_k
  396. : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM));
  397. int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN;
  398. int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM;
  399. int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2);
  400. // Always enable PackGQA for Split
  401. // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits.
  402. // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending
  403. // that batch = 1.
  404. int total_mblocks = (!varlen ? params.b : 1) * params.h_k * num_m_blocks;
  405. return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128);
  406. #endif
  407. }
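  // Intuition (illustrative): in decode-like cases where total_mblocks (batch * h_k *
  // num_m_blocks) is much smaller than num_sm, splitting along the KV sequence raises
  // occupancy; num_splits_heuristic weighs that against the extra combine pass, with the
  // final argument (128) presumably bounding the split count.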
  408. inline int get_max_headdim() {
  409. #ifndef FLASHATTENTION_DISABLE_HDIM256
  410. return 256;
  411. #endif
  412. #ifndef FLASHATTENTION_DISABLE_HDIM192
  413. return 192;
  414. #endif
  415. #ifndef FLASHATTENTION_DISABLE_HDIM128
  416. return 128;
  417. #endif
  418. #ifndef FLASHATTENTION_DISABLE_HDIM96
  419. return 96;
  420. #endif
  421. #ifndef FLASHATTENTION_DISABLE_HDIM64
  422. return 64;
  423. #endif
  424. return 0;
  425. }
  426. inline int round_up_headdim(int head_size) {
  427. #ifndef FLASHATTENTION_DISABLE_HDIM64
  428. if (head_size <= 64) { return 64; }
  429. #endif
  430. #ifndef FLASHATTENTION_DISABLE_HDIM96
  431. if (head_size <= 96) { return 96; }
  432. #endif
  433. #ifndef FLASHATTENTION_DISABLE_HDIM128
  434. if (head_size <= 128) { return 128; }
  435. #endif
  436. #ifndef FLASHATTENTION_DISABLE_HDIM192
  437. if (head_size <= 192) { return 192; }
  438. #endif
  439. #ifndef FLASHATTENTION_DISABLE_HDIM256
  440. if (head_size <= 256) { return 256; }
  441. #endif
  442. return 256;
  443. }
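  // Example: in a build with all HDIM options enabled, round_up_headdim(80) returns 96 and
  // round_up_headdim(200) returns 256; disabling a bucket falls through to the next larger
  // enabled one.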
  444. // b: batch_size
  445. // b_k: batch_size_k
  446. // s_q: seqlen_q
  447. // s_k: seqlen_k
  448. // s_k_new: seqlen_k_new
  449. // h: num_heads
  450. // h_k: num_heads_k
  451. // d: head_size
  452. std::vector<at::Tensor>
  453. mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
  454. const at::Tensor &k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table.
  455. const at::Tensor &v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table.
  456. std::optional<const at::Tensor> &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
  457. std::optional<const at::Tensor> &v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
  458. std::optional<const at::Tensor> &q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
  459. std::optional<at::Tensor> &out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
  460. std::optional<const at::Tensor> &cu_seqlens_q_, // b+1
  461. std::optional<const at::Tensor> &cu_seqlens_k_, // b+1
  462. std::optional<const at::Tensor> &cu_seqlens_k_new_, // b+1
  463. std::optional<const at::Tensor> &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
  464. std::optional<const at::Tensor> &seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
  465. std::optional<int> max_seqlen_q_,
  466. // TODO: check if we need max_seqlen_k
  467. std::optional<int> max_seqlen_k_,
  468. std::optional<const at::Tensor> &page_table_, // (b_k, max_num_pages_per_seq)
  469. std::optional<const at::Tensor> &kv_batch_idx_, // b. indices to index into the KV cache
  470. std::optional<const at::Tensor> &leftpad_k_, // b
  471. std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
  472. std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
  473. std::optional<at::Tensor> &q_descale_, // (b, h_k), not (b, h)
  474. std::optional<at::Tensor> &k_descale_, // (b, h_k)
  475. std::optional<at::Tensor> &v_descale_, // (b, h_k)
  476. float const softmax_scale,
  477. bool is_causal,
  478. int window_size_left,
  479. int window_size_right,
  480. float const softcap,
  481. bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
  482. int num_splits,
  483. std::optional<bool> pack_gqa_,
  484. int const sm_margin
  485. ) {
  486. auto dprops = at::cuda::getCurrentDeviceProperties();
  487. bool is_sm8x = dprops->major >= 8;
  488. TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
  489. auto q_type = q.scalar_type();
  490. TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn,
  491. "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
  492. if (dprops->major < 9) {
  493. TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
  494. "FlashAttention on Ampere/Ada cards only supports fp16 and bf16 data type");
  495. }
  496. TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype");
  497. TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype");
  498. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  499. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  500. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  501. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  502. at::Tensor page_table;
  503. const bool paged_KV = page_table_.has_value();
  504. if (paged_KV) {
  505. page_table = page_table_.value();
  506. CHECK_DEVICE(page_table);
  507. TORCH_CHECK(page_table.dtype() == torch::kInt32, "page_table must have dtype torch.int32");
  508. TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension");
  509. }
  510. at::Tensor cu_seqlens_q;
  511. bool const is_varlen_q = cu_seqlens_q_.has_value();
  512. if (is_varlen_q) {
  513. cu_seqlens_q = cu_seqlens_q_.value();
  514. CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);
  515. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
  516. TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
  517. }
  518. at::Tensor cu_seqlens_k;
  519. bool const is_varlen_k = cu_seqlens_k_.has_value();
  520. if (is_varlen_k) {
  521. cu_seqlens_k = cu_seqlens_k_.value();
  522. CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);
  523. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
  524. TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided");
  525. TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported");
  526. TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then page table is not supported");
  527. }
  528. // This is what we will template on
  529. bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value();
  530. #ifdef FLASHATTENTION_DISABLE_VARLEN
  531. TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
  532. #endif
  533. auto const sizes = q.sizes();
  534. const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
  535. int seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
  536. int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
  537. int num_heads = q.size(-2);
  538. int const head_size = q.size(-1);
  539. int const head_size_v = v.size(-1);
  540. int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1);
  541. int const num_pages = !paged_KV ? 0 : k.size(0);
  542. int const page_size = !paged_KV ? 1 : k.size(1);
  543. int const seqlen_k = !is_varlen_k ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value();
  544. int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
  545. int const num_heads_k = k.size(-2);
  546. int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0);
  547. if (!kv_batch_idx_.has_value()) {
  548. TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k");
  549. }
  550. int const max_headdim = get_max_headdim();
  551. TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim));
  552. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  553. if (head_size_v != head_size) {
  554. TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) ||
  555. (head_size <= 64 && head_size_v <= 512),
  556. "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], "
  557. "or (Q/K <= 64 and V <= 512).");
  558. TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim");
  559. if (head_size_v > 256) {
  560. TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
  561. "HeaddimV > 256 requires fp16 and bf16 data type");
  562. }
  563. }
  564. // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
  565. // TODO: check this
  566. if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
  567. if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
  568. if (is_causal) { window_size_right = 0; }
  569. // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_fprop will set params.is_causal=true.
  570. // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM.
  571. is_causal = window_size_left < 0 && window_size_right == 0;
  572. if (!is_varlen_q) {
  573. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
  574. } else {
  575. CHECK_SHAPE(q, total_q, num_heads, head_size);
  576. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  577. }
  578. if (!paged_KV) {
  579. if (!is_varlen_k) {
  580. CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size);
  581. CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v);
  582. } else {
  583. CHECK_SHAPE(k, total_k, num_heads_k, head_size);
  584. CHECK_SHAPE(v, total_k, num_heads_k, head_size_v);
  585. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  586. }
  587. } else {
  588. CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size);
  589. CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v);
  590. CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq);
  591. }
  592. if (seqused_q_.has_value()){
  593. auto seqused_q = seqused_q_.value();
  594. TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
  595. CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);
  596. CHECK_SHAPE(seqused_q, batch_size);
  597. }
  598. if (seqused_k_.has_value()) {
  599. auto seqused_k = seqused_k_.value();
  600. TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
  601. CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
  602. CHECK_SHAPE(seqused_k, batch_size);
  603. }
  604. int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8;
  605. TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment));
  606. TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment));
  607. auto opts = q.options();
  608. auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type;
  609. at::Tensor out;
  610. if (out_.has_value()) {
  611. out = out_.value();
  612. TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16");
  613. CHECK_DEVICE(out);
  614. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  615. if (!is_varlen_q) {
  616. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v);
  617. } else {
  618. CHECK_SHAPE(out, total_q, num_heads, head_size_v);
  619. }
  620. } else {
  621. out = !is_varlen_q
  622. ? torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type))
  623. : torch::empty({total_q, num_heads, head_size_v}, opts.dtype(out_type));
  624. }
  625. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  626. int const head_size_rounded = round_up_headdim(head_size);
  627. int const head_size_v_rounded = round_up_headdim(head_size_v);
  628. int const seqlen_q_rounded = round_multiple(seqlen_q, 128);
  629. int const seqlen_k_rounded = round_multiple(seqlen_k, 128);
  630. // Otherwise the kernel will be launched from cuda:0 device
  631. // Cast to char to avoid compiler warning about narrowing
  632. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  633. at::Tensor softmax_lse;
  634. if (!is_varlen_q) {
  635. softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  636. } else {
  637. softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
  638. }
  639. Flash_fwd_params params;
  640. set_params_fprop(params,
  641. batch_size,
  642. seqlen_q, seqlen_k,
  643. seqlen_q_rounded, seqlen_k_rounded,
  644. num_heads, num_heads_k,
  645. head_size, head_size_rounded,
  646. q, k, v, out,
  647. !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
  648. !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),
  649. seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
  650. seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
  651. softmax_lse.data_ptr(),
  652. /*p_dropout=*/0.f,
  653. softmax_scale,
  654. window_size_left,
  655. window_size_right,
  656. softcap,
  657. sm_margin);
  658. params.total_q = total_q;
  659. params.total_k = total_k;
  660. params.b_k = batch_size_k;
  661. params.dv = head_size_v;
  662. params.dv_rounded = head_size_v_rounded;
  663. if (paged_KV) {
  664. params.page_table = page_table.data_ptr<int>();
  665. params.page_table_batch_stride = page_table.stride(0);
  666. }
  667. params.page_size = page_size;
  668. params.num_pages = num_pages;
  669. params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
  670. // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide
  671. params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
  672. if (k_new_.has_value()) {
  673. at::Tensor k_new, v_new;
  674. TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in");
  675. TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqused_k must also be passed in");
  676. TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, seqlen_q must be <= the seqlen of the KV cache");
  677. at::Tensor cu_seqlens_k_new;
  678. bool const is_varlen_k_new = cu_seqlens_k_new_.has_value();
  679. if (is_varlen_k_new) {
  680. cu_seqlens_k_new = cu_seqlens_k_new_.value();
  681. CHECK_DEVICE(cu_seqlens_k_new); CHECK_CONTIGUOUS(cu_seqlens_k_new);
  682. TORCH_CHECK(cu_seqlens_k_new.dtype() == torch::kInt32, "cu_seqlens_k_new must have dtype torch.int32");
  683. }
  684. k_new = k_new_.value();
  685. v_new = v_new_.value();
  686. TORCH_CHECK(k_new.dtype() == q_type, "k_new must have the same dtype as query");
  687. TORCH_CHECK(v_new.dtype() == q_type, "v_new must have the same dtype as query");
  688. CHECK_DEVICE(k_new); CHECK_DEVICE(v_new);
  689. TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension");
  690. TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension");
  691. // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new
  692. int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0;
  693. int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1): k_new.size(0);
  694. if (!is_varlen_k_new) {
  695. CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size);
  696. CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v);
  697. } else {
  698. CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size);
  699. CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v);
  700. CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1);
  701. }
  702. params.seqlen_knew = seqlen_k_new;
  703. params.total_knew = total_k_new;
  704. params.knew_ptr = k_new.data_ptr();
  705. params.vnew_ptr = v_new.data_ptr();
  706. // All strides are in elements, not bytes.
  707. params.knew_row_stride = k_new.stride(-3);
  708. params.vnew_row_stride = v_new.stride(-3);
  709. params.knew_head_stride = k_new.stride(-2);
  710. params.vnew_head_stride = v_new.stride(-2);
  711. if (!is_varlen_k_new) {
  712. params.knew_batch_stride = k_new.stride(0);
  713. params.vnew_batch_stride = v_new.stride(0);
  714. }
  715. if (is_varlen_k_new) {
  716. params.cu_seqlens_knew = static_cast<int*>(cu_seqlens_k_new.data_ptr());
  717. }
  718. }
  719. if (q_v_.has_value()) {
  720. TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64");
  721. TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
  722. "q_v is only supported for fp16 and bf16 data type");
  723. TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs");
  724. at::Tensor q_v = q_v_.value();
  725. TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query");
  726. CHECK_DEVICE(q_v);
  727. TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension");
  728. if (!is_varlen_q) {
  729. CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v);
  730. } else {
  731. CHECK_SHAPE(q_v, total_q, num_heads, head_size_v);
  732. }
  733. params.qv_ptr = q_v.data_ptr();
  734. // All strides are in elements, not bytes.
  735. params.qv_row_stride = q_v.stride(-3);
  736. params.qv_head_stride = q_v.stride(-2);
  737. if (!is_varlen_q) {
  738. params.qv_batch_stride = q_v.stride(0);
  739. }
  740. }
  741. if (leftpad_k_.has_value()) {
  742. auto leftpad_k = leftpad_k_.value();
  743. TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
  744. CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k);
  745. CHECK_SHAPE(leftpad_k, batch_size);
  746. params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
  747. }
  748. if (rotary_cos_.has_value()) {
  749. TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
  750. auto rotary_cos = rotary_cos_.value();
  751. CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos);
  752. params.rotary_dim = rotary_cos.size(1) * 2;
  753. TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
  754. TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
  755. const int seqlen_ro = rotary_cos.size(0);
  756. if (paged_KV) {
  757. TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
  758. }
  759. CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
  760. TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query");
  761. TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
  762. auto rotary_sin = rotary_sin_.value();
  763. CHECK_DEVICE(rotary_sin); CHECK_CONTIGUOUS(rotary_sin);
  764. CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
  765. TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_sin must have the same dtype as query");
  766. params.rotary_cos_ptr = rotary_cos.data_ptr();
  767. params.rotary_sin_ptr = rotary_sin.data_ptr();
  768. params.is_rotary_interleaved = is_rotary_interleaved;
  769. } else {
  770. params.rotary_dim = 0;
  771. }
  772. if (kv_batch_idx_.has_value()) {
  773. auto kv_batch_idx = kv_batch_idx_.value();
  774. CHECK_DEVICE(kv_batch_idx); CHECK_CONTIGUOUS(kv_batch_idx);
  775. TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32");
  776. params.kv_batch_idx = reinterpret_cast<int *>(kv_batch_idx.data_ptr());
  777. }
  778. at::Tensor out_accum, softmax_lse_accum;
  779. auto outaccum_type = at::ScalarType::Float;
  780. if (params.num_splits > 1) {
  781. TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported");
  782. if (!is_varlen_q) {
  783. out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(outaccum_type));
  784. softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  785. params.oaccum_batch_stride = out_accum.stride(1);
  786. params.lseaccum_batch_stride = softmax_lse_accum.stride(1);
  787. } else {
  788. out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size_v}, opts.dtype(outaccum_type));
  789. softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat));
  790. }
  791. params.is_fp32 = false;
  792. params.oaccum_ptr = out_accum.data_ptr();
  793. params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
  794. params.oaccum_split_stride = out_accum.stride(0);
  795. params.oaccum_row_stride = out_accum.stride(-2);
  796. params.oaccum_head_stride = out_accum.stride(-3);
  797. params.lseaccum_split_stride = softmax_lse_accum.stride(0);
  798. params.lseaccum_head_stride = softmax_lse_accum.stride(-2);
  799. }
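  // The split accumulators are kept in fp32 (outaccum_type) regardless of the input dtype,
  // with the split index as dim 0; run_mha_fwd_combine reduces over that dimension after the
  // main kernel (see the call further below).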
  800. at::Tensor tile_count_semaphore, num_m_n_blocks_splits;
  801. // We don't use the persistent scheduler if Split and not Varlen
  802. bool const persistent_scheduler = params.arch >= 90
  803. ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen)
  804. : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1));
  805. if (persistent_scheduler) {
  806. tile_count_semaphore = torch::empty({1}, opts.dtype(torch::kInt32));
  807. if (!is_varlen) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing
  808. params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
  809. if (is_varlen) {
  810. num_m_n_blocks_splits = torch::empty({batch_size * 3}, opts.dtype(torch::kInt32));
  811. params.num_m_blocks_ptr = num_m_n_blocks_splits.data_ptr<int>();
  812. params.num_n_blocks_ptr = num_m_n_blocks_splits.data_ptr<int>() + batch_size;
  813. params.num_splits_dynamic_ptr = num_m_n_blocks_splits.data_ptr<int>() + batch_size * 2;
  814. }
  815. } else {
  816. params.tile_count_semaphore = nullptr;
  817. }
  818. if (q_type == at::ScalarType::Float8_e4m3fn) {
  819. if (q_descale_.has_value()) {
  820. auto q_descale = q_descale_.value();
  821. CHECK_DEVICE(q_descale);
  822. CHECK_SHAPE(q_descale, batch_size, num_heads_k);
  823. params.q_descale_ptr = q_descale.data_ptr<float>();
  824. params.q_descale_batch_stride = q_descale.stride(0);
  825. params.q_descale_head_stride = q_descale.stride(1);
  826. } else {
  827. params.q_descale_ptr = nullptr;
  828. }
  829. if (k_descale_.has_value()) {
  830. auto k_descale = k_descale_.value();
  831. CHECK_DEVICE(k_descale);
  832. CHECK_SHAPE(k_descale, batch_size, num_heads_k);
  833. params.k_descale_ptr = k_descale.data_ptr<float>();
  834. params.k_descale_batch_stride = k_descale.stride(0);
  835. params.k_descale_head_stride = k_descale.stride(1);
  836. } else {
  837. params.k_descale_ptr = nullptr;
  838. }
  839. if (v_descale_.has_value()) {
  840. auto v_descale = v_descale_.value();
  841. CHECK_DEVICE(v_descale);
  842. CHECK_SHAPE(v_descale, batch_size, num_heads_k);
  843. params.v_descale_ptr = v_descale.data_ptr<float>();
  844. params.v_descale_batch_stride = v_descale.stride(0);
  845. params.v_descale_head_stride = v_descale.stride(1);
  846. } else {
  847. params.v_descale_ptr = nullptr;
  848. }
  849. }
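  // The q/k/v descale tensors are per-(batch, kv-head) fp32 scale factors used to dequantize
  // the fp8 inputs inside the kernel; when a descale pointer is null the kernel presumably
  // uses a scale of 1.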
  850. #ifdef FLASHATTENTION_DISABLE_LOCAL
  851. TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
  852. #endif
  853. #ifdef FLASHATTENTION_DISABLE_SOFTCAP
  854. TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
  855. #endif
  856. #ifdef FLASHATTENTION_DISABLE_SPLIT
  857. TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits.");
  858. #endif
  859. #ifdef FLASHATTENTION_DISABLE_PACKGQA
  860. TORCH_CHECK(!params.pack_gqa || params.arch < 90 || params.page_table || params.num_splits > 1, "This flash attention build does not support pack_gqa.");
  861. #endif
  862. #ifdef FLASHATTENTION_DISABLE_PAGEDKV
  863. TORCH_CHECK(!paged_KV, "This flash attention build does not support paged KV.");
  864. #endif
  865. #ifdef FLASHATTENTION_DISABLE_APPENDKV
  866. TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV.");
  867. #endif
  868. if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) {
  869. auto stream = at::cuda::getCurrentCUDAStream().stream();
  870. run_mha_fwd(params, stream);
  871. if (params.num_splits > 1) {
  872. if (out_type == at::ScalarType::BFloat16) {
  873. // We want the output in BF16; otherwise fwd_combine would output FP16.
  874. params.is_bf16 = true;
  875. }
  876. // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1
  877. // and seqlen = total_q, and don't need to dispatch to Varlen there.
  878. // However, with dynamic split, each row needs to know which batch it belongs to
  879. // to read the number of splits, so we just use the varlen version of combine kernel.
  880. // if (is_varlen_q && !seqused_q_.has_value()) {
  881. // if (is_varlen_q) {
  882. // params.b = 1;
  883. // params.seqlen_q = total_q;
  884. // }
  885. run_mha_fwd_combine(params, stream);
  886. }
  887. } else if (total_q > 0 && num_heads_k > 0) {
  888. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  889. out.zero_();
  890. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  891. }
  892. // return {out, softmax_lse};
  893. return {out, softmax_lse, out_accum, softmax_lse_accum};
  894. }
  895. void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
  896. #ifndef FLASHATTENTION_DISABLE_BACKWARD
  897. // FP16_SWITCH(!params.is_bf16, [&] {
  898. // HEADDIM_SWITCH(params.d, [&] {
  899. // run_mha_bwd_<elem_type, kHeadDim>(params, stream);
  900. // });
  901. // });
  902. ARCH_SWITCH(params.arch, Arch, [&] {
  903. SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
  904. if (!params.is_bf16) {
  905. #ifndef FLASHATTENTION_DISABLE_FP16
  906. #ifndef FLASHATTENTION_DISABLE_HDIM64
  907. if (params.d <= 64) { return run_mha_bwd_<Arch, cutlass::half_t, 64, Has_softcap>(params, stream); }
  908. #endif
  909. #ifndef FLASHATTENTION_DISABLE_HDIM96
  910. if (params.d <= 96) { return run_mha_bwd_<Arch, cutlass::half_t, 96, Has_softcap>(params, stream); }
  911. #endif
  912. #ifndef FLASHATTENTION_DISABLE_HDIM128
  913. if (params.d <= 128) { return run_mha_bwd_<Arch, cutlass::half_t, 128, Has_softcap>(params, stream); }
  914. #endif
  915. #ifndef FLASHATTENTION_DISABLE_HDIM192
  916. if (params.d <= 192) { return run_mha_bwd_<Arch, cutlass::half_t, 192, Has_softcap>(params, stream); }
  917. #endif
  918. #ifndef FLASHATTENTION_DISABLE_HDIM256
  919. if (params.d <= 256) { return run_mha_bwd_<Arch, cutlass::half_t, 256, Has_softcap>(params, stream); }
  920. #endif
  921. #else
  922. TORCH_CHECK(false, "This flash attention build does not support FP16.");
  923. #endif
  924. } else {
  925. #ifndef FLASHATTENTION_DISABLE_HDIM64
  926. if (params.d <= 64) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 64, Has_softcap>(params, stream); }
  927. #endif
  928. #ifndef FLASHATTENTION_DISABLE_HDIM96
  929. if (params.d <= 96) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 96, Has_softcap>(params, stream); }
  930. #endif
  931. #ifndef FLASHATTENTION_DISABLE_HDIM128
  932. if (params.d <= 128) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 128, Has_softcap>(params, stream); }
  933. #endif
  934. #ifndef FLASHATTENTION_DISABLE_HDIM192
  935. if (params.d <= 192) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 192, Has_softcap>(params, stream); }
  936. #endif
  937. #ifndef FLASHATTENTION_DISABLE_HDIM256
  938. if (params.d <= 256) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 256, Has_softcap>(params, stream); }
  939. #endif
  940. }
  941. });
  942. });
  943. #endif
  944. }
  945. // b: batch_size
  946. // s_q: seqlen_q
  947. // s_k: seqlen_k
  948. // h: num_heads
  949. // h_k: num_heads_k
  950. // d: head_size
  951. std::vector<at::Tensor> mha_bwd(
  952. const at::Tensor &dout, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
  953. const at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
  954. const at::Tensor &k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
  955. const at::Tensor &v, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
  956. const at::Tensor &out, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
  957. const at::Tensor &softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q
  958. std::optional<at::Tensor> &dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
  959. std::optional<at::Tensor> &dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
  960. std::optional<at::Tensor> &dv_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
  961. std::optional<const at::Tensor> &cu_seqlens_q_, // b+1
  962. std::optional<const at::Tensor> &cu_seqlens_k_, // b+1
  963. std::optional<const at::Tensor> &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
  964. std::optional<const at::Tensor> &seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
  965. std::optional<int> max_seqlen_q_,
  966. std::optional<int> max_seqlen_k_,
  967. float const softmax_scale,
  968. bool is_causal,
  969. int window_size_left,
  970. int window_size_right,
  971. float const softcap,
  972. bool const deterministic,
  973. int const sm_margin) {
  974. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  975. TORCH_CHECK(false, "This flash attention build does not support backward.");
  976. #endif
  977. auto dprops = at::cuda::getCurrentDeviceProperties();
  978. bool is_sm8x = dprops->major >= 8;
  979. TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
  980. auto q_type = q.dtype();
  981. TORCH_CHECK(q_type == torch::kFloat16 || q_type == torch::kBFloat16,
  982. "FlashAttention only support fp16 and bf16 data type");
  983. TORCH_CHECK(k.dtype() == q_type, "query and key must have the same dtype");
  984. TORCH_CHECK(v.dtype() == q_type, "query and value must have the same dtype");
  985. TORCH_CHECK(out.dtype() == q_type, "query and out must have the same dtype");
  986. TORCH_CHECK(dout.dtype() == q_type, "query and dout must have the same dtype");
  987. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  988. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  989. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  990. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  991. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  992. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  993. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
    at::Tensor cu_seqlens_q;
    bool const is_varlen_q = cu_seqlens_q_.has_value();
    if (is_varlen_q) {
        cu_seqlens_q = cu_seqlens_q_.value();
        CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);
        TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
        TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
    }
    at::Tensor cu_seqlens_k;
    bool const is_varlen_k = cu_seqlens_k_.has_value();
    if (is_varlen_k) {
        cu_seqlens_k = cu_seqlens_k_.value();
        CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);
        TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
        TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided");
    }

    // This is what we will template on
    bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value();
    #ifdef FLASHATTENTION_DISABLE_VARLEN
    TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
    #endif

    auto const sizes = q.sizes();
    int const batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
    int const seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
    int const total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
    int const num_heads = q.size(-2);
    int const head_size = q.size(-1);
    int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value();
    int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
    int const num_heads_k = k.size(-2);
    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
    int const max_headdim = get_max_headdim();
    TORCH_CHECK(head_size <= max_headdim, "FlashAttention backward only supports head dimension at most " + std::to_string(max_headdim));
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
    // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
    if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
    if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
    if (is_causal) { window_size_right = 0; }
    // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_dgrad will set params.is_causal=true.
    // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA).
    is_causal = window_size_left < 0 && window_size_right == 0;
    int const arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
    int const head_size_rounded = round_up_headdim(head_size);
    // Very important that these match the kernel configs
    bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal;
    int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128)
        : (head_size_rounded <= 96 ? 64
        : (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80)
        : 64));
    int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64;
    int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32;
    int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80);
    int const kBlockN_sm90 = head_size_rounded <= 128
        ? 128
        : (head_size_rounded <= 192 ? 96 : 80);
    int const kBlockN_sm80 = head_size_rounded <= 128
        ? 128
        : (head_size_rounded <= 192 ? 80 : 64);
    int const kBlockN_sm86 = head_size_rounded <= 64 ? 128
        : (head_size_rounded <= 96 ? 128
        : (head_size_rounded <= 128 ? 96
        : (head_size_rounded <= 192 ? 64 : 64)));
    int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 89 ? kBlockN_sm86 : kBlockN_sm80);
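    // Round the sequence lengths up to multiples of the tile sizes; the packed (varlen)
    // totals also get one extra tile of padding headroom per batch element.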
    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);
    int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN);
    int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM);
    int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN);

    if (!is_varlen_q) {
        CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
        CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);
    } else {
        CHECK_SHAPE(q, total_q, num_heads, head_size);
        CHECK_SHAPE(out, total_q, num_heads, head_size);
        CHECK_SHAPE(dout, total_q, num_heads, head_size);
        CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    }
    if (!is_varlen_k) {
        CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
        CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
    } else {
        CHECK_SHAPE(k, total_k, num_heads_k, head_size);
        CHECK_SHAPE(v, total_k, num_heads_k, head_size);
        CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
    }

    if (seqused_q_.has_value()) {
        auto seqused_q = seqused_q_.value();
        TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
        CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);
        CHECK_SHAPE(seqused_q, batch_size);
    }
    if (seqused_k_.has_value()) {
        auto seqused_k = seqused_k_.value();
        TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
        CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
        CHECK_SHAPE(seqused_k, batch_size);
    }
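    // Gradients can be written into caller-provided dq/dk/dv (validated below) or are
    // freshly allocated with the same layout as q/k/v.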
    at::Tensor dq, dk, dv;
    if (dq_.has_value()) {
        dq = dq_.value();
        TORCH_CHECK(dq.dtype() == q_type, "dq must have the same dtype as q");
        CHECK_DEVICE(dq);
        TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
        if (!is_varlen_q) {
            CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
        } else {
            CHECK_SHAPE(dq, total_q, num_heads, head_size);
        }
    } else {
        dq = torch::empty_like(q);
    }
    if (dk_.has_value()) {
        dk = dk_.value();
        TORCH_CHECK(dk.dtype() == q_type, "dk must have the same dtype as q");
        CHECK_DEVICE(dk);
        TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
        if (!is_varlen_k) {
            CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
        } else {
            CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
        }
    } else {
        dk = torch::empty_like(k);
    }
    if (dv_.has_value()) {
        dv = dv_.value();
        TORCH_CHECK(dv.dtype() == q_type, "dv must have the same dtype as q");
        CHECK_DEVICE(dv);
        TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
        if (!is_varlen_k) {
            CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
        } else {
            CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
        }
    } else {
        dv = torch::empty_like(v);
    }
    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();
    // softmax_d and softmax_lse_log2 use the rounded lengths (seqlen_q_rounded, or
    // total_q_padded_rounded for varlen) so that their addresses are aligned by 16/8 bytes for TMA / LDG.64.
    at::Tensor softmax_d, softmax_lse_log2;
    if (!is_varlen) {
        softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
        softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
    } else {
        softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
        softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
    }
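    // fp32 accumulators: dq is always accumulated in fp32; dk/dv additionally get
    // zero-initialized fp32 accumulators in the MQA/GQA case (num_heads_k != num_heads),
    // where gradients from several query heads are reduced into the same key/value head.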
    at::Tensor dq_accum, dk_accum, dv_accum;
    if (!is_varlen) {
        dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, opts.dtype(at::kFloat));
    } else {
        dq_accum = torch::empty({num_heads, total_q_padded_rounded * head_size_rounded}, opts.dtype(at::kFloat));
    }
    if (num_heads_k != num_heads) {  // MQA / GQA
        if (!is_varlen) {
            dk_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat));
            dv_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat));
        } else {
            dk_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
            dv_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
        }
    }
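    // Fill the backward kernel parameter struct. Dropout is not supported on this path,
    // so p_dropout is fixed at 0.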
    Flash_bwd_params params;
    set_params_dgrad(params,
                     batch_size,
                     seqlen_q, seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q, k, v, out,
                     dout, dq, dk, dv,
                     !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
                     !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),
                     seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
                     seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
                     dq_accum.data_ptr(),
                     num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr,
                     num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr,
                     softmax_lse.data_ptr(),
                     softmax_d.data_ptr(),
                     /*p_dropout=*/0.f,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap,
                     deterministic,
                     sm_margin);
    params.total_q = total_q;
    params.total_k = total_k;
    params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();

    // auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
    // params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();

    // Will be zero'ed out in the backward preprocess kernel
    at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
    params.dq_semaphore = dq_semaphore.data_ptr<int>();
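    // dk/dv semaphores are only allocated for deterministic MQA/GQA (presumably to order the
    // cross-block reductions into shared key/value-head gradients). They are declared at
    // function scope so the buffers remain alive when the kernel is launched below.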
    at::Tensor dk_semaphore, dv_semaphore;
    if (num_heads_k != num_heads && params.deterministic) {
        // TODO: do we need to zero them out?
        dk_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
        dv_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
        params.dk_semaphore = dk_semaphore.data_ptr<int>();
        params.dv_semaphore = dv_semaphore.data_ptr<int>();
    }
    #ifdef FLASHATTENTION_DISABLE_LOCAL
    TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
    #endif
    #ifdef FLASHATTENTION_DISABLE_SOFTCAP
    TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
    #endif
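    // Only launch the kernel when there is actual work; otherwise zero out whichever
    // outputs the kernel would have written.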
    if (total_q > 0 && total_k > 0 && num_heads_k > 0) {
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        run_mha_bwd(params, stream);
    } else if (total_k > 0 && num_heads_k > 0) {
        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
        dk.zero_();
        dv.zero_();
        softmax_d.zero_();
    } else if (total_q > 0 && num_heads_k > 0) {
        dq.zero_();
        softmax_d.zero_();
    }

    return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
}
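
// Combine num_splits partial attention outputs and their partial log-sum-exp values
// (e.g. from splitting the computation over the key/value sequence) into a single
// output tensor and final LSE.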
std::vector<at::Tensor>
mha_combine(const at::Tensor &out_partial,         // num_splits x batch_size x seqlen x num_heads x head_size
            const at::Tensor &lse_partial,         // num_splits x batch_size x seqlen x num_heads
            std::optional<at::Tensor> out_,        // batch_size x seqlen x num_heads x head_size
            std::optional<at::ScalarType> out_dtype_
            ) {

    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major >= 8;
    TORCH_CHECK(is_sm8x, "Attention combine function only supports Ampere GPUs or newer.");

    auto out_partial_type = out_partial.scalar_type();
    TORCH_CHECK(out_partial_type == at::ScalarType::Float, "Attention combine function only supports the fp32 data type");
    TORCH_CHECK(lse_partial.scalar_type() == at::ScalarType::Float, "Attention combine function only supports the fp32 data type");

    CHECK_DEVICE(out_partial); CHECK_DEVICE(lse_partial);

    TORCH_CHECK(out_partial.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(lse_partial.stride(-2) == 1, "LSE tensor must be contiguous in the seqlen dimension");

    const auto sizes = out_partial.sizes();
    const int num_splits = sizes[0];
    const int batch_size = sizes[1];
    const int seqlen = sizes[2];
    const int num_heads = sizes[3];
    const int head_size_og = sizes[4];
    TORCH_CHECK(head_size_og <= 512, "FlashAttention combine only supports head dimension at most 512");
    TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256");
    CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og);
    CHECK_SHAPE(lse_partial, num_splits, batch_size, seqlen, num_heads);
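    // Pad the head dimension of the partial outputs up to a multiple of `alignment` floats
    // (presumably so the kernel can use aligned, vectorized accesses); the final output is
    // sliced back down to head_size_og before returning.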
    int const alignment = 4;
    at::Tensor out_partial_padded;
    auto pad = [](at::Tensor x, int alignment) {
        return x.size(-1) % alignment == 0 ? x : torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment}));
    };
    out_partial_padded = pad(out_partial, alignment);

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size = round_multiple(head_size_og, alignment);

    auto opts = out_partial.options();
    at::ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type());
    TORCH_CHECK(out_type == at::ScalarType::Float || out_type == at::ScalarType::BFloat16 || out_type == at::ScalarType::Half, "Output type must be FP32, FP16 or BF16");
    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.scalar_type() == out_type);
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, batch_size, seqlen, num_heads, head_size_og);
        if (head_size_og % alignment != 0) {
            out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
        }
    } else {
        out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)out_partial.get_device()};
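    // Allocate the final LSE as (batch, heads, seqlen) and return a (batch, seqlen, heads)
    // view, so the seqlen dimension stays contiguous, matching the layout expected of lse_partial.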
    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen}, opts.dtype(at::kFloat)).transpose(1, 2);

    Flash_fwd_params params {};  // Need to reset the params to set everything to zero
    params.is_fp32 = out_type == at::ScalarType::Float;
    params.is_bf16 = out_type == at::ScalarType::BFloat16;
    params.oaccum_ptr = out_partial_padded.data_ptr();
    params.softmax_lseaccum_ptr = lse_partial.data_ptr();
    params.o_ptr = out.data_ptr();
    params.softmax_lse_ptr = softmax_lse.data_ptr();
    params.b = batch_size;
    params.h = num_heads;
    params.seqlen_q = seqlen;
    params.dv = head_size;
    params.num_splits = num_splits;
    params.oaccum_split_stride = out_partial_padded.stride(0);
    params.oaccum_row_stride = out_partial_padded.stride(2);
    params.oaccum_head_stride = out_partial_padded.stride(3);
    params.oaccum_batch_stride = out_partial_padded.stride(1);
    params.lseaccum_split_stride = lse_partial.stride(0);
    params.lseaccum_head_stride = lse_partial.stride(3);
    params.lseaccum_batch_stride = lse_partial.stride(1);
    params.o_row_stride = out.stride(1);
    params.o_head_stride = out.stride(2);
    params.o_batch_stride = out.stride(0);
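    // Skip the kernel launch entirely for empty inputs.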
    if (seqlen > 0 && batch_size > 0) {
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        run_mha_fwd_combine(params, stream);
    }

    at::Tensor out_padded = out;
    if (head_size_og % alignment != 0) {
        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        // if (out_.has_value()) { out_.value().copy_(out); }
    }
    return {out, softmax_lse};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "FlashAttention";
    m.def("fwd", &mha_fwd, "Forward pass");
    m.def("bwd", &mha_bwd, "Backward pass");
    m.def("fwd_combine", &mha_combine, "Combine partial attention outputs");
}
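
// Illustrative Python-side use of the combine binding (a sketch only: the import name
// depends on how TORCH_EXTENSION_NAME is set at build time, and `flashattn_ext` below
// is just a placeholder):
//
//   import torch
//   import flashattn_ext
//   # out_partial: (num_splits, batch, seqlen, num_heads, head_size), fp32, CUDA
//   # lse_partial: (num_splits, batch, seqlen, num_heads), fp32, CUDA
//   out, lse = flashattn_ext.fwd_combine(out_partial, lse_partial, None, torch.float16)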