flash_api.cpp 65 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
  5. #include <torch/python.h>
  6. #include <torch/nn/functional.h>
  7. #include <torch/version.h> // For TORCH_VERSION* macros
  8. #include <ATen/cuda/CUDAContext.h>
  9. #include <c10/cuda/CUDAGuard.h>
  10. #include <cutlass/numeric_types.h>
  11. #include "flash.h"
  12. #include "static_switch.h"
  13. #include "tile_size.h"
  14. #include "heuristics.h"
  15. // Copied from https://github.com/pytorch/pytorch/commit/7931eee5c5ebcdf468bff4d308510b03355cd909
  16. // This is so that we can pass in torch.dtype as a parameter to the function.
  17. #if TORCH_VERSION_MAJOR < 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR < 4)
  18. #include <pybind11/pybind11.h>
  19. #include <pybind11/stl.h>
  20. namespace pybind11::detail {
  21. template <>
  22. struct type_caster<at::ScalarType> {
  23. public:
  24. // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  25. PYBIND11_TYPE_CASTER(at::ScalarType, _("torch.dtype"));
  26. // PYBIND11_TYPE_CASTER defines a member field called value. at::ScalarType
  27. // cannot be default-initialized, we provide this constructor to explicitly
  28. // initialize that field. The value doesn't matter as it will be overwritten
  29. // after a successful call to load.
  30. type_caster() : value(at::kFloat) {}
  31. bool load(handle src, bool) {
  32. PyObject* obj = src.ptr();
  33. if (THPDtype_Check(obj)) {
  34. value = reinterpret_cast<THPDtype*>(obj)->scalar_type;
  35. return true;
  36. }
  37. return false;
  38. }
  39. static handle cast(
  40. const at::ScalarType& src,
  41. return_value_policy /* policy */,
  42. handle /* parent */) {
  43. return Py_NewRef(torch::getTHPDtype(src));
  44. }
  45. };
  46. } // namespace pybind11::detail
  47. #endif
  48. #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
  49. #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
  50. #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
  51. void set_params_fprop(Flash_fwd_params &params,
  52. // sizes
  53. const size_t b,
  54. const size_t seqlen_q,
  55. const size_t seqlen_k,
  56. const size_t seqlen_q_rounded,
  57. const size_t seqlen_k_rounded,
  58. const size_t h,
  59. const size_t h_k,
  60. const size_t d,
  61. const size_t d_rounded,
  62. // device pointers
  63. const at::Tensor q,
  64. const at::Tensor k,
  65. const at::Tensor v,
  66. at::Tensor out,
  67. void *cu_seqlens_q_d,
  68. void *cu_seqlens_k_d,
  69. void *seqused_q,
  70. void *seqused_k,
  71. void *softmax_lse_d,
  72. float p_dropout,
  73. float softmax_scale,
  74. int window_size_left,
  75. int window_size_right,
  76. const float softcap=0.f,
  77. const int sm_margin=0) {
  78. // Reset the parameters
  79. params = {};
  80. params.is_bf16 = q.dtype() == torch::kBFloat16;
  81. params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;
  82. // Set the pointers and strides.
  83. params.q_ptr = q.data_ptr();
  84. params.k_ptr = k.data_ptr();
  85. params.v_ptr = v.data_ptr();
  86. // All stride are in elements, not bytes.
  87. params.q_row_stride = q.stride(-3);
  88. params.k_row_stride = k.stride(-3);
  89. params.v_row_stride = v.stride(-3);
  90. params.q_head_stride = q.stride(-2);
  91. params.k_head_stride = k.stride(-2);
  92. params.v_head_stride = v.stride(-2);
  93. params.v_dim_stride = v.stride(-1);
  94. params.o_ptr = out.data_ptr();
  95. params.o_row_stride = out.stride(-3);
  96. params.o_head_stride = out.stride(-2);
  97. if (cu_seqlens_q_d == nullptr) {
  98. params.q_batch_stride = q.stride(0);
  99. params.o_batch_stride = out.stride(0);
  100. }
  101. if (cu_seqlens_k_d == nullptr) {
  102. params.k_batch_stride = k.stride(0);
  103. params.v_batch_stride = v.stride(0);
  104. }
  105. params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
  106. params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
  107. params.seqused_q = static_cast<int *>(seqused_q);
  108. params.seqused_k = static_cast<int *>(seqused_k);
  109. // Softmax sum
  110. params.softmax_lse_ptr = softmax_lse_d;
  111. // Set the dimensions.
  112. params.b = b;
  113. params.h = h;
  114. params.h_k = h_k;
  115. params.seqlen_q = seqlen_q;
  116. params.seqlen_k = seqlen_k;
  117. params.seqlen_q_rounded = seqlen_q_rounded;
  118. params.seqlen_k_rounded = seqlen_k_rounded;
  119. params.d = d;
  120. params.d_rounded = d_rounded;
  121. // Set the different scale values.
  122. params.scale_softmax = softmax_scale;
  123. params.softcap = softcap;
  124. // Set this to probability of keeping an element to simplify things.
  125. params.p_dropout = 1.f - p_dropout;
  126. // Convert p from float to int so we don't have to convert the random uint to float to compare.
  127. // [Minor] We want to round down since when we do the comparison we use <= instead of <
  128. // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
  129. // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
  130. params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
  131. params.rp_dropout = 1.f / params.p_dropout;
  132. TORCH_CHECK(p_dropout < 1.f);
  133. #ifdef FLASHATTENTION_DISABLE_DROPOUT
  134. TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
  135. #endif
  136. // Causal is the special case where window_size_right == 0 and window_size_left < 0.
  137. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
  138. params.is_causal = window_size_left < 0 && window_size_right == 0;
  139. params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal;
  140. // TODO: check this
  141. if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k - 1; }
  142. if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_q - 1; }
  143. params.window_size_left = window_size_left;
  144. params.window_size_right = window_size_right;
  145. params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
  146. params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;
  147. #ifdef FLASHATTENTION_DISABLE_LOCAL
  148. TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
  149. #endif
  150. }
  151. void set_params_dgrad(Flash_bwd_params &params,
  152. // sizes
  153. const size_t b,
  154. const size_t seqlen_q,
  155. const size_t seqlen_k,
  156. const size_t seqlen_q_rounded,
  157. const size_t seqlen_k_rounded,
  158. const size_t h,
  159. const size_t h_k,
  160. const size_t d,
  161. const size_t d_rounded,
  162. // device pointers
  163. const at::Tensor q,
  164. const at::Tensor k,
  165. const at::Tensor v,
  166. const at::Tensor out,
  167. const at::Tensor dout,
  168. at::Tensor dq,
  169. at::Tensor dk,
  170. at::Tensor dv,
  171. void *cu_seqlens_q_d,
  172. void *cu_seqlens_k_d,
  173. void *seqused_q,
  174. void *seqused_k,
  175. void *dq_accum_d,
  176. void *dk_accum_d,
  177. void *dv_accum_d,
  178. void *softmax_lse_d,
  179. void *dsoftmax_sum_d,
  180. float p_dropout,
  181. float softmax_scale,
  182. int window_size_left,
  183. int window_size_right,
  184. const float softcap=0.f,
  185. bool deterministic=false,
  186. int const sm_margin=0) {
  187. set_params_fprop(params,
  188. b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
  189. q, k, v, out,
  190. cu_seqlens_q_d,
  191. cu_seqlens_k_d,
  192. seqused_q,
  193. seqused_k,
  194. softmax_lse_d,
  195. p_dropout,
  196. softmax_scale,
  197. window_size_left,
  198. window_size_right,
  199. softcap,
  200. sm_margin);
  201. // Set the pointers and strides.
  202. params.do_ptr = dout.data_ptr();
  203. params.do_row_stride = dout.stride(-3);
  204. params.do_head_stride = dout.stride(-2);
  205. params.dq_ptr = dq.data_ptr();
  206. params.dk_ptr = dk.data_ptr();
  207. params.dv_ptr = dv.data_ptr();
  208. params.dq_row_stride = dq.stride(-3);
  209. params.dk_row_stride = dk.stride(-3);
  210. params.dv_row_stride = dv.stride(-3);
  211. params.dq_head_stride = dq.stride(-2);
  212. params.dk_head_stride = dk.stride(-2);
  213. params.dv_head_stride = dv.stride(-2);
  214. if (cu_seqlens_q_d == nullptr) {
  215. params.do_batch_stride = dout.stride(0);
  216. params.dq_batch_stride = dq.stride(0);
  217. params.dk_batch_stride = dk.stride(0);
  218. params.dv_batch_stride = dv.stride(0);
  219. }
  220. params.dq_accum_ptr = dq_accum_d;
  221. params.dk_accum_ptr = dk_accum_d;
  222. params.dv_accum_ptr = dv_accum_d;
  223. // Softmax sum
  224. params.dsoftmax_sum = dsoftmax_sum_d;
  225. params.deterministic = deterministic;
  226. }
  227. void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
  228. // HEADDIM_SWITCH(params.d, [&] {
  229. // run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
  230. // });
  231. TORCH_CHECK(params.num_splits >= 1);
  232. ARCH_SWITCH(params.arch, Arch, [&] {
  233. SPLIT_SWITCH(params.num_splits > 1, Split, [&] {
  234. PAGEDKV_SWITCH(params.page_table, PagedKV, [&] {
  235. PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] {
  236. // Always enable PackGQA for Sm8x or PagedKV or Split to reduce compilation
  237. static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKV || Split;
  238. SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] {
  239. if (!params.is_e4m3) {
  240. if (params.is_bf16) {
  241. #ifndef FLASHATTENTION_DISABLE_HDIM64
  242. if (params.d <= 64) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  243. #endif
  244. #ifndef FLASHATTENTION_DISABLE_HDIM96
  245. if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 96, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  246. #endif
  247. #ifndef FLASHATTENTION_DISABLE_HDIM128
  248. if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  249. #endif
  250. #ifndef FLASHATTENTION_DISABLE_HDIM192
  251. if (params.d <= 192) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  252. #endif
  253. #ifndef FLASHATTENTION_DISABLE_HDIM256
  254. if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 256, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  255. #endif
  256. } else {
  257. #ifndef FLASHATTENTION_DISABLE_FP16
  258. #ifndef FLASHATTENTION_DISABLE_HDIM64
  259. if (params.d <= 64) { return run_mha_fwd_<Arch, cutlass::half_t, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  260. #endif
  261. #ifndef FLASHATTENTION_DISABLE_HDIM96
  262. if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::half_t, 96, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  263. #endif
  264. #ifndef FLASHATTENTION_DISABLE_HDIM128
  265. if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::half_t, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  266. #endif
  267. #ifndef FLASHATTENTION_DISABLE_HDIM192
  268. if (params.d <= 192) { return run_mha_fwd_<Arch, cutlass::half_t, 192, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  269. #endif
  270. #ifndef FLASHATTENTION_DISABLE_HDIM256
  271. if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::half_t, 256, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  272. #endif
  273. #else
  274. TORCH_CHECK(false, "This flash attention build does not support FP16.");
  275. #endif
  276. }
  277. } else {
  278. #ifndef FLASHATTENTION_DISABLE_FP8
  279. #ifndef FLASHATTENTION_DISABLE_HDIM64
  280. if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  281. #endif
  282. #ifndef FLASHATTENTION_DISABLE_HDIM96
  283. if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  284. #endif
  285. #ifndef FLASHATTENTION_DISABLE_HDIM128
  286. if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  287. #endif
  288. #ifndef FLASHATTENTION_DISABLE_HDIM192
  289. if (params.d <= 192) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  290. #endif
  291. #ifndef FLASHATTENTION_DISABLE_HDIM256
  292. if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
  293. #endif
  294. #else
  295. TORCH_CHECK(false, "This flash attention build does not support FP8.");
  296. #endif
  297. }
  298. });
  299. });
  300. });
  301. });
  302. });
  303. }
  304. void run_mha_fwd_combine(Flash_fwd_params &params, cudaStream_t stream) {
  305. #ifndef FLASHATTENTION_DISABLE_SPLIT
  306. // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively
  307. // so that kBlockM is smaller and we have more parallelism.
  308. if (params.is_fp32) {
  309. if (params.d <= 64) {
  310. run_mha_fwd_combine_<float, float, 64>(params, stream);
  311. } else if (params.d <= 128) {
  312. run_mha_fwd_combine_<float, float, 128>(params, stream);
  313. } else {
  314. run_mha_fwd_combine_<float, float, 256>(params, stream);
  315. }
  316. } else if (params.is_bf16) {
  317. if (params.d <= 64) {
  318. run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(params, stream);
  319. } else if (params.d <= 128) {
  320. run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(params, stream);
  321. } else {
  322. run_mha_fwd_combine_<cutlass::bfloat16_t, float, 256>(params, stream);
  323. }
  324. } else {
  325. if (params.d <= 64) {
  326. run_mha_fwd_combine_<cutlass::half_t, float, 64>(params, stream);
  327. } else if (params.d <= 128) {
  328. run_mha_fwd_combine_<cutlass::half_t, float, 128>(params, stream);
  329. } else {
  330. run_mha_fwd_combine_<cutlass::half_t, float, 256>(params, stream);
  331. }
  332. }
  333. #else
  334. TORCH_CHECK(false, "This flash attention build does not support combine kernels.");
  335. #endif
  336. }
  337. inline bool get_pack_gqa(Flash_fwd_params const& params) {
  338. // Always enable PackGQA for Sm8x or PagedKV or Split to reduce compilation and binary size.
  339. // Has little effect on speed.
  340. if (params.arch < 90 || params.page_table || params.num_splits > 1) { return true; }
  341. #ifdef FLASHATTENTION_DISABLE_PACKGQA
  342. return false;
  343. #else
  344. // params.page_table must already be set
  345. if (params.h == params.h_k) { return false; }
  346. // This needs to match the kernel configs
  347. auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f);
  348. int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);
  349. return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM);
  350. #endif
  351. }
  352. inline int get_num_splits(Flash_fwd_params const& params) {
  353. #ifdef FLASHATTENTION_DISABLE_SPLIT
  354. return 1;
  355. #else
  356. // Always enable PackGQA for Split
  357. // params.page_table must already be set
  358. // This needs to match the kernel configs
  359. bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k;
  360. auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f);
  361. // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits
  362. // has not been set here. It's OK though because we might just underestimate kBlockN a bit
  363. auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr);
  364. int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
  365. int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
  366. int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k);
  367. // If is_local, we're not going to load all of seqlen_k
  368. int const seqlen_k_loaded = !params.is_local
  369. ? params.seqlen_k
  370. : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM));
  371. int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN;
  372. int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM;
  373. return num_splits_heuristic(params.b * (!params.pack_gqa ? params.h : params.h_k) * num_m_blocks, params.num_sm, num_n_blocks, 128);
  374. // return num_splits_heuristic(params.b * params.h_k * num_m_blocks, params.b * params.h_k,
  375. // params.num_sm, num_n_blocks, 128, params.d_rounded);
  376. #endif
  377. }
  378. inline int get_max_headdim() {
  379. #ifndef FLASHATTENTION_DISABLE_HDIM256
  380. return 256;
  381. #endif
  382. #ifndef FLASHATTENTION_DISABLE_HDIM192
  383. return 192;
  384. #endif
  385. #ifndef FLASHATTENTION_DISABLE_HDIM128
  386. return 128;
  387. #endif
  388. #ifndef FLASHATTENTION_DISABLE_HDIM96
  389. return 96;
  390. #endif
  391. #ifndef FLASHATTENTION_DISABLE_HDIM64
  392. return 64;
  393. #endif
  394. return 0;
  395. }
  396. inline int round_up_headdim(int head_size) {
  397. #ifndef FLASHATTENTION_DISABLE_HDIM64
  398. if (head_size <= 64) { return 64; }
  399. #endif
  400. #ifndef FLASHATTENTION_DISABLE_HDIM96
  401. if (head_size <= 96) { return 96; }
  402. #endif
  403. #ifndef FLASHATTENTION_DISABLE_HDIM128
  404. if (head_size <= 128) { return 128; }
  405. #endif
  406. #ifndef FLASHATTENTION_DISABLE_HDIM192
  407. if (head_size <= 192) { return 192; }
  408. #endif
  409. #ifndef FLASHATTENTION_DISABLE_HDIM256
  410. if (head_size <= 256) { return 256; }
  411. #endif
  412. return 256;
  413. }
  414. // b: batch_size
  415. // b_k: batch_size_k
  416. // s_q: seqlen_q
  417. // s_k: seqlen_k
  418. // s_k_new: seqlen_k_new
  419. // h: num_heads
  420. // h_k: num_heads_k
  421. // d: head_size
  422. std::vector<at::Tensor>
  423. mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
  424. const at::Tensor &k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table.
  425. const at::Tensor &v, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table.
  426. std::optional<const at::Tensor> &k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
  427. std::optional<const at::Tensor> &v_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
  428. std::optional<at::Tensor> &out_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
  429. std::optional<const at::Tensor> &cu_seqlens_q_, // b+1
  430. std::optional<const at::Tensor> &cu_seqlens_k_, // b+1
  431. std::optional<const at::Tensor> &cu_seqlens_k_new_, // b+1
  432. std::optional<const at::Tensor> &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
  433. std::optional<const at::Tensor> &seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
  434. std::optional<int> max_seqlen_q_,
  435. // TODO: check if we need max_seqlen_k
  436. std::optional<int> max_seqlen_k_,
  437. std::optional<const at::Tensor> &page_table_, // (b_k, max_num_pages_per_seq)
  438. std::optional<const at::Tensor> &kv_batch_idx_, // b. indices to index into the KV cache
  439. std::optional<const at::Tensor> &leftpad_k_, // b
  440. std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
  441. std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
  442. std::optional<at::Tensor> &q_descale_, // (b, h_k), not (b, h)
  443. std::optional<at::Tensor> &k_descale_, // (b, h_k)
  444. std::optional<at::Tensor> &v_descale_, // (b, h_k)
  445. float const softmax_scale,
  446. bool is_causal,
  447. int window_size_left,
  448. int window_size_right,
  449. int sink_token_length,
  450. float const softcap,
  451. bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
  452. int num_splits,
  453. std::optional<bool> pack_gqa_,
  454. int const sm_margin
  455. ) {
  456. auto dprops = at::cuda::getCurrentDeviceProperties();
  457. bool is_sm8x = dprops->major >= 8;
  458. TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
  459. auto q_type = q.scalar_type();
  460. TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn,
  461. "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
  462. if (dprops->major < 9) {
  463. TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
  464. "FlashAttention on Ampere/Ada cards only supports fp16 and bf16 data type");
  465. }
  466. TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype");
  467. TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype");
  468. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  469. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  470. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  471. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  472. at::Tensor page_table;
  473. const bool paged_KV = page_table_.has_value();
  474. if (paged_KV) {
  475. page_table = page_table_.value();
  476. CHECK_DEVICE(page_table);
  477. TORCH_CHECK(page_table.dtype() == torch::kInt32, "page_table must have dtype torch.int32");
  478. TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension");
  479. }
  480. at::Tensor cu_seqlens_q;
  481. bool const is_varlen_q = cu_seqlens_q_.has_value();
  482. if (is_varlen_q) {
  483. cu_seqlens_q = cu_seqlens_q_.value();
  484. CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);
  485. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
  486. TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
  487. }
  488. at::Tensor cu_seqlens_k;
  489. bool const is_varlen_k = cu_seqlens_k_.has_value();
  490. if (is_varlen_k) {
  491. cu_seqlens_k = cu_seqlens_k_.value();
  492. CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);
  493. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
  494. TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided");
  495. TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported");
  496. TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then page table is not supported");
  497. }
  498. // This is what we will template on
  499. bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value();
  500. #ifdef FLASHATTENTION_DISABLE_VARLEN
  501. TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
  502. #endif
  503. auto const sizes = q.sizes();
  504. const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
  505. int seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
  506. int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
  507. int num_heads = q.size(-2);
  508. int const head_size = q.size(-1);
  509. int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1);
  510. int const num_pages = !paged_KV ? 0 : k.size(0);
  511. int const page_size = !paged_KV ? 1 : k.size(1);
  512. int const seqlen_k = !is_varlen_k ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value();
  513. int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
  514. int const num_heads_k = k.size(-2);
  515. int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0);
  516. if (!kv_batch_idx_.has_value()) {
  517. TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k");
  518. }
  519. int const max_headdim = get_max_headdim();
  520. TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim));
  521. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  522. // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
  523. // TODO: check this
  524. if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
  525. if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
  526. if (is_causal) { window_size_right = 0; }
  527. // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_fprop will set params.is_causal=true.
  528. // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM.
  529. is_causal = window_size_left < 0 && window_size_right == 0;
  530. if (!is_varlen_q) {
  531. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
  532. } else {
  533. CHECK_SHAPE(q, total_q, num_heads, head_size);
  534. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  535. }
  536. if (!paged_KV) {
  537. if (!is_varlen_k) {
  538. CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size);
  539. CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size);
  540. } else {
  541. CHECK_SHAPE(k, total_k, num_heads_k, head_size);
  542. CHECK_SHAPE(v, total_k, num_heads_k, head_size);
  543. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  544. }
  545. } else {
  546. CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size);
  547. CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size);
  548. CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq);
  549. }
  550. if (seqused_q_.has_value()){
  551. auto seqused_q = seqused_q_.value();
  552. TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
  553. CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);
  554. CHECK_SHAPE(seqused_q, batch_size);
  555. }
  556. if (seqused_k_.has_value()) {
  557. auto seqused_k = seqused_k_.value();
  558. TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
  559. CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
  560. CHECK_SHAPE(seqused_k, batch_size);
  561. }
  562. int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8;
  563. TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment));
  564. auto opts = q.options();
  565. auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type;
  566. at::Tensor out;
  567. if (out_.has_value()) {
  568. out = out_.value();
  569. TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16");
  570. CHECK_DEVICE(out);
  571. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  572. if (!is_varlen_q) {
  573. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
  574. } else {
  575. CHECK_SHAPE(out, total_q, num_heads, head_size);
  576. }
  577. } else {
  578. out = torch::empty_like(q, opts.dtype(out_type));
  579. }
  580. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  581. int const head_size_rounded = round_up_headdim(head_size);
  582. int const seqlen_q_rounded = round_multiple(seqlen_q, 128);
  583. int const seqlen_k_rounded = round_multiple(seqlen_k, 128);
  584. // Otherwise the kernel will be launched from cuda:0 device
  585. // Cast to char to avoid compiler warning about narrowing
  586. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  587. at::Tensor softmax_lse;
  588. if (!is_varlen_q) {
  589. softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  590. } else {
  591. softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
  592. }
  593. Flash_fwd_params params;
  594. set_params_fprop(params,
  595. batch_size,
  596. seqlen_q, seqlen_k,
  597. seqlen_q_rounded, seqlen_k_rounded,
  598. num_heads, num_heads_k,
  599. head_size, head_size_rounded,
  600. q, k, v, out,
  601. !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
  602. !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),
  603. seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
  604. seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
  605. softmax_lse.data_ptr(),
  606. /*p_dropout=*/0.f,
  607. softmax_scale,
  608. window_size_left,
  609. window_size_right,
  610. softcap,
  611. sm_margin);
  612. params.total_q = total_q;
  613. params.total_k = total_k;
  614. params.sink_token_length = sink_token_length;
  615. params.b_k = batch_size_k;
  616. if (paged_KV) {
  617. params.page_table = page_table.data_ptr<int>();
  618. params.page_table_batch_stride = page_table.stride(0);
  619. }
  620. params.page_size = page_size;
  621. params.num_pages = num_pages;
  622. params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
  623. params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
  624. if (k_new_.has_value()) {
  625. at::Tensor k_new, v_new;
  626. TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in");
  627. TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqlens_k must also be passed in");
  628. TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache");
  629. at::Tensor cu_seqlens_k_new;
  630. bool const is_varlen_k_new = cu_seqlens_k_new_.has_value();
  631. if (is_varlen_k_new) {
  632. cu_seqlens_k_new = cu_seqlens_k_new_.value();
  633. CHECK_DEVICE(cu_seqlens_k_new); CHECK_CONTIGUOUS(cu_seqlens_k_new);
  634. TORCH_CHECK(cu_seqlens_k_new.dtype() == torch::kInt32, "cu_seqlens_k_new must have dtype torch.int32");
  635. }
  636. k_new = k_new_.value();
  637. v_new = v_new_.value();
  638. TORCH_CHECK(k_new.dtype() == q_type, "k_new must have the same dtype as query");
  639. TORCH_CHECK(v_new.dtype() == q_type, "v_new must have the same dtype as query");
  640. CHECK_DEVICE(k_new); CHECK_DEVICE(v_new);
  641. TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension");
  642. TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension");
  643. // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new
  644. int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0;
  645. int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1): k_new.size(0);
  646. if (!is_varlen_k_new) {
  647. CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size);
  648. CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size);
  649. } else {
  650. CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size);
  651. CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size);
  652. CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1);
  653. }
  654. params.seqlen_knew = seqlen_k_new;
  655. params.total_knew = total_k_new;
  656. params.knew_ptr = k_new.data_ptr();
  657. params.vnew_ptr = v_new.data_ptr();
  658. // All stride are in elements, not bytes.
  659. params.knew_row_stride = k_new.stride(-3);
  660. params.vnew_row_stride = v_new.stride(-3);
  661. params.knew_head_stride = k_new.stride(-2);
  662. params.vnew_head_stride = v_new.stride(-2);
  663. if (!is_varlen_k_new) {
  664. params.knew_batch_stride = k_new.stride(0);
  665. params.vnew_batch_stride = v_new.stride(0);
  666. }
  667. if (is_varlen_k_new) {
  668. params.cu_seqlens_knew = static_cast<int*>(cu_seqlens_k_new.data_ptr());
  669. }
  670. }
  671. if (leftpad_k_.has_value()) {
  672. auto leftpad_k = leftpad_k_.value();
  673. TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
  674. CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k);
  675. CHECK_SHAPE(leftpad_k, batch_size);
  676. params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
  677. }
  678. if (rotary_cos_.has_value()) {
  679. TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
  680. auto rotary_cos = rotary_cos_.value();
  681. CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos);
  682. params.rotary_dim = rotary_cos.size(1) * 2;
  683. TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
  684. TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
  685. const int seqlen_ro = rotary_cos.size(0);
  686. if (paged_KV) {
  687. TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
  688. }
  689. CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
  690. TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query");
  691. TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
  692. auto rotary_sin = rotary_sin_.value();
  693. CHECK_DEVICE(rotary_sin); CHECK_CONTIGUOUS(rotary_sin);
  694. CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
  695. TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query");
  696. params.rotary_cos_ptr = rotary_cos.data_ptr();
  697. params.rotary_sin_ptr = rotary_sin.data_ptr();
  698. params.is_rotary_interleaved = is_rotary_interleaved;
  699. } else {
  700. params.rotary_dim = 0;
  701. }
  702. if (kv_batch_idx_.has_value()) {
  703. auto kv_batch_idx = kv_batch_idx_.value();
  704. CHECK_DEVICE(kv_batch_idx); CHECK_CONTIGUOUS(kv_batch_idx);
  705. TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32");
  706. params.kv_batch_idx = reinterpret_cast<int *>(kv_batch_idx.data_ptr());
  707. }
  708. at::Tensor out_accum, softmax_lse_accum;
  709. auto outaccum_type = at::ScalarType::Float;
  710. if (params.num_splits > 1) {
  711. TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported");
  712. if (!is_varlen_q) {
  713. out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size}, opts.dtype(outaccum_type));
  714. softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  715. params.oaccum_batch_stride = out_accum.stride(1);
  716. params.lseaccum_batch_stride = softmax_lse_accum.stride(1);
  717. } else {
  718. out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size}, opts.dtype(outaccum_type));
  719. softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat));
  720. }
  721. params.is_fp32 = false;
  722. params.oaccum_ptr = out_accum.data_ptr();
  723. params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
  724. params.oaccum_split_stride = out_accum.stride(0);
  725. params.oaccum_row_stride = out_accum.stride(-2);
  726. params.oaccum_head_stride = out_accum.stride(-3);
  727. params.lseaccum_split_stride = softmax_lse_accum.stride(0);
  728. params.lseaccum_head_stride = softmax_lse_accum.stride(-2);
  729. }
  730. at::Tensor tile_count_semaphore;
  731. // We don't use the persistent scheduler if Split and not Varlen
  732. bool const persistent_scheduler = params.arch >= 90
  733. ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen)
  734. : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1));
  735. if (persistent_scheduler) {
  736. tile_count_semaphore = torch::zeros({1}, opts.dtype(torch::kInt32));
  737. params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
  738. } else {
  739. params.tile_count_semaphore = nullptr;
  740. }
  741. if (q_type == at::ScalarType::Float8_e4m3fn) {
  742. if (q_descale_.has_value()) {
  743. auto q_descale = q_descale_.value();
  744. CHECK_DEVICE(q_descale);
  745. CHECK_SHAPE(q_descale, batch_size, num_heads_k);
  746. params.q_descale_ptr = q_descale.data_ptr<float>();
  747. params.q_descale_batch_stride = q_descale.stride(0);
  748. params.q_descale_head_stride = q_descale.stride(1);
  749. } else {
  750. params.q_descale_ptr = nullptr;
  751. }
  752. if (k_descale_.has_value()) {
  753. auto k_descale = k_descale_.value();
  754. CHECK_DEVICE(k_descale);
  755. CHECK_SHAPE(k_descale, batch_size, num_heads_k);
  756. params.k_descale_ptr = k_descale.data_ptr<float>();
  757. params.k_descale_batch_stride = k_descale.stride(0);
  758. params.k_descale_head_stride = k_descale.stride(1);
  759. } else {
  760. params.k_descale_ptr = nullptr;
  761. }
  762. if (v_descale_.has_value()) {
  763. auto v_descale = v_descale_.value();
  764. CHECK_DEVICE(v_descale);
  765. CHECK_SHAPE(v_descale, batch_size, num_heads_k);
  766. params.v_descale_ptr = v_descale.data_ptr<float>();
  767. params.v_descale_batch_stride = v_descale.stride(0);
  768. params.v_descale_head_stride = v_descale.stride(1);
  769. } else {
  770. params.v_descale_ptr = nullptr;
  771. }
  772. }
  773. #ifdef FLASHATTENTION_DISABLE_LOCAL
  774. TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
  775. #endif
  776. #ifdef FLASHATTENTION_DISABLE_SOFTCAP
  777. TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
  778. #endif
  779. #ifdef FLASHATTENTION_DISABLE_SPLIT
  780. TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits.");
  781. #endif
  782. #ifdef FLASHATTENTION_DISABLE_PACKGQA
  783. TORCH_CHECK(!params.pack_gqa || params.arch < 90 || params.page_table || params.num_splits > 1, "This flash attention build does not support pack_gqa.");
  784. #endif
  785. #ifdef FLASHATTENTION_DISABLE_PAGEDKV
  786. TORCH_CHECK(!paged_KV, "This flash attention build does not support paged KV.");
  787. #endif
  788. #ifdef FLASHATTENTION_DISABLE_APPENDKV
  789. TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV.");
  790. #endif
  791. if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) {
  792. auto stream = at::cuda::getCurrentCUDAStream().stream();
  793. run_mha_fwd(params, stream);
  794. if (params.num_splits > 1) {
  795. if (out_type == at::ScalarType::BFloat16) {
  796. // Since we want output in BF16. Otherwise fwd_combine will output to FP16
  797. params.is_bf16 = true;
  798. }
  799. // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1
  800. // and seqlen = total_q, and don't need to dispatch to Varlen there.
  801. // if (is_varlen_q && !seqused_q_.has_value()) {
  802. if (is_varlen_q) {
  803. params.b = 1;
  804. params.seqlen_q = total_q;
  805. }
  806. run_mha_fwd_combine(params, stream);
  807. }
  808. } else if (total_q > 0 && num_heads_k > 0) {
  809. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  810. out.zero_();
  811. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  812. }
  813. // return {out, softmax_lse};
  814. return {out, softmax_lse, out_accum, softmax_lse_accum};
  815. }
  816. void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
  817. #ifndef FLASHATTENTION_DISABLE_BACKWARD
  818. // FP16_SWITCH(!params.is_bf16, [&] {
  819. // HEADDIM_SWITCH(params.d, [&] {
  820. // run_mha_bwd_<elem_type, kHeadDim>(params, stream);
  821. // });
  822. // });
  823. ARCH_SWITCH(params.arch, Arch, [&] {
  824. SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
  825. if (!params.is_bf16) {
  826. #ifndef FLASHATTENTION_DISABLE_FP16
  827. #ifndef FLASHATTENTION_DISABLE_HDIM64
  828. if (params.d <= 64) { return run_mha_bwd_<Arch, cutlass::half_t, 64, Has_softcap>(params, stream); }
  829. #endif
  830. #ifndef FLASHATTENTION_DISABLE_HDIM96
  831. if (params.d <= 96) { return run_mha_bwd_<Arch, cutlass::half_t, 96, Has_softcap>(params, stream); }
  832. #endif
  833. #ifndef FLASHATTENTION_DISABLE_HDIM128
  834. if (params.d <= 128) { return run_mha_bwd_<Arch, cutlass::half_t, 128, Has_softcap>(params, stream); }
  835. #endif
  836. #ifndef FLASHATTENTION_DISABLE_HDIM192
  837. if (params.d <= 192) { return run_mha_bwd_<Arch, cutlass::half_t, 192, Has_softcap>(params, stream); }
  838. #endif
  839. #ifndef FLASHATTENTION_DISABLE_HDIM256
  840. if (params.d <= 256) { return run_mha_bwd_<Arch, cutlass::half_t, 256, Has_softcap>(params, stream); }
  841. #endif
  842. #else
  843. TORCH_CHECK(false, "This flash attention build does not support FP16.");
  844. #endif
  845. } else {
  846. #ifndef FLASHATTENTION_DISABLE_HDIM64
  847. if (params.d <= 64) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 64, Has_softcap>(params, stream); }
  848. #endif
  849. #ifndef FLASHATTENTION_DISABLE_HDIM96
  850. if (params.d <= 96) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 96, Has_softcap>(params, stream); }
  851. #endif
  852. #ifndef FLASHATTENTION_DISABLE_HDIM128
  853. if (params.d <= 128) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 128, Has_softcap>(params, stream); }
  854. #endif
  855. #ifndef FLASHATTENTION_DISABLE_HDIM192
  856. if (params.d <= 192) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 192, Has_softcap>(params, stream); }
  857. #endif
  858. #ifndef FLASHATTENTION_DISABLE_HDIM256
  859. if (params.d <= 256) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 256, Has_softcap>(params, stream); }
  860. #endif
  861. }
  862. });
  863. });
  864. #endif
  865. }
  866. // b: batch_size
  867. // s_q: seqlen_q
  868. // s_k: seqlen_k
  869. // h: num_heads
  870. // h_k: num_heads_k
  871. // d: head_size
  872. std::vector<at::Tensor> mha_bwd(
  873. const at::Tensor &dout, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
  874. const at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
  875. const at::Tensor &k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
  876. const at::Tensor &v, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
  877. const at::Tensor &out, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
  878. const at::Tensor &softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q
  879. std::optional<at::Tensor> &dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
  880. std::optional<at::Tensor> &dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
  881. std::optional<at::Tensor> &dv_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
  882. std::optional<const at::Tensor> &cu_seqlens_q_, // b+1
  883. std::optional<const at::Tensor> &cu_seqlens_k_, // b+1
  884. std::optional<const at::Tensor> &seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
  885. std::optional<const at::Tensor> &seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
  886. std::optional<int> max_seqlen_q_,
  887. std::optional<int> max_seqlen_k_,
  888. float const softmax_scale,
  889. bool is_causal,
  890. int window_size_left,
  891. int window_size_right,
  892. int const sink_token_length,
  893. float const softcap,
  894. bool const deterministic,
  895. int const sm_margin) {
  896. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  897. TORCH_CHECK(false, "This flash attention build does not support backward.");
  898. #endif
  899. auto dprops = at::cuda::getCurrentDeviceProperties();
  900. bool is_sm8x = dprops->major >= 8;
  901. TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
  902. auto q_type = q.dtype();
  903. TORCH_CHECK(q_type == torch::kFloat16 || q_type == torch::kBFloat16,
  904. "FlashAttention only support fp16 and bf16 data type");
  905. TORCH_CHECK(k.dtype() == q_type, "query and key must have the same dtype");
  906. TORCH_CHECK(v.dtype() == q_type, "query and value must have the same dtype");
  907. TORCH_CHECK(out.dtype() == q_type, "query and out must have the same dtype");
  908. TORCH_CHECK(dout.dtype() == q_type, "query and dout must have the same dtype");
  909. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  910. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  911. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  912. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  913. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  914. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  915. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
  916. at::Tensor cu_seqlens_q;
  917. bool const is_varlen_q = cu_seqlens_q_.has_value();
  918. if (is_varlen_q) {
  919. cu_seqlens_q = cu_seqlens_q_.value();
  920. CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);
  921. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
  922. TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
  923. }
  924. at::Tensor cu_seqlens_k;
  925. bool const is_varlen_k = cu_seqlens_k_.has_value();
  926. if (is_varlen_k) {
  927. cu_seqlens_k = cu_seqlens_k_.value();
  928. CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);
  929. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
  930. TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided");
  931. }
  932. // This is what we will template on
  933. bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value();
  934. #ifdef FLASHATTENTION_DISABLE_VARLEN
  935. TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
  936. #endif
  937. auto const sizes = q.sizes();
  938. int const batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
  939. int const seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
  940. int const total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
  941. int const num_heads = q.size(-2);
  942. int const head_size = q.size(-1);
  943. int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value();
  944. int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
  945. int const num_heads_k = k.size(-2);
  946. TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  947. int const max_headdim = get_max_headdim();
  948. TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim));
  949. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  950. // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
  951. if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
  952. if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
  953. if (is_causal) { window_size_right = 0; }
  954. // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_bprop will set params.is_causal=true.
  955. // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA).
  956. is_causal = window_size_left < 0 && window_size_right == 0;
  957. int const arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
  958. int const head_size_rounded = round_up_headdim(head_size);
  959. // Very important that these match the kernel configs
  960. bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal;
  961. int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128)
  962. : (head_size_rounded <= 96 ? 64
  963. : (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80)
  964. : 64));
  965. int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64;
  966. int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32;
  967. int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80);
  968. int const kBlockN_sm90 = head_size_rounded <= 128
  969. ? 128
  970. : (head_size_rounded <= 192 ? 96 : 80);
  971. int const kBlockN_sm80 = head_size_rounded <= 128
  972. ? 128
  973. : (head_size_rounded <= 192 ? 80 : 64);
  974. int const kBlockN_sm86 = head_size_rounded <= 64 ? 128
  975. : (head_size_rounded <= 96 ? 128
  976. : (head_size_rounded <= 128 ? 96
  977. : (head_size_rounded <= 192 ? 64 : 64)));
  978. int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 89 ? kBlockN_sm86 : kBlockN_sm80);
  979. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  980. int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);
  981. int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN);
  982. int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM);
  983. int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN);
  984. if (!is_varlen_q) {
  985. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
  986. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
  987. CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);
  988. } else {
  989. CHECK_SHAPE(q, total_q, num_heads, head_size);
  990. CHECK_SHAPE(out, total_q, num_heads, head_size);
  991. CHECK_SHAPE(dout, total_q, num_heads, head_size);
  992. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  993. }
  994. if (!is_varlen_k) {
  995. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
  996. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
  997. } else {
  998. CHECK_SHAPE(k, total_k, num_heads_k, head_size);
  999. CHECK_SHAPE(v, total_k, num_heads_k, head_size);
  1000. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  1001. }
  1002. if (seqused_q_.has_value()){
  1003. auto seqused_q = seqused_q_.value();
  1004. TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
  1005. CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);
  1006. CHECK_SHAPE(seqused_q, batch_size);
  1007. }
  1008. if (seqused_k_.has_value()){
  1009. auto seqused_k = seqused_k_.value();
  1010. TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
  1011. CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
  1012. CHECK_SHAPE(seqused_k, batch_size);
  1013. }
  1014. at::Tensor dq, dk, dv;
  1015. if (dq_.has_value()) {
  1016. dq = dq_.value();
  1017. TORCH_CHECK(dq.dtype() == q_type, "dq must have the same dtype as q");
  1018. CHECK_DEVICE(dq);
  1019. TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
  1020. if (!is_varlen_q) {
  1021. CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
  1022. } else {
  1023. CHECK_SHAPE(dq, total_q, num_heads, head_size);
  1024. }
  1025. } else {
  1026. dq = torch::empty_like(q);
  1027. }
  1028. if (dk_.has_value()) {
  1029. dk = dk_.value();
  1030. TORCH_CHECK(dk.dtype() == q_type, "dk must have the same dtype as q");
  1031. CHECK_DEVICE(dk);
  1032. TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
  1033. if (!is_varlen_k) {
  1034. CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
  1035. } else {
  1036. CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
  1037. }
  1038. } else {
  1039. dk = torch::empty_like(k);
  1040. }
  1041. if (dv_.has_value()) {
  1042. dv = dv_.value();
  1043. TORCH_CHECK(dv.dtype() == q_type, "dv must have the same dtype as q");
  1044. CHECK_DEVICE(dv);
  1045. TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
  1046. if (!is_varlen_k) {
  1047. CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
  1048. } else {
  1049. CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
  1050. }
  1051. } else {
  1052. dv = torch::empty_like(v);
  1053. }
  1054. // Otherwise the kernel will be launched from cuda:0 device
  1055. // Cast to char to avoid compiler warning about narrowing
  1056. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  1057. auto opts = q.options();
  1058. // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
  1059. at::Tensor softmax_d, softmax_lse_log2;
  1060. if (!is_varlen) {
  1061. // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
  1062. softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
  1063. softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
  1064. } else {
  1065. softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
  1066. softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
  1067. }
  1068. at::Tensor dq_accum, dk_accum, dv_accum;
  1069. if (!is_varlen) {
  1070. dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, opts.dtype(at::kFloat));
  1071. } else {
  1072. dq_accum = torch::empty({num_heads, total_q_padded_rounded * head_size_rounded}, opts.dtype(at::kFloat));
  1073. }
  1074. if (num_heads_k != num_heads) { // MQA / GQA
  1075. if (!is_varlen) {
  1076. dk_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat));
  1077. dv_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat));
  1078. } else {
  1079. dk_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
  1080. dv_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
  1081. }
  1082. }
  1083. Flash_bwd_params params;
  1084. set_params_dgrad(params,
  1085. batch_size,
  1086. seqlen_q, seqlen_k,
  1087. seqlen_q_rounded, seqlen_k_rounded,
  1088. num_heads, num_heads_k,
  1089. head_size, head_size_rounded,
  1090. q, k, v, out,
  1091. dout, dq, dk, dv,
  1092. !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
  1093. !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),
  1094. seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
  1095. seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
  1096. dq_accum.data_ptr(),
  1097. num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr,
  1098. num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr,
  1099. softmax_lse.data_ptr(),
  1100. softmax_d.data_ptr(),
  1101. /*p_dropout=*/0.f,
  1102. softmax_scale,
  1103. window_size_left,
  1104. window_size_right,
  1105. softcap,
  1106. deterministic,
  1107. sm_margin);
  1108. params.total_q = total_q;
  1109. params.total_k = total_k;
  1110. params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
  1111. params.sink_token_length = sink_token_length;
  1112. // auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
  1113. // params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
  1114. // Will be zero'ed out in the backward preprocess kernel
  1115. at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
  1116. params.dq_semaphore = dq_semaphore.data_ptr<int>();
  1117. if (num_heads_k != num_heads && params.deterministic) {
  1118. // TODO: do we need to zero them out?
  1119. at::Tensor dk_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
  1120. at::Tensor dv_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
  1121. params.dk_semaphore = dk_semaphore.data_ptr<int>();
  1122. params.dv_semaphore = dv_semaphore.data_ptr<int>();
  1123. }
  1124. #ifdef FLASHATTENTION_DISABLE_LOCAL
  1125. TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
  1126. #endif
  1127. #ifdef FLASHATTENTION_DISABLE_SOFTCAP
  1128. TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
  1129. #endif
  1130. if (total_q > 0 && total_k > 0 && num_heads_k > 0) {
  1131. auto stream = at::cuda::getCurrentCUDAStream().stream();
  1132. run_mha_bwd(params, stream);
  1133. } else if (total_k > 0 && num_heads_k > 0) {
  1134. // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
  1135. dk.zero_();
  1136. dv.zero_();
  1137. softmax_d.zero_();
  1138. } else if (total_q > 0 && num_heads_k > 0) {
  1139. dq.zero_();
  1140. softmax_d.zero_();
  1141. }
  1142. return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
  1143. }
  1144. std::vector<at::Tensor>
  1145. mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x seqlen x num_heads x head_size
  1146. const at::Tensor &lse_partial, // num_splits x batch_size x seqlen x num_heads
  1147. std::optional<at::Tensor> out_, // batch_size x seqlen x num_heads x head_size
  1148. std::optional<at::ScalarType> out_dtype_
  1149. ) {
  1150. auto dprops = at::cuda::getCurrentDeviceProperties();
  1151. bool is_sm8x = dprops->major >= 8;
  1152. TORCH_CHECK(is_sm8x, "Attention combine function only supports Ampere GPUs or newer.");
  1153. auto out_partial_type = out_partial.scalar_type();
  1154. TORCH_CHECK(out_partial_type == at::ScalarType::Float, "Attention combine function only support fp32 data type");
  1155. TORCH_CHECK(lse_partial.scalar_type() == at::ScalarType::Float, "Attention combine function only support fp32 data type");
  1156. CHECK_DEVICE(out_partial); CHECK_DEVICE(lse_partial);
  1157. TORCH_CHECK(out_partial.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1158. TORCH_CHECK(lse_partial.stride(-2) == 1, "LSE tensor must be contiguous in the seqlen dimension");
  1159. const auto sizes = out_partial.sizes();
  1160. const int num_splits = sizes[0];
  1161. const int batch_size = sizes[1];
  1162. const int seqlen = sizes[2];
  1163. const int num_heads = sizes[3];
  1164. const int head_size_og = sizes[4];
  1165. TORCH_CHECK(head_size_og <= 256, "FlashAttention combine only supports head dimension at most 256");
  1166. TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256");
  1167. CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og);
  1168. CHECK_SHAPE(lse_partial, num_splits, batch_size, seqlen, num_heads);
  1169. int const alignment = 4;
  1170. at::Tensor out_partial_padded;
  1171. auto pad = [](at::Tensor x, int alignment) {
  1172. return x.size(-1) % alignment == 0 ? x : torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment}));
  1173. };
  1174. out_partial_padded = pad(out_partial, alignment);
  1175. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  1176. const int head_size = round_multiple(head_size_og, alignment);
  1177. auto opts = out_partial.options();
  1178. at::ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type());
  1179. TORCH_CHECK(out_type == at::ScalarType::Float || out_type == at::ScalarType::BFloat16 || out_type == at::ScalarType::Half, "Output type must be FP32, FP16 or BF16");
  1180. at::Tensor out;
  1181. if (out_.has_value()) {
  1182. out = out_.value();
  1183. TORCH_CHECK(out.scalar_type() == out_type);
  1184. CHECK_DEVICE(out);
  1185. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  1186. CHECK_SHAPE(out, batch_size, seqlen, num_heads, head_size_og);
  1187. if (head_size_og % alignment != 0) {
  1188. out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
  1189. }
  1190. } else {
  1191. out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
  1192. }
  1193. // Otherwise the kernel will be launched from cuda:0 device
  1194. // Cast to char to avoid compiler warning about narrowing
  1195. at::cuda::CUDAGuard device_guard{(char)out_partial.get_device()};
  1196. auto softmax_lse = torch::empty({batch_size, num_heads, seqlen}, opts.dtype(at::kFloat)).transpose(1, 2);
  1197. Flash_fwd_params params {}; // Need to reset the params to set everything to zero
  1198. params.is_fp32 = out_type == at::ScalarType::Float;
  1199. params.is_bf16 = out_type == at::ScalarType::BFloat16;
  1200. params.oaccum_ptr = out_partial_padded.data_ptr();
  1201. params.softmax_lseaccum_ptr = lse_partial.data_ptr();
  1202. params.o_ptr = out.data_ptr();
  1203. params.softmax_lse_ptr = softmax_lse.data_ptr();
  1204. params.b = batch_size;
  1205. params.h = num_heads;
  1206. params.seqlen_q = seqlen;
  1207. params.d = head_size;
  1208. params.num_splits = num_splits;
  1209. params.oaccum_split_stride = out_partial_padded.stride(0);
  1210. params.oaccum_row_stride = out_partial_padded.stride(2);
  1211. params.oaccum_head_stride = out_partial_padded.stride(3);
  1212. params.oaccum_batch_stride = out_partial_padded.stride(1);
  1213. params.lseaccum_split_stride = lse_partial.stride(0);
  1214. params.lseaccum_head_stride = lse_partial.stride(3);
  1215. params.lseaccum_batch_stride = lse_partial.stride(1);
  1216. params.o_row_stride = out.stride(1);
  1217. params.o_head_stride = out.stride(2);
  1218. params.o_batch_stride = out.stride(0);
  1219. if (seqlen > 0 && batch_size > 0) {
  1220. auto stream = at::cuda::getCurrentCUDAStream().stream();
  1221. run_mha_fwd_combine(params, stream);
  1222. }
  1223. at::Tensor out_padded = out;
  1224. if (head_size_og % alignment != 0) {
  1225. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  1226. // if (out_.has_value()) { out_.value().copy_(out); }
  1227. }
  1228. return {out, softmax_lse};
  1229. }
  1230. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  1231. m.doc() = "FlashAttention";
  1232. m.def("fwd", &mha_fwd, "Forward pass");
  1233. m.def("bwd", &mha_bwd, "Backward pass");
  1234. m.def("fwd_combine", &mha_combine, "Combine partial attention outputs");
  1235. }