mha_varlen_fwd.cpp

/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#include "flash_common.hpp"
#include "fmha_fwd.hpp"
#include "mask.hpp"
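
// This translation unit adapts the varlen (packed, variable-length) forward
// attention entry point to the CK tile FMHA kernels. Two paths exist below:
// the dense packed-varlen path dispatched through fmha_fwd, and the paged-KV
// (block_table) path dispatched through fmha_fwd_splitkv.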
fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
                                              std::string dtype,
                                              int head_size,
                                              bool has_dropout,
                                              bool has_lse,
                                              bool enable_alibi)
{
    return fmha_fwd_traits{head_size,
                           head_size,
                           dtype,
                           true, // is_group_mode
                           true, // is_v_rowmajor
                           mask.type,
                           enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
                           has_lse,
                           has_dropout,
                           false}; // do_fp8_static_quant
}
fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &mask,
                                                              std::string dtype,
                                                              int head_size,
                                                              bool has_lse,
                                                              bool enable_alibi)
{
    return fmha_fwd_splitkv_traits{head_size,
                                   head_size,
                                   dtype,
                                   true, // is_group_mode
                                   true, // is_v_rowmajor
                                   mask.type,
                                   enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
                                   has_lse,
                                   false}; // do_fp8_static_quant
}
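
// Both traits builders hard-code is_group_mode = true: varlen inputs are packed
// along the token dimension (total_q/total_k x nheads x d), and the kernels find
// each sequence's slice through the cu_seqlens prefix sums passed below as the
// seqstart_q/seqstart_k pointers. is_v_rowmajor = true matches the
// (total_k, nheads_k, d) layout of V.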
fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                                          bool has_dropout_randval,
                                          const mask_info &mask,
                                          // sizes
                                          const int b,
                                          const int max_seqlen_q,
                                          const int h,
                                          const int h_k,
                                          const int d,
                                          // device pointers
                                          const at::Tensor q,
                                          const at::Tensor k,
                                          const at::Tensor v,
                                          const at::Tensor seqlens_q,
                                          const at::Tensor seqlens_k,
                                          std::optional<at::Tensor> &alibi_slopes_,
                                          at::Tensor out,
                                          at::Tensor softmax_lse,
                                          at::Tensor dropout_randval,
                                          float softmax_scale,
                                          float p_dropout,
                                          std::pair<uint64_t*, uint64_t*> drop_seed_offset)
{
    // q: (total_q, nheads, d)
    // k: (total_k, nheads_k, d)
    // v: (total_k, nheads_k, d)
    // o: (total_q, nheads, d)
    // alibi_slopes: (batch, nheads) or (nheads)
    // lse: (nheads, total_q)
    // randval: (nheads, total_q, max_seqlen_k)

    ck_tile::index_t total_q = q.size(0);
    ck_tile::index_t total_k = k.size(0);

    ck_tile::index_t stride_q = q.stride(0);
    ck_tile::index_t stride_k = k.stride(0);
    ck_tile::index_t stride_v = v.stride(0);
    ck_tile::index_t stride_o = out.stride(0);
    ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0;

    ck_tile::index_t nhead_stride_q = q.stride(1);
    ck_tile::index_t nhead_stride_k = k.stride(1);
    ck_tile::index_t nhead_stride_v = v.stride(1);
    ck_tile::index_t nhead_stride_o = out.stride(1);
    ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(0) : 0;
    ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0;

    // Group-mode kernels locate each sequence through the seqstart_q/seqstart_k
    // prefix sums, so per-batch strides are unused and left at zero.
    ck_tile::index_t batch_stride_q = 0;
    ck_tile::index_t batch_stride_k = 0;
    ck_tile::index_t batch_stride_v = 0;
    ck_tile::index_t batch_stride_o = 0;
    ck_tile::index_t batch_stride_lse = 0;
    ck_tile::index_t batch_stride_randval = 0;

    void *alibi_slopes_ptr = nullptr;
    ck_tile::index_t stride_alibi_slopes = 0;

    if (alibi_slopes_.has_value()) {
        auto alibi_slopes = alibi_slopes_.value();
        CHECK_DEVICE(alibi_slopes);
        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
        alibi_slopes_ptr = alibi_slopes.data_ptr();
        stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
    }

    return fmha_fwd_args{q.data_ptr(),
                         k.data_ptr(),
                         v.data_ptr(),
                         alibi_slopes_ptr, // bias
                         has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
                         has_lse ? softmax_lse.data_ptr() : nullptr,
                         out.data_ptr(),
                         seqlens_q.data_ptr(), // seqstart_q
                         seqlens_k.data_ptr(), // seqstart_k
                         nullptr, // seqlen_kpads
                         total_q,
                         total_k,
                         b,
                         max_seqlen_q,
                         d, // hdim_q
                         d, // hdim_v
                         h, // nhead
                         h_k, // nhead_k
                         softmax_scale, // scale_s
                         1, // scale_p
                         1, // scale_o
                         stride_q,
                         stride_k,
                         stride_v,
                         stride_alibi_slopes,
                         stride_randval,
                         stride_o,
                         nhead_stride_q,
                         nhead_stride_k,
                         nhead_stride_v,
                         0, // nhead_stride_bias, FA without bias
                         nhead_stride_randval,
                         nhead_stride_lse,
                         nhead_stride_o,
                         batch_stride_q,
                         batch_stride_k,
                         batch_stride_v,
                         0, // batch_stride_bias, FA without bias
                         batch_stride_randval,
                         batch_stride_lse,
                         batch_stride_o,
                         mask.left,
                         mask.right,
                         static_cast<ck_tile::index_t>(mask.type),
                         p_dropout,
                         has_dropout_randval,
                         drop_seed_offset};
}
fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
                                                          const mask_info &mask,
                                                          const int b,
                                                          const int max_seqlen_q,
                                                          const int h,
                                                          const int h_k,
                                                          const int d,
                                                          const int page_block_size,
                                                          const int num_splits,
                                                          float softmax_scale,
                                                          // device pointers
                                                          const at::Tensor q,
                                                          const at::Tensor k,
                                                          const at::Tensor v,
                                                          const at::Tensor seqlens_q,
                                                          const at::Tensor seqlens_k,
                                                          std::optional<at::Tensor> &block_table_,
                                                          std::optional<at::Tensor> &alibi_slopes_,
                                                          at::Tensor out,
                                                          at::Tensor lse,
                                                          at::Tensor lse_acc,
                                                          at::Tensor out_acc)
{
    // q: (total_q, nheads, d)
    // k: (num_blocks, page_block_size, num_heads_k, d)
    // v: (num_blocks, page_block_size, num_heads_k, d)
    // o: (total_q, nheads, d)
    // alibi_slopes: (batch_size, nheads) or (nheads)
    // lse: (nheads, total_q)
    // lse_acc: (nheads, split, total_q)
    // o_acc: (nheads, split, total_q, d)
    // block_table: (batch_size, max_num_blocks_per_seq)

    fmha_fwd_splitkv_args args;

    args.q_ptr = q.data_ptr();
    args.k_ptr = k.data_ptr();
    args.v_ptr = v.data_ptr();
    args.bias_ptr = nullptr;
    args.lse_acc_ptr = lse_acc.data_ptr();
    args.o_acc_ptr = out_acc.data_ptr();
    args.lse_ptr = nullptr;
    args.o_ptr = out.data_ptr();

    if (block_table_.has_value())
    {
        auto block_table = block_table_.value();
        args.block_table_ptr = block_table.data_ptr();
        args.batch_stride_block_table = block_table.stride(0);
        args.page_block_size = page_block_size;
    }
    else
    {
        args.block_table_ptr = nullptr;
        args.batch_stride_block_table = 0;
        args.page_block_size = 0;
    }
    args.is_gappy = false;

    args.cache_batch_idx = nullptr;

    args.seqstart_q_ptr = seqlens_q.data_ptr();
    args.seqstart_k_ptr = seqlens_k.data_ptr();
    args.seqlen_k_ptr = nullptr;

    args.batch = b;
    args.max_seqlen_q = max_seqlen_q;
    args.hdim_q = d;
    args.hdim_v = d;
    args.nhead_q = h;
    args.nhead_k = h_k;
    args.num_splits = num_splits;

    args.scale_s = softmax_scale;
    args.scale_p = 1;
    args.scale_o = 1;

    args.batch_stride_q = 0;
    args.stride_q = q.stride(0);
    args.nhead_stride_q = q.stride(1);

    args.batch_stride_k = k.stride(0);
    args.stride_k = k.stride(1);
    args.nhead_stride_k = k.stride(2);

    args.batch_stride_v = v.stride(0);
    args.stride_v = v.stride(1);
    args.nhead_stride_v = v.stride(2);

    args.batch_stride_o = 0;
    args.stride_o = out.stride(0);
    args.nhead_stride_o = out.stride(1);

    args.batch_stride_bias = 0;
    args.stride_bias = 0;
    args.nhead_stride_bias = 0;

    args.batch_stride_lse = 0;
    args.nhead_stride_lse = 0;

    args.batch_stride_lse_acc = 0;
    args.nhead_stride_lse_acc = lse_acc.stride(0);
    args.split_stride_lse_acc = lse_acc.stride(1);

    args.batch_stride_o_acc = 0;
    args.nhead_stride_o_acc = out_acc.stride(0);
    args.split_stride_o_acc = out_acc.stride(1);
    args.stride_o_acc = out_acc.stride(2);

    if (has_lse) {
        args.lse_ptr = lse.data_ptr();
        args.batch_stride_lse = 0;
        args.nhead_stride_lse = lse.stride(0);
    }

    if (alibi_slopes_.has_value()) {
        auto alibi_slopes = alibi_slopes_.value();
        CHECK_DEVICE(alibi_slopes);
        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
        args.bias_ptr = alibi_slopes.data_ptr();
        args.stride_bias = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
    }

    args.window_size_left = mask.left;
    args.window_size_right = mask.right;
    args.mask_type = static_cast<ck_tile::index_t>(mask.type);

    return args;
}
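
// In the split-KV path the kernel writes per-split partial results into
// lse_acc and out_acc (float32) and is expected to reduce them into the final
// out / lse tensors; this file only allocates and wires up those buffers.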
std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q,                   // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,             // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               const at::Tensor &v,             // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,  // b+1
               const at::Tensor &cu_seqlens_k,  // b+1
               std::optional<at::Tensor> & /*seqused_k*/,
               std::optional<const at::Tensor> & /*leftpad_k_*/, // batch_size
               std::optional<at::Tensor> &block_table_,  // batch_size x max_num_blocks_per_seq
               std::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
               int max_seqlen_q,
               const int max_seqlen_k,
               const float p_dropout,
               const float softmax_scale,
               const bool zero_tensors,
               bool is_causal,
               int window_size_left,
               int window_size_right,
               const float /*softcap*/,
               const bool return_dropout_randval,
               std::optional<at::Generator> gen_)
{
    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only supports fp16 and bf16 data types");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");

    std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(cu_seqlens_q);
    CHECK_DEVICE(cu_seqlens_k);

    at::Tensor block_table;
    const bool paged_KV = block_table_.has_value();
    if (paged_KV) {
        block_table = block_table_.value();
        CHECK_DEVICE(block_table);
        TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
        TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
    }

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    CHECK_CONTIGUOUS(cu_seqlens_q);
    CHECK_CONTIGUOUS(cu_seqlens_k);
    const auto sizes = q.sizes();

    const int batch_size = cu_seqlens_q.numel() - 1;
    int num_heads = sizes[1];
    const int head_size = sizes[2];
    const int num_heads_k = paged_KV ? k.size(2) : k.size(1);

    const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
    const int num_blocks = !paged_KV ? 0 : k.size(0);
    const int page_block_size = !paged_KV ? 1 : k.size(1);
    TORCH_CHECK(!paged_KV || page_block_size % 128 == 0, "Paged KV cache block size must be divisible by 128");

    if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case

    // TODO
    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
    // H/t Daniel Haziza

    const int total_q = q.size(0);

    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size <= 256, "CK only supports head dimension at most 256");
    TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
    mask_info mask;
    if (is_causal) {
        // Causal is the special case where window_size_right == 0 and window_size_left < 0.
        window_size_right = 0;
        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
        mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // causal
    }
    else if (window_size_left == -1 && window_size_right == -1) {
        mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask
    }
    else {
        // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
        mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local
    }
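
    // For example, plain causal attention (window_size_left == -1) produces the
    // string "b:-1,0", and a sliding-window causal mask with a 128-token window
    // produces "b:128,0"; mask_info::decode turns these window specifications
    // into the mask type and left/right extents consumed by the kernels.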
    CHECK_SHAPE(q, total_q, num_heads, head_size);
    if (!paged_KV) {
        const int total_k = k.size(0);
        CHECK_SHAPE(k, total_k, num_heads_k, head_size);
        CHECK_SHAPE(v, total_k, num_heads_k, head_size);
    } else {
        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size);
        CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size);
        CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
    }
    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);

    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, total_q, num_heads, head_size);
    }
    else {
        out = torch::empty_like(q);
    }
    // Otherwise the kernel will be launched from cuda:0 device
    at::cuda::CUDAGuard device_guard{q.device()};

    auto opts = q.options();
    bool has_lse = true;
    bool has_dropout = p_dropout > 0.0f;
    if (has_dropout)
        TORCH_CHECK(!paged_KV, "Paged KV does not support dropout");

    at::Tensor softmax_lse;
    // TODO - check gradient, only training requires lse
    softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(torch::kFloat32));

    at::Tensor p;
    if (return_dropout_randval) {
        TORCH_CHECK(has_dropout, "return_dropout_randval requires p_dropout > 0");
        p = torch::empty({num_heads, total_q, max_seqlen_k}, opts.dtype(torch::kUInt8));
    }
    else {
        p = torch::empty({ 0 }, opts);
    }

    if (zero_tensors)
    {
        out.zero_();
        softmax_lse.fill_(-std::numeric_limits<float>::infinity());
        if (return_dropout_randval) { p.zero_(); }
    }

    int num_splits = 0;
    num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, max_seqlen_q, head_size, 0, num_splits);
    TORCH_CHECK(num_splits > 0, "num_splits should be greater than 0");
    TORCH_CHECK(num_splits <= 128, "num_splits greater than 128 is not supported");

    auto softmax_lse_accum = torch::empty({num_heads, num_splits, total_q}, opts.dtype(at::kFloat));
    auto out_accum = torch::empty({num_heads, num_splits, total_q, head_size}, opts.dtype(at::kFloat));

    int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
    // rng_state holds the philox {seed, offset} pair on device; its two slots
    // are later passed to the kernel as drop_seed_offset.
    auto rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
    auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
    if (p_dropout > 0.0) {
        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
            gen_, at::cuda::detail::getDefaultCUDAGenerator());
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        auto philox_args = gen->philox_cuda_state(counter_offset);
        hipLaunchKernelGGL(
            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, rng_state_ptr);
    }
    if (max_seqlen_k > 0) {
        auto stream = at::cuda::getCurrentHIPStream().stream();
        ck_tile::stream_config stream_config{stream};

        if (paged_KV)
        {
            auto traits =
                get_ck_fmha_varlen_fwd_splitkv_traits(
                    mask,
                    q_dtype_str,
                    head_size,
                    has_lse,
                    alibi_slopes_.has_value());

            auto args =
                get_ck_fmha_varlen_fwd_splitkv_args(
                    has_lse,
                    mask,
                    batch_size,
                    max_seqlen_q,
                    num_heads,
                    num_heads_k,
                    head_size,
                    page_block_size,
                    num_splits,
                    softmax_scale,
                    q,
                    k,
                    v,
                    cu_seqlens_q,
                    cu_seqlens_k,
                    block_table_,
                    alibi_slopes_,
                    out,
                    softmax_lse,
                    softmax_lse_accum,
                    out_accum);

            float t = fmha_fwd_splitkv(traits, args, stream_config);
            TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd_splitkv");
        }
        else
        {
            auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);

            auto traits =
                get_ck_fmha_varlen_fwd_traits(
                    mask,
                    q_dtype_str,
                    head_size,
                    has_dropout,
                    has_lse,
                    alibi_slopes_.has_value());

            auto args =
                get_ck_fmha_varlen_fwd_args(
                    has_lse,
                    return_dropout_randval,
                    mask,
                    batch_size,
                    max_seqlen_q,
                    num_heads,
                    num_heads_k,
                    head_size,
                    q,
                    k,
                    v,
                    cu_seqlens_q,
                    cu_seqlens_k,
                    alibi_slopes_,
                    out,
                    softmax_lse,
                    p,
                    softmax_scale,
                    p_dropout,
                    drop_seed_offset);

            float t = fmha_fwd(traits, args, stream_config);
            TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
        }
    }
    else {
        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
        out.zero_();
        softmax_lse.fill_(std::numeric_limits<float>::infinity());
    }

    return {out, softmax_lse, p, rng_state};
}
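
// Returned tensors:
//   out         - (total_q, num_heads, head_size) attention output
//   softmax_lse - (num_heads, total_q) float32 log-sum-exp values
//   p           - (num_heads, total_q, max_seqlen_k) dropout randval when
//                 return_dropout_randval is set, otherwise an empty tensor
//   rng_state   - 2-element int64 tensor holding the philox seed/offset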