flash_api.cpp 71 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525
  1. /******************************************************************************
  2. * Copyright (c) 2024, Tri Dao.
  3. ******************************************************************************/
  4. // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
  5. #include <torch/python.h>
  6. #include <torch/nn/functional.h>
  7. #include <ATen/cuda/CUDAContext.h>
  8. #include <c10/cuda/CUDAGuard.h>
  9. #include <cutlass/numeric_types.h>
  10. #include "flash.h"
  11. #include "static_switch.h"
  12. #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
  13. #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
  14. #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
  15. void set_params_fprop(Flash_fwd_params &params,
  16. // sizes
  17. const size_t b,
  18. const size_t seqlen_q,
  19. const size_t seqlen_k,
  20. const size_t seqlen_q_rounded,
  21. const size_t seqlen_k_rounded,
  22. const size_t h,
  23. const size_t h_k,
  24. const size_t d,
  25. const size_t d_rounded,
  26. // device pointers
  27. const at::Tensor q,
  28. const at::Tensor k,
  29. const at::Tensor v,
  30. at::Tensor out,
  31. void *cu_seqlens_q_d,
  32. void *cu_seqlens_k_d,
  33. void *seqused_k,
  34. void *p_d,
  35. void *softmax_lse_d,
  36. float p_dropout,
  37. float softmax_scale,
  38. int window_size_left,
  39. int window_size_right,
  40. const float softcap,
  41. bool seqlenq_ngroups_swapped=false,
  42. const bool unpadded_lse=false) {
  43. // Reset the parameters
  44. params = {};
  45. params.is_bf16 = q.dtype() == torch::kBFloat16;
  46. // Set the pointers and strides.
  47. params.q_ptr = q.data_ptr();
  48. params.k_ptr = k.data_ptr();
  49. params.v_ptr = v.data_ptr();
  50. // All stride are in elements, not bytes.
  51. params.q_row_stride = q.stride(-3);
  52. params.k_row_stride = k.stride(-3);
  53. params.v_row_stride = v.stride(-3);
  54. params.q_head_stride = q.stride(-2);
  55. params.k_head_stride = k.stride(-2);
  56. params.v_head_stride = v.stride(-2);
  57. params.o_ptr = out.data_ptr();
  58. params.o_row_stride = out.stride(-3);
  59. params.o_head_stride = out.stride(-2);
  60. if (cu_seqlens_q_d == nullptr) {
  61. params.q_batch_stride = q.stride(0);
  62. params.k_batch_stride = k.stride(0);
  63. params.v_batch_stride = v.stride(0);
  64. params.o_batch_stride = out.stride(0);
  65. if (seqlenq_ngroups_swapped) {
  66. params.q_batch_stride *= seqlen_q;
  67. params.o_batch_stride *= seqlen_q;
  68. }
  69. }
  70. params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
  71. params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
  72. params.seqused_k = static_cast<int *>(seqused_k);
  73. // P = softmax(QK^T)
  74. params.p_ptr = p_d;
  75. // Softmax sum
  76. params.softmax_lse_ptr = softmax_lse_d;
  77. // Set the dimensions.
  78. params.b = b;
  79. params.h = h;
  80. params.h_k = h_k;
  81. params.h_h_k_ratio = h / h_k;
  82. params.seqlen_q = seqlen_q;
  83. params.seqlen_k = seqlen_k;
  84. params.seqlen_q_rounded = seqlen_q_rounded;
  85. params.seqlen_k_rounded = seqlen_k_rounded;
  86. params.d = d;
  87. params.d_rounded = d_rounded;
  88. // Set the different scale values.
  89. #ifdef FLASHATTENTION_DISABLE_SOFTCAP
  90. TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
  91. #endif
  92. if (softcap > 0.0) {
  93. params.softcap = softmax_scale / softcap;
  94. params.scale_softmax = softcap;
  95. params.scale_softmax_log2 = softcap * M_LOG2E;
  96. } else{
  97. // Remove potential NaN
  98. params.softcap = 0.0;
  99. params.scale_softmax = softmax_scale;
  100. params.scale_softmax_log2 = softmax_scale * M_LOG2E;
  101. }
  102. // Set this to probability of keeping an element to simplify things.
  103. params.p_dropout = 1.f - p_dropout;
  104. // Convert p from float to int so we don't have to convert the random uint to float to compare.
  105. // [Minor] We want to round down since when we do the comparison we use <= instead of <
  106. // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
  107. // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
  108. params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
  109. params.rp_dropout = 1.f / params.p_dropout;
  110. params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
  111. TORCH_CHECK(p_dropout < 1.f);
  112. #ifdef FLASHATTENTION_DISABLE_DROPOUT
  113. TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
  114. #endif
  115. // Causal is the special case where window_size_right == 0 and window_size_left < 0.
  116. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
  117. params.is_causal = window_size_left < 0 && window_size_right == 0;
  118. if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
  119. if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
  120. params.window_size_left = window_size_left;
  121. params.window_size_right = window_size_right;
  122. #ifdef FLASHATTENTION_DISABLE_LOCAL
  123. TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
  124. "This flash attention build does not support local attention.");
  125. #endif
  126. params.is_seqlens_k_cumulative = true;
  127. #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
  128. TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
  129. #endif
  130. params.unpadded_lse = unpadded_lse;
  131. params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
  132. }
  133. void set_params_dgrad(Flash_bwd_params &params,
  134. // sizes
  135. const size_t b,
  136. const size_t seqlen_q,
  137. const size_t seqlen_k,
  138. const size_t seqlen_q_rounded,
  139. const size_t seqlen_k_rounded,
  140. const size_t h,
  141. const size_t h_k,
  142. const size_t d,
  143. const size_t d_rounded,
  144. // device pointers
  145. const at::Tensor q,
  146. const at::Tensor k,
  147. const at::Tensor v,
  148. const at::Tensor out,
  149. const at::Tensor dout,
  150. at::Tensor dq,
  151. at::Tensor dk,
  152. at::Tensor dv,
  153. void *cu_seqlens_q_d,
  154. void *cu_seqlens_k_d,
  155. void *dq_accum_d,
  156. void *dk_accum_d,
  157. void *dv_accum_d,
  158. void *softmax_lse_d,
  159. void *dsoftmax_sum_d,
  160. float p_dropout,
  161. float softmax_scale,
  162. int window_size_left,
  163. int window_size_right,
  164. const float softcap,
  165. bool deterministic,
  166. const bool unpadded_lse) {
  167. set_params_fprop(params,
  168. b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
  169. q, k, v, out,
  170. cu_seqlens_q_d,
  171. cu_seqlens_k_d,
  172. nullptr,
  173. nullptr,
  174. softmax_lse_d,
  175. p_dropout,
  176. softmax_scale,
  177. window_size_left,
  178. window_size_right,
  179. softcap,
  180. false, // seqlenq_ngroups_swapped
  181. unpadded_lse);
  182. // Set the pointers and strides.
  183. params.do_ptr = dout.data_ptr();
  184. params.do_row_stride = dout.stride(-3);
  185. params.do_head_stride = dout.stride(-2);
  186. params.dq_ptr = dq.data_ptr();
  187. params.dk_ptr = dk.data_ptr();
  188. params.dv_ptr = dv.data_ptr();
  189. params.dq_row_stride = dq.stride(-3);
  190. params.dk_row_stride = dk.stride(-3);
  191. params.dv_row_stride = dv.stride(-3);
  192. params.dq_head_stride = dq.stride(-2);
  193. params.dk_head_stride = dk.stride(-2);
  194. params.dv_head_stride = dv.stride(-2);
  195. if (cu_seqlens_q_d == nullptr) {
  196. params.do_batch_stride = dout.stride(0);
  197. params.dq_batch_stride = dq.stride(0);
  198. params.dk_batch_stride = dk.stride(0);
  199. params.dv_batch_stride = dv.stride(0);
  200. }
  201. params.dq_accum_ptr = dq_accum_d;
  202. params.dk_accum_ptr = dk_accum_d;
  203. params.dv_accum_ptr = dv_accum_d;
  204. // Softmax sum
  205. params.dsoftmax_sum = dsoftmax_sum_d;
  206. params.deterministic = deterministic;
  207. }
  208. void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
  209. FP16_SWITCH(!params.is_bf16, [&] {
  210. HEADDIM_SWITCH(params.d, [&] {
  211. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  212. if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
  213. run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
  214. } else {
  215. run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
  216. }
  217. });
  218. });
  219. });
  220. }
  221. // Find the number of splits that maximizes the occupancy. For example, if we have
  222. // batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
  223. // better than having 3 splits (efficiency = 0.67). However, we also don't want too many
  224. // splits as that would incur more HBM reads/writes.
  225. // So we find the best efficiency, then find the smallest number of splits that gets 85%
  226. // of the best efficiency.
  227. inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
  228. // If we have enough to almost fill the SMs, then just use 1 split
  229. if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
  230. max_splits = std::min({max_splits, num_SMs, num_n_blocks});
  231. float max_efficiency = 0.f;
  232. std::vector<float> efficiency;
  233. efficiency.reserve(max_splits);
  234. auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
  235. // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
  236. // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
  237. // (i.e. it's 11 splits anyway).
  238. // So we check if the number of blocks per split is the same as the previous num_splits.
  239. auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
  240. return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
  241. };
  242. for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
  243. if (!is_split_eligible(num_splits)) {
  244. efficiency.push_back(0.f);
  245. } else {
  246. float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
  247. float eff = n_waves / ceil(n_waves);
  248. // printf("num_splits = %d, eff = %f\n", num_splits, eff);
  249. if (eff > max_efficiency) { max_efficiency = eff; }
  250. efficiency.push_back(eff);
  251. }
  252. }
  253. for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
  254. if (!is_split_eligible(num_splits)) { continue; }
  255. if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
  256. // printf("num_splits chosen = %d\n", num_splits);
  257. return num_splits;
  258. }
  259. }
  260. return 1;
  261. }
  262. std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params &params, const int batch_size,
  263. const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
  264. const int head_size_rounded, const float p_dropout,
  265. const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) {
  266. // This needs to match with run_mha_fwd_splitkv_dispatch
  267. const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
  268. const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
  269. // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
  270. // In any case we don't expect seqlen_q to be larger than 64 for inference.
  271. const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
  272. params.num_splits = num_splits;
  273. at::Tensor softmax_lse_accum;
  274. at::Tensor out_accum;
  275. if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
  276. if (num_splits < 1) {
  277. // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
  278. params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128);
  279. }
  280. if (params.num_splits > 1) {
  281. softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
  282. out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
  283. params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
  284. params.oaccum_ptr = out_accum.data_ptr();
  285. }
  286. TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
  287. }
  288. return std::make_tuple(softmax_lse_accum, out_accum);
  289. }
  290. void set_params_alibi(Flash_fwd_params &params, c10::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){
  291. #ifdef FLASHATTENTION_DISABLE_ALIBI
  292. TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi.");
  293. params.alibi_slopes_ptr = nullptr;
  294. #else
  295. if (alibi_slopes_.has_value()) {
  296. auto alibi_slopes = alibi_slopes_.value();
  297. TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
  298. CHECK_DEVICE(alibi_slopes);
  299. TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
  300. TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
  301. params.alibi_slopes_ptr = alibi_slopes.data_ptr();
  302. params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
  303. } else {
  304. params.alibi_slopes_ptr = nullptr;
  305. }
  306. #endif
  307. }
  308. std::vector<at::Tensor>
  309. mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
  310. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
  311. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
  312. c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
  313. c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
  314. const float p_dropout,
  315. const float softmax_scale,
  316. bool is_causal,
  317. int window_size_left,
  318. int window_size_right,
  319. const float softcap,
  320. const bool return_softmax,
  321. c10::optional<at::Generator> gen_) {
  322. auto dprops = at::cuda::getCurrentDeviceProperties();
  323. // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
  324. bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  325. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  326. TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
  327. // We will support Turing in the near future
  328. // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
  329. auto q_dtype = q.dtype();
  330. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  331. "FlashAttention only support fp16 and bf16 data type");
  332. if (q_dtype == torch::kBFloat16) {
  333. TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
  334. }
  335. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  336. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  337. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  338. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  339. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  340. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  341. const auto sizes = q.sizes();
  342. const int batch_size = sizes[0];
  343. int seqlen_q = sizes[1];
  344. int num_heads = sizes[2];
  345. const int head_size = sizes[3];
  346. const int seqlen_k = k.size(1);
  347. const int num_heads_k = k.size(2);
  348. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  349. TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256");
  350. TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
  351. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  352. if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
  353. if (window_size_left >= seqlen_k) { window_size_left = -1; }
  354. if (window_size_right >= seqlen_k) { window_size_right = -1; }
  355. // causal=true is the same as causal=false in this case
  356. if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
  357. if (is_causal) { window_size_right = 0; }
  358. // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
  359. // H/t Daniel Haziza
  360. const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
  361. const int ngroups = num_heads / num_heads_k;
  362. if (seqlenq_ngroups_swapped) {
  363. q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
  364. seqlen_q = ngroups;
  365. num_heads = num_heads_k;
  366. }
  367. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
  368. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
  369. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
  370. at::Tensor out;
  371. if (out_.has_value()) {
  372. out = out_.value();
  373. TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  374. CHECK_DEVICE(out);
  375. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  376. CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size);
  377. if (seqlenq_ngroups_swapped) {
  378. out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
  379. }
  380. } else {
  381. out = torch::empty_like(q);
  382. }
  383. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  384. const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
  385. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  386. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  387. // Otherwise the kernel will be launched from cuda:0 device
  388. // Cast to char to avoid compiler warning about narrowing
  389. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  390. auto opts = q.options();
  391. auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  392. at::Tensor p;
  393. // Only return softmax if there's dropout to reduce compilation time
  394. if (return_softmax) {
  395. TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
  396. p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
  397. }
  398. else {
  399. p = torch::empty({ 0 }, opts);
  400. }
  401. Flash_fwd_params params;
  402. set_params_fprop(params,
  403. batch_size,
  404. seqlen_q, seqlen_k,
  405. seqlen_q_rounded, seqlen_k_rounded,
  406. num_heads, num_heads_k,
  407. head_size, head_size_rounded,
  408. q, k, v, out,
  409. /*cu_seqlens_q_d=*/nullptr,
  410. /*cu_seqlens_k_d=*/nullptr,
  411. /*seqused_k=*/nullptr,
  412. return_softmax ? p.data_ptr() : nullptr,
  413. softmax_lse.data_ptr(),
  414. p_dropout,
  415. softmax_scale,
  416. window_size_left,
  417. window_size_right,
  418. softcap
  419. );
  420. // Keep references to these tensors to extend their lifetime
  421. at::Tensor softmax_lse_accum, out_accum;
  422. std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
  423. params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
  424. head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts);
  425. // number of times random will be generated per thread, to offset philox counter in thc random
  426. // state
  427. // We use a custom RNG that increases the offset by batch_size * nheads * 32.
  428. int64_t counter_offset = params.b * params.h * 32;
  429. auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  430. auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
  431. // Forward kernel will populate memory with the seed and offset.
  432. params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
  433. if (p_dropout > 0.0) {
  434. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  435. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  436. // See Note [Acquire lock when using random generators]
  437. std::lock_guard<std::mutex> lock(gen->mutex_);
  438. params.philox_args = gen->philox_cuda_state(counter_offset);
  439. }
  440. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  441. if (seqlen_k > 0) {
  442. auto stream = at::cuda::getCurrentCUDAStream().stream();
  443. run_mha_fwd(params, stream);
  444. } else {
  445. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  446. out.zero_();
  447. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  448. }
  449. if (seqlenq_ngroups_swapped) {
  450. out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
  451. q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
  452. softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
  453. }
  454. return {out, softmax_lse, p, rng_state};
  455. }
  456. std::vector<at::Tensor>
  457. mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  458. const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  459. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  460. c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
  461. const at::Tensor &cu_seqlens_q, // b+1
  462. const at::Tensor &cu_seqlens_k, // b+1
  463. c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
  464. c10::optional<const at::Tensor> &leftpad_k_, // batch_size
  465. c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
  466. c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
  467. int max_seqlen_q,
  468. const int max_seqlen_k,
  469. const float p_dropout,
  470. const float softmax_scale,
  471. const bool zero_tensors,
  472. bool is_causal,
  473. int window_size_left,
  474. int window_size_right,
  475. const float softcap,
  476. const bool return_softmax,
  477. c10::optional<at::Generator> gen_) {
  478. auto dprops = at::cuda::getCurrentDeviceProperties();
  479. // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
  480. bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  481. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  482. TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
  483. // We will support Turing in the near future
  484. // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
  485. auto q_dtype = q.dtype();
  486. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  487. "FlashAttention only support fp16 and bf16 data type");
  488. if (q_dtype == torch::kBFloat16) {
  489. TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
  490. }
  491. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  492. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  493. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
  494. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
  495. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  496. CHECK_DEVICE(cu_seqlens_q);
  497. CHECK_DEVICE(cu_seqlens_k);
  498. at::Tensor block_table;
  499. const bool paged_KV = block_table_.has_value();
  500. if (paged_KV) {
  501. block_table = block_table_.value();
  502. CHECK_DEVICE(block_table);
  503. TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
  504. TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
  505. }
  506. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  507. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  508. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  509. CHECK_CONTIGUOUS(cu_seqlens_q);
  510. CHECK_CONTIGUOUS(cu_seqlens_k);
  511. const auto sizes = q.sizes();
  512. const int batch_size = cu_seqlens_q.numel() - 1;
  513. int num_heads = sizes[1];
  514. const int head_size = sizes[2];
  515. const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
  516. if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
  517. const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
  518. const int num_blocks = !paged_KV ? 0 : k.size(0);
  519. const int page_block_size = !paged_KV ? 1 : k.size(1);
  520. TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
  521. if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
  522. if (is_causal) { window_size_right = 0; }
  523. void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
  524. // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
  525. // H/t Daniel Haziza
  526. const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
  527. const int ngroups = num_heads / num_heads_k;
  528. if (seqlenq_ngroups_swapped) {
  529. q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
  530. max_seqlen_q = ngroups;
  531. num_heads = num_heads_k;
  532. cu_seqlens_q_d = nullptr;
  533. }
  534. const int total_q = q.sizes()[0];
  535. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  536. TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256");
  537. TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
  538. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  539. if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
  540. if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
  541. CHECK_SHAPE(q, total_q, num_heads, head_size);
  542. if (!paged_KV) {
  543. const int total_k = k.size(0);
  544. CHECK_SHAPE(k, total_k, num_heads_k, head_size);
  545. CHECK_SHAPE(v, total_k, num_heads_k, head_size);
  546. } else {
  547. CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size);
  548. CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size);
  549. CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
  550. }
  551. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  552. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  553. if (seqused_k.has_value()){
  554. auto seqused_k_ = seqused_k.value();
  555. TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
  556. TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
  557. TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
  558. CHECK_SHAPE(seqused_k_, batch_size);
  559. }
  560. at::Tensor out;
  561. if (out_.has_value()) {
  562. out = out_.value();
  563. TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  564. CHECK_DEVICE(out);
  565. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  566. CHECK_SHAPE(out, sizes[0], sizes[1], head_size);
  567. if (seqlenq_ngroups_swapped) {
  568. out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
  569. }
  570. } else {
  571. out = torch::empty_like(q);
  572. }
  573. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  574. const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
  575. const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
  576. const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
  577. // Otherwise the kernel will be launched from cuda:0 device
  578. // Cast to char to avoid compiler warning about narrowing
  579. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  580. auto opts = q.options();
  581. auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
  582. at::Tensor p;
  583. // Only return softmax if there's dropout to reduce compilation time
  584. if (return_softmax) {
  585. TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
  586. p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
  587. }
  588. else {
  589. p = torch::empty({ 0 }, opts);
  590. }
  591. if (zero_tensors) {
  592. out.zero_();
  593. softmax_lse.fill_(-std::numeric_limits<float>::infinity());
  594. if (return_softmax) {p.zero_();}
  595. }
  596. Flash_fwd_params params;
  597. set_params_fprop(params,
  598. batch_size,
  599. max_seqlen_q, max_seqlen_k,
  600. seqlen_q_rounded, seqlen_k_rounded,
  601. num_heads, num_heads_k,
  602. head_size, head_size_rounded,
  603. q, k, v, out,
  604. cu_seqlens_q_d,
  605. cu_seqlens_k.data_ptr(),
  606. seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
  607. return_softmax ? p.data_ptr() : nullptr,
  608. softmax_lse.data_ptr(),
  609. p_dropout,
  610. softmax_scale,
  611. window_size_left,
  612. window_size_right,
  613. softcap,
  614. seqlenq_ngroups_swapped,
  615. /*unpadded_lse*/true);
  616. params.total_q = total_q;
  617. if (paged_KV) {
  618. params.block_table = block_table.data_ptr<int>();
  619. params.block_table_batch_stride = block_table.stride(0);
  620. params.k_batch_stride = k.stride(0);
  621. params.v_batch_stride = v.stride(0);
  622. }
  623. params.page_block_size = page_block_size;
  624. // Keep references to these tensors to extend their lifetime
  625. at::Tensor softmax_lse_accum, out_accum;
  626. if (seqlenq_ngroups_swapped) {
  627. // Only apply split-k for decoding
  628. std::tie(softmax_lse_accum, out_accum) =
  629. set_params_splitkv(params, batch_size, num_heads, head_size,
  630. max_seqlen_k, max_seqlen_q, head_size_rounded,
  631. p_dropout, /*num_splits*/ 0, dprops, opts);
  632. }
  633. if (leftpad_k_.has_value()) {
  634. auto leftpad_k = leftpad_k_.value();
  635. TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
  636. TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
  637. CHECK_DEVICE(leftpad_k);
  638. CHECK_CONTIGUOUS(leftpad_k);
  639. CHECK_SHAPE(leftpad_k, batch_size);
  640. params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
  641. }
  642. // number of times random will be generated per thread, to offset philox counter in thc random
  643. // state
  644. // We use a custom RNG that increases the offset by batch_size * nheads * 32.
  645. int64_t counter_offset = params.b * params.h * 32;
  646. auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  647. auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
  648. // Forward kernel will populate memory with the seed and offset.
  649. params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
  650. if (p_dropout > 0.0) {
  651. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  652. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  653. // See Note [Acquire lock when using random generators]
  654. std::lock_guard<std::mutex> lock(gen->mutex_);
  655. params.philox_args = gen->philox_cuda_state(counter_offset);
  656. }
  657. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  658. if (max_seqlen_k > 0) {
  659. auto stream = at::cuda::getCurrentCUDAStream().stream();
  660. run_mha_fwd(params, stream, paged_KV);
  661. } else {
  662. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  663. out.zero_();
  664. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  665. }
  666. if (seqlenq_ngroups_swapped) {
  667. int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size};
  668. int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size};
  669. out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
  670. q = q.reshape(size_before).transpose(1, 2).reshape(size_after);
  671. softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
  672. }
  673. return {out, softmax_lse, p, rng_state};
  674. }
  675. void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
  676. FP16_SWITCH(!params.is_bf16, [&] {
  677. HEADDIM_SWITCH(params.d, [&] {
  678. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  679. run_mha_bwd_<elem_type, kHeadDim, Is_causal>(params, stream);
  680. });
  681. });
  682. });
  683. }
  684. std::vector<at::Tensor>
  685. mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
  686. const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  687. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  688. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  689. const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
  690. const at::Tensor &softmax_lse, // b x h x seqlen_q
  691. c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
  692. c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
  693. c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
  694. c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
  695. const float p_dropout, // probability to drop
  696. const float softmax_scale,
  697. const bool is_causal,
  698. int window_size_left,
  699. int window_size_right,
  700. const float softcap,
  701. const bool deterministic,
  702. c10::optional<at::Generator> gen_,
  703. c10::optional<at::Tensor> &rng_state) {
  704. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  705. TORCH_CHECK(false, "This flash attention build does not support backward.");
  706. #endif
  707. if (is_causal) { window_size_right = 0; }
  708. auto dprops = at::cuda::getCurrentDeviceProperties();
  709. // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
  710. bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  711. bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
  712. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  713. TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
  714. // We will support Turing in the near future
  715. // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
  716. bool is_dropout = p_dropout > 0.0;
  717. auto stream = at::cuda::getCurrentCUDAStream().stream();
  718. auto q_dtype = q.dtype();
  719. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  720. "FlashAttention only support fp16 and bf16 data type");
  721. if (q_dtype == torch::kBFloat16) {
  722. TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
  723. }
  724. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  725. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  726. TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  727. TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
  728. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  729. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  730. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  731. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  732. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  733. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  734. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
  735. const auto sizes = q.sizes();
  736. const int batch_size = sizes[0];
  737. const int seqlen_q = sizes[1];
  738. const int num_heads = sizes[2];
  739. const int head_size = sizes[3];
  740. const int seqlen_k = k.size(1);
  741. const int num_heads_k = k.size(2);
  742. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  743. TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  744. TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
  745. if (head_size > 192 && is_dropout) {
  746. TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 with dropout requires A100/A800 or H100/H800");
  747. }
  748. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  749. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  750. const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
  751. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  752. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  753. if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
  754. if (window_size_left >= seqlen_k) { window_size_left = -1; }
  755. if (window_size_right >= seqlen_k) { window_size_right = -1; }
  756. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
  757. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
  758. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
  759. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
  760. CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);
  761. at::Tensor dq, dk, dv;
  762. if (dq_.has_value()) {
  763. dq = dq_.value();
  764. TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
  765. CHECK_DEVICE(dq);
  766. TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
  767. CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
  768. } else {
  769. dq = torch::empty_like(q);
  770. }
  771. if (dk_.has_value()) {
  772. dk = dk_.value();
  773. TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
  774. CHECK_DEVICE(dk);
  775. TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
  776. CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
  777. } else {
  778. dk = torch::empty_like(k);
  779. }
  780. if (dv_.has_value()) {
  781. dv = dv_.value();
  782. TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
  783. CHECK_DEVICE(dv);
  784. TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
  785. CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
  786. } else {
  787. dv = torch::empty_like(v);
  788. }
  789. // bool loop = seqlen_k > blocksize_c;
  790. // TODO: change later, for now set to true for simplicity
  791. bool loop = true;
  792. // Otherwise the kernel will be launched from cuda:0 device
  793. // Cast to char to avoid compiler warning about narrowing
  794. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  795. auto opts = q.options();
  796. auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
  797. at::Tensor dq_accum;
  798. at::Tensor dk_accum, dv_accum;
  799. if (loop) {
  800. if (!deterministic) {
  801. dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
  802. } else {
  803. const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
  804. dq_accum = torch::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
  805. }
  806. // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
  807. // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
  808. }
  809. at::Tensor dk_expanded, dv_expanded;
  810. if (num_heads_k != num_heads) { // MQA / GQA
  811. dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
  812. dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
  813. } else {
  814. dk_expanded = dk;
  815. dv_expanded = dv;
  816. }
  817. Flash_bwd_params params;
  818. set_params_dgrad(params,
  819. batch_size,
  820. seqlen_q, seqlen_k,
  821. seqlen_q_rounded, seqlen_k_rounded,
  822. num_heads, num_heads_k,
  823. head_size, head_size_rounded,
  824. q, k, v, out,
  825. dout, dq, dk_expanded, dv_expanded,
  826. nullptr,
  827. nullptr,
  828. loop ? dq_accum.data_ptr() : nullptr,
  829. // loop ? dk_accum.data_ptr() : nullptr,
  830. // loop ? dv_accum.data_ptr() : nullptr,
  831. nullptr,
  832. nullptr,
  833. softmax_lse.data_ptr(),
  834. softmax_d.data_ptr(),
  835. p_dropout,
  836. softmax_scale,
  837. window_size_left,
  838. window_size_right,
  839. softcap,
  840. deterministic,
  841. /*unpadded_lse*/false);
  842. params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
  843. auto launch = &run_mha_bwd;
  844. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  845. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  846. // We use a custom RNG that increases the offset by batch_size * nheads * 32.
  847. int64_t counter_offset = params.b * params.h * 32;
  848. if ( rng_state.has_value() ) {
  849. params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
  850. } else if( is_dropout ) {
  851. // See Note [Acquire lock when using random generators]
  852. std::lock_guard<std::mutex> lock(gen->mutex_);
  853. params.philox_args = gen->philox_cuda_state(counter_offset);
  854. auto seeds = at::cuda::philox::unpack(params.philox_args);
  855. params.rng_state[0] = std::get<0>(seeds);
  856. params.rng_state[1] = std::get<1>(seeds);
  857. }
  858. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  859. if (seqlen_q > 0) {
  860. launch(params, stream);
  861. } else {
  862. // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
  863. dk_expanded.zero_();
  864. dv_expanded.zero_();
  865. softmax_d.zero_();
  866. }
  867. // For MQA/GQA we need to sum dK and dV across the groups
  868. if (num_heads_k != num_heads) {
  869. at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
  870. at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
  871. }
  872. return { dq, dk, dv, softmax_d };
  873. }
  874. std::vector<at::Tensor>
  875. mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
  876. const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  877. const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  878. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  879. const at::Tensor &out, // total_q x num_heads x head_size
  880. const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp
  881. c10::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  882. c10::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  883. c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  884. const at::Tensor &cu_seqlens_q, // b+1
  885. const at::Tensor &cu_seqlens_k, // b+1
  886. c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
  887. const int max_seqlen_q,
  888. const int max_seqlen_k, // max sequence length to choose the kernel
  889. const float p_dropout, // probability to drop
  890. const float softmax_scale,
  891. const bool zero_tensors,
  892. const bool is_causal,
  893. int window_size_left,
  894. int window_size_right,
  895. const float softcap,
  896. const bool deterministic,
  897. c10::optional<at::Generator> gen_,
  898. c10::optional<at::Tensor> &rng_state) {
  899. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  900. TORCH_CHECK(false, "This flash attention build does not support backward.");
  901. #endif
  902. if (is_causal) { window_size_right = 0; }
  903. auto dprops = at::cuda::getCurrentDeviceProperties();
  904. // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
  905. bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  906. bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
  907. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  908. TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
  909. // We will support Turing in the near future
  910. // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
  911. bool is_dropout = p_dropout > 0.0;
  912. auto stream = at::cuda::getCurrentCUDAStream().stream();
  913. auto q_dtype = q.dtype();
  914. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  915. "FlashAttention only support fp16 and bf16 data type");
  916. if (q_dtype == torch::kBFloat16) {
  917. TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
  918. }
  919. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  920. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  921. TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  922. TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
  923. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
  924. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
  925. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  926. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  927. CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
  928. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  929. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  930. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  931. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  932. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
  933. CHECK_CONTIGUOUS(cu_seqlens_q);
  934. CHECK_CONTIGUOUS(cu_seqlens_k);
  935. const auto sizes = q.sizes();
  936. const int total_q = sizes[0];
  937. const int batch_size = cu_seqlens_q.numel() - 1;
  938. const int num_heads = sizes[1];
  939. const int head_size = sizes[2];
  940. const int total_k = k.size(0);
  941. const int num_heads_k = k.size(1);
  942. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  943. TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  944. TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
  945. if (head_size > 192 && is_dropout) {
  946. TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 with dropout requires A100/A800 or H100/H800");
  947. }
  948. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  949. if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
  950. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  951. const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
  952. const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
  953. const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
  954. if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
  955. if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
  956. CHECK_SHAPE(q, total_q, num_heads, head_size);
  957. CHECK_SHAPE(k, total_k, num_heads_k, head_size);
  958. CHECK_SHAPE(v, total_k, num_heads_k, head_size);
  959. CHECK_SHAPE(out, total_q, num_heads, head_size);
  960. CHECK_SHAPE(dout, total_q, num_heads, head_size);
  961. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  962. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  963. at::Tensor dq, dk, dv;
  964. if (dq_.has_value()) {
  965. dq = dq_.value();
  966. TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
  967. CHECK_DEVICE(dq);
  968. TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
  969. CHECK_SHAPE(dq, total_q, num_heads, head_size);
  970. } else {
  971. dq = torch::empty_like(q);
  972. }
  973. if (dk_.has_value()) {
  974. dk = dk_.value();
  975. TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
  976. CHECK_DEVICE(dk);
  977. TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
  978. CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
  979. } else {
  980. dk = torch::empty_like(k);
  981. }
  982. if (dv_.has_value()) {
  983. dv = dv_.value();
  984. TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
  985. CHECK_DEVICE(dv);
  986. TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
  987. CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
  988. } else {
  989. dv = torch::empty_like(v);
  990. }
  991. // bool loop = max_seqlen_k > blocksize_c;
  992. // TODO: change later, for now set to true for simplicity
  993. bool loop = true;
  994. // Otherwise the kernel will be launched from cuda:0 device
  995. // Cast to char to avoid compiler warning about narrowing
  996. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  997. auto opts = q.options();
  998. auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat));
  999. at::Tensor dq_accum;
  1000. if (loop) {
  1001. // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded)
  1002. // because that would be too large if there is a very long sequence and the rest of the sequences are short.
  1003. // Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded).
  1004. // Note that 128 is the max block size on the seqlen_q dimension.
  1005. // For dQ, the i-th sequence is stored in indices from cu_seqlens[i] + 128 * i to
  1006. // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
  1007. // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
  1008. // allowed to do. So we won't have to do any bound checking, and performance should stay the same.
  1009. // Same holds for softmax_d, since LSE is stored in unpadded format.
  1010. if (!deterministic) {
  1011. dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
  1012. } else {
  1013. const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
  1014. dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
  1015. }
  1016. }
  1017. at::Tensor dk_expanded, dv_expanded;
  1018. if (num_heads_k != num_heads) { // MQA / GQA
  1019. dk_expanded = torch::empty({total_k, num_heads, head_size}, opts);
  1020. dv_expanded = torch::empty({total_k, num_heads, head_size}, opts);
  1021. } else {
  1022. dk_expanded = dk;
  1023. dv_expanded = dv;
  1024. }
  1025. if( zero_tensors ) {
  1026. dq.zero_();
  1027. dk_expanded.zero_();
  1028. dv_expanded.zero_();
  1029. softmax_d.zero_();
  1030. }
  1031. Flash_bwd_params params;
  1032. set_params_dgrad(params,
  1033. batch_size,
  1034. max_seqlen_q, max_seqlen_k,
  1035. seqlen_q_rounded, seqlen_k_rounded,
  1036. num_heads, num_heads_k,
  1037. head_size, head_size_rounded,
  1038. q, k, v, out,
  1039. dout, dq, dk_expanded, dv_expanded,
  1040. cu_seqlens_q.data_ptr(),
  1041. cu_seqlens_k.data_ptr(),
  1042. loop ? dq_accum.data_ptr() : nullptr,
  1043. nullptr,
  1044. nullptr,
  1045. softmax_lse.data_ptr(),
  1046. softmax_d.data_ptr(),
  1047. p_dropout,
  1048. softmax_scale,
  1049. window_size_left,
  1050. window_size_right,
  1051. softcap,
  1052. deterministic,
  1053. /*unpadded_lse*/true);
  1054. params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
  1055. params.total_q = total_q;
  1056. auto launch = &run_mha_bwd;
  1057. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  1058. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  1059. // We use a custom RNG that increases the offset by batch_size * nheads * 32.
  1060. int64_t counter_offset = params.b * params.h * 32;
  1061. if ( rng_state.has_value() ) {
  1062. params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
  1063. } else if( is_dropout ) {
  1064. // See Note [Acquire lock when using random generators]
  1065. std::lock_guard<std::mutex> lock(gen->mutex_);
  1066. params.philox_args = gen->philox_cuda_state(counter_offset);
  1067. auto seeds = at::cuda::philox::unpack(params.philox_args);
  1068. params.rng_state[0] = std::get<0>(seeds);
  1069. params.rng_state[1] = std::get<1>(seeds);
  1070. }
  1071. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  1072. if (max_seqlen_q > 0) {
  1073. launch(params, stream);
  1074. } else {
  1075. // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
  1076. dk_expanded.zero_();
  1077. dv_expanded.zero_();
  1078. softmax_d.zero_();
  1079. }
  1080. // For MQA/GQA we need to sum dK and dV across the groups
  1081. if (num_heads_k != num_heads) {
  1082. at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
  1083. at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
  1084. }
  1085. return { dq, dk, dv, softmax_d };
  1086. }
  1087. std::vector<at::Tensor>
  1088. mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  1089. const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  1090. const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  1091. c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
  1092. c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
  1093. c10::optional<const at::Tensor> &seqlens_k_, // batch_size
  1094. c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
  1095. c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
  1096. c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
  1097. c10::optional<const at::Tensor> &leftpad_k_, // batch_size
  1098. c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
  1099. c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
  1100. c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
  1101. const float softmax_scale,
  1102. bool is_causal,
  1103. int window_size_left,
  1104. int window_size_right,
  1105. const float softcap,
  1106. bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
  1107. int num_splits
  1108. ) {
  1109. auto dprops = at::cuda::getCurrentDeviceProperties();
  1110. // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
  1111. bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  1112. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  1113. TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
  1114. // We will support Turing in the near future
  1115. // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
  1116. auto q_dtype = q.dtype();
  1117. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  1118. "FlashAttention only support fp16 and bf16 data type");
  1119. if (q_dtype == torch::kBFloat16) {
  1120. TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
  1121. }
  1122. TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
  1123. TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");
  1124. CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
  1125. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1126. TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1127. TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1128. at::Tensor block_table;
  1129. const bool paged_KV = block_table_.has_value();
  1130. if (paged_KV) {
  1131. TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx");
  1132. block_table = block_table_.value();
  1133. CHECK_DEVICE(block_table);
  1134. TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
  1135. TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
  1136. }
  1137. const auto sizes = q.sizes();
  1138. const int batch_size = sizes[0];
  1139. int seqlen_q = sizes[1];
  1140. int num_heads = sizes[2];
  1141. const int head_size_og = sizes[3];
  1142. const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
  1143. const int num_blocks = !paged_KV ? 0 : kcache.size(0);
  1144. const int page_block_size = !paged_KV ? 1 : kcache.size(1);
  1145. TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
  1146. const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
  1147. const int num_heads_k = kcache.size(2);
  1148. const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
  1149. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  1150. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  1151. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  1152. // causal=true is the same as causal=false in this case
  1153. if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
  1154. if (is_causal) { window_size_right = 0; }
  1155. // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
  1156. // H/t Daniel Haziza
  1157. const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
  1158. if (seqlenq_ngroups_swapped) {
  1159. const int ngroups = num_heads / num_heads_k;
  1160. q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
  1161. seqlen_q = ngroups;
  1162. num_heads = num_heads_k;
  1163. }
  1164. if (window_size_left >= seqlen_k) { window_size_left = -1; }
  1165. if (window_size_right >= seqlen_k) { window_size_right = -1; }
  1166. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
  1167. if (!paged_KV) {
  1168. CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
  1169. CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
  1170. } else {
  1171. CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
  1172. CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
  1173. CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
  1174. }
  1175. at::Tensor q_padded, kcache_padded, vcache_padded;
  1176. if (head_size_og % 8 != 0) {
  1177. q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1178. kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1179. vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1180. } else {
  1181. q_padded = q;
  1182. kcache_padded = kcache;
  1183. vcache_padded = vcache;
  1184. }
  1185. at::Tensor out;
  1186. if (out_.has_value()) {
  1187. out = out_.value();
  1188. TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  1189. CHECK_DEVICE(out);
  1190. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  1191. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
  1192. if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
  1193. } else {
  1194. out = torch::empty_like(q_padded);
  1195. }
  1196. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  1197. const int head_size = round_multiple(head_size_og, 8);
  1198. const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
  1199. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  1200. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  1201. // Otherwise the kernel will be launched from cuda:0 device
  1202. // Cast to char to avoid compiler warning about narrowing
  1203. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  1204. auto opts = q.options();
  1205. auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  1206. Flash_fwd_params params;
  1207. set_params_fprop(params,
  1208. batch_size,
  1209. seqlen_q, seqlen_k,
  1210. seqlen_q_rounded, seqlen_k_rounded,
  1211. num_heads, num_heads_k,
  1212. head_size, head_size_rounded,
  1213. q_padded, kcache_padded, vcache_padded, out,
  1214. /*cu_seqlens_q_d=*/nullptr,
  1215. /*cu_seqlens_k_d=*/nullptr,
  1216. /*seqused_k=*/nullptr,
  1217. /*p_ptr=*/nullptr,
  1218. softmax_lse.data_ptr(),
  1219. /*p_dropout=*/0.f,
  1220. softmax_scale,
  1221. window_size_left,
  1222. window_size_right,
  1223. softcap
  1224. );
  1225. at::Tensor k, v, k_padded, v_padded;
  1226. if (k_.has_value()) {
  1227. TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
  1228. TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
  1229. TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache");
  1230. k = k_.value();
  1231. v = v_.value();
  1232. TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
  1233. TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
  1234. CHECK_DEVICE(k); CHECK_DEVICE(v);
  1235. TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
  1236. TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
  1237. int seqlen_knew = k.size(1);
  1238. CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
  1239. CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
  1240. if (head_size_og % 8 != 0) {
  1241. k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1242. v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1243. } else {
  1244. k_padded = k;
  1245. v_padded = v;
  1246. }
  1247. params.seqlen_knew = seqlen_knew;
  1248. params.knew_ptr = k_padded.data_ptr();
  1249. params.vnew_ptr = v_padded.data_ptr();
  1250. // All stride are in elements, not bytes.
  1251. params.knew_batch_stride = k_padded.stride(0);
  1252. params.vnew_batch_stride = v_padded.stride(0);
  1253. params.knew_row_stride = k_padded.stride(-3);
  1254. params.vnew_row_stride = v_padded.stride(-3);
  1255. params.knew_head_stride = k_padded.stride(-2);
  1256. params.vnew_head_stride = v_padded.stride(-2);
  1257. }
  1258. if (seqlens_k_.has_value()) {
  1259. auto seqlens_k = seqlens_k_.value();
  1260. TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
  1261. CHECK_DEVICE(seqlens_k);
  1262. CHECK_CONTIGUOUS(seqlens_k);
  1263. CHECK_SHAPE(seqlens_k, batch_size);
  1264. params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
  1265. }
  1266. params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
  1267. if (leftpad_k_.has_value()) {
  1268. TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
  1269. auto leftpad_k = leftpad_k_.value();
  1270. TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
  1271. CHECK_DEVICE(leftpad_k);
  1272. CHECK_CONTIGUOUS(leftpad_k);
  1273. CHECK_SHAPE(leftpad_k, batch_size);
  1274. params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
  1275. }
  1276. if (rotary_cos_.has_value()) {
  1277. TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
  1278. auto rotary_cos = rotary_cos_.value();
  1279. CHECK_DEVICE(rotary_cos);
  1280. params.rotary_dim = rotary_cos.size(1) * 2;
  1281. TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
  1282. TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
  1283. const int seqlen_ro = rotary_cos.size(0);
  1284. TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
  1285. CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
  1286. CHECK_CONTIGUOUS(rotary_cos);
  1287. TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
  1288. TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
  1289. auto rotary_sin = rotary_sin_.value();
  1290. CHECK_DEVICE(rotary_sin);
  1291. CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
  1292. CHECK_CONTIGUOUS(rotary_sin);
  1293. TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
  1294. params.rotary_cos_ptr = rotary_cos.data_ptr();
  1295. params.rotary_sin_ptr = rotary_sin.data_ptr();
  1296. params.is_rotary_interleaved = is_rotary_interleaved;
  1297. } else {
  1298. params.rotary_dim = 0;
  1299. }
  1300. if (cache_batch_idx_.has_value()) {
  1301. auto cache_batch_idx = cache_batch_idx_.value();
  1302. CHECK_DEVICE(cache_batch_idx);
  1303. CHECK_CONTIGUOUS(cache_batch_idx);
  1304. TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
  1305. params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
  1306. }
  1307. // Keep references to these tensors to extend their lifetime
  1308. at::Tensor softmax_lse_accum, out_accum;
  1309. std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
  1310. params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
  1311. head_size_rounded, /*dropout*/ 0.f, num_splits, dprops, opts);
  1312. if (paged_KV) {
  1313. params.block_table = block_table.data_ptr<int>();
  1314. params.block_table_batch_stride = block_table.stride(0);
  1315. }
  1316. params.page_block_size = page_block_size;
  1317. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  1318. auto stream = at::cuda::getCurrentCUDAStream().stream();
  1319. // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx,
  1320. // or paged KV cache
  1321. run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV);
  1322. if (head_size_og % 8 != 0) {
  1323. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  1324. if (out_.has_value()) { out_.value().copy_(out); }
  1325. if (k_.has_value()) {
  1326. // It's expensive to copy the KV cache here for the case where head size not divisible by 8,
  1327. // but we don't expect to get this case in practice. This is just so that the code works for that case.
  1328. kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
  1329. vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
  1330. }
  1331. }
  1332. if (seqlenq_ngroups_swapped) {
  1333. out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
  1334. softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
  1335. }
  1336. return {out, softmax_lse};
  1337. }
  1338. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  1339. m.doc() = "FlashAttention";
  1340. m.def("fwd", &mha_fwd, "Forward pass");
  1341. m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
  1342. m.def("bwd", &mha_bwd, "Backward pass");
  1343. m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
  1344. m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
  1345. }