mha_fwd_kvcache.cpp 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567
  1. /******************************************************************************
  2. * Copyright (c) 2024, Tri Dao.
  3. ******************************************************************************/
  4. #include "flash_common.hpp"
  5. #include "fmha_fwd.hpp"
  6. #include "rotary.hpp"
  7. fmha_fwd_appendkv_traits get_ck_fmha_fwd_appendkv_traits(std::string dtype,
  8. int head_size,
  9. int rotary_dim,
  10. bool is_rotary_interleaved)
  11. {
  12. rope_enum rope_type = (0 < rotary_dim ? (is_rotary_interleaved ? rope_enum::interleaved
  13. : rope_enum::half_rotated)
  14. : rope_enum::none);
  15. return fmha_fwd_appendkv_traits{head_size,
  16. head_size,
  17. dtype,
  18. true, // is_v_rowmajor
  19. rope_type};
  20. }
  21. fmha_fwd_splitkv_traits get_ck_fmha_fwd_splitkv_traits(const mask_info &mask,
  22. std::string dtype,
  23. int head_size,
  24. bool has_lse,
  25. bool enable_alibi)
  26. {
  27. return fmha_fwd_splitkv_traits{head_size,
  28. head_size,
  29. dtype,
  30. false, // is_group_mode
  31. true, // is_v_rowmajor
  32. mask.type,
  33. enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
  34. has_lse,
  35. false}; // do_fp8_static_quant
  36. }
  37. fmha_fwd_appendkv_args get_ck_fmha_fwd_appendkv_args(const int b,
  38. const int seqlen_q,
  39. const int seqlen_knew,
  40. const int h,
  41. const int h_k,
  42. const int d,
  43. const int rotary_dim,
  44. const bool has_mask,
  45. const int page_block_size,
  46. // device pointers
  47. const at::Tensor q,
  48. const at::Tensor kcache,
  49. const at::Tensor vcache,
  50. const at::Tensor knew,
  51. const at::Tensor vnew,
  52. std::optional<const at::Tensor> &seqlens_k_,
  53. std::optional<const at::Tensor> &rotary_cos_,
  54. std::optional<const at::Tensor> &rotary_sin_,
  55. std::optional<const at::Tensor> &cache_batch_idx_,
  56. std::optional<at::Tensor> &block_table_)
  57. {
  58. // q: (batch_size, seqlen_q, nheads, d)
  59. // kcache: (batch_size_c, seqlen_k, nheads_k, d) or (num_blocks, page_block_size, nheads_k, d)
  60. // vcache: (batch_size_c, seqlen_k, nheads_k, d) or (num_blocks, page_block_size, nheads_k, d)
  61. // knew: (batch_size, seqlen_knew, nheads_k, d)
  62. // vnew: (batch_size, seqlen_knew, nheads_k, d)
  63. // seqlens_k: (batch_size)
  64. // rotary_cos: (seqlen_ro, rotary_dim / 2)
  65. // rotary_sin: (seqlen_ro, rotary_dim / 2)
  66. // block_table: (batch_size, max_num_blocks_per_seq)
  67. fmha_fwd_appendkv_args args;
  68. args.q_ptr = q.data_ptr();
  69. args.k_ptr = kcache.data_ptr();
  70. args.knew_ptr = knew.data_ptr();
  71. args.v_ptr = vcache.data_ptr();
  72. args.vnew_ptr = vnew.data_ptr();
  73. args.seqlen_k_ptr = seqlens_k_.has_value() ? seqlens_k_.value().data_ptr() : nullptr;
  74. args.seqlen_q = seqlen_q;
  75. args.seqlen_knew = seqlen_knew;
  76. args.batch = b;
  77. args.hdim_q = d;
  78. args.hdim_v = d;
  79. args.nhead_q = h;
  80. args.nhead_k = h_k;
  81. args.rotary_cos_ptr = rotary_cos_.has_value() ? rotary_cos_.value().data_ptr() : nullptr;
  82. args.rotary_sin_ptr = rotary_sin_.has_value() ? rotary_sin_.value().data_ptr() : nullptr;
  83. args.rotary_dim = rotary_dim;
  84. args.has_mask = has_mask;
  85. if (block_table_.has_value())
  86. {
  87. auto block_table = block_table_.value();
  88. args.block_table_ptr = block_table.data_ptr();
  89. args.batch_stride_block_table = block_table.stride(0);
  90. args.page_block_size = page_block_size;
  91. }
  92. else
  93. {
  94. args.block_table_ptr = nullptr;
  95. args.batch_stride_block_table = 0;
  96. args.page_block_size = 0;
  97. }
  98. args.cache_batch_idx = cache_batch_idx_.has_value() ?
  99. reinterpret_cast<int *>(cache_batch_idx_.value().data_ptr()) : nullptr;
  100. args.batch_stride_q = q.stride(0);
  101. args.stride_q = q.stride(1);
  102. args.nhead_stride_q = q.stride(2);
  103. args.batch_stride_k = kcache.stride(0);
  104. args.stride_k = kcache.stride(1);
  105. args.nhead_stride_k = kcache.stride(2);
  106. args.batch_stride_knew = knew.stride(0);
  107. args.stride_knew = knew.stride(1);
  108. args.nhead_stride_knew = knew.stride(2);
  109. args.batch_stride_v = vcache.stride(0);
  110. args.stride_v = vcache.stride(1);
  111. args.nhead_stride_v = vcache.stride(2);
  112. args.batch_stride_vnew = vnew.stride(0);
  113. args.stride_vnew = vnew.stride(1);
  114. args.nhead_stride_vnew = vnew.stride(2);
  115. return args;
  116. }
  117. fmha_fwd_splitkv_args get_ck_fmha_fwd_splitkv_args(bool has_lse,
  118. const mask_info &mask,
  119. const int b,
  120. const int seqlen_q,
  121. const int seqlen_k,
  122. const int h,
  123. const int h_k,
  124. const int d,
  125. const int page_block_size,
  126. const int num_splits,
  127. float softmax_scale,
  128. // device pointers
  129. const at::Tensor q,
  130. const at::Tensor k,
  131. const at::Tensor v,
  132. const at::Tensor seqlens_k,
  133. std::optional<const at::Tensor> &cache_batch_idx_,
  134. std::optional<at::Tensor> &block_table_,
  135. std::optional<at::Tensor> &alibi_slopes_,
  136. at::Tensor out,
  137. at::Tensor lse,
  138. at::Tensor lse_acc,
  139. at::Tensor out_acc)
  140. {
  141. // q: (batch_size, seqlen_q, nheads, d)
  142. // k: (batch_size, seqlen_k, nheads_k, d)
  143. // v: (batch_size, seqlen_k, nheads_k, d)
  144. // o: (batch_size, seqlen_q, nheads, d)
  145. // alibi_slopes:(batch_size, nheads) or (nhead)
  146. // lse: (batch_size, nheads, seqlen_q)
  147. // lse_acc: (split, batch_size, nheads, seqlen_q)
  148. // o_acc: (split, batch_size, nheads, seqlen_q, d)
  149. fmha_fwd_splitkv_args args;
  150. args.q_ptr = q.data_ptr();
  151. args.k_ptr = k.data_ptr();
  152. args.v_ptr = v.data_ptr();
  153. args.bias_ptr = nullptr;
  154. args.lse_acc_ptr = lse_acc.data_ptr();
  155. args.o_acc_ptr = out_acc.data_ptr();
  156. args.lse_ptr = nullptr;
  157. args.o_ptr = out.data_ptr();
  158. if (block_table_.has_value())
  159. {
  160. auto block_table = block_table_.value();
  161. args.block_table_ptr = block_table.data_ptr();
  162. args.batch_stride_block_table = block_table.stride(0);
  163. args.page_block_size = page_block_size;
  164. }
  165. else
  166. {
  167. args.block_table_ptr = nullptr;
  168. args.batch_stride_block_table = 0;
  169. args.page_block_size = 0;
  170. }
  171. args.cache_batch_idx = cache_batch_idx_.has_value() ? cache_batch_idx_.value().data_ptr() : nullptr;
  172. args.seqstart_q_ptr = nullptr;
  173. args.seqstart_k_ptr = nullptr;
  174. args.seqlen_k_ptr = seqlens_k.data_ptr();
  175. args.seqlen_q = seqlen_q;
  176. args.seqlen_k = seqlen_k;
  177. args.batch = b;
  178. args.max_seqlen_q = seqlen_q;
  179. args.hdim_q = d;
  180. args.hdim_v = d;
  181. args.nhead_q = h;
  182. args.nhead_k = h_k;
  183. args.num_splits = num_splits;
  184. args.scale_s = softmax_scale;
  185. args.scale_p = 1;
  186. args.scale_o = 1;
  187. args.batch_stride_q = q.stride(0);
  188. args.stride_q = q.stride(1);
  189. args.nhead_stride_q = q.stride(2);
  190. args.batch_stride_k = k.stride(0);
  191. args.stride_k = k.stride(1);
  192. args.nhead_stride_k = k.stride(2);
  193. args.batch_stride_v = v.stride(0);
  194. args.stride_v = v.stride(1);
  195. args.nhead_stride_v = v.stride(2);
  196. args.batch_stride_o = out.stride(0);
  197. args.stride_o = out.stride(1);
  198. args.nhead_stride_o = out.stride(2);
  199. args.batch_stride_bias = 0;
  200. args.stride_bias = 0;
  201. args.nhead_stride_bias = 0;
  202. args.batch_stride_lse = 0;
  203. args.nhead_stride_lse = 0;
  204. args.split_stride_lse_acc = lse_acc.stride(0);
  205. args.batch_stride_lse_acc = lse_acc.stride(1);
  206. args.nhead_stride_lse_acc = lse_acc.stride(2);
  207. args.split_stride_o_acc = out_acc.stride(0);
  208. args.batch_stride_o_acc = out_acc.stride(1);
  209. args.nhead_stride_o_acc = out_acc.stride(2);
  210. args.stride_o_acc = out_acc.stride(3);
  211. if (has_lse) {
  212. args.lse_ptr = lse.data_ptr();
  213. args.batch_stride_lse = lse.stride(0);
  214. args.nhead_stride_lse = lse.stride(1);
  215. }
  216. if (alibi_slopes_.has_value()) {
  217. auto alibi_slopes = alibi_slopes_.value();
  218. CHECK_DEVICE(alibi_slopes);
  219. TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
  220. TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
  221. args.bias_ptr = alibi_slopes.data_ptr();
  222. args.stride_bias = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
  223. }
  224. args.window_size_left = mask.left;
  225. args.window_size_right = mask.right;
  226. args.mask_type = static_cast<ck_tile::index_t>(mask.type);
  227. return args;
  228. }
  229. std::vector<at::Tensor>
  230. mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  231. const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  232. const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  233. std::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
  234. std::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
  235. std::optional<const at::Tensor> &seqlens_k_, // batch_size
  236. std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
  237. std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
  238. std::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
  239. std::optional<const at::Tensor> & /*leftpad_k_*/, // batch_size
  240. std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
  241. std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
  242. std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
  243. const float softmax_scale,
  244. bool is_causal,
  245. int window_size_left,
  246. int window_size_right,
  247. const float /*softcap*/,
  248. bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
  249. int num_splits)
  250. {
  251. auto q_dtype = q.dtype();
  252. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  253. "FlashAttention only support fp16 and bf16 data type");
  254. TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
  255. TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");
  256. std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
  257. CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
  258. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  259. TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  260. TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  261. at::Tensor block_table;
  262. const bool paged_KV = block_table_.has_value();
  263. if (paged_KV) {
  264. TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx");
  265. block_table = block_table_.value();
  266. CHECK_DEVICE(block_table);
  267. TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
  268. TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
  269. }
  270. const auto sizes = q.sizes();
  271. const int batch_size = sizes[0];
  272. int seqlen_q = sizes[1];
  273. int num_heads = sizes[2];
  274. const int head_size_og = sizes[3];
  275. const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
  276. const int num_blocks = !paged_KV ? 0 : kcache.size(0);
  277. const int page_block_size = !paged_KV ? 1 : kcache.size(1);
  278. TORCH_CHECK(!paged_KV || page_block_size % 128 == 0, "Paged KV cache block size must be divisible by 128");
  279. const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
  280. const int num_heads_k = kcache.size(2);
  281. const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
  282. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  283. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  284. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  285. // causal=true is the same as causal=false in this case
  286. if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
  287. if (is_causal) { window_size_right = 0; }
  288. mask_info mask;
  289. if (is_causal) {
  290. // Causal is the special case where window_size_right == 0 and window_size_left < 0.
  291. window_size_right = 0;
  292. std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
  293. mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // casual
  294. }
  295. else if (window_size_left == -1 && window_size_right == -1) {
  296. mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask
  297. }
  298. else {
  299. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
  300. std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
  301. mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local
  302. }
  303. // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
  304. // H/t Daniel Haziza
  305. const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
  306. if (seqlenq_ngroups_swapped) {
  307. const int ngroups = num_heads / num_heads_k;
  308. q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
  309. seqlen_q = ngroups;
  310. num_heads = num_heads_k;
  311. }
  312. if (window_size_left >= seqlen_k) { window_size_left = -1; }
  313. if (window_size_right >= seqlen_k) { window_size_right = -1; }
  314. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
  315. if (!paged_KV) {
  316. CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
  317. CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
  318. } else {
  319. CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
  320. CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
  321. CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
  322. }
  323. at::Tensor q_padded, kcache_padded, vcache_padded;
  324. if (head_size_og % 8 != 0) {
  325. q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  326. kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  327. vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  328. } else {
  329. q_padded = q;
  330. kcache_padded = kcache;
  331. vcache_padded = vcache;
  332. }
  333. at::Tensor out;
  334. if (out_.has_value()) {
  335. out = out_.value();
  336. TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  337. CHECK_DEVICE(out);
  338. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  339. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
  340. if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
  341. } else {
  342. out = torch::empty_like(q_padded);
  343. }
  344. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  345. const int head_size_8x = round_multiple(head_size_og, 8);
  346. // Otherwise the kernel will be launched from cuda:0 device
  347. at::cuda::CUDAGuard device_guard{q.device()};
  348. auto opts = q.options();
  349. // TODO - check gradient, only training require lse
  350. bool has_lse = true;
  351. auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  352. int seqlen_knew = 0;
  353. at::Tensor k, v, k_padded, v_padded;
  354. if (k_.has_value()) {
  355. TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
  356. TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
  357. TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache");
  358. k = k_.value();
  359. v = v_.value();
  360. TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
  361. TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
  362. CHECK_DEVICE(k); CHECK_DEVICE(v);
  363. TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
  364. TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
  365. seqlen_knew = k.size(1);
  366. CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
  367. CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
  368. if (head_size_og % 8 != 0) {
  369. k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  370. v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  371. } else {
  372. k_padded = k;
  373. v_padded = v;
  374. }
  375. }
  376. if (seqlens_k_.has_value()) {
  377. auto seqlens_k = seqlens_k_.value();
  378. TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
  379. CHECK_DEVICE(seqlens_k);
  380. CHECK_CONTIGUOUS(seqlens_k);
  381. CHECK_SHAPE(seqlens_k, batch_size);
  382. }
  383. int rotary_dim = 0;
  384. if (rotary_cos_.has_value()) {
  385. TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
  386. auto rotary_cos = rotary_cos_.value();
  387. CHECK_DEVICE(rotary_cos);
  388. rotary_dim = rotary_cos.size(1) * 2;
  389. TORCH_CHECK(rotary_dim <= head_size_og, "rotary_dim must be <= headdim");
  390. TORCH_CHECK(rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
  391. const int seqlen_ro = rotary_cos.size(0);
  392. TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
  393. CHECK_SHAPE(rotary_cos, seqlen_ro, rotary_dim / 2);
  394. CHECK_CONTIGUOUS(rotary_cos);
  395. TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
  396. TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
  397. auto rotary_sin = rotary_sin_.value();
  398. CHECK_DEVICE(rotary_sin);
  399. CHECK_SHAPE(rotary_sin, seqlen_ro, rotary_dim / 2);
  400. CHECK_CONTIGUOUS(rotary_sin);
  401. TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
  402. }
  403. if (cache_batch_idx_.has_value()) {
  404. auto cache_batch_idx = cache_batch_idx_.value();
  405. CHECK_DEVICE(cache_batch_idx);
  406. CHECK_CONTIGUOUS(cache_batch_idx);
  407. TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
  408. }
  409. num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, seqlen_q, head_size_8x, 0, num_splits);
  410. TORCH_CHECK(num_splits > 0, "num_splits should greater than 0");
  411. TORCH_CHECK(num_splits <= 128, "num_splits greater than 128 is not supported");
  412. // Keep references to these tensors to extend their lifetime
  413. auto softmax_lse_accum = torch::empty({num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  414. auto out_accum = torch::empty({num_splits, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat));
  415. auto stream = at::cuda::getCurrentCUDAStream().stream();
  416. ck_tile::stream_config stream_config{stream};
  417. if (seqlen_knew > 0 || rotary_dim > 0) {
  418. auto appendkv_traits =
  419. get_ck_fmha_fwd_appendkv_traits(q_dtype_str, head_size_8x, rotary_dim, is_rotary_interleaved);
  420. auto appendkv_args =
  421. get_ck_fmha_fwd_appendkv_args(
  422. batch_size,
  423. seqlen_q,
  424. seqlen_knew,
  425. num_heads,
  426. num_heads_k,
  427. head_size_8x,
  428. rotary_dim,
  429. mask.type != mask_enum::no_mask,
  430. page_block_size,
  431. q_padded,
  432. kcache_padded,
  433. vcache_padded,
  434. k_padded,
  435. v_padded,
  436. seqlens_k_,
  437. rotary_cos_,
  438. rotary_sin_,
  439. cache_batch_idx_,
  440. block_table_);
  441. fmha_fwd_appendkv(appendkv_traits, appendkv_args, stream_config);
  442. }
  443. // seqlens_k_ is the seqlen of kvcache. We need to add seqlen_knew for before attention
  444. auto append_seqlens_k = torch::empty({batch_size}, opts.dtype(torch::kInt32));
  445. if (seqlens_k_.has_value())
  446. append_seqlens_k = seqlens_k_.value() + seqlen_knew;
  447. else
  448. append_seqlens_k.fill_(seqlen_knew);
  449. // we use splitkv even num_splits == 1, because fmha_fwd() does not support seqlen_k_ in batch mode
  450. auto splitkv_traits =
  451. get_ck_fmha_fwd_splitkv_traits(mask, q_dtype_str, head_size_8x, has_lse, alibi_slopes_.has_value());
  452. auto splitkv_args =
  453. get_ck_fmha_fwd_splitkv_args(
  454. has_lse,
  455. mask,
  456. batch_size,
  457. seqlen_q,
  458. seqlen_k,
  459. num_heads,
  460. num_heads_k,
  461. head_size_8x,
  462. page_block_size,
  463. num_splits,
  464. softmax_scale,
  465. q_padded,
  466. kcache_padded,
  467. vcache_padded,
  468. append_seqlens_k,
  469. cache_batch_idx_,
  470. block_table_,
  471. alibi_slopes_,
  472. out,
  473. softmax_lse,
  474. softmax_lse_accum,
  475. out_accum);
  476. fmha_fwd_splitkv(splitkv_traits, splitkv_args, stream_config);
  477. if (head_size_og % 8 != 0) {
  478. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  479. if (out_.has_value()) { out_.value().copy_(out); }
  480. if (k_.has_value()) {
  481. // It's expensive to copy the KV cache here for the case where head size not divisible by 8,
  482. // but we don't expect to get this case in practice. This is just so that the code works for that case.
  483. kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
  484. vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
  485. }
  486. }
  487. if (seqlenq_ngroups_swapped) {
  488. out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
  489. softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
  490. }
  491. return {out, softmax_lse};
  492. }