flash_api.cpp

/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <torch/version.h>  // For TORCH_VERSION* macros
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cutlass/numeric_types.h>

#include "flash.h"
#include "static_switch.h"
#include "tile_size.h"
#include "heuristics.h"

// Copied from https://github.com/pytorch/pytorch/commit/7931eee5c5ebcdf468bff4d308510b03355cd909
// This is so that we can pass in torch.dtype as a parameter to the function.
#if TORCH_VERSION_MAJOR < 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR < 4)
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace pybind11::detail {

template <>
struct type_caster<at::ScalarType> {
public:
    // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
    PYBIND11_TYPE_CASTER(at::ScalarType, _("torch.dtype"));
    // PYBIND11_TYPE_CASTER defines a member field called value. at::ScalarType
    // cannot be default-initialized, we provide this constructor to explicitly
    // initialize that field. The value doesn't matter as it will be overwritten
    // after a successful call to load.
    type_caster() : value(at::kFloat) {}
    bool load(handle src, bool) {
        PyObject* obj = src.ptr();
        if (THPDtype_Check(obj)) {
            value = reinterpret_cast<THPDtype*>(obj)->scalar_type;
            return true;
        }
        return false;
    }
    static handle cast(
        const at::ScalarType& src,
        return_value_policy /* policy */,
        handle /* parent */) {
        return Py_NewRef(torch::getTHPDtype(src));
    }
};

}  // namespace pybind11::detail

#endif
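
// Illustrative sketch (not part of the original file): with the caster above
// registered, a pybind11-bound function can take at::ScalarType directly and be
// called from Python with a torch.dtype object, e.g.
//
//     m.def("dtype_roundtrip", [](at::ScalarType t) { return t; });
//     // Python: assert mod.dtype_roundtrip(torch.float16) is torch.float16
//
// The function and module names here are hypothetical; the snippet only shows
// how the type_caster converts torch.dtype <-> at::ScalarType in both directions.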
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

void set_params_fprop(Flash_fwd_params &params,
                      // sizes
                      const size_t b,
                      const size_t seqlen_q,
                      const size_t seqlen_k,
                      const size_t seqlen_q_rounded,
                      const size_t seqlen_k_rounded,
                      const size_t h,
                      const size_t h_k,
                      const size_t d,
                      const size_t d_rounded,
                      // device pointers
                      const at::Tensor q,
                      const at::Tensor k,
                      const at::Tensor v,
                      at::Tensor out,
                      void *cu_seqlens_q_d,
                      void *cu_seqlens_k_d,
                      void *seqused_q,
                      void *seqused_k,
                      void *softmax_lse_d,
                      float p_dropout,
                      float softmax_scale,
                      int window_size_left,
                      int window_size_right,
                      const float softcap=0.f,
                      const int sm_margin=0) {

    // Reset the parameters
    params = {};

    params.is_bf16 = q.dtype() == torch::kBFloat16;
    params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;

    // Set the pointers and strides.
    params.q_ptr = q.data_ptr();
    params.k_ptr = k.data_ptr();
    params.v_ptr = v.data_ptr();
    // All strides are in elements, not bytes.
    params.q_row_stride = q.stride(-3);
    params.k_row_stride = k.stride(-3);
    params.v_row_stride = v.stride(-3);
    params.q_head_stride = q.stride(-2);
    params.k_head_stride = k.stride(-2);
    params.v_head_stride = v.stride(-2);
    params.v_dim_stride = v.stride(-1);
    params.o_ptr = out.data_ptr();
    params.o_row_stride = out.stride(-3);
    params.o_head_stride = out.stride(-2);

    if (cu_seqlens_q_d == nullptr) {
        params.q_batch_stride = q.stride(0);
        params.o_batch_stride = out.stride(0);
    }
    if (cu_seqlens_k_d == nullptr) {
        params.k_batch_stride = k.stride(0);
        params.v_batch_stride = v.stride(0);
    }

    params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
    params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
    params.seqused_q = static_cast<int *>(seqused_q);
    params.seqused_k = static_cast<int *>(seqused_k);

    // Softmax sum
    params.softmax_lse_ptr = softmax_lse_d;

    // Set the dimensions.
    params.b = b;
    params.h = h;
    params.h_k = h_k;
    params.seqlen_q = seqlen_q;
    params.seqlen_k = seqlen_k;
    params.seqlen_q_rounded = seqlen_q_rounded;
    params.seqlen_k_rounded = seqlen_k_rounded;
    params.d = d;
    params.d_rounded = d_rounded;

    // Set the different scale values.
    params.scale_softmax = softmax_scale;
    params.softcap = softcap;

    // Set this to probability of keeping an element to simplify things.
    params.p_dropout = 1.f - p_dropout;
    // Convert p from float to int so we don't have to convert the random uint to float to compare.
    // [Minor] We want to round down since when we do the comparison we use <= instead of <
    // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
    // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
    params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
    params.rp_dropout = 1.f / params.p_dropout;
    TORCH_CHECK(p_dropout < 1.f);
#ifdef FLASHATTENTION_DISABLE_DROPOUT
    TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
#endif

    // Causal is the special case where window_size_right == 0 and window_size_left < 0.
    // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
    params.is_causal = window_size_left < 0 && window_size_right == 0;
    params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal;
    // TODO: check this
    if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k - 1; }
    if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_q - 1; }
    params.window_size_left = window_size_left;
    params.window_size_right = window_size_right;

    params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
    params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;

#ifdef FLASHATTENTION_DISABLE_LOCAL
    TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
#endif
}
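
// Illustrative sketch (not part of the original file): how the window-size
// convention used by set_params_fprop plays out for a few inputs, where a
// negative size means "unbounded" on that side.
//   (left, right) = (-1,  0)  -> is_causal = true,  is_local = false
//   (left, right) = (-1, -1)  -> is_causal = false, is_local = false (full attention)
//   (left, right) = (128, 0)  -> is_causal = false, is_local = true,  right stays 0
//   (left, right) = (-1,  4)  -> is_causal = false, is_local = true,  left is widened to seqlen_k - 1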
void set_params_dgrad(Flash_bwd_params &params,
                      // sizes
                      const size_t b,
                      const size_t seqlen_q,
                      const size_t seqlen_k,
                      const size_t seqlen_q_rounded,
                      const size_t seqlen_k_rounded,
                      const size_t h,
                      const size_t h_k,
                      const size_t d,
                      const size_t d_rounded,
                      // device pointers
                      const at::Tensor q,
                      const at::Tensor k,
                      const at::Tensor v,
                      const at::Tensor out,
                      const at::Tensor dout,
                      at::Tensor dq,
                      at::Tensor dk,
                      at::Tensor dv,
                      void *cu_seqlens_q_d,
                      void *cu_seqlens_k_d,
                      void *seqused_q,
                      void *seqused_k,
                      void *dq_accum_d,
                      void *dk_accum_d,
                      void *dv_accum_d,
                      void *softmax_lse_d,
                      void *dsoftmax_sum_d,
                      float p_dropout,
                      float softmax_scale,
                      int window_size_left,
                      int window_size_right,
                      const float softcap=0.f,
                      bool deterministic=false,
                      int const sm_margin=0) {

    set_params_fprop(params,
                     b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
                     q, k, v, out,
                     cu_seqlens_q_d,
                     cu_seqlens_k_d,
                     seqused_q,
                     seqused_k,
                     softmax_lse_d,
                     p_dropout,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap,
                     sm_margin);

    // Set the pointers and strides.
    params.do_ptr = dout.data_ptr();
    params.do_row_stride = dout.stride(-3);
    params.do_head_stride = dout.stride(-2);
    params.dq_ptr = dq.data_ptr();
    params.dk_ptr = dk.data_ptr();
    params.dv_ptr = dv.data_ptr();
    params.dq_row_stride = dq.stride(-3);
    params.dk_row_stride = dk.stride(-3);
    params.dv_row_stride = dv.stride(-3);
    params.dq_head_stride = dq.stride(-2);
    params.dk_head_stride = dk.stride(-2);
    params.dv_head_stride = dv.stride(-2);

    if (cu_seqlens_q_d == nullptr) {
        params.do_batch_stride = dout.stride(0);
        params.dq_batch_stride = dq.stride(0);
        params.dk_batch_stride = dk.stride(0);
        params.dv_batch_stride = dv.stride(0);
    }

    params.dq_accum_ptr = dq_accum_d;
    params.dk_accum_ptr = dk_accum_d;
    params.dv_accum_ptr = dv_accum_d;

    // Softmax sum
    params.dsoftmax_sum = dsoftmax_sum_d;

    params.deterministic = deterministic;
}
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    // HEADDIM_SWITCH(params.d, [&] {
    //     run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
    // });
    TORCH_CHECK(params.num_splits >= 1);
    ARCH_SWITCH(params.arch, Arch, [&] {
        SPLIT_SWITCH(params.num_splits > 1, Split, [&] {
            PAGEDKV_SWITCH(params.page_table, PagedKV, [&] {
                PACKGQA_SWITCH(params.pack_gqa, PackGQA, [&] {
                    SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] {
                        if (!params.is_e4m3) {
                            if (params.is_bf16) {
#ifndef FLASHATTENTION_DISABLE_HDIM64
                                if (params.d <= 64) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM96
                                if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 96, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM128
                                if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM192
                                if (params.d <= 192) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM256
                                if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 256, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
#endif
                            } else {
#ifndef FLASHATTENTION_DISABLE_FP16
#ifndef FLASHATTENTION_DISABLE_HDIM64
                                if (params.d <= 64) { return run_mha_fwd_<Arch, cutlass::half_t, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM96
                                if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::half_t, 96, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM128
                                if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::half_t, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM192
                                if (params.d <= 192) { return run_mha_fwd_<Arch, cutlass::half_t, 192, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM256
                                if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::half_t, 256, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
#endif
#else
                                TORCH_CHECK(false, "This flash attention build does not support FP16.");
#endif
                            }
                        } else {
#ifndef FLASHATTENTION_DISABLE_FP8
#ifndef FLASHATTENTION_DISABLE_HDIM64
                            if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM96
                            if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM128
                            if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM192
                            if (params.d <= 192) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM256
                            if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, Split, PagedKV, Has_softcap, PackGQA>(params, stream); }
#endif
#else
                            TORCH_CHECK(false, "This flash attention build does not support FP8.");
#endif
                        }
                    });
                });
            });
        });
    });
}
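
// Dispatch example (illustrative, not part of the original file): with all head
// dimensions enabled, a BF16 call with params.d == 80 on SM90 falls through the
// d <= 64 check and lands on the 96-specialisation, i.e.
// run_mha_fwd_<90, cutlass::bfloat16_t, 96, ...>. FP8 (e4m3) inputs always
// instantiate the SM90 template regardless of Arch, as the branches above show.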
void run_mha_fwd_combine(Flash_fwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_SPLIT
    // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively
    // so that kBlockM is smaller and we have more parallelism.
    if (params.is_fp32) {
        if (params.d <= 64) {
            run_mha_fwd_combine_<float, float, 64>(params, stream);
        } else if (params.d <= 128) {
            run_mha_fwd_combine_<float, float, 128>(params, stream);
        } else {
            run_mha_fwd_combine_<float, float, 256>(params, stream);
        }
    } else if (params.is_bf16) {
        if (params.d <= 64) {
            run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(params, stream);
        } else if (params.d <= 128) {
            run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(params, stream);
        } else {
            run_mha_fwd_combine_<cutlass::bfloat16_t, float, 256>(params, stream);
        }
    } else {
        if (params.d <= 64) {
            run_mha_fwd_combine_<cutlass::half_t, float, 64>(params, stream);
        } else if (params.d <= 128) {
            run_mha_fwd_combine_<cutlass::half_t, float, 128>(params, stream);
        } else {
            run_mha_fwd_combine_<cutlass::half_t, float, 256>(params, stream);
        }
    }
#else
    TORCH_CHECK(false, "This flash attention build does not support combine kernels.");
#endif
}
inline bool get_pack_gqa(Flash_fwd_params const& params) {
#ifdef FLASHATTENTION_DISABLE_PACKGQA
    return false;
#else
    // params.page_table must already be set
    if (params.h == params.h_k) { return false; }
    // This needs to match the kernel configs
    bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k;
    auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f);
    // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits
    // has not been set here. It's OK though because for Sm80 kBlockM = 128 always.
    auto kBlockMN_kernel_args_sm80 = tile_size_fwd_sm80(params.arch == 86 || params.arch == 89, params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, varlen, params.softcap > 0.f);
    int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm80);
    return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM);
#endif
}
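
// Illustrative note (not part of the original file): with h = 32 query heads and
// h_k = 4 KV heads, the GQA ratio h / h_k is 8, so PackGQA can treat each KV head
// as owning 8 * seqlen_q query rows when tiling over kBlockM. Whether that is
// worthwhile is decided by should_pack_gqa() above.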
inline int get_num_splits(Flash_fwd_params const& params) {
#ifdef FLASHATTENTION_DISABLE_SPLIT
    return 1;
#else
    // params.pack_gqa must already be set
    // params.page_table must already be set
    // This needs to match the kernel configs
    bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k;
    auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table, params.softcap > 0.f);
    // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits
    // has not been set here. It's OK though because we might just underestimate kBlockN a bit
    auto kBlockMN_kernel_args_sm80 = tile_size_fwd_sm80(params.arch == 86 || params.arch == 89, params.d_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, varlen, params.softcap > 0.f);
    int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm80);
    int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm80);
    int seqlen_q_packgqa = params.seqlen_q * (params.pack_gqa ? params.h / params.h_k : 1);
    // If is_local, we're not going to load all of seqlen_k
    int const seqlen_k_loaded = !params.is_local
        ? params.seqlen_k
        : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM));
    int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN;
    int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM;
    return num_splits_heuristic(params.b * (!params.pack_gqa ? params.h : params.h_k) * num_m_blocks, params.num_sm, num_n_blocks, 128);
    // return num_splits_heuristic(params.b * params.h_k * num_m_blocks, params.b * params.h_k,
    //                             params.num_sm, num_n_blocks, 128, params.d_rounded);
#endif
}
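
// Worked example (illustrative, numbers assumed): for a decode-style call with
// b = 4, h = 32, pack_gqa = false, seqlen_q = 1 and kBlockM = 128, there is
// ceil(1 / 128) = 1 m-block per head, i.e. 4 * 32 * 1 = 128 independent tiles.
// If seqlen_k = 8192 and the tile-size heuristic returned kBlockN = 176, then
// num_n_blocks = ceil(8192 / 176) = 47, and num_splits_heuristic() decides how
// many of those n-blocks to process in parallel per tile given params.num_sm.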
inline int get_max_headdim() {
#ifndef FLASHATTENTION_DISABLE_HDIM256
    return 256;
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM192
    return 192;
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM128
    return 128;
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM96
    return 96;
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM64
    return 64;
#endif
    return 0;
}
inline int round_up_headdim(int head_size) {
#ifndef FLASHATTENTION_DISABLE_HDIM64
    if (head_size <= 64) { return 64; }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM96
    if (head_size <= 96) { return 96; }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM128
    if (head_size <= 128) { return 128; }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM192
    if (head_size <= 192) { return 192; }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM256
    if (head_size <= 256) { return 256; }
#endif
    return 256;
}
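
// Illustrative values (assuming a build with every head dimension enabled):
// round_up_headdim(48) == 64, round_up_headdim(80) == 96,
// round_up_headdim(128) == 128, round_up_headdim(200) == 256.
// get_max_headdim() above likewise reports the largest compiled-in bucket, 256.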
// b: batch_size
// b_k: batch_size_k
// s_q: seqlen_q
// s_k: seqlen_k
// s_k_new: seqlen_k_new
// h: num_heads
// h_k: num_heads_k
// d: head_size
std::vector<at::Tensor>
mha_fwd(at::Tensor &q,   // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
        const at::Tensor &k,   // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table.
        const at::Tensor &v,   // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table.
        c10::optional<const at::Tensor> &k_new_,   // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
        c10::optional<const at::Tensor> &v_new_,   // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
        c10::optional<at::Tensor> &out_,   // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
        c10::optional<const at::Tensor> &cu_seqlens_q_,   // b+1
        c10::optional<const at::Tensor> &cu_seqlens_k_,   // b+1
        c10::optional<const at::Tensor> &cu_seqlens_k_new_,   // b+1
        c10::optional<const at::Tensor> &seqused_q_,   // b. If given, only this many elements of each batch element's queries and outputs are used.
        c10::optional<const at::Tensor> &seqused_k_,   // b. If given, only this many elements of each batch element's keys are used.
        c10::optional<int> max_seqlen_q_,
        // TODO: check if we need max_seqlen_k
        c10::optional<int> max_seqlen_k_,
        c10::optional<const at::Tensor> &page_table_,   // (b_k, max_num_pages_per_seq)
        c10::optional<const at::Tensor> &kv_batch_idx_,   // b. indices to index into the KV cache
        c10::optional<const at::Tensor> &leftpad_k_,   // b
        c10::optional<const at::Tensor> &rotary_cos_,   // seqlen_ro x (rotary_dim / 2)
        c10::optional<const at::Tensor> &rotary_sin_,   // seqlen_ro x (rotary_dim / 2)
        c10::optional<at::Tensor> &q_descale_,   // (b, h_k), not (b, h)
        c10::optional<at::Tensor> &k_descale_,   // (b, h_k)
        c10::optional<at::Tensor> &v_descale_,   // (b, h_k)
        float const softmax_scale,
        bool is_causal,
        int window_size_left,
        int window_size_right,
        int sink_token_length,
        float const softcap,
        bool const is_rotary_interleaved,   // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
        int num_splits,
        c10::optional<bool> pack_gqa_,
        int const sm_margin
        ) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major >= 8;
    TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");

    auto q_type = q.scalar_type();
    TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn,
                "FlashAttention only supports fp16, bf16, and fp8_e4m3 data types");
    TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype");
    TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");

    at::Tensor page_table;
    const bool paged_KV = page_table_.has_value();
    if (paged_KV) {
        page_table = page_table_.value();
        CHECK_DEVICE(page_table);
        TORCH_CHECK(page_table.dtype() == torch::kInt32, "page_table must have dtype torch.int32");
        TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension");
    }

    at::Tensor cu_seqlens_q;
    bool const is_varlen_q = cu_seqlens_q_.has_value();
    if (is_varlen_q) {
        cu_seqlens_q = cu_seqlens_q_.value();
        CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);
        TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
        TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
    }
    at::Tensor cu_seqlens_k;
    bool const is_varlen_k = cu_seqlens_k_.has_value();
    if (is_varlen_k) {
        cu_seqlens_k = cu_seqlens_k_.value();
        CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);
        TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
        TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided");
        TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported");
        TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then page table is not supported");
    }

    // This is what we will template on
    bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value();
#ifdef FLASHATTENTION_DISABLE_VARLEN
    TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
#endif

    auto const sizes = q.sizes();
    const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
    int seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
    int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
    int num_heads = q.size(-2);
    int const head_size = q.size(-1);
    int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1);
    int const num_pages = !paged_KV ? 0 : k.size(0);
    int const page_size = !paged_KV ? 1 : k.size(1);
    int const seqlen_k = !is_varlen_k ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value();
    int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
    int const num_heads_k = k.size(-2);
    int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0);
    if (!kv_batch_idx_.has_value()) {
        TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k");
    }
    int const max_headdim = get_max_headdim();
    TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim));
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
    // TODO: check this
    if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
    if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
    if (is_causal) { window_size_right = 0; }
    // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_fprop will set params.is_causal=true.
    // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM.
    is_causal = window_size_left < 0 && window_size_right == 0;

    if (!is_varlen_q) {
        CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
    } else {
        CHECK_SHAPE(q, total_q, num_heads, head_size);
        CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    }
    if (!paged_KV) {
        if (!is_varlen_k) {
            CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size);
            CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size);
        } else {
            CHECK_SHAPE(k, total_k, num_heads_k, head_size);
            CHECK_SHAPE(v, total_k, num_heads_k, head_size);
            CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
        }
    } else {
        CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size);
        CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size);
        CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq);
    }

    if (seqused_q_.has_value()) {
        auto seqused_q = seqused_q_.value();
        TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
        CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);
        CHECK_SHAPE(seqused_q, batch_size);
    }
    if (seqused_k_.has_value()) {
        auto seqused_k = seqused_k_.value();
        TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
        CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
        CHECK_SHAPE(seqused_k, batch_size);
    }
    int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8;
    TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment));

    auto opts = q.options();
    auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type;
    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        if (!is_varlen_q) {
            CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
        } else {
            CHECK_SHAPE(out, total_q, num_heads, head_size);
        }
    } else {
        out = torch::empty_like(q, opts.dtype(out_type));
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    int const head_size_rounded = round_up_headdim(head_size);
    int const seqlen_q_rounded = round_multiple(seqlen_q, 128);
    int const seqlen_k_rounded = round_multiple(seqlen_k, 128);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    at::Tensor softmax_lse;
    if (!is_varlen_q) {
        softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
    } else {
        softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
    }

    Flash_fwd_params params;
    set_params_fprop(params,
                     batch_size,
                     seqlen_q, seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q, k, v, out,
                     !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
                     !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),
                     seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
                     seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
                     softmax_lse.data_ptr(),
                     /*p_dropout=*/0.f,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap,
                     sm_margin);
    params.total_q = total_q;
    params.total_k = total_k;
    params.sink_token_length = sink_token_length;
    params.b_k = batch_size_k;
    if (paged_KV) {
        params.page_table = page_table.data_ptr<int>();
        params.page_table_batch_stride = page_table.stride(0);
    }
    params.page_size = page_size;
    params.num_pages = num_pages;
    params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
    params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;

    if (k_new_.has_value()) {
        at::Tensor k_new, v_new;
        TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in");
        TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqlens_k must also be passed in");
        TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache");
        at::Tensor cu_seqlens_k_new;
        bool const is_varlen_k_new = cu_seqlens_k_new_.has_value();
        if (is_varlen_k_new) {
            cu_seqlens_k_new = cu_seqlens_k_new_.value();
            CHECK_DEVICE(cu_seqlens_k_new); CHECK_CONTIGUOUS(cu_seqlens_k_new);
            TORCH_CHECK(cu_seqlens_k_new.dtype() == torch::kInt32, "cu_seqlens_k_new must have dtype torch.int32");
        }
        k_new = k_new_.value();
        v_new = v_new_.value();
        TORCH_CHECK(k_new.dtype() == q_type, "k_new must have the same dtype as query");
        TORCH_CHECK(v_new.dtype() == q_type, "v_new must have the same dtype as query");
        CHECK_DEVICE(k_new); CHECK_DEVICE(v_new);
        TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension");
        TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension");
        // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new
        int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0;
        int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1) : k_new.size(0);
        if (!is_varlen_k_new) {
            CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size);
            CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size);
        } else {
            CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size);
            CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size);
            CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1);
        }
        params.seqlen_knew = seqlen_k_new;
        params.total_knew = total_k_new;
        params.knew_ptr = k_new.data_ptr();
        params.vnew_ptr = v_new.data_ptr();
        // All strides are in elements, not bytes.
        params.knew_row_stride = k_new.stride(-3);
        params.vnew_row_stride = v_new.stride(-3);
        params.knew_head_stride = k_new.stride(-2);
        params.vnew_head_stride = v_new.stride(-2);
        if (!is_varlen_k_new) {
            params.knew_batch_stride = k_new.stride(0);
            params.vnew_batch_stride = v_new.stride(0);
        }
        if (is_varlen_k_new) {
            params.cu_seqlens_knew = static_cast<int*>(cu_seqlens_k_new.data_ptr());
        }
    }
    if (leftpad_k_.has_value()) {
        auto leftpad_k = leftpad_k_.value();
        TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
        CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k);
        CHECK_SHAPE(leftpad_k, batch_size);
        params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
    }

    if (rotary_cos_.has_value()) {
        TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
        auto rotary_cos = rotary_cos_.value();
        CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos);
        params.rotary_dim = rotary_cos.size(1) * 2;
        TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
        TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
        const int seqlen_ro = rotary_cos.size(0);
        if (paged_KV) {
            TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
        }
        CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
        TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query");

        TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
        auto rotary_sin = rotary_sin_.value();
        CHECK_DEVICE(rotary_sin); CHECK_CONTIGUOUS(rotary_sin);
        CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
        TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_sin must have the same dtype as query");
        params.rotary_cos_ptr = rotary_cos.data_ptr();
        params.rotary_sin_ptr = rotary_sin.data_ptr();
        params.is_rotary_interleaved = is_rotary_interleaved;
    } else {
        params.rotary_dim = 0;
    }

    if (kv_batch_idx_.has_value()) {
        auto kv_batch_idx = kv_batch_idx_.value();
        CHECK_DEVICE(kv_batch_idx); CHECK_CONTIGUOUS(kv_batch_idx);
        TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32");
        params.kv_batch_idx = reinterpret_cast<int *>(kv_batch_idx.data_ptr());
    }

    at::Tensor out_accum, softmax_lse_accum;
    auto outaccum_type = at::ScalarType::Float;
    if (params.num_splits > 1) {
        TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported");
        if (!is_varlen_q) {
            out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size}, opts.dtype(outaccum_type));
            softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
            params.oaccum_batch_stride = out_accum.stride(1);
            params.lseaccum_batch_stride = softmax_lse_accum.stride(1);
        } else {
            out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size}, opts.dtype(outaccum_type));
            softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat));
        }
        params.is_fp32 = false;
        params.oaccum_ptr = out_accum.data_ptr();
        params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
        params.oaccum_split_stride = out_accum.stride(0);
        params.oaccum_row_stride = out_accum.stride(-2);
        params.oaccum_head_stride = out_accum.stride(-3);
        params.lseaccum_split_stride = softmax_lse_accum.stride(0);
        params.lseaccum_head_stride = softmax_lse_accum.stride(-2);
    }
    at::Tensor tile_count_semaphore;
    // We don't use the persistent scheduler if Split and not Varlen
    if (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) {
        tile_count_semaphore = torch::zeros({1}, opts.dtype(torch::kInt32));
        params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
    } else {
        params.tile_count_semaphore = nullptr;
    }

    if (q_type == at::ScalarType::Float8_e4m3fn) {
        if (q_descale_.has_value()) {
            auto q_descale = q_descale_.value();
            CHECK_DEVICE(q_descale);
            CHECK_SHAPE(q_descale, batch_size, num_heads_k);
            params.q_descale_ptr = q_descale.data_ptr<float>();
            params.q_descale_batch_stride = q_descale.stride(0);
            params.q_descale_head_stride = q_descale.stride(1);
        } else {
            params.q_descale_ptr = nullptr;
        }
        if (k_descale_.has_value()) {
            auto k_descale = k_descale_.value();
            CHECK_DEVICE(k_descale);
            CHECK_SHAPE(k_descale, batch_size, num_heads_k);
            params.k_descale_ptr = k_descale.data_ptr<float>();
            params.k_descale_batch_stride = k_descale.stride(0);
            params.k_descale_head_stride = k_descale.stride(1);
        } else {
            params.k_descale_ptr = nullptr;
        }
        if (v_descale_.has_value()) {
            auto v_descale = v_descale_.value();
            CHECK_DEVICE(v_descale);
            CHECK_SHAPE(v_descale, batch_size, num_heads_k);
            params.v_descale_ptr = v_descale.data_ptr<float>();
            params.v_descale_batch_stride = v_descale.stride(0);
            params.v_descale_head_stride = v_descale.stride(1);
        } else {
            params.v_descale_ptr = nullptr;
        }
    }

#ifdef FLASHATTENTION_DISABLE_LOCAL
    TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
#endif
#ifdef FLASHATTENTION_DISABLE_SOFTCAP
    TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
#endif
#ifdef FLASHATTENTION_DISABLE_SPLIT
    TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits.");
#endif
#ifdef FLASHATTENTION_DISABLE_PACKGQA
    TORCH_CHECK(!params.pack_gqa, "This flash attention build does not support pack_gqa.");
#endif
#ifdef FLASHATTENTION_DISABLE_PAGEDKV
    TORCH_CHECK(!paged_KV, "This flash attention build does not support paged KV.");
#endif
#ifdef FLASHATTENTION_DISABLE_APPENDKV
    TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV.");
#endif
    if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) {
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        run_mha_fwd(params, stream);
        if (params.num_splits > 1) {
            if (out_type == at::ScalarType::BFloat16) {
                // We want the output in BF16; otherwise fwd_combine would output FP16.
                params.is_bf16 = true;
            }
            // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1
            // and seqlen = total_q, and don't need to dispatch to Varlen there.
            // if (is_varlen_q && !seqused_q_.has_value()) {
            if (is_varlen_q) {
                params.b = 1;
                params.seqlen_q = total_q;
            }
            run_mha_fwd_combine(params, stream);
        }
    } else if (total_q > 0 && num_heads_k > 0) {
        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
        out.zero_();
        softmax_lse.fill_(std::numeric_limits<float>::infinity());
    }

    // return {out, softmax_lse};
    return {out, softmax_lse, out_accum, softmax_lse_accum};
}
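
// Illustrative usage sketch (not part of the original file; shapes and variable
// names are assumptions). A caller binding mha_fwd directly from C++ would look
// roughly like the following, leaving every optional feature disabled:
//
//     at::Tensor q = torch::randn({2, 1024, 16, 128},
//                                 torch::dtype(torch::kBFloat16).device(torch::kCUDA));
//     at::Tensor k = torch::randn({2, 1024, 4, 128}, q.options());
//     at::Tensor v = torch::randn({2, 1024, 4, 128}, q.options());
//     c10::optional<const at::Tensor> none_t;   // reused for every unused tensor argument
//     c10::optional<at::Tensor> out_, q_descale_, k_descale_, v_descale_;
//     c10::optional<int> no_int;
//     c10::optional<bool> no_bool;
//     auto outs = mha_fwd(q, k, v, none_t, none_t, out_, none_t, none_t, none_t,
//                         none_t, none_t, no_int, no_int, none_t, none_t, none_t,
//                         none_t, none_t, q_descale_, k_descale_, v_descale_,
//                         /*softmax_scale=*/1.f / std::sqrt(128.f), /*is_causal=*/true,
//                         /*window_size_left=*/-1, /*window_size_right=*/-1,
//                         /*sink_token_length=*/0, /*softcap=*/0.f,
//                         /*is_rotary_interleaved=*/false, /*num_splits=*/0,
//                         no_bool, /*sm_margin=*/0);
//     // outs = {out, softmax_lse, out_accum, softmax_lse_accum}; the last two
//     // are only defined when more than one split is used.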
void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
#ifndef FLASHATTENTION_DISABLE_BACKWARD
    // FP16_SWITCH(!params.is_bf16, [&] {
    //     HEADDIM_SWITCH(params.d, [&] {
    //         run_mha_bwd_<elem_type, kHeadDim>(params, stream);
    //     });
    // });
    ARCH_SWITCH(params.arch, Arch, [&] {
        if (!params.is_bf16) {
#ifndef FLASHATTENTION_DISABLE_FP16
#ifndef FLASHATTENTION_DISABLE_HDIM64
            if (params.d <= 64) { return run_mha_bwd_<Arch, cutlass::half_t, 64>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM96
            if (params.d <= 96) { return run_mha_bwd_<Arch, cutlass::half_t, 96>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM128
            if (params.d <= 128) { return run_mha_bwd_<Arch, cutlass::half_t, 128>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM192
            if (params.d <= 192) { return run_mha_bwd_<Arch, cutlass::half_t, 192>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM256
            if (params.d <= 256) { return run_mha_bwd_<Arch, cutlass::half_t, 256>(params, stream); }
#endif
#else
            TORCH_CHECK(false, "This flash attention build does not support FP16.");
#endif
        } else {
#ifndef FLASHATTENTION_DISABLE_HDIM64
            if (params.d <= 64) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 64>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM96
            if (params.d <= 96) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 96>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM128
            if (params.d <= 128) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 128>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM192
            if (params.d <= 192) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 192>(params, stream); }
#endif
#ifndef FLASHATTENTION_DISABLE_HDIM256
            if (params.d <= 256) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 256>(params, stream); }
#endif
        }
    });
#endif
}
// b: batch_size
// s_q: seqlen_q
// s_k: seqlen_k
// h: num_heads
// h_k: num_heads_k
// d: head_size
std::vector<at::Tensor> mha_bwd(
        const at::Tensor &dout,   // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
        const at::Tensor &q,   // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
        const at::Tensor &k,   // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
        const at::Tensor &v,   // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
        const at::Tensor &out,   // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
        const at::Tensor &softmax_lse,   // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q
        c10::optional<at::Tensor> &dq_,   // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
        c10::optional<at::Tensor> &dk_,   // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
        c10::optional<at::Tensor> &dv_,   // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
        c10::optional<const at::Tensor> &cu_seqlens_q_,   // b+1
        c10::optional<const at::Tensor> &cu_seqlens_k_,   // b+1
        c10::optional<const at::Tensor> &seqused_q_,   // b. If given, only this many elements of each batch element's queries and outputs are used.
        c10::optional<const at::Tensor> &seqused_k_,   // b. If given, only this many elements of each batch element's keys are used.
        c10::optional<int> max_seqlen_q_,
        c10::optional<int> max_seqlen_k_,
        float const softmax_scale,
        bool is_causal,
        int window_size_left,
        int window_size_right,
        int const sink_token_length,
        float const softcap,
        bool const deterministic,
        int const sm_margin) {

#ifdef FLASHATTENTION_DISABLE_BACKWARD
    TORCH_CHECK(false, "This flash attention build does not support backward.");
#endif

    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x = dprops->major >= 8;
    TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");

    auto q_type = q.dtype();
    TORCH_CHECK(q_type == torch::kFloat16 || q_type == torch::kBFloat16,
                "FlashAttention only supports fp16 and bf16 data types");
    TORCH_CHECK(k.dtype() == q_type, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_type, "query and value must have the same dtype");
    TORCH_CHECK(out.dtype() == q_type, "query and out must have the same dtype");
    TORCH_CHECK(dout.dtype() == q_type, "query and dout must have the same dtype");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
    TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");

    at::Tensor cu_seqlens_q;
    bool const is_varlen_q = cu_seqlens_q_.has_value();
    if (is_varlen_q) {
        cu_seqlens_q = cu_seqlens_q_.value();
        CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);
        TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
        TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
    }
    at::Tensor cu_seqlens_k;
    bool const is_varlen_k = cu_seqlens_k_.has_value();
    if (is_varlen_k) {
        cu_seqlens_k = cu_seqlens_k_.value();
        CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);
        TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
        TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided");
    }
    // This is what we will template on
    bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value();
#ifdef FLASHATTENTION_DISABLE_VARLEN
    TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
#endif

    auto const sizes = q.sizes();
    int const batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
    int const seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
    int const total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
    int const num_heads = q.size(-2);
    int const head_size = q.size(-1);
    int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value();
    int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
    int const num_heads_k = k.size(-2);
    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
    int const max_headdim = get_max_headdim();
    TORCH_CHECK(head_size <= max_headdim, "FlashAttention backward only supports head dimension at most " + std::to_string(max_headdim));
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
    if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
    if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
    if (is_causal) { window_size_right = 0; }
    // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_dgrad will set params.is_causal=true.
    // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA).
    is_causal = window_size_left < 0 && window_size_right == 0;

    int const arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
    int const head_size_rounded = round_up_headdim(head_size);
    // Very important that these match the kernel configs
    bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal;
    int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128)
        : (head_size_rounded <= 96 ? 64
           : (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80)
              : 64));
    int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64;
    int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32;
    int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80);
    int const kBlockN_sm90_sm80 = head_size_rounded <= 128
        ? 128
        : (arch >= 90
           ? (head_size_rounded <= 192 ? 96 : 80)
           : (head_size_rounded <= 192 ? 80 : 64));
    int const kBlockN_sm86 = head_size_rounded <= 64 ? 128
        : (head_size_rounded <= 96 ? 128
           : (head_size_rounded <= 128 ? 96
              : (head_size_rounded <= 192 ? 64 : 64)));
    int const kBlockN = arch >= 90 ? kBlockN_sm90_sm80 : (arch == 86 || arch == 89 ? kBlockN_sm86 : kBlockN_sm90_sm80);
    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);
    int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN);
    int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM);
    int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN);

    if (!is_varlen_q) {
        CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
        CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);
    } else {
        CHECK_SHAPE(q, total_q, num_heads, head_size);
        CHECK_SHAPE(out, total_q, num_heads, head_size);
        CHECK_SHAPE(dout, total_q, num_heads, head_size);
        CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    }
    if (!is_varlen_k) {
        CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
        CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
    } else {
        CHECK_SHAPE(k, total_k, num_heads_k, head_size);
        CHECK_SHAPE(v, total_k, num_heads_k, head_size);
        CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
    }

    if (seqused_q_.has_value()) {
        auto seqused_q = seqused_q_.value();
        TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
        CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);
        CHECK_SHAPE(seqused_q, batch_size);
    }
    if (seqused_k_.has_value()) {
        auto seqused_k = seqused_k_.value();
        TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
        CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
        CHECK_SHAPE(seqused_k, batch_size);
    }
    at::Tensor dq, dk, dv;
    if (dq_.has_value()) {
        dq = dq_.value();
        TORCH_CHECK(dq.dtype() == q_type, "dq must have the same dtype as q");
        CHECK_DEVICE(dq);
        TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
        if (!is_varlen_q) {
            CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
        } else {
            CHECK_SHAPE(dq, total_q, num_heads, head_size);
        }
    } else {
        dq = torch::empty_like(q);
    }
    if (dk_.has_value()) {
        dk = dk_.value();
        TORCH_CHECK(dk.dtype() == q_type, "dk must have the same dtype as q");
        CHECK_DEVICE(dk);
        TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
        if (!is_varlen_k) {
            CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
        } else {
            CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
        }
    } else {
        dk = torch::empty_like(k);
    }
    if (dv_.has_value()) {
        dv = dv_.value();
        TORCH_CHECK(dv.dtype() == q_type, "dv must have the same dtype as q");
        CHECK_DEVICE(dv);
        TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
        if (!is_varlen_k) {
            CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
        } else {
            CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
        }
    } else {
        dv = torch::empty_like(v);
    }
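    // Any gradient buffer the caller did not preallocate is created above with the same dtype and
    // options as the corresponding input (q, k, or v).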
    // Guard the current device; otherwise the kernel would be launched on cuda:0 regardless of
    // which device q lives on. Cast to char to avoid a compiler warning about narrowing.
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};
    auto opts = q.options();
    // softmax_d and softmax_lse_log2 are padded to seqlen_q_rounded (total_q_padded_rounded in the
    // varlen case) so that their addresses are aligned to 16/8 bytes for TMA / LDG.64.
    at::Tensor softmax_d, softmax_lse_log2;
    if (!is_varlen) {
        softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
        softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
    } else {
        softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
        softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
    }
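    // Presumably: softmax_d holds the per-row dot(dout, out) term of the backward pass and
    // softmax_lse_log2 the log-sum-exp rescaled to base 2 (so the kernel can use exp2); both are
    // filled by the backward preprocess kernel.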
    at::Tensor dq_accum, dk_accum, dv_accum;
    if (!is_varlen) {
        dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, opts.dtype(at::kFloat));
    } else {
        dq_accum = torch::empty({num_heads, total_q_padded_rounded * head_size_rounded}, opts.dtype(at::kFloat));
    }
    if (num_heads_k != num_heads) {  // MQA / GQA
        if (!is_varlen) {
            dk_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat));
            dv_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat));
        } else {
            dk_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
            dv_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
        }
    }
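    // dq is accumulated in fp32 in the separate dq_accum scratch buffer rather than directly in
    // dq's dtype. For MQA / GQA, dk / dv additionally get zero-initialized fp32 accumulators,
    // since several query heads contribute to the same KV head before the result is cast back.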
    Flash_bwd_params params;
    set_params_dgrad(params,
                     batch_size,
                     seqlen_q, seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q, k, v, out,
                     dout, dq, dk, dv,
                     !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
                     !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),
                     seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
                     seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
                     dq_accum.data_ptr(),
                     num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr,
                     num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr,
                     softmax_lse.data_ptr(),
                     softmax_d.data_ptr(),
                     /*p_dropout=*/0.f,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap,
                     deterministic,
                     sm_margin);
    params.total_q = total_q;
    params.total_k = total_k;
    params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
    params.sink_token_length = sink_token_length;
    // auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
    // params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
    // Will be zeroed out in the backward preprocess kernel
    at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
    params.dq_semaphore = dq_semaphore.data_ptr<int>();
    // Declared outside the if-block so the semaphore storage stays alive until the kernel launch.
    at::Tensor dk_semaphore, dv_semaphore;
    if (num_heads_k != num_heads && params.deterministic) {
        // TODO: do we need to zero them out?
        dk_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
        dv_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
        params.dk_semaphore = dk_semaphore.data_ptr<int>();
        params.dv_semaphore = dv_semaphore.data_ptr<int>();
    }
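    // The semaphores appear to coordinate cross-thread-block accumulation: dq_semaphore for the
    // per-block dq contributions, and dk / dv semaphores for the deterministic MQA / GQA path.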
#ifdef FLASHATTENTION_DISABLE_LOCAL
    TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
#endif
#ifdef FLASHATTENTION_DISABLE_SOFTCAP
    TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
#endif
    if (total_q > 0 && total_k > 0 && num_heads_k > 0) {
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        run_mha_bwd(params, stream);
    } else if (total_k > 0 && num_heads_k > 0) {
        // No query tokens: the gradients w.r.t. k and v (and softmax_d) are identically zero.
        dk.zero_();
        dv.zero_();
        softmax_d.zero_();
    } else if (total_q > 0 && num_heads_k > 0) {
        // No key tokens: zero out the gradient w.r.t. q.
        dq.zero_();
        softmax_d.zero_();
    }
    return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
}

std::vector<at::Tensor>
mha_combine(const at::Tensor &out_partial,          // num_splits x batch_size x seqlen x num_heads x head_size
            const at::Tensor &lse_partial,          // num_splits x batch_size x seqlen x num_heads
            std::optional<at::Tensor> out_,         // batch_size x seqlen x num_heads x head_size
            std::optional<at::ScalarType> out_dtype_
            ) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm8x_or_newer = dprops->major >= 8;
    TORCH_CHECK(is_sm8x_or_newer, "Attention combine function only supports Ampere GPUs or newer.");
    auto out_partial_type = out_partial.scalar_type();
    TORCH_CHECK(out_partial_type == at::ScalarType::Float, "Attention combine function only supports fp32 data type");
    TORCH_CHECK(lse_partial.scalar_type() == at::ScalarType::Float, "Attention combine function only supports fp32 data type");
    CHECK_DEVICE(out_partial); CHECK_DEVICE(lse_partial);
    TORCH_CHECK(out_partial.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(lse_partial.stride(-2) == 1, "LSE tensor must be contiguous in the seqlen dimension");
    const auto sizes = out_partial.sizes();
    const int num_splits = sizes[0];
    const int batch_size = sizes[1];
    const int seqlen = sizes[2];
    const int num_heads = sizes[3];
    const int head_size_og = sizes[4];
    TORCH_CHECK(head_size_og <= 256, "FlashAttention combine only supports head dimension at most 256");
    TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256");
    CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og);
    CHECK_SHAPE(lse_partial, num_splits, batch_size, seqlen, num_heads);
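    // Pad the head dimension to a multiple of 4 fp32 elements (16 bytes), presumably so the
    // combine kernel can rely on vectorized 128-bit accesses of out_partial.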
    int const alignment = 4;
    at::Tensor out_partial_padded;
    auto pad = [](at::Tensor x, int alignment) {
        return x.size(-1) % alignment == 0 ? x : torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment}));
    };
    out_partial_padded = pad(out_partial, alignment);
    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size = round_multiple(head_size_og, alignment);
    auto opts = out_partial.options();
    at::ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type());
    TORCH_CHECK(out_type == at::ScalarType::Float || out_type == at::ScalarType::BFloat16 || out_type == at::ScalarType::Half, "Output type must be FP32, FP16 or BF16");
    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.scalar_type() == out_type);
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, batch_size, seqlen, num_heads, head_size_og);
        if (head_size_og % alignment != 0) {
            out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
        }
    } else {
        out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
    }
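    // Note that when head_size_og is not a multiple of `alignment`, a fresh padded output buffer
    // is allocated even if the caller supplied out_; the result is sliced back to head_size_og
    // below rather than copied into out_ (see the commented-out copy near the end of the function).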
    // Guard the current device; otherwise the kernel would be launched on cuda:0 regardless of
    // which device out_partial lives on. Cast to char to avoid a compiler warning about narrowing.
    at::cuda::CUDAGuard device_guard{(char)out_partial.get_device()};
    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen}, opts.dtype(at::kFloat)).transpose(1, 2);
    Flash_fwd_params params {};  // Need to reset the params to set everything to zero
    params.is_fp32 = out_type == at::ScalarType::Float;
    params.is_bf16 = out_type == at::ScalarType::BFloat16;
    params.oaccum_ptr = out_partial_padded.data_ptr();
    params.softmax_lseaccum_ptr = lse_partial.data_ptr();
    params.o_ptr = out.data_ptr();
    params.softmax_lse_ptr = softmax_lse.data_ptr();
    params.b = batch_size;
    params.h = num_heads;
    params.seqlen_q = seqlen;
    params.d = head_size;
    params.num_splits = num_splits;
    params.oaccum_split_stride = out_partial_padded.stride(0);
    params.oaccum_row_stride = out_partial_padded.stride(2);
    params.oaccum_head_stride = out_partial_padded.stride(3);
    params.oaccum_batch_stride = out_partial_padded.stride(1);
    params.lseaccum_split_stride = lse_partial.stride(0);
    params.lseaccum_head_stride = lse_partial.stride(3);
    params.lseaccum_batch_stride = lse_partial.stride(1);
    params.o_row_stride = out.stride(1);
    params.o_head_stride = out.stride(2);
    params.o_batch_stride = out.stride(0);
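    // Strides are passed explicitly so the kernel can read out_partial_padded and lse_partial in
    // whatever (possibly non-contiguous) layout they arrived in, and write the transposed softmax_lse.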
    if (seqlen > 0 && batch_size > 0) {
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        run_mha_fwd_combine(params, stream);
    }
    at::Tensor out_padded = out;
    if (head_size_og % alignment != 0) {
        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        // if (out_.has_value()) { out_.value().copy_(out); }
    }
    return {out, softmax_lse};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "FlashAttention";
    m.def("fwd", &mha_fwd, "Forward pass");
    m.def("bwd", &mha_bwd, "Backward pass");
    m.def("fwd_combine", &mha_combine, "Combine partial attention outputs");
}
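
// Illustrative sketch only (kept in a comment so it is not compiled): one way mha_combine could
// be driven from C++ given the partial outputs of a split-KV forward pass. The tensor shapes
// follow the comments on the function signature above; the sizes and the Half output dtype below
// are arbitrary placeholders, not values this file prescribes.
//
//   int const num_splits = 4, batch = 2, seqlen = 128, heads = 8, head_dim = 64;
//   auto opts_f32 = torch::dtype(torch::kFloat).device(torch::kCUDA);
//   // out_partial: num_splits x batch x seqlen x heads x head_dim, contiguous in head_dim.
//   at::Tensor out_partial = torch::randn({num_splits, batch, seqlen, heads, head_dim}, opts_f32);
//   // lse_partial must be contiguous along seqlen, hence the transpose of a buffer laid out as
//   // num_splits x batch x heads x seqlen.
//   at::Tensor lse_partial =
//       torch::randn({num_splits, batch, heads, seqlen}, opts_f32).transpose(-1, -2);
//   auto result = mha_combine(out_partial, lse_partial,
//                             /*out_=*/std::nullopt, /*out_dtype_=*/at::ScalarType::Half);
//   at::Tensor out = result[0];          // batch x seqlen x heads x head_dim, fp16
//   at::Tensor softmax_lse = result[1];  // batch x seqlen x heads, fp32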