flash_api.h 46 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167
  1. /******************************************************************************
  2. * Copyright (c) 2024, Tri Dao.
  3. ******************************************************************************/
  4. // Include these 2 headers instead of torch/extension.h since we don't need all
  5. // of the torch headers.
  6. #include "registration.h"
  7. #include <torch/library.h>
  8. #include <torch/nn/functional.h>
  9. #include <ATen/cuda/CUDAContext.h>
  10. #include <c10/cuda/CUDAGuard.h>
  11. #include <cutlass/numeric_types.h>
  12. #include "flash.h"
  13. #include "static_switch.h"
  14. #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
  15. #define CHECK_SHAPE(x, ...) \
  16. TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \
  17. #x " must have shape (" #__VA_ARGS__ ")")
  18. #define CHECK_CONTIGUOUS(x) \
  19. TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
  20. void set_params_fprop(Flash_fwd_params& params,
  21. // sizes
  22. const size_t b, const size_t seqlen_q,
  23. const size_t seqlen_k, const size_t seqlen_q_rounded,
  24. const size_t seqlen_k_rounded, const size_t h,
  25. const size_t h_k, const size_t d, const size_t d_rounded,
  26. // device pointers
  27. const at::Tensor q, const at::Tensor k,
  28. const at::Tensor v, at::Tensor out, void* cu_seqlens_q_d,
  29. void* cu_seqlens_k_d, void* seqused_k, void* p_d,
  30. void* softmax_lse_d, float p_dropout, float softmax_scale,
  31. int window_size_left, int window_size_right,
  32. const float softcap, bool seqlenq_ngroups_swapped = false,
  33. const bool unpadded_lse = false) {
  34. // Reset the parameters
  35. params = {};
  36. params.is_bf16 = q.dtype() == torch::kBFloat16;
  37. // Set the pointers and strides.
  38. params.q_ptr = q.data_ptr();
  39. params.k_ptr = k.data_ptr();
  40. params.v_ptr = v.data_ptr();
  41. // All stride are in elements, not bytes.
  42. params.q_row_stride = q.stride(-3);
  43. params.k_row_stride = k.stride(-3);
  44. params.v_row_stride = v.stride(-3);
  45. params.q_head_stride = q.stride(-2);
  46. params.k_head_stride = k.stride(-2);
  47. params.v_head_stride = v.stride(-2);
  48. params.o_ptr = out.data_ptr();
  49. params.o_row_stride = out.stride(-3);
  50. params.o_head_stride = out.stride(-2);
  51. if (cu_seqlens_q_d == nullptr) {
  52. params.q_batch_stride = q.stride(0);
  53. params.k_batch_stride = k.stride(0);
  54. params.v_batch_stride = v.stride(0);
  55. params.o_batch_stride = out.stride(0);
  56. if (seqlenq_ngroups_swapped) {
  57. params.q_batch_stride *= seqlen_q;
  58. params.o_batch_stride *= seqlen_q;
  59. }
  60. }
  61. params.cu_seqlens_q = static_cast<int*>(cu_seqlens_q_d);
  62. params.cu_seqlens_k = static_cast<int*>(cu_seqlens_k_d);
  63. params.seqused_k = static_cast<int*>(seqused_k);
  64. // P = softmax(QK^T)
  65. params.p_ptr = p_d;
  66. // Softmax sum
  67. params.softmax_lse_ptr = softmax_lse_d;
  68. // Set the dimensions.
  69. params.b = b;
  70. params.h = h;
  71. params.h_k = h_k;
  72. params.h_h_k_ratio = h / h_k;
  73. params.seqlen_q = seqlen_q;
  74. params.seqlen_k = seqlen_k;
  75. params.seqlen_q_rounded = seqlen_q_rounded;
  76. params.seqlen_k_rounded = seqlen_k_rounded;
  77. params.d = d;
  78. params.d_rounded = d_rounded;
  79. // Set the different scale values.
  80. #ifdef FLASHATTENTION_DISABLE_SOFTCAP
  81. TORCH_CHECK(softcap <= 0.0,
  82. "This flash attention build does not support softcap.");
  83. #endif
  84. if (softcap > 0.0) {
  85. params.softcap = softmax_scale / softcap;
  86. params.scale_softmax = softcap;
  87. params.scale_softmax_log2 = softcap * M_LOG2E;
  88. } else {
  89. // Remove potential NaN
  90. params.softcap = 0.0;
  91. params.scale_softmax = softmax_scale;
  92. params.scale_softmax_log2 = softmax_scale * M_LOG2E;
  93. }
  94. // Set this to probability of keeping an element to simplify things.
  95. params.p_dropout = 1.f - p_dropout;
  96. // Convert p from float to int so we don't have to convert the random uint to
  97. // float to compare. [Minor] We want to round down since when we do the
  98. // comparison we use <= instead of < params.p_dropout_in_uint =
  99. // uint32_t(std::floor(params.p_dropout * 4294967295.0));
  100. // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout *
  101. // 65535.0));
  102. params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
  103. params.rp_dropout = 1.f / params.p_dropout;
  104. params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
  105. TORCH_CHECK(p_dropout < 1.f);
  106. #ifdef FLASHATTENTION_DISABLE_DROPOUT
  107. TORCH_CHECK(p_dropout == 0.0f,
  108. "This flash attention build does not support dropout.");
  109. #endif
  110. // Causal is the special case where window_size_right == 0 and
  111. // window_size_left < 0. Local is the more general case where
  112. // window_size_right >= 0 or window_size_left >= 0.
  113. params.is_causal = window_size_left < 0 && window_size_right == 0;
  114. if (window_size_left < 0 && window_size_right >= 0) {
  115. window_size_left = seqlen_k;
  116. }
  117. if (window_size_left >= 0 && window_size_right < 0) {
  118. window_size_right = seqlen_k;
  119. }
  120. params.window_size_left = window_size_left;
  121. params.window_size_right = window_size_right;
  122. #ifdef FLASHATTENTION_DISABLE_LOCAL
  123. TORCH_CHECK(
  124. params.is_causal || (window_size_left < 0 && window_size_right < 0),
  125. "This flash attention build does not support local attention.");
  126. #endif
  127. params.is_seqlens_k_cumulative = true;
  128. #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
  129. TORCH_CHECK(d == d_rounded,
  130. "This flash attention build does not support headdim not being a "
  131. "multiple of 32.");
  132. #endif
  133. params.unpadded_lse = unpadded_lse;
  134. params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
  135. }
  136. void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream,
  137. bool force_split_kernel = false) {
  138. FP16_SWITCH(!params.is_bf16, [&] {
  139. HEADDIM_SWITCH(params.d, [&] {
  140. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  141. if (params.num_splits <= 1 &&
  142. !force_split_kernel) { // If we don't set it num_splits == 0
  143. run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
  144. } else {
  145. run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params,
  146. stream);
  147. }
  148. });
  149. });
  150. });
  151. }
  152. // Find the number of splits that maximizes the occupancy. For example, if we
  153. // have batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency =
  154. // 0.89) is better than having 3 splits (efficiency = 0.67). However, we also
  155. // don't want too many splits as that would incur more HBM reads/writes. So we
  156. // find the best efficiency, then find the smallest number of splits that gets
  157. // 85% of the best efficiency.
  158. inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs,
  159. int num_n_blocks, int max_splits) {
  160. // If we have enough to almost fill the SMs, then just use 1 split
  161. if (batch_nheads_mblocks >= 0.8f * num_SMs) {
  162. return 1;
  163. }
  164. max_splits = std::min({max_splits, num_SMs, num_n_blocks});
  165. float max_efficiency = 0.f;
  166. std::vector<float> efficiency;
  167. efficiency.reserve(max_splits);
  168. auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
  169. // Some splits are not eligible. For example, if we have 64 blocks and choose
  170. // 11 splits, we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have
  171. // 6 * 11 + (-2) blocks (i.e. it's 11 splits anyway). So we check if the
  172. // number of blocks per split is the same as the previous num_splits.
  173. auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
  174. return num_splits == 1 || ceildiv(num_n_blocks, num_splits) !=
  175. ceildiv(num_n_blocks, num_splits - 1);
  176. };
  177. for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
  178. if (!is_split_eligible(num_splits)) {
  179. efficiency.push_back(0.f);
  180. } else {
  181. float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
  182. float eff = n_waves / ceil(n_waves);
  183. // printf("num_splits = %d, eff = %f\n", num_splits, eff);
  184. if (eff > max_efficiency) {
  185. max_efficiency = eff;
  186. }
  187. efficiency.push_back(eff);
  188. }
  189. }
  190. for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
  191. if (!is_split_eligible(num_splits)) {
  192. continue;
  193. }
  194. if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
  195. // printf("num_splits chosen = %d\n", num_splits);
  196. return num_splits;
  197. }
  198. }
  199. return 1;
  200. }
  201. std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
  202. Flash_fwd_params& params, const int batch_size, const int num_heads,
  203. const int head_size, const int max_seqlen_k, const int max_seqlen_q,
  204. const int head_size_rounded, const float p_dropout, const int num_splits,
  205. cudaDeviceProp* dprops, struct c10::TensorOptions opts) {
  206. // This needs to match with run_mha_fwd_splitkv_dispatch
  207. const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
  208. const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
  209. // Technically kBlockM = 64 only for the splitKV kernels, not the standard
  210. // kernel. In any case we don't expect seqlen_q to be larger than 64 for
  211. // inference.
  212. const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
  213. params.num_splits = num_splits;
  214. at::Tensor softmax_lse_accum;
  215. at::Tensor out_accum;
  216. if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
  217. if (num_splits < 1) {
  218. // We multiply number of SMs by 2 to hard-code the fact that we're using
  219. // 128 threads per block.
  220. params.num_splits = num_splits_heuristic(
  221. batch_size * num_heads * num_m_blocks,
  222. dprops->multiProcessorCount * 2, num_n_blocks, 128);
  223. }
  224. if (params.num_splits > 1) {
  225. softmax_lse_accum =
  226. torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q},
  227. opts.dtype(at::kFloat));
  228. out_accum = torch::empty({params.num_splits, batch_size, num_heads,
  229. max_seqlen_q, head_size_rounded},
  230. opts.dtype(at::kFloat));
  231. params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
  232. params.oaccum_ptr = out_accum.data_ptr();
  233. }
  234. TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
  235. }
  236. return std::make_tuple(softmax_lse_accum, out_accum);
  237. }
  238. void set_params_alibi(Flash_fwd_params& params,
  239. const c10::optional<at::Tensor>& alibi_slopes_,
  240. int batch_size, int num_heads) {
  241. #ifdef FLASHATTENTION_DISABLE_ALIBI
  242. TORCH_CHECK(!alibi_slopes_.has_value(),
  243. "This flash attention build does not support alibi.");
  244. params.alibi_slopes_ptr = nullptr;
  245. #else
  246. if (alibi_slopes_.has_value()) {
  247. auto alibi_slopes = alibi_slopes_.value();
  248. TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32,
  249. "ALiBi slopes must have dtype fp32");
  250. CHECK_DEVICE(alibi_slopes);
  251. TORCH_CHECK(alibi_slopes.stride(-1) == 1,
  252. "ALiBi slopes tensor must have contiguous last dimension");
  253. TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) ||
  254. alibi_slopes.sizes() ==
  255. torch::IntArrayRef({batch_size, num_heads}));
  256. params.alibi_slopes_ptr = alibi_slopes.data_ptr();
  257. params.alibi_slopes_batch_stride =
  258. alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
  259. } else {
  260. params.alibi_slopes_ptr = nullptr;
  261. }
  262. #endif
  263. }
  264. std::vector<at::Tensor> mha_fwd(
  265. at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
  266. const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size
  267. const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size
  268. const c10::optional<at::Tensor>&
  269. out_, // batch_size x seqlen_q x num_heads x head_size
  270. const c10::optional<at::Tensor>&
  271. alibi_slopes_, // num_heads or batch_size x num_heads
  272. const double p_dropout, const double softmax_scale, bool is_causal,
  273. int64_t window_size_left, int64_t window_size_right, const double softcap,
  274. const bool return_softmax, c10::optional<at::Generator> gen_) {
  275. auto dprops = at::cuda::getCurrentDeviceProperties();
  276. // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
  277. bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  278. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  279. TORCH_CHECK(is_sm90 || is_sm8x,
  280. "FlashAttention only supports Ampere GPUs or newer.");
  281. // We will support Turing in the near future
  282. // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports
  283. // Turing GPUs or newer.");
  284. auto q_dtype = q.dtype();
  285. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  286. "FlashAttention only support fp16 and bf16 data type");
  287. if (q_dtype == torch::kBFloat16) {
  288. TORCH_CHECK(is_sm90 || is_sm8x,
  289. "bfloat16 is only supported on Ampere GPUs or newer");
  290. }
  291. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  292. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  293. CHECK_DEVICE(q);
  294. CHECK_DEVICE(k);
  295. CHECK_DEVICE(v);
  296. TORCH_CHECK(q.stride(-1) == 1,
  297. "Input tensor must have contiguous last dimension");
  298. TORCH_CHECK(k.stride(-1) == 1,
  299. "Input tensor must have contiguous last dimension");
  300. TORCH_CHECK(v.stride(-1) == 1,
  301. "Input tensor must have contiguous last dimension");
  302. const auto sizes = q.sizes();
  303. const int batch_size = sizes[0];
  304. int seqlen_q = sizes[1];
  305. int num_heads = sizes[2];
  306. const int head_size_og = sizes[3];
  307. const int seqlen_k = k.size(1);
  308. const int num_heads_k = k.size(2);
  309. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  310. TORCH_CHECK(
  311. head_size_og <= 256,
  312. "FlashAttention forward only supports head dimension at most 256");
  313. TORCH_CHECK(
  314. num_heads % num_heads_k == 0,
  315. "Number of heads in key/value must divide number of heads in query");
  316. if (softcap > 0.f) {
  317. TORCH_CHECK(p_dropout == 0.f,
  318. "Softcapping does not support dropout for now");
  319. }
  320. if (window_size_left >= seqlen_k) {
  321. window_size_left = -1;
  322. }
  323. if (window_size_right >= seqlen_k) {
  324. window_size_right = -1;
  325. }
  326. // causal=true is the same as causal=false in this case
  327. if (seqlen_q == 1 && !alibi_slopes_.has_value()) {
  328. is_causal = false;
  329. }
  330. if (is_causal) {
  331. window_size_right = 0;
  332. }
  333. // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups,
  334. // nheads_kv, d) in this case H/t Daniel Haziza
  335. const int seqlenq_ngroups_swapped =
  336. seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 &&
  337. window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 &&
  338. !alibi_slopes_.has_value();
  339. const int ngroups = num_heads / num_heads_k;
  340. if (seqlenq_ngroups_swapped) {
  341. q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og})
  342. .transpose(1, 2);
  343. seqlen_q = ngroups;
  344. num_heads = num_heads_k;
  345. }
  346. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
  347. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
  348. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
  349. at::Tensor q_padded, k_padded, v_padded;
  350. if (head_size_og % 8 != 0) {
  351. q_padded = torch::nn::functional::pad(
  352. q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  353. k_padded = torch::nn::functional::pad(
  354. k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  355. v_padded = torch::nn::functional::pad(
  356. v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  357. } else {
  358. q_padded = q;
  359. k_padded = k;
  360. v_padded = v;
  361. }
  362. at::Tensor out;
  363. if (out_.has_value()) {
  364. out = out_.value();
  365. TORCH_CHECK(out.dtype() == q_dtype,
  366. "Output must have the same dtype as inputs");
  367. CHECK_DEVICE(out);
  368. TORCH_CHECK(out.stride(-1) == 1,
  369. "Output tensor must have contiguous last dimension");
  370. CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
  371. if (seqlenq_ngroups_swapped) {
  372. out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og})
  373. .transpose(1, 2);
  374. }
  375. if (head_size_og % 8 != 0) {
  376. out = torch::empty_like(q_padded);
  377. }
  378. } else {
  379. out = torch::empty_like(q_padded);
  380. }
  381. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  382. const int head_size = round_multiple(head_size_og, 8);
  383. const int head_size_rounded = round_multiple(head_size, 32);
  384. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  385. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  386. // Otherwise the kernel will be launched from cuda:0 device
  387. // Cast to char to avoid compiler warning about narrowing
  388. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  389. auto opts = q.options();
  390. auto softmax_lse =
  391. torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  392. at::Tensor p;
  393. // Only return softmax if there's dropout to reduce compilation time
  394. if (return_softmax) {
  395. TORCH_CHECK(p_dropout > 0.0f,
  396. "return_softmax is only supported when p_dropout > 0.0");
  397. p = torch::empty(
  398. {batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded}, opts);
  399. }
  400. Flash_fwd_params params;
  401. set_params_fprop(params, batch_size, seqlen_q, seqlen_k, seqlen_q_rounded,
  402. seqlen_k_rounded, num_heads, num_heads_k, head_size,
  403. head_size_rounded, q_padded, k_padded, v_padded, out,
  404. /*cu_seqlens_q_d=*/nullptr,
  405. /*cu_seqlens_k_d=*/nullptr,
  406. /*seqused_k=*/nullptr,
  407. return_softmax ? p.data_ptr() : nullptr,
  408. softmax_lse.data_ptr(), p_dropout, softmax_scale,
  409. window_size_left, window_size_right, softcap);
  410. // Keep references to these tensors to extend their lifetime
  411. at::Tensor softmax_lse_accum, out_accum;
  412. std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
  413. params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
  414. head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts);
  415. // number of times random will be generated per thread, to offset philox
  416. // counter in thc random state We use a custom RNG that increases the offset
  417. // by batch_size * nheads * 32.
  418. int64_t counter_offset = params.b * params.h * 32;
  419. auto options =
  420. torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  421. auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
  422. // Forward kernel will populate memory with the seed and offset.
  423. params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
  424. if (p_dropout > 0.0) {
  425. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  426. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  427. // See Note [Acquire lock when using random generators]
  428. std::lock_guard<std::mutex> lock(gen->mutex_);
  429. params.philox_args = gen->philox_cuda_state(counter_offset);
  430. }
  431. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  432. if (seqlen_k > 0) {
  433. auto stream = at::cuda::getCurrentCUDAStream().stream();
  434. run_mha_fwd(params, stream);
  435. } else {
  436. // If seqlen_k == 0, then we have an empty tensor. We need to set the output
  437. // to 0.
  438. out.zero_();
  439. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  440. }
  441. at::Tensor out_padded = out;
  442. if (head_size_og % 8 != 0) {
  443. out = out.index(
  444. {"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  445. if (out_.has_value()) {
  446. out_.value().copy_(out);
  447. }
  448. }
  449. if (seqlenq_ngroups_swapped) {
  450. out = out.transpose(1, 2).reshape(
  451. {batch_size, 1, num_heads_k * seqlen_q, head_size_og});
  452. out_padded = out_padded.transpose(1, 2).reshape(
  453. {batch_size, 1, num_heads_k * seqlen_q, head_size_og});
  454. q_padded = q_padded.transpose(1, 2).reshape(
  455. {batch_size, 1, num_heads_k * seqlen_q, head_size_og});
  456. softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
  457. }
  458. return {out, q_padded, k_padded, v_padded,
  459. out_padded, softmax_lse, p, rng_state};
  460. }
  461. std::vector<at::Tensor> mha_varlen_fwd(
  462. at::Tensor&
  463. q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  464. const at::Tensor& k, // total_k x num_heads_k x head_size, total_k :=
  465. // \sum_{i=0}^{b} s_i or num_blocks x page_block_size
  466. // x num_heads_k x head_size if there's a block_table.
  467. const at::Tensor& v, // total_k x num_heads_k x head_size, total_k :=
  468. // \sum_{i=0}^{b} s_i or num_blocks x page_block_size
  469. // x num_heads_k x head_size if there's a block_table.
  470. const c10::optional<at::Tensor>&
  471. out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
  472. const at::Tensor& cu_seqlens_q, // b+1
  473. const at::Tensor& cu_seqlens_k, // b+1
  474. const c10::optional<at::Tensor>&
  475. seqused_k, // b. If given, only this many elements of each batch
  476. // element's keys are used.
  477. const c10::optional<at::Tensor>&
  478. block_table_, // batch_size x max_num_blocks_per_seq
  479. const c10::optional<at::Tensor>&
  480. alibi_slopes_, // num_heads or b x num_heads
  481. int64_t max_seqlen_q, const int64_t max_seqlen_k, const double p_dropout,
  482. const double softmax_scale, const bool zero_tensors, bool is_causal,
  483. int64_t window_size_left, int64_t window_size_right, const double softcap,
  484. const bool return_softmax, c10::optional<at::Generator> gen_) {
  485. auto dprops = at::cuda::getCurrentDeviceProperties();
  486. // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
  487. bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  488. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  489. TORCH_CHECK(is_sm90 || is_sm8x,
  490. "FlashAttention only supports Ampere GPUs or newer.");
  491. // We will support Turing in the near future
  492. // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports
  493. // Turing GPUs or newer.");
  494. auto q_dtype = q.dtype();
  495. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  496. "FlashAttention only support fp16 and bf16 data type");
  497. if (q_dtype == torch::kBFloat16) {
  498. TORCH_CHECK(is_sm90 || is_sm8x,
  499. "bfloat16 is only supported on Ampere GPUs or newer");
  500. }
  501. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  502. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  503. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32,
  504. "cu_seqlens_q must have dtype int32");
  505. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32,
  506. "cu_seqlens_k must have dtype int32");
  507. CHECK_DEVICE(q);
  508. CHECK_DEVICE(k);
  509. CHECK_DEVICE(v);
  510. CHECK_DEVICE(cu_seqlens_q);
  511. CHECK_DEVICE(cu_seqlens_k);
  512. at::Tensor block_table;
  513. const bool paged_KV = block_table_.has_value();
  514. if (paged_KV) {
  515. block_table = block_table_.value();
  516. CHECK_DEVICE(block_table);
  517. TORCH_CHECK(block_table.dtype() == torch::kInt32,
  518. "block_table must have dtype torch.int32");
  519. TORCH_CHECK(block_table.stride(-1) == 1,
  520. "block_table must have contiguous last dimension");
  521. }
  522. TORCH_CHECK(q.stride(-1) == 1,
  523. "Input tensor must have contiguous last dimension");
  524. TORCH_CHECK(k.stride(-1) == 1,
  525. "Input tensor must have contiguous last dimension");
  526. TORCH_CHECK(v.stride(-1) == 1,
  527. "Input tensor must have contiguous last dimension");
  528. CHECK_CONTIGUOUS(cu_seqlens_q);
  529. CHECK_CONTIGUOUS(cu_seqlens_k);
  530. const auto sizes = q.sizes();
  531. const int batch_size = cu_seqlens_q.numel() - 1;
  532. int num_heads = sizes[1];
  533. const int head_size_og = sizes[2];
  534. const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
  535. if (softcap > 0.f) {
  536. TORCH_CHECK(p_dropout == 0.f,
  537. "Softcapping does not support dropout for now");
  538. }
  539. const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
  540. const int num_blocks = !paged_KV ? 0 : k.size(0);
  541. const int page_block_size = !paged_KV ? 1 : k.size(1);
  542. TORCH_CHECK(!paged_KV || page_block_size % 16 == 0,
  543. "Paged KV cache block size must be divisible by 16");
  544. if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) {
  545. is_causal = false;
  546. } // causal=true is the same as causal=false in this case
  547. if (is_causal) {
  548. window_size_right = 0;
  549. }
  550. void* cu_seqlens_q_d = cu_seqlens_q.data_ptr();
  551. // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups,
  552. // nheads_kv, d) in this case H/t Daniel Haziza
  553. const int seqlenq_ngroups_swapped =
  554. max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 &&
  555. window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 &&
  556. !alibi_slopes_.has_value();
  557. const int ngroups = num_heads / num_heads_k;
  558. if (seqlenq_ngroups_swapped) {
  559. q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og})
  560. .transpose(1, 2)
  561. .reshape({batch_size * ngroups, num_heads_k, head_size_og});
  562. max_seqlen_q = ngroups;
  563. num_heads = num_heads_k;
  564. cu_seqlens_q_d = nullptr;
  565. }
  566. const int total_q = q.sizes()[0];
  567. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  568. TORCH_CHECK(
  569. head_size_og <= 256,
  570. "FlashAttention forward only supports head dimension at most 256");
  571. TORCH_CHECK(
  572. num_heads % num_heads_k == 0,
  573. "Number of heads in key/value must divide number of heads in query");
  574. if (window_size_left >= max_seqlen_k) {
  575. window_size_left = -1;
  576. }
  577. if (window_size_right >= max_seqlen_k) {
  578. window_size_right = -1;
  579. }
  580. CHECK_SHAPE(q, total_q, num_heads, head_size_og);
  581. if (!paged_KV) {
  582. const int total_k = k.size(0);
  583. CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
  584. CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
  585. } else {
  586. CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
  587. CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
  588. CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
  589. }
  590. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  591. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  592. if (seqused_k.has_value()) {
  593. auto seqused_k_ = seqused_k.value();
  594. TORCH_CHECK(seqused_k_.dtype() == torch::kInt32,
  595. "seqused_k must have dtype int32");
  596. TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
  597. TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
  598. CHECK_SHAPE(seqused_k_, batch_size);
  599. }
  600. at::Tensor q_padded, k_padded, v_padded;
  601. if (head_size_og % 8 != 0) {
  602. q_padded = torch::nn::functional::pad(
  603. q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  604. k_padded = torch::nn::functional::pad(
  605. k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  606. v_padded = torch::nn::functional::pad(
  607. v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  608. } else {
  609. q_padded = q;
  610. k_padded = k;
  611. v_padded = v;
  612. }
  613. at::Tensor out;
  614. if (out_.has_value()) {
  615. out = out_.value();
  616. TORCH_CHECK(out.dtype() == q_dtype,
  617. "Output must have the same dtype as inputs");
  618. CHECK_DEVICE(out);
  619. TORCH_CHECK(out.stride(-1) == 1,
  620. "Output tensor must have contiguous last dimension");
  621. CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
  622. if (seqlenq_ngroups_swapped) {
  623. out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og})
  624. .transpose(1, 2)
  625. .reshape({batch_size * ngroups, num_heads_k, head_size_og});
  626. }
  627. if (head_size_og % 8 != 0) {
  628. out = torch::empty_like(q_padded);
  629. }
  630. } else {
  631. out = torch::empty_like(q_padded);
  632. }
  633. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  634. const int head_size = round_multiple(head_size_og, 8);
  635. const int head_size_rounded = round_multiple(head_size, 32);
  636. const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
  637. const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
  638. // Otherwise the kernel will be launched from cuda:0 device
  639. // Cast to char to avoid compiler warning about narrowing
  640. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  641. auto opts = q.options();
  642. auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
  643. at::Tensor p;
  644. // Only return softmax if there's dropout to reduce compilation time
  645. if (return_softmax) {
  646. TORCH_CHECK(p_dropout > 0.0f,
  647. "return_softmax is only supported when p_dropout > 0.0");
  648. p = torch::empty(
  649. {batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded}, opts);
  650. }
  651. if (zero_tensors) {
  652. out.zero_();
  653. softmax_lse.fill_(-std::numeric_limits<float>::infinity());
  654. if (return_softmax) {
  655. p.zero_();
  656. }
  657. }
  658. Flash_fwd_params params;
  659. set_params_fprop(
  660. params, batch_size, max_seqlen_q, max_seqlen_k, seqlen_q_rounded,
  661. seqlen_k_rounded, num_heads, num_heads_k, head_size, head_size_rounded,
  662. q_padded, k_padded, v_padded, out, cu_seqlens_q_d,
  663. cu_seqlens_k.data_ptr(),
  664. seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
  665. return_softmax ? p.data_ptr() : nullptr, softmax_lse.data_ptr(),
  666. p_dropout, softmax_scale, window_size_left, window_size_right, softcap,
  667. seqlenq_ngroups_swapped,
  668. /*unpadded_lse*/ true);
  669. params.total_q = total_q;
  670. if (paged_KV) {
  671. params.block_table = block_table.data_ptr<int>();
  672. params.block_table_batch_stride = block_table.stride(0);
  673. params.k_batch_stride = k_padded.stride(0);
  674. params.v_batch_stride = v_padded.stride(0);
  675. }
  676. params.page_block_size = page_block_size;
  677. // Keep references to these tensors to extend their lifetime
  678. at::Tensor softmax_lse_accum, out_accum;
  679. if (seqlenq_ngroups_swapped) {
  680. // Only apply split-k for decoding
  681. std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
  682. params, batch_size, num_heads, head_size, max_seqlen_k, max_seqlen_q,
  683. head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts);
  684. }
  685. // number of times random will be generated per thread, to offset philox
  686. // counter in thc random state We use a custom RNG that increases the offset
  687. // by batch_size * nheads * 32.
  688. int64_t counter_offset = params.b * params.h * 32;
  689. auto options =
  690. torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  691. auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
  692. // Forward kernel will populate memory with the seed and offset.
  693. params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
  694. if (p_dropout > 0.0) {
  695. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  696. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  697. // See Note [Acquire lock when using random generators]
  698. std::lock_guard<std::mutex> lock(gen->mutex_);
  699. params.philox_args = gen->philox_cuda_state(counter_offset);
  700. }
  701. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  702. if (max_seqlen_k > 0) {
  703. auto stream = at::cuda::getCurrentCUDAStream().stream();
  704. run_mha_fwd(params, stream, paged_KV);
  705. } else {
  706. // If seqlen_k == 0, then we have an empty tensor. We need to set the output
  707. // to 0.
  708. out.zero_();
  709. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  710. }
  711. at::Tensor out_padded = out;
  712. if (head_size_og % 8 != 0) {
  713. out = out.index(
  714. {"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  715. if (out_.has_value()) {
  716. out_.value().copy_(out);
  717. }
  718. }
  719. if (seqlenq_ngroups_swapped) {
  720. int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k,
  721. head_size_og};
  722. int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q,
  723. head_size_og};
  724. out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
  725. out_padded =
  726. out_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
  727. q_padded =
  728. q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
  729. softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
  730. }
  731. return {out, q_padded, k_padded, v_padded,
  732. out_padded, softmax_lse, p, rng_state};
  733. }
  734. std::vector<at::Tensor> mha_fwd_kvcache(
  735. at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
  736. const at::Tensor&
  737. kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or
  738. // num_blocks x page_block_size x num_heads_k x head_size if
  739. // there's a block_table.
  740. const at::Tensor&
  741. vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or
  742. // num_blocks x page_block_size x num_heads_k x head_size if
  743. // there's a block_table.
  744. const c10::optional<at::Tensor>&
  745. k_, // batch_size x seqlen_knew x num_heads_k x head_size
  746. const c10::optional<at::Tensor>&
  747. v_, // batch_size x seqlen_knew x num_heads_k x head_size
  748. const c10::optional<at::Tensor>& seqlens_k_, // batch_size
  749. const c10::optional<at::Tensor>&
  750. rotary_cos_, // seqlen_ro x (rotary_dim / 2)
  751. const c10::optional<at::Tensor>&
  752. rotary_sin_, // seqlen_ro x (rotary_dim / 2)
  753. const c10::optional<at::Tensor>&
  754. cache_batch_idx_, // indices to index into the KV cache
  755. const c10::optional<at::Tensor>&
  756. block_table_, // batch_size x max_num_blocks_per_seq
  757. const c10::optional<at::Tensor>&
  758. alibi_slopes_, // num_heads or batch_size x num_heads
  759. const c10::optional<at::Tensor>&
  760. out_, // batch_size x seqlen_q x num_heads x head_size
  761. const double softmax_scale, bool is_causal, int64_t window_size_left,
  762. int64_t window_size_right, const double softcap,
  763. bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else
  764. // indices 0 & rotary_dim / 2
  765. int64_t num_splits) {
  766. auto dprops = at::cuda::getCurrentDeviceProperties();
  767. // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
  768. bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  769. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  770. TORCH_CHECK(is_sm90 || is_sm8x,
  771. "FlashAttention only supports Ampere GPUs or newer.");
  772. // We will support Turing in the near future
  773. // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports
  774. // Turing GPUs or newer.");
  775. auto q_dtype = q.dtype();
  776. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  777. "FlashAttention only support fp16 and bf16 data type");
  778. if (q_dtype == torch::kBFloat16) {
  779. TORCH_CHECK(is_sm90 || is_sm8x,
  780. "bfloat16 is only supported on Ampere GPUs or newer");
  781. }
  782. TORCH_CHECK(kcache.dtype() == q_dtype,
  783. "query and key must have the same dtype");
  784. TORCH_CHECK(vcache.dtype() == q_dtype,
  785. "query and value must have the same dtype");
  786. CHECK_DEVICE(q);
  787. CHECK_DEVICE(kcache);
  788. CHECK_DEVICE(vcache);
  789. TORCH_CHECK(q.stride(-1) == 1,
  790. "Input tensor must have contiguous last dimension");
  791. TORCH_CHECK(kcache.stride(-1) == 1,
  792. "Input tensor must have contiguous last dimension");
  793. TORCH_CHECK(vcache.stride(-1) == 1,
  794. "Input tensor must have contiguous last dimension");
  795. at::Tensor block_table;
  796. const bool paged_KV = block_table_.has_value();
  797. if (paged_KV) {
  798. TORCH_CHECK(!cache_batch_idx_.has_value(),
  799. "Paged KVcache does not support cache_batch_idx");
  800. block_table = block_table_.value();
  801. CHECK_DEVICE(block_table);
  802. TORCH_CHECK(block_table.dtype() == torch::kInt32,
  803. "block_table must have dtype torch.int32");
  804. TORCH_CHECK(block_table.stride(-1) == 1,
  805. "block_table must have contiguous last dimension");
  806. }
  807. const auto sizes = q.sizes();
  808. const int batch_size = sizes[0];
  809. int seqlen_q = sizes[1];
  810. const int seqlen_q_og = seqlen_q;
  811. int num_heads = sizes[2];
  812. const int num_heads_og = num_heads;
  813. const int head_size_og = sizes[3];
  814. const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
  815. const int num_blocks = !paged_KV ? 0 : kcache.size(0);
  816. const int page_block_size = !paged_KV ? 1 : kcache.size(1);
  817. TORCH_CHECK(!paged_KV || page_block_size % 16 == 0,
  818. "Paged KV cache block size must be divisible by 16");
  819. const int seqlen_k =
  820. !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
  821. const int num_heads_k = kcache.size(2);
  822. const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
  823. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  824. TORCH_CHECK(
  825. head_size_og <= 256,
  826. "FlashAttention forward only supports head dimension at most 256");
  827. TORCH_CHECK(
  828. num_heads % num_heads_k == 0,
  829. "Number of heads in key/value must divide number of heads in query");
  830. // causal=true is the same as causal=false in this case
  831. if (seqlen_q == 1 && !alibi_slopes_.has_value()) {
  832. is_causal = false;
  833. }
  834. if (is_causal) {
  835. window_size_right = 0;
  836. }
  837. // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups,
  838. // nheads_kv, d) in this case H/t Daniel Haziza
  839. const int seqlenq_ngroups_swapped =
  840. seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 &&
  841. window_size_right < 0 && head_size_og % 8 == 0 &&
  842. !alibi_slopes_.has_value();
  843. if (seqlenq_ngroups_swapped) {
  844. const int ngroups = num_heads / num_heads_k;
  845. q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og})
  846. .transpose(1, 2);
  847. seqlen_q = ngroups;
  848. num_heads = num_heads_k;
  849. }
  850. if (window_size_left >= seqlen_k) {
  851. window_size_left = -1;
  852. }
  853. if (window_size_right >= seqlen_k) {
  854. window_size_right = -1;
  855. }
  856. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
  857. if (!paged_KV) {
  858. CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
  859. CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
  860. } else {
  861. CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
  862. CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
  863. CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
  864. }
  865. at::Tensor q_padded, kcache_padded, vcache_padded;
  866. if (head_size_og % 8 != 0) {
  867. q_padded = torch::nn::functional::pad(
  868. q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  869. kcache_padded = torch::nn::functional::pad(
  870. kcache,
  871. torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  872. vcache_padded = torch::nn::functional::pad(
  873. vcache,
  874. torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  875. } else {
  876. q_padded = q;
  877. kcache_padded = kcache;
  878. vcache_padded = vcache;
  879. }
  880. at::Tensor out;
  881. if (out_.has_value()) {
  882. out = out_.value();
  883. TORCH_CHECK(out.dtype() == q_dtype,
  884. "Output must have the same dtype as inputs");
  885. CHECK_DEVICE(out);
  886. TORCH_CHECK(out.stride(-1) == 1,
  887. "Output tensor must have contiguous last dimension");
  888. CHECK_SHAPE(out, batch_size, seqlen_q_og, num_heads_og, head_size_og);
  889. if (head_size_og % 8 != 0) {
  890. out = torch::empty_like(q_padded);
  891. } else if (seqlenq_ngroups_swapped) {
  892. out = out.reshape({batch_size, num_heads, seqlen_q, head_size_og})
  893. .transpose(1, 2);
  894. }
  895. } else {
  896. out = torch::empty_like(q_padded);
  897. }
  898. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  899. const int head_size = round_multiple(head_size_og, 8);
  900. const int head_size_rounded = round_multiple(head_size, 32);
  901. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  902. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  903. // Otherwise the kernel will be launched from cuda:0 device
  904. // Cast to char to avoid compiler warning about narrowing
  905. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  906. auto opts = q.options();
  907. auto softmax_lse =
  908. torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  909. Flash_fwd_params params;
  910. set_params_fprop(params, batch_size, seqlen_q, seqlen_k, seqlen_q_rounded,
  911. seqlen_k_rounded, num_heads, num_heads_k, head_size,
  912. head_size_rounded, q_padded, kcache_padded, vcache_padded,
  913. out,
  914. /*cu_seqlens_q_d=*/nullptr,
  915. /*cu_seqlens_k_d=*/nullptr,
  916. /*seqused_k=*/nullptr,
  917. /*p_ptr=*/nullptr, softmax_lse.data_ptr(),
  918. /*p_dropout=*/0.f, softmax_scale, window_size_left,
  919. window_size_right, softcap);
  920. at::Tensor k, v, k_padded, v_padded;
  921. if (k_.has_value()) {
  922. TORCH_CHECK(v_.has_value(),
  923. "If key is supplied, value must also be passed in");
  924. TORCH_CHECK(seqlens_k_.has_value(),
  925. "If key is supplied, seqlens_k must also be passed in");
  926. TORCH_CHECK(seqlen_q <= seqlen_k,
  927. "If key is supplied, it must have seqlen <= the seqlen of the "
  928. "KV cache");
  929. k = k_.value();
  930. v = v_.value();
  931. TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
  932. TORCH_CHECK(v.dtype() == q_dtype,
  933. "Value must have the same dtype as query");
  934. CHECK_DEVICE(k);
  935. CHECK_DEVICE(v);
  936. TORCH_CHECK(k.stride(-1) == 1,
  937. "Key tensor must have contiguous last dimension");
  938. TORCH_CHECK(v.stride(-1) == 1,
  939. "Value tensor must have contiguous last dimension");
  940. int seqlen_knew = k.size(1);
  941. CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
  942. CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
  943. if (head_size_og % 8 != 0) {
  944. k_padded = torch::nn::functional::pad(
  945. k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  946. v_padded = torch::nn::functional::pad(
  947. v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  948. } else {
  949. k_padded = k;
  950. v_padded = v;
  951. }
  952. params.seqlen_knew = seqlen_knew;
  953. params.knew_ptr = k_padded.data_ptr();
  954. params.vnew_ptr = v_padded.data_ptr();
  955. // All stride are in elements, not bytes.
  956. params.knew_batch_stride = k_padded.stride(0);
  957. params.vnew_batch_stride = v_padded.stride(0);
  958. params.knew_row_stride = k_padded.stride(-3);
  959. params.vnew_row_stride = v_padded.stride(-3);
  960. params.knew_head_stride = k_padded.stride(-2);
  961. params.vnew_head_stride = v_padded.stride(-2);
  962. }
  963. if (seqlens_k_.has_value()) {
  964. auto seqlens_k = seqlens_k_.value();
  965. TORCH_CHECK(seqlens_k.dtype() == torch::kInt32,
  966. "seqlens_k must have dtype int32");
  967. CHECK_DEVICE(seqlens_k);
  968. CHECK_CONTIGUOUS(seqlens_k);
  969. CHECK_SHAPE(seqlens_k, batch_size);
  970. params.cu_seqlens_k = static_cast<int*>(seqlens_k.data_ptr());
  971. }
  972. params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
  973. if (rotary_cos_.has_value()) {
  974. TORCH_CHECK(k_.has_value(),
  975. "If rotary cos/sin are provided, new key / value to be "
  976. "appended to KV cache must also be provided");
  977. auto rotary_cos = rotary_cos_.value();
  978. CHECK_DEVICE(rotary_cos);
  979. params.rotary_dim = rotary_cos.size(1) * 2;
  980. TORCH_CHECK(params.rotary_dim <= head_size,
  981. "rotary_dim must be <= headdim");
  982. TORCH_CHECK(
  983. params.rotary_dim % 16 == 0,
  984. "Only rotary dimensions divisible by 16 are currently supported");
  985. const int seqlen_ro = rotary_cos.size(0);
  986. TORCH_CHECK(seqlen_ro >= seqlen_k,
  987. "cos/sin seqlen must be at least the seqlen of KV cache");
  988. CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
  989. CHECK_CONTIGUOUS(rotary_cos);
  990. TORCH_CHECK(rotary_cos.scalar_type() == q_dtype,
  991. "rotary_cos must have the same dtype as query");
  992. TORCH_CHECK(rotary_sin_.has_value(),
  993. "If rotary cos is provided, rotary sin must also be provided");
  994. auto rotary_sin = rotary_sin_.value();
  995. CHECK_DEVICE(rotary_sin);
  996. CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
  997. CHECK_CONTIGUOUS(rotary_sin);
  998. TORCH_CHECK(rotary_sin.scalar_type() == q_dtype,
  999. "rotary_cos must have the same dtype as query");
  1000. params.rotary_cos_ptr = rotary_cos.data_ptr();
  1001. params.rotary_sin_ptr = rotary_sin.data_ptr();
  1002. params.is_rotary_interleaved = is_rotary_interleaved;
  1003. } else {
  1004. params.rotary_dim = 0;
  1005. }
  1006. if (cache_batch_idx_.has_value()) {
  1007. auto cache_batch_idx = cache_batch_idx_.value();
  1008. CHECK_DEVICE(cache_batch_idx);
  1009. CHECK_CONTIGUOUS(cache_batch_idx);
  1010. TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32,
  1011. "cache_batch_idx must have dtype int32");
  1012. params.cache_batch_idx = reinterpret_cast<int*>(cache_batch_idx.data_ptr());
  1013. }
  1014. // Keep references to these tensors to extend their lifetime
  1015. at::Tensor softmax_lse_accum, out_accum;
  1016. std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
  1017. params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
  1018. head_size_rounded, /*dropout*/ 0.f, num_splits, dprops, opts);
  1019. if (paged_KV) {
  1020. params.block_table = block_table.data_ptr<int>();
  1021. params.block_table_batch_stride = block_table.stride(0);
  1022. }
  1023. params.page_block_size = page_block_size;
  1024. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  1025. auto stream = at::cuda::getCurrentCUDAStream().stream();
  1026. // Only split kernel supports appending to KV cache, or indexing to the cache
  1027. // with cache_batch_idx, or paged KV cache
  1028. run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() ||
  1029. cache_batch_idx_.has_value() || paged_KV);
  1030. if (head_size_og % 8 != 0) {
  1031. out = out.index(
  1032. {"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  1033. if (out_.has_value()) {
  1034. out_.value().copy_(out);
  1035. }
  1036. if (k_.has_value()) {
  1037. // It's expensive to copy the KV cache here for the case where head size
  1038. // not divisible by 8, but we don't expect to get this case in practice.
  1039. // This is just so that the code works for that case.
  1040. kcache.copy_(kcache_padded.index(
  1041. {"...",
  1042. torch::indexing::Slice(torch::indexing::None, head_size_og)}));
  1043. vcache.copy_(vcache_padded.index(
  1044. {"...",
  1045. torch::indexing::Slice(torch::indexing::None, head_size_og)}));
  1046. }
  1047. }
  1048. if (seqlenq_ngroups_swapped) {
  1049. out = out.transpose(1, 2).reshape(
  1050. {batch_size, 1, num_heads_k * seqlen_q, head_size_og});
  1051. softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
  1052. }
  1053. return {out, softmax_lse};
  1054. }