flash_api.cpp 69 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485
  1. /******************************************************************************
  2. * Copyright (c) 2024, Tri Dao.
  3. ******************************************************************************/
  4. // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
  5. #include <torch/python.h>
  6. #include <torch/nn/functional.h>
  7. #include <c10/cuda/CUDAGuard.h>
  8. #include <c10/cuda/CUDAStream.h>
  9. #include <ATen/cuda/CUDAGeneratorImpl.h> // For at::Generator and at::PhiloxCudaState
  10. #include "philox_unpack.cuh" // For at::cuda::philox::unpack
  11. #include <cutlass/numeric_types.h>
  12. #include "namespace_config.h"
  13. #include "hardware_info.h"
  14. #include "flash.h"
  15. #include "static_switch.h"
  16. #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
  17. #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
  18. #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
  19. namespace FLASH_NAMESPACE {
  20. void set_params_fprop(Flash_fwd_params &params,
  21. // sizes
  22. const size_t b,
  23. const size_t seqlen_q,
  24. const size_t seqlen_k,
  25. const size_t seqlen_q_rounded,
  26. const size_t seqlen_k_rounded,
  27. const size_t h,
  28. const size_t h_k,
  29. const size_t d,
  30. const size_t d_rounded,
  31. // device pointers
  32. const at::Tensor q,
  33. const at::Tensor k,
  34. const at::Tensor v,
  35. at::Tensor out,
  36. void *cu_seqlens_q_d,
  37. void *cu_seqlens_k_d,
  38. void *seqused_k,
  39. void *p_d,
  40. void *softmax_lse_d,
  41. float p_dropout,
  42. float softmax_scale,
  43. int window_size_left,
  44. int window_size_right,
  45. const float softcap,
  46. bool seqlenq_ngroups_swapped=false,
  47. const bool unpadded_lse=false) {
  48. // Reset the parameters
  49. params = {};
  50. params.is_bf16 = q.dtype() == torch::kBFloat16;
  51. // Set the pointers and strides.
  52. params.q_ptr = q.data_ptr();
  53. params.k_ptr = k.data_ptr();
  54. params.v_ptr = v.data_ptr();
  55. // All stride are in elements, not bytes.
  56. params.q_row_stride = q.stride(-3);
  57. params.k_row_stride = k.stride(-3);
  58. params.v_row_stride = v.stride(-3);
  59. params.q_head_stride = q.stride(-2);
  60. params.k_head_stride = k.stride(-2);
  61. params.v_head_stride = v.stride(-2);
  62. params.o_ptr = out.data_ptr();
  63. params.o_row_stride = out.stride(-3);
  64. params.o_head_stride = out.stride(-2);
  65. if (cu_seqlens_q_d == nullptr) {
  66. params.q_batch_stride = q.stride(0);
  67. params.k_batch_stride = k.stride(0);
  68. params.v_batch_stride = v.stride(0);
  69. params.o_batch_stride = out.stride(0);
  70. if (seqlenq_ngroups_swapped) {
  71. params.q_batch_stride *= seqlen_q;
  72. params.o_batch_stride *= seqlen_q;
  73. }
  74. }
  75. params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
  76. params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
  77. params.seqused_k = static_cast<int *>(seqused_k);
  78. // P = softmax(QK^T)
  79. params.p_ptr = p_d;
  80. // Softmax sum
  81. params.softmax_lse_ptr = softmax_lse_d;
  82. // Set the dimensions.
  83. params.b = b;
  84. params.h = h;
  85. params.h_k = h_k;
  86. params.h_h_k_ratio = h / h_k;
  87. params.seqlen_q = seqlen_q;
  88. params.seqlen_k = seqlen_k;
  89. params.seqlen_q_rounded = seqlen_q_rounded;
  90. params.seqlen_k_rounded = seqlen_k_rounded;
  91. params.d = d;
  92. params.d_rounded = d_rounded;
  93. // Set the different scale values.
  94. #ifdef FLASHATTENTION_DISABLE_SOFTCAP
  95. TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
  96. #endif
  97. if (softcap > 0.0) {
  98. params.softcap = softmax_scale / softcap;
  99. params.scale_softmax = softcap;
  100. params.scale_softmax_log2 = softcap * M_LOG2E;
  101. } else{
  102. // Remove potential NaN
  103. params.softcap = 0.0;
  104. params.scale_softmax = softmax_scale;
  105. params.scale_softmax_log2 = softmax_scale * M_LOG2E;
  106. }
  107. // Set this to probability of keeping an element to simplify things.
  108. params.p_dropout = 1.f - p_dropout;
  109. // Convert p from float to int so we don't have to convert the random uint to float to compare.
  110. // [Minor] We want to round down since when we do the comparison we use <= instead of <
  111. // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
  112. // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
  113. params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
  114. params.rp_dropout = 1.f / params.p_dropout;
  115. params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
  116. TORCH_CHECK(p_dropout < 1.f);
  117. #ifdef FLASHATTENTION_DISABLE_DROPOUT
  118. TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
  119. #endif
  120. // Causal is the special case where window_size_right == 0 and window_size_left < 0.
  121. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
  122. params.is_causal = window_size_left < 0 && window_size_right == 0;
  123. if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
  124. if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
  125. params.window_size_left = window_size_left;
  126. params.window_size_right = window_size_right;
  127. #ifdef FLASHATTENTION_DISABLE_LOCAL
  128. TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
  129. "This flash attention build does not support local attention.");
  130. #endif
  131. params.is_seqlens_k_cumulative = true;
  132. #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
  133. TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
  134. #endif
  135. params.unpadded_lse = unpadded_lse;
  136. params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
  137. }
  138. void set_params_dgrad(Flash_bwd_params &params,
  139. // sizes
  140. const size_t b,
  141. const size_t seqlen_q,
  142. const size_t seqlen_k,
  143. const size_t seqlen_q_rounded,
  144. const size_t seqlen_k_rounded,
  145. const size_t h,
  146. const size_t h_k,
  147. const size_t d,
  148. const size_t d_rounded,
  149. // device pointers
  150. const at::Tensor q,
  151. const at::Tensor k,
  152. const at::Tensor v,
  153. const at::Tensor out,
  154. const at::Tensor dout,
  155. at::Tensor dq,
  156. at::Tensor dk,
  157. at::Tensor dv,
  158. void *cu_seqlens_q_d,
  159. void *cu_seqlens_k_d,
  160. void *dq_accum_d,
  161. void *dk_accum_d,
  162. void *dv_accum_d,
  163. void *softmax_lse_d,
  164. void *dsoftmax_sum_d,
  165. float p_dropout,
  166. float softmax_scale,
  167. int window_size_left,
  168. int window_size_right,
  169. const float softcap,
  170. bool deterministic,
  171. const bool unpadded_lse) {
  172. set_params_fprop(params,
  173. b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
  174. q, k, v, out,
  175. cu_seqlens_q_d,
  176. cu_seqlens_k_d,
  177. nullptr,
  178. nullptr,
  179. softmax_lse_d,
  180. p_dropout,
  181. softmax_scale,
  182. window_size_left,
  183. window_size_right,
  184. softcap,
  185. false, // seqlenq_ngroups_swapped
  186. unpadded_lse);
  187. // Set the pointers and strides.
  188. params.do_ptr = dout.data_ptr();
  189. params.do_row_stride = dout.stride(-3);
  190. params.do_head_stride = dout.stride(-2);
  191. params.dq_ptr = dq.data_ptr();
  192. params.dk_ptr = dk.data_ptr();
  193. params.dv_ptr = dv.data_ptr();
  194. params.dq_row_stride = dq.stride(-3);
  195. params.dk_row_stride = dk.stride(-3);
  196. params.dv_row_stride = dv.stride(-3);
  197. params.dq_head_stride = dq.stride(-2);
  198. params.dk_head_stride = dk.stride(-2);
  199. params.dv_head_stride = dv.stride(-2);
  200. if (cu_seqlens_q_d == nullptr) {
  201. params.do_batch_stride = dout.stride(0);
  202. params.dq_batch_stride = dq.stride(0);
  203. params.dk_batch_stride = dk.stride(0);
  204. params.dv_batch_stride = dv.stride(0);
  205. }
  206. params.dq_accum_ptr = dq_accum_d;
  207. params.dk_accum_ptr = dk_accum_d;
  208. params.dv_accum_ptr = dv_accum_d;
  209. // Softmax sum
  210. params.dsoftmax_sum = dsoftmax_sum_d;
  211. params.deterministic = deterministic;
  212. }
  213. void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
  214. FP16_SWITCH(!params.is_bf16, [&] {
  215. HEADDIM_SWITCH(params.d, [&] {
  216. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  217. if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
  218. run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
  219. } else {
  220. run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
  221. }
  222. });
  223. });
  224. });
  225. }
  226. // Find the number of splits that maximizes the occupancy. For example, if we have
  227. // batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
  228. // better than having 3 splits (efficiency = 0.67). However, we also don't want too many
  229. // splits as that would incur more HBM reads/writes.
  230. // So we find the best efficiency, then find the smallest number of splits that gets 85%
  231. // of the best efficiency.
  232. inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
  233. // If we have enough to almost fill the SMs, then just use 1 split
  234. if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
  235. max_splits = std::min({max_splits, num_SMs, num_n_blocks});
  236. float max_efficiency = 0.f;
  237. std::vector<float> efficiency;
  238. efficiency.reserve(max_splits);
  239. auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
  240. // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
  241. // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
  242. // (i.e. it's 11 splits anyway).
  243. // So we check if the number of blocks per split is the same as the previous num_splits.
  244. auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
  245. return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
  246. };
  247. for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
  248. if (!is_split_eligible(num_splits)) {
  249. efficiency.push_back(0.f);
  250. } else {
  251. float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
  252. float eff = n_waves / ceil(n_waves);
  253. // printf("num_splits = %d, eff = %f\n", num_splits, eff);
  254. if (eff > max_efficiency) { max_efficiency = eff; }
  255. efficiency.push_back(eff);
  256. }
  257. }
  258. for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
  259. if (!is_split_eligible(num_splits)) { continue; }
  260. if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
  261. // printf("num_splits chosen = %d\n", num_splits);
  262. return num_splits;
  263. }
  264. }
  265. return 1;
  266. }
  267. std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params &params, const int batch_size,
  268. const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
  269. const int head_size_rounded, const float p_dropout,
  270. const int num_splits, const int num_sm, struct c10::TensorOptions opts) {
  271. // This needs to match with run_mha_fwd_splitkv_dispatch
  272. const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
  273. const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
  274. // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
  275. // In any case we don't expect seqlen_q to be larger than 64 for inference.
  276. const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
  277. params.num_splits = num_splits;
  278. at::Tensor softmax_lse_accum;
  279. at::Tensor out_accum;
  280. if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
  281. if (num_splits < 1) {
  282. // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
  283. params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128);
  284. }
  285. if (params.num_splits > 1) {
  286. softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
  287. out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
  288. params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
  289. params.oaccum_ptr = out_accum.data_ptr();
  290. }
  291. TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
  292. }
  293. return std::make_tuple(softmax_lse_accum, out_accum);
  294. }
  295. void set_params_alibi(Flash_fwd_params &params, std::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){
  296. #ifdef FLASHATTENTION_DISABLE_ALIBI
  297. TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi.");
  298. params.alibi_slopes_ptr = nullptr;
  299. #else
  300. if (alibi_slopes_.has_value()) {
  301. auto alibi_slopes = alibi_slopes_.value();
  302. TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
  303. CHECK_DEVICE(alibi_slopes);
  304. TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
  305. TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
  306. params.alibi_slopes_ptr = alibi_slopes.data_ptr();
  307. params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
  308. } else {
  309. params.alibi_slopes_ptr = nullptr;
  310. }
  311. #endif
  312. }
  313. std::vector<at::Tensor>
  314. mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
  315. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
  316. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
  317. std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
  318. std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
  319. const float p_dropout,
  320. const float softmax_scale,
  321. bool is_causal,
  322. int window_size_left,
  323. int window_size_right,
  324. const float softcap,
  325. const bool return_softmax,
  326. std::optional<at::Generator> gen_) {
  327. // Otherwise the kernel will be launched from cuda:0 device
  328. at::cuda::CUDAGuard device_guard{q.device()};
  329. auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
  330. bool is_sm8x_min = cc_major >= 8;
  331. TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
  332. auto q_dtype = q.dtype();
  333. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  334. "FlashAttention only support fp16 and bf16 data type");
  335. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  336. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  337. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  338. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  339. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  340. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  341. const auto sizes = q.sizes();
  342. const int batch_size = sizes[0];
  343. int seqlen_q = sizes[1];
  344. int num_heads = sizes[2];
  345. const int head_size = sizes[3];
  346. const int seqlen_k = k.size(1);
  347. const int num_heads_k = k.size(2);
  348. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  349. TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256");
  350. TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
  351. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  352. if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
  353. if (window_size_left >= seqlen_k) { window_size_left = -1; }
  354. if (window_size_right >= seqlen_k) { window_size_right = -1; }
  355. // causal=true is the same as causal=false in this case
  356. if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
  357. if (is_causal) { window_size_right = 0; }
  358. // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
  359. // H/t Daniel Haziza
  360. const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
  361. const int ngroups = num_heads / num_heads_k;
  362. if (seqlenq_ngroups_swapped) {
  363. q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
  364. seqlen_q = ngroups;
  365. num_heads = num_heads_k;
  366. }
  367. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
  368. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
  369. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
  370. at::Tensor out;
  371. if (out_.has_value()) {
  372. out = out_.value();
  373. TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  374. CHECK_DEVICE(out);
  375. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  376. CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size);
  377. if (seqlenq_ngroups_swapped) {
  378. out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
  379. }
  380. } else {
  381. out = torch::empty_like(q);
  382. }
  383. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  384. const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
  385. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  386. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  387. auto opts = q.options();
  388. auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  389. at::Tensor p;
  390. // Only return softmax if there's dropout to reduce compilation time
  391. if (return_softmax) {
  392. TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
  393. p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
  394. }
  395. else {
  396. p = torch::empty({ 0 }, opts);
  397. }
  398. Flash_fwd_params params;
  399. set_params_fprop(params,
  400. batch_size,
  401. seqlen_q, seqlen_k,
  402. seqlen_q_rounded, seqlen_k_rounded,
  403. num_heads, num_heads_k,
  404. head_size, head_size_rounded,
  405. q, k, v, out,
  406. /*cu_seqlens_q_d=*/nullptr,
  407. /*cu_seqlens_k_d=*/nullptr,
  408. /*seqused_k=*/nullptr,
  409. return_softmax ? p.data_ptr() : nullptr,
  410. softmax_lse.data_ptr(),
  411. p_dropout,
  412. softmax_scale,
  413. window_size_left,
  414. window_size_right,
  415. softcap
  416. );
  417. // Keep references to these tensors to extend their lifetime
  418. at::Tensor softmax_lse_accum, out_accum;
  419. std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
  420. params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
  421. head_size_rounded, p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts);
  422. // number of times random will be generated per thread, to offset philox counter in thc random
  423. // state
  424. // We use a custom RNG that increases the offset by batch_size * nheads * 32.
  425. int64_t counter_offset = params.b * params.h * 32;
  426. auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  427. auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
  428. // Forward kernel will populate memory with the seed and offset.
  429. params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
  430. if (p_dropout > 0.0) {
  431. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  432. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  433. // See Note [Acquire lock when using random generators]
  434. std::lock_guard<std::mutex> lock(gen->mutex_);
  435. params.philox_args = gen->philox_cuda_state(counter_offset);
  436. }
  437. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  438. if (seqlen_k > 0) {
  439. auto stream = at::cuda::getCurrentCUDAStream().stream();
  440. run_mha_fwd(params, stream);
  441. } else {
  442. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  443. out.zero_();
  444. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  445. }
  446. if (seqlenq_ngroups_swapped) {
  447. out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
  448. q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
  449. softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
  450. }
  451. return {out, softmax_lse, p, rng_state};
  452. }
  453. std::vector<at::Tensor>
  454. mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  455. const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  456. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  457. std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
  458. const at::Tensor &cu_seqlens_q, // b+1
  459. const at::Tensor &cu_seqlens_k, // b+1
  460. std::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
  461. std::optional<const at::Tensor> &leftpad_k_, // batch_size
  462. std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
  463. std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
  464. int max_seqlen_q,
  465. const int max_seqlen_k,
  466. const float p_dropout,
  467. const float softmax_scale,
  468. const bool zero_tensors,
  469. bool is_causal,
  470. int window_size_left,
  471. int window_size_right,
  472. const float softcap,
  473. const bool return_softmax,
  474. std::optional<at::Generator> gen_) {
  475. // Otherwise the kernel will be launched from cuda:0 device
  476. at::cuda::CUDAGuard device_guard{q.device()};
  477. auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
  478. bool is_sm8x_min = cc_major >= 8;
  479. TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
  480. auto q_dtype = q.dtype();
  481. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  482. "FlashAttention only support fp16 and bf16 data type");
  483. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  484. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  485. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
  486. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
  487. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  488. CHECK_DEVICE(cu_seqlens_q);
  489. CHECK_DEVICE(cu_seqlens_k);
  490. at::Tensor block_table;
  491. const bool paged_KV = block_table_.has_value();
  492. if (paged_KV) {
  493. block_table = block_table_.value();
  494. CHECK_DEVICE(block_table);
  495. TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
  496. TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
  497. }
  498. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  499. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  500. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  501. CHECK_CONTIGUOUS(cu_seqlens_q);
  502. CHECK_CONTIGUOUS(cu_seqlens_k);
  503. const auto sizes = q.sizes();
  504. const int batch_size = cu_seqlens_q.numel() - 1;
  505. int num_heads = sizes[1];
  506. const int head_size = sizes[2];
  507. const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
  508. if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
  509. const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
  510. const int num_blocks = !paged_KV ? 0 : k.size(0);
  511. const int page_block_size = !paged_KV ? 1 : k.size(1);
  512. TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
  513. if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
  514. if (is_causal) { window_size_right = 0; }
  515. void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
  516. // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
  517. // H/t Daniel Haziza
  518. const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
  519. const int ngroups = num_heads / num_heads_k;
  520. if (seqlenq_ngroups_swapped) {
  521. q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
  522. max_seqlen_q = ngroups;
  523. num_heads = num_heads_k;
  524. cu_seqlens_q_d = nullptr;
  525. }
  526. const int total_q = q.sizes()[0];
  527. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  528. TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256");
  529. TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
  530. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  531. if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
  532. if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
  533. CHECK_SHAPE(q, total_q, num_heads, head_size);
  534. if (!paged_KV) {
  535. const int total_k = k.size(0);
  536. CHECK_SHAPE(k, total_k, num_heads_k, head_size);
  537. CHECK_SHAPE(v, total_k, num_heads_k, head_size);
  538. } else {
  539. CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size);
  540. CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size);
  541. CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
  542. }
  543. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  544. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  545. if (seqused_k.has_value()){
  546. auto seqused_k_ = seqused_k.value();
  547. TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
  548. TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
  549. TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
  550. CHECK_SHAPE(seqused_k_, batch_size);
  551. }
  552. at::Tensor out;
  553. if (out_.has_value()) {
  554. out = out_.value();
  555. TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  556. CHECK_DEVICE(out);
  557. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  558. CHECK_SHAPE(out, sizes[0], sizes[1], head_size);
  559. if (seqlenq_ngroups_swapped) {
  560. out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
  561. }
  562. } else {
  563. out = torch::empty_like(q);
  564. }
  565. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  566. const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
  567. const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
  568. const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
  569. auto opts = q.options();
  570. auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
  571. at::Tensor p;
  572. // Only return softmax if there's dropout to reduce compilation time
  573. if (return_softmax) {
  574. TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
  575. p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
  576. }
  577. else {
  578. p = torch::empty({ 0 }, opts);
  579. }
  580. if (zero_tensors) {
  581. out.zero_();
  582. softmax_lse.fill_(-std::numeric_limits<float>::infinity());
  583. if (return_softmax) {p.zero_();}
  584. }
  585. Flash_fwd_params params;
  586. set_params_fprop(params,
  587. batch_size,
  588. max_seqlen_q, max_seqlen_k,
  589. seqlen_q_rounded, seqlen_k_rounded,
  590. num_heads, num_heads_k,
  591. head_size, head_size_rounded,
  592. q, k, v, out,
  593. cu_seqlens_q_d,
  594. cu_seqlens_k.data_ptr(),
  595. seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
  596. return_softmax ? p.data_ptr() : nullptr,
  597. softmax_lse.data_ptr(),
  598. p_dropout,
  599. softmax_scale,
  600. window_size_left,
  601. window_size_right,
  602. softcap,
  603. seqlenq_ngroups_swapped,
  604. /*unpadded_lse*/true);
  605. params.total_q = total_q;
  606. if (paged_KV) {
  607. params.block_table = block_table.data_ptr<int>();
  608. params.block_table_batch_stride = block_table.stride(0);
  609. params.k_batch_stride = k.stride(0);
  610. params.v_batch_stride = v.stride(0);
  611. }
  612. params.page_block_size = page_block_size;
  613. // Keep references to these tensors to extend their lifetime
  614. at::Tensor softmax_lse_accum, out_accum;
  615. if (seqlenq_ngroups_swapped) {
  616. // Only apply split-k for decoding
  617. std::tie(softmax_lse_accum, out_accum) =
  618. set_params_splitkv(params, batch_size, num_heads, head_size,
  619. max_seqlen_k, max_seqlen_q, head_size_rounded,
  620. p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts);
  621. }
  622. if (leftpad_k_.has_value()) {
  623. auto leftpad_k = leftpad_k_.value();
  624. TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
  625. TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
  626. CHECK_DEVICE(leftpad_k);
  627. CHECK_CONTIGUOUS(leftpad_k);
  628. CHECK_SHAPE(leftpad_k, batch_size);
  629. params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
  630. }
  631. // number of times random will be generated per thread, to offset philox counter in thc random
  632. // state
  633. // We use a custom RNG that increases the offset by batch_size * nheads * 32.
  634. int64_t counter_offset = params.b * params.h * 32;
  635. auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  636. auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
  637. // Forward kernel will populate memory with the seed and offset.
  638. params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
  639. if (p_dropout > 0.0) {
  640. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  641. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  642. // See Note [Acquire lock when using random generators]
  643. std::lock_guard<std::mutex> lock(gen->mutex_);
  644. params.philox_args = gen->philox_cuda_state(counter_offset);
  645. }
  646. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  647. if (max_seqlen_k > 0) {
  648. auto stream = at::cuda::getCurrentCUDAStream().stream();
  649. run_mha_fwd(params, stream, paged_KV);
  650. } else {
  651. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  652. out.zero_();
  653. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  654. }
  655. if (seqlenq_ngroups_swapped) {
  656. int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size};
  657. int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size};
  658. out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
  659. q = q.reshape(size_before).transpose(1, 2).reshape(size_after);
  660. softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
  661. }
  662. return {out, softmax_lse, p, rng_state};
  663. }
  664. void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
  665. FP16_SWITCH(!params.is_bf16, [&] {
  666. HEADDIM_SWITCH(params.d, [&] {
  667. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  668. run_mha_bwd_<elem_type, kHeadDim, Is_causal>(params, stream);
  669. });
  670. });
  671. });
  672. }
  673. std::vector<at::Tensor>
  674. mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
  675. const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  676. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  677. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  678. const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
  679. const at::Tensor &softmax_lse, // b x h x seqlen_q
  680. std::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
  681. std::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
  682. std::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
  683. std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
  684. const float p_dropout, // probability to drop
  685. const float softmax_scale,
  686. const bool is_causal,
  687. int window_size_left,
  688. int window_size_right,
  689. const float softcap,
  690. const bool deterministic,
  691. std::optional<at::Generator> gen_,
  692. std::optional<at::Tensor> &rng_state) {
  693. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  694. TORCH_CHECK(false, "This flash attention build does not support backward.");
  695. #endif
  696. if (is_causal) { window_size_right = 0; }
  697. // Otherwise the kernel will be launched from cuda:0 device
  698. at::cuda::CUDAGuard device_guard{q.device()};
  699. auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
  700. bool is_sm8x_min = cc_major >= 8;
  701. TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
  702. bool is_dropout = p_dropout > 0.0;
  703. auto stream = at::cuda::getCurrentCUDAStream().stream();
  704. auto q_dtype = q.dtype();
  705. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  706. "FlashAttention only support fp16 and bf16 data type");
  707. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  708. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  709. TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  710. TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
  711. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  712. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  713. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  714. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  715. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  716. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  717. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
  718. const auto sizes = q.sizes();
  719. const int batch_size = sizes[0];
  720. const int seqlen_q = sizes[1];
  721. const int num_heads = sizes[2];
  722. const int head_size = sizes[3];
  723. const int seqlen_k = k.size(1);
  724. const int num_heads_k = k.size(2);
  725. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  726. TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  727. TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
  728. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  729. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  730. const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
  731. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  732. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  733. if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
  734. if (window_size_left >= seqlen_k) { window_size_left = -1; }
  735. if (window_size_right >= seqlen_k) { window_size_right = -1; }
  736. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
  737. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
  738. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
  739. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
  740. CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);
  741. at::Tensor dq, dk, dv;
  742. if (dq_.has_value()) {
  743. dq = dq_.value();
  744. TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
  745. CHECK_DEVICE(dq);
  746. TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
  747. CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
  748. } else {
  749. dq = torch::empty_like(q);
  750. }
  751. if (dk_.has_value()) {
  752. dk = dk_.value();
  753. TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
  754. CHECK_DEVICE(dk);
  755. TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
  756. CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
  757. } else {
  758. dk = torch::empty_like(k);
  759. }
  760. if (dv_.has_value()) {
  761. dv = dv_.value();
  762. TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
  763. CHECK_DEVICE(dv);
  764. TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
  765. CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
  766. } else {
  767. dv = torch::empty_like(v);
  768. }
  769. // bool loop = seqlen_k > blocksize_c;
  770. // TODO: change later, for now set to true for simplicity
  771. bool loop = true;
  772. auto opts = q.options();
  773. auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
  774. at::Tensor dq_accum;
  775. at::Tensor dk_accum, dv_accum;
  776. if (loop) {
  777. if (!deterministic) {
  778. dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
  779. } else {
  780. const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads);
  781. dq_accum = torch::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
  782. }
  783. // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
  784. // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
  785. }
  786. at::Tensor dk_expanded, dv_expanded;
  787. if (num_heads_k != num_heads) { // MQA / GQA
  788. dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
  789. dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
  790. } else {
  791. dk_expanded = dk;
  792. dv_expanded = dv;
  793. }
  794. Flash_bwd_params params;
  795. set_params_dgrad(params,
  796. batch_size,
  797. seqlen_q, seqlen_k,
  798. seqlen_q_rounded, seqlen_k_rounded,
  799. num_heads, num_heads_k,
  800. head_size, head_size_rounded,
  801. q, k, v, out,
  802. dout, dq, dk_expanded, dv_expanded,
  803. nullptr,
  804. nullptr,
  805. loop ? dq_accum.data_ptr() : nullptr,
  806. // loop ? dk_accum.data_ptr() : nullptr,
  807. // loop ? dv_accum.data_ptr() : nullptr,
  808. nullptr,
  809. nullptr,
  810. softmax_lse.data_ptr(),
  811. softmax_d.data_ptr(),
  812. p_dropout,
  813. softmax_scale,
  814. window_size_left,
  815. window_size_right,
  816. softcap,
  817. deterministic,
  818. /*unpadded_lse*/false);
  819. params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
  820. auto launch = &run_mha_bwd;
  821. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  822. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  823. // We use a custom RNG that increases the offset by batch_size * nheads * 32.
  824. int64_t counter_offset = params.b * params.h * 32;
  825. if ( rng_state.has_value() ) {
  826. params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
  827. } else if( is_dropout ) {
  828. // See Note [Acquire lock when using random generators]
  829. std::lock_guard<std::mutex> lock(gen->mutex_);
  830. params.philox_args = gen->philox_cuda_state(counter_offset);
  831. auto seeds = at::cuda::philox::unpack(params.philox_args);
  832. params.rng_state[0] = std::get<0>(seeds);
  833. params.rng_state[1] = std::get<1>(seeds);
  834. }
  835. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  836. if (seqlen_q > 0) {
  837. launch(params, stream);
  838. } else {
  839. // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
  840. dk_expanded.zero_();
  841. dv_expanded.zero_();
  842. softmax_d.zero_();
  843. }
  844. // For MQA/GQA we need to sum dK and dV across the groups
  845. if (num_heads_k != num_heads) {
  846. at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
  847. at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
  848. }
  849. return { dq, dk, dv, softmax_d };
  850. }
  851. std::vector<at::Tensor>
  852. mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
  853. const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  854. const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  855. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  856. const at::Tensor &out, // total_q x num_heads x head_size
  857. const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp
  858. std::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  859. std::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  860. std::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  861. const at::Tensor &cu_seqlens_q, // b+1
  862. const at::Tensor &cu_seqlens_k, // b+1
  863. std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
  864. const int max_seqlen_q,
  865. const int max_seqlen_k, // max sequence length to choose the kernel
  866. const float p_dropout, // probability to drop
  867. const float softmax_scale,
  868. const bool zero_tensors,
  869. const bool is_causal,
  870. int window_size_left,
  871. int window_size_right,
  872. const float softcap,
  873. const bool deterministic,
  874. std::optional<at::Generator> gen_,
  875. std::optional<at::Tensor> &rng_state) {
  876. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  877. TORCH_CHECK(false, "This flash attention build does not support backward.");
  878. #endif
  879. if (is_causal) { window_size_right = 0; }
  880. // Otherwise the kernel will be launched from cuda:0 device
  881. at::cuda::CUDAGuard device_guard{q.device()};
  882. auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
  883. bool is_sm8x_min = cc_major >= 8;
  884. TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
  885. bool is_dropout = p_dropout > 0.0;
  886. auto stream = at::cuda::getCurrentCUDAStream().stream();
  887. auto q_dtype = q.dtype();
  888. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  889. "FlashAttention only support fp16 and bf16 data type");
  890. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  891. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  892. TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  893. TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
  894. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
  895. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
  896. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  897. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  898. CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
  899. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  900. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  901. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  902. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  903. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
  904. CHECK_CONTIGUOUS(cu_seqlens_q);
  905. CHECK_CONTIGUOUS(cu_seqlens_k);
  906. const auto sizes = q.sizes();
  907. const int total_q = sizes[0];
  908. const int batch_size = cu_seqlens_q.numel() - 1;
  909. const int num_heads = sizes[1];
  910. const int head_size = sizes[2];
  911. const int total_k = k.size(0);
  912. const int num_heads_k = k.size(1);
  913. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  914. TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  915. TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
  916. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  917. if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
  918. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  919. const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
  920. const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
  921. const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
  922. if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
  923. if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
  924. CHECK_SHAPE(q, total_q, num_heads, head_size);
  925. CHECK_SHAPE(k, total_k, num_heads_k, head_size);
  926. CHECK_SHAPE(v, total_k, num_heads_k, head_size);
  927. CHECK_SHAPE(out, total_q, num_heads, head_size);
  928. CHECK_SHAPE(dout, total_q, num_heads, head_size);
  929. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  930. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  931. at::Tensor dq, dk, dv;
  932. if (dq_.has_value()) {
  933. dq = dq_.value();
  934. TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
  935. CHECK_DEVICE(dq);
  936. TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
  937. CHECK_SHAPE(dq, total_q, num_heads, head_size);
  938. } else {
  939. dq = torch::empty_like(q);
  940. }
  941. if (dk_.has_value()) {
  942. dk = dk_.value();
  943. TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
  944. CHECK_DEVICE(dk);
  945. TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
  946. CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
  947. } else {
  948. dk = torch::empty_like(k);
  949. }
  950. if (dv_.has_value()) {
  951. dv = dv_.value();
  952. TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
  953. CHECK_DEVICE(dv);
  954. TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
  955. CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
  956. } else {
  957. dv = torch::empty_like(v);
  958. }
  959. // bool loop = max_seqlen_k > blocksize_c;
  960. // TODO: change later, for now set to true for simplicity
  961. bool loop = true;
  962. auto opts = q.options();
  963. auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat));
  964. at::Tensor dq_accum;
  965. if (loop) {
  966. // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded)
  967. // because that would be too large if there is a very long sequence and the rest of the sequences are short.
  968. // Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded).
  969. // Note that 128 is the max block size on the seqlen_q dimension.
  970. // For dQ, the i-th sequence is stored in indices from cu_seqlens[i] + 128 * i to
  971. // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
  972. // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
  973. // allowed to do. So we won't have to do any bound checking, and performance should stay the same.
  974. // Same holds for softmax_d, since LSE is stored in unpadded format.
  975. if (!deterministic) {
  976. dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
  977. } else {
  978. const int nsplits = (get_num_sm(get_current_device()) + batch_size * num_heads - 1) / (batch_size * num_heads);
  979. dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
  980. }
  981. }
  982. at::Tensor dk_expanded, dv_expanded;
  983. if (num_heads_k != num_heads) { // MQA / GQA
  984. dk_expanded = torch::empty({total_k, num_heads, head_size}, opts);
  985. dv_expanded = torch::empty({total_k, num_heads, head_size}, opts);
  986. } else {
  987. dk_expanded = dk;
  988. dv_expanded = dv;
  989. }
  990. if( zero_tensors ) {
  991. dq.zero_();
  992. dk_expanded.zero_();
  993. dv_expanded.zero_();
  994. softmax_d.zero_();
  995. }
  996. Flash_bwd_params params;
  997. set_params_dgrad(params,
  998. batch_size,
  999. max_seqlen_q, max_seqlen_k,
  1000. seqlen_q_rounded, seqlen_k_rounded,
  1001. num_heads, num_heads_k,
  1002. head_size, head_size_rounded,
  1003. q, k, v, out,
  1004. dout, dq, dk_expanded, dv_expanded,
  1005. cu_seqlens_q.data_ptr(),
  1006. cu_seqlens_k.data_ptr(),
  1007. loop ? dq_accum.data_ptr() : nullptr,
  1008. nullptr,
  1009. nullptr,
  1010. softmax_lse.data_ptr(),
  1011. softmax_d.data_ptr(),
  1012. p_dropout,
  1013. softmax_scale,
  1014. window_size_left,
  1015. window_size_right,
  1016. softcap,
  1017. deterministic,
  1018. /*unpadded_lse*/true);
  1019. params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
  1020. params.total_q = total_q;
  1021. auto launch = &run_mha_bwd;
  1022. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  1023. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  1024. // We use a custom RNG that increases the offset by batch_size * nheads * 32.
  1025. int64_t counter_offset = params.b * params.h * 32;
  1026. if ( rng_state.has_value() ) {
  1027. params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
  1028. } else if( is_dropout ) {
  1029. // See Note [Acquire lock when using random generators]
  1030. std::lock_guard<std::mutex> lock(gen->mutex_);
  1031. params.philox_args = gen->philox_cuda_state(counter_offset);
  1032. auto seeds = at::cuda::philox::unpack(params.philox_args);
  1033. params.rng_state[0] = std::get<0>(seeds);
  1034. params.rng_state[1] = std::get<1>(seeds);
  1035. }
  1036. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  1037. if (max_seqlen_q > 0) {
  1038. launch(params, stream);
  1039. } else {
  1040. // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
  1041. dk_expanded.zero_();
  1042. dv_expanded.zero_();
  1043. softmax_d.zero_();
  1044. }
  1045. // For MQA/GQA we need to sum dK and dV across the groups
  1046. if (num_heads_k != num_heads) {
  1047. at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
  1048. at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
  1049. }
  1050. return { dq, dk, dv, softmax_d };
  1051. }
  1052. std::vector<at::Tensor>
  1053. mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  1054. const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  1055. const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  1056. std::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
  1057. std::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
  1058. std::optional<const at::Tensor> &seqlens_k_, // batch_size
  1059. std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
  1060. std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
  1061. std::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
  1062. std::optional<const at::Tensor> &leftpad_k_, // batch_size
  1063. std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
  1064. std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
  1065. std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
  1066. const float softmax_scale,
  1067. bool is_causal,
  1068. int window_size_left,
  1069. int window_size_right,
  1070. const float softcap,
  1071. bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
  1072. int num_splits
  1073. ) {
  1074. // Otherwise the kernel will be launched from cuda:0 device
  1075. at::cuda::CUDAGuard device_guard{q.device()};
  1076. auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
  1077. bool is_sm8x_min = cc_major >= 8;
  1078. TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
  1079. auto q_dtype = q.dtype();
  1080. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  1081. "FlashAttention only support fp16 and bf16 data type");
  1082. TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
  1083. TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");
  1084. CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
  1085. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1086. TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1087. TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1088. at::Tensor block_table;
  1089. const bool paged_KV = block_table_.has_value();
  1090. if (paged_KV) {
  1091. TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx");
  1092. block_table = block_table_.value();
  1093. CHECK_DEVICE(block_table);
  1094. TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
  1095. TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
  1096. }
  1097. const auto sizes = q.sizes();
  1098. const int batch_size = sizes[0];
  1099. int seqlen_q = sizes[1];
  1100. int num_heads = sizes[2];
  1101. const int head_size_og = sizes[3];
  1102. const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
  1103. const int num_blocks = !paged_KV ? 0 : kcache.size(0);
  1104. const int page_block_size = !paged_KV ? 1 : kcache.size(1);
  1105. TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
  1106. const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
  1107. const int num_heads_k = kcache.size(2);
  1108. const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
  1109. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  1110. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  1111. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  1112. // causal=true is the same as causal=false in this case
  1113. if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
  1114. if (is_causal) { window_size_right = 0; }
  1115. // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
  1116. // H/t Daniel Haziza
  1117. const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
  1118. if (seqlenq_ngroups_swapped) {
  1119. const int ngroups = num_heads / num_heads_k;
  1120. q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
  1121. seqlen_q = ngroups;
  1122. num_heads = num_heads_k;
  1123. }
  1124. if (window_size_left >= seqlen_k) { window_size_left = -1; }
  1125. if (window_size_right >= seqlen_k) { window_size_right = -1; }
  1126. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
  1127. if (!paged_KV) {
  1128. CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
  1129. CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
  1130. } else {
  1131. CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
  1132. CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
  1133. CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
  1134. }
  1135. at::Tensor q_padded, kcache_padded, vcache_padded;
  1136. if (head_size_og % 8 != 0) {
  1137. q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1138. kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1139. vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1140. } else {
  1141. q_padded = q;
  1142. kcache_padded = kcache;
  1143. vcache_padded = vcache;
  1144. }
  1145. at::Tensor out;
  1146. if (out_.has_value()) {
  1147. out = out_.value();
  1148. TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  1149. CHECK_DEVICE(out);
  1150. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  1151. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
  1152. if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
  1153. } else {
  1154. out = torch::empty_like(q_padded);
  1155. }
  1156. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  1157. const int head_size = round_multiple(head_size_og, 8);
  1158. const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
  1159. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  1160. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  1161. auto opts = q.options();
  1162. auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  1163. Flash_fwd_params params;
  1164. set_params_fprop(params,
  1165. batch_size,
  1166. seqlen_q, seqlen_k,
  1167. seqlen_q_rounded, seqlen_k_rounded,
  1168. num_heads, num_heads_k,
  1169. head_size, head_size_rounded,
  1170. q_padded, kcache_padded, vcache_padded, out,
  1171. /*cu_seqlens_q_d=*/nullptr,
  1172. /*cu_seqlens_k_d=*/nullptr,
  1173. /*seqused_k=*/nullptr,
  1174. /*p_ptr=*/nullptr,
  1175. softmax_lse.data_ptr(),
  1176. /*p_dropout=*/0.f,
  1177. softmax_scale,
  1178. window_size_left,
  1179. window_size_right,
  1180. softcap
  1181. );
  1182. at::Tensor k, v, k_padded, v_padded;
  1183. if (k_.has_value()) {
  1184. TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
  1185. TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
  1186. TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache");
  1187. k = k_.value();
  1188. v = v_.value();
  1189. TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
  1190. TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
  1191. CHECK_DEVICE(k); CHECK_DEVICE(v);
  1192. TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
  1193. TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
  1194. int seqlen_knew = k.size(1);
  1195. CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
  1196. CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
  1197. if (head_size_og % 8 != 0) {
  1198. k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1199. v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1200. } else {
  1201. k_padded = k;
  1202. v_padded = v;
  1203. }
  1204. params.seqlen_knew = seqlen_knew;
  1205. params.knew_ptr = k_padded.data_ptr();
  1206. params.vnew_ptr = v_padded.data_ptr();
  1207. // All stride are in elements, not bytes.
  1208. params.knew_batch_stride = k_padded.stride(0);
  1209. params.vnew_batch_stride = v_padded.stride(0);
  1210. params.knew_row_stride = k_padded.stride(-3);
  1211. params.vnew_row_stride = v_padded.stride(-3);
  1212. params.knew_head_stride = k_padded.stride(-2);
  1213. params.vnew_head_stride = v_padded.stride(-2);
  1214. }
  1215. if (seqlens_k_.has_value()) {
  1216. auto seqlens_k = seqlens_k_.value();
  1217. TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
  1218. CHECK_DEVICE(seqlens_k);
  1219. CHECK_CONTIGUOUS(seqlens_k);
  1220. CHECK_SHAPE(seqlens_k, batch_size);
  1221. params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
  1222. }
  1223. params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
  1224. if (leftpad_k_.has_value()) {
  1225. TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
  1226. auto leftpad_k = leftpad_k_.value();
  1227. TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
  1228. CHECK_DEVICE(leftpad_k);
  1229. CHECK_CONTIGUOUS(leftpad_k);
  1230. CHECK_SHAPE(leftpad_k, batch_size);
  1231. params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
  1232. }
  1233. if (rotary_cos_.has_value()) {
  1234. TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
  1235. auto rotary_cos = rotary_cos_.value();
  1236. CHECK_DEVICE(rotary_cos);
  1237. params.rotary_dim = rotary_cos.size(1) * 2;
  1238. TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
  1239. TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
  1240. const int seqlen_ro = rotary_cos.size(0);
  1241. TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
  1242. CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
  1243. CHECK_CONTIGUOUS(rotary_cos);
  1244. TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
  1245. TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
  1246. auto rotary_sin = rotary_sin_.value();
  1247. CHECK_DEVICE(rotary_sin);
  1248. CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
  1249. CHECK_CONTIGUOUS(rotary_sin);
  1250. TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
  1251. params.rotary_cos_ptr = rotary_cos.data_ptr();
  1252. params.rotary_sin_ptr = rotary_sin.data_ptr();
  1253. params.is_rotary_interleaved = is_rotary_interleaved;
  1254. } else {
  1255. params.rotary_dim = 0;
  1256. }
  1257. if (cache_batch_idx_.has_value()) {
  1258. auto cache_batch_idx = cache_batch_idx_.value();
  1259. CHECK_DEVICE(cache_batch_idx);
  1260. CHECK_CONTIGUOUS(cache_batch_idx);
  1261. TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
  1262. params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
  1263. }
  1264. // Keep references to these tensors to extend their lifetime
  1265. at::Tensor softmax_lse_accum, out_accum;
  1266. std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
  1267. params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
  1268. head_size_rounded, /*dropout*/ 0.f, num_splits, get_num_sm(get_current_device()), opts);
  1269. if (paged_KV) {
  1270. params.block_table = block_table.data_ptr<int>();
  1271. params.block_table_batch_stride = block_table.stride(0);
  1272. }
  1273. params.page_block_size = page_block_size;
  1274. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  1275. auto stream = at::cuda::getCurrentCUDAStream().stream();
  1276. // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx,
  1277. // or paged KV cache
  1278. run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV);
  1279. if (head_size_og % 8 != 0) {
  1280. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  1281. if (out_.has_value()) { out_.value().copy_(out); }
  1282. if (k_.has_value()) {
  1283. // It's expensive to copy the KV cache here for the case where head size not divisible by 8,
  1284. // but we don't expect to get this case in practice. This is just so that the code works for that case.
  1285. kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
  1286. vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
  1287. }
  1288. }
  1289. if (seqlenq_ngroups_swapped) {
  1290. out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
  1291. softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
  1292. }
  1293. return {out, softmax_lse};
  1294. }
  1295. } // namespace FLASH_NAMESPACE
  1296. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  1297. m.doc() = "FlashAttention";
  1298. m.def("fwd", &FLASH_NAMESPACE::mha_fwd, "Forward pass");
  1299. m.def("varlen_fwd", &FLASH_NAMESPACE::mha_varlen_fwd, "Forward pass (variable length)");
  1300. m.def("bwd", &FLASH_NAMESPACE::mha_bwd, "Backward pass");
  1301. m.def("varlen_bwd", &FLASH_NAMESPACE::mha_varlen_bwd, "Backward pass (variable length)");
  1302. m.def("fwd_kvcache", &FLASH_NAMESPACE::mha_fwd_kvcache, "Forward pass, with KV-cache");
  1303. }