mha_varlen_fwd.cpp

/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#include "flash_common.hpp"
#include "fmha_fwd.hpp"
#include "mask.hpp"
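
// Builds the CK fmha_fwd_traits for the varlen (group-mode) forward path.
// hdim_q/hdim_v, dtype, mask type, bias kind (ALiBi or none), LSE and dropout
// flags are chosen at runtime; group mode and row-major V are fixed for this
// entry point, and FP8 static quantization is disabled.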
fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
                                              std::string dtype,
                                              int head_size,
                                              bool has_dropout,
                                              bool has_lse,
                                              bool enable_alibi)
{
    return fmha_fwd_traits{head_size,
                           head_size,
                           dtype,
                           true, // is_group_mode
                           true, // is_v_rowmajor
                           mask.type,
                           enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
                           has_lse,
                           has_dropout,
                           false}; // do_fp8_static_quant
}
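
// Assembles the CK fmha_fwd_args for group mode. Because all sequences are
// packed along dim 0 of q/k/v/out, the per-batch strides below are zero and
// the kernel locates each sequence through the cumulative seqstart_q /
// seqstart_k arrays passed in here as seqlens_q / seqlens_k.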
fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                                          bool has_dropout_randval,
                                          const mask_info &mask,
                                          // sizes
                                          const int b,
                                          const int max_seqlen_q,
                                          const int h,
                                          const int h_k,
                                          const int d,
                                          // device pointers
                                          const at::Tensor q,
                                          const at::Tensor k,
                                          const at::Tensor v,
                                          const at::Tensor seqlens_q,
                                          const at::Tensor seqlens_k,
                                          c10::optional<at::Tensor> &alibi_slopes_,
                                          at::Tensor out,
                                          at::Tensor softmax_lse,
                                          at::Tensor dropout_randval,
                                          float softmax_scale,
                                          float p_dropout,
                                          uint64_t drop_seed,
                                          uint64_t drop_offset)
{
    // q: (total_q, nheads, d)
    // k: (total_k, nheads_k, d)
    // v: (total_k, nheads_k, d)
    // o: (total_q, nheads, d)
    // alibi_slopes: (batch, nheads) or (nhead)
    // lse: (batch, nheads, max_seqlen_q)
    // randval: (nheads, total_q, max_seqlen_k)

    ck_tile::index_t total_q = q.size(0);
    ck_tile::index_t total_k = k.size(0);

    ck_tile::index_t stride_q = q.stride(0);
    ck_tile::index_t stride_k = k.stride(0);
    ck_tile::index_t stride_v = v.stride(0);
    ck_tile::index_t stride_o = out.stride(0);
    ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0;

    ck_tile::index_t nhead_stride_q = q.stride(1);
    ck_tile::index_t nhead_stride_k = k.stride(1);
    ck_tile::index_t nhead_stride_v = v.stride(1);
    ck_tile::index_t nhead_stride_o = out.stride(1);
    ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0;
    ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0;

    ck_tile::index_t batch_stride_q = 0;
    ck_tile::index_t batch_stride_k = 0;
    ck_tile::index_t batch_stride_v = 0;
    ck_tile::index_t batch_stride_o = 0;
    ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0;
    ck_tile::index_t batch_stride_randval = 0;

    void *alibi_slopes_ptr = nullptr;
    ck_tile::index_t stride_alibi_slopes = 0;

    if (alibi_slopes_.has_value()) {
        auto alibi_slopes = alibi_slopes_.value();
        CHECK_DEVICE(alibi_slopes);
        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
        alibi_slopes_ptr = alibi_slopes.data_ptr();
        stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
    }

    return fmha_fwd_args{q.data_ptr(),
                         k.data_ptr(),
                         v.data_ptr(),
                         alibi_slopes_ptr, // bias
                         has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
                         nullptr, // lse_acc
                         nullptr, // o_acc
                         has_lse ? softmax_lse.data_ptr() : nullptr,
                         out.data_ptr(),
                         seqlens_q.data_ptr(), // seqstart_q
                         seqlens_k.data_ptr(), // seqstart_k
                         nullptr, // seqlen_kpads
                         total_q,
                         total_k,
                         b,
                         max_seqlen_q,
                         d, // hdim_q
                         d, // hdim_v
                         h, // nhead
                         h_k, // nhead_k
                         1, // num_splits
                         softmax_scale, // scale_s
                         1, // scale_p
                         1, // scale_o
                         stride_q,
                         stride_k,
                         stride_v,
                         stride_alibi_slopes,
                         stride_randval,
                         0, // stride_o_acc
                         stride_o,
                         nhead_stride_q,
                         nhead_stride_k,
                         nhead_stride_v,
                         0, // nhead_stride_bias, FA without bias
                         nhead_stride_randval,
                         nhead_stride_lse,
                         0, // nhead_stride_lse_acc
                         0, // nhead_stride_o_acc
                         nhead_stride_o,
                         batch_stride_q,
                         batch_stride_k,
                         batch_stride_v,
                         0, // batch_stride_bias, FA without bias
                         batch_stride_randval,
                         batch_stride_lse,
                         0, // batch_stride_lse_acc
                         0, // batch_stride_o_acc
                         batch_stride_o,
                         0, // split_stride_lse_acc
                         0, // split_stride_o_acc
                         mask.left,
                         mask.right,
                         static_cast<ck_tile::index_t>(mask.type),
                         p_dropout,
                         has_dropout_randval,
                         {drop_seed, drop_offset}};
}
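
// Varlen forward entry point. q/k/v are packed along the token dimension and
// delimited by the cumulative sequence lengths cu_seqlens_q / cu_seqlens_k
// (each of length batch_size + 1). For example, two query sequences of
// lengths 3 and 5 are packed as total_q = 8 rows with cu_seqlens_q = {0, 3, 8}.
// Returns {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state}.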
std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q,                   // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,             // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               const at::Tensor &v,             // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,  // b+1
               const at::Tensor &cu_seqlens_k,  // b+1
               c10::optional<at::Tensor> & /*seqused_k*/,
               c10::optional<const at::Tensor> & /*leftpad_k_*/, // batch_size
               c10::optional<at::Tensor> &block_table_,  // batch_size x max_num_blocks_per_seq
               c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
               int max_seqlen_q,
               const int max_seqlen_k,
               const float p_dropout,
               const float softmax_scale,
               const bool zero_tensors,
               bool is_causal,
               int window_size_left,
               int window_size_right,
               const float /*softcap*/,
               const bool return_dropout_randval,
               c10::optional<at::Generator> gen_)
{
    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only supports fp16 and bf16 data types");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");

    std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(cu_seqlens_q);
    CHECK_DEVICE(cu_seqlens_k);

    // TODO - Support paged_KV
    const bool paged_KV = block_table_.has_value();
    TORCH_CHECK(!paged_KV, "CK does not support paged_KV yet");

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    CHECK_CONTIGUOUS(cu_seqlens_q);
    CHECK_CONTIGUOUS(cu_seqlens_k);

    const auto sizes = q.sizes();

    const int batch_size = cu_seqlens_q.numel() - 1;
    int num_heads = sizes[1];
    const int head_size_og = sizes[2];
    const int num_heads_k = k.size(1);

    const int max_num_blocks_per_seq = 0;
    const int num_blocks = 0;

    if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
    // TODO
    // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
    // H/t Daniel Haziza

    const int total_q = q.size(0);
    const int total_k = k.size(0);

    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
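
    // Encode the window into the mask string consumed by mask_info::decode:
    // "0" means no mask and "b:<left>,<right>" a windowed mask (the "b:"
    // prefix presumably selecting CK's bottom-right aligned variant); causal
    // is the special case with right == 0 and an unlimited left window.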
    mask_info mask;
    if (is_causal) {
        // Causal is the special case where window_size_right == 0 and window_size_left < 0.
        window_size_right = 0;
        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
        mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // causal
    }
    else if (window_size_left == -1 && window_size_right == -1) {
        mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask
    }
    else {
        // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
        mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local
    }

    CHECK_SHAPE(q, total_q, num_heads, head_size_og);
    CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
    CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
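
    // The kernel is launched with a head dimension rounded up to a multiple of
    // 8 (head_size_8x below), so narrower heads are zero-padded here and the
    // padding is sliced off the output again before returning.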
    at::Tensor q_padded, k_padded, v_padded;
    if (head_size_og % 8 != 0) {
        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    }
    else {
        q_padded = q;
        k_padded = k;
        v_padded = v;
    }

    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, total_q, num_heads, head_size_og);
        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
    }
    else {
        out = torch::empty_like(q_padded);
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size_8x = round_multiple(head_size_og, 8);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();
    bool has_lse = true;
    bool has_dropout = p_dropout > 0.0f;

    at::Tensor softmax_lse;
    // TODO - check gradient, only training requires lse
    softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(torch::kFloat32));

    at::Tensor p;
    if (return_dropout_randval) {
        TORCH_CHECK(has_dropout, "return_dropout_randval requires p_dropout > 0");
        p = torch::empty({num_heads, total_q, max_seqlen_k}, opts.dtype(torch::kUInt8));
    }

    if (zero_tensors)
    {
        out.zero_();
        softmax_lse.fill_(-std::numeric_limits<float>::infinity());
        if (return_dropout_randval) { p.zero_(); }
    }
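
    // Draw the Philox seed/offset for dropout from the CUDA generator and
    // stash them in rng_state so the same dropout pattern can be reproduced
    // later (e.g. by the backward pass); counter_offset reserves the Philox
    // counter range consumed by this launch.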
    uint64_t drop_seed = 1, drop_offset = 0;
    int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));

    if (p_dropout > 0.0) {
        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
            gen_, at::cuda::detail::getDefaultCUDAGenerator());
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        auto philox_args = gen->philox_cuda_state(counter_offset);
        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
    }

    rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
    rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
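
    // Launch the CK kernel on the current HIP stream when there is anything to
    // attend to; with max_seqlen_k == 0 every key sequence is empty, so the
    // output is simply zeroed instead.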
    if (max_seqlen_k > 0) {
        auto stream = at::cuda::getCurrentHIPStream().stream();
        ck_tile::stream_config stream_config{stream};

        auto traits =
            get_ck_fmha_varlen_fwd_traits(mask, q_dtype_str, head_size_8x, has_dropout, has_lse, alibi_slopes_.has_value());

        auto args =
            get_ck_fmha_varlen_fwd_args(
                has_lse,
                return_dropout_randval,
                mask,
                batch_size,
                max_seqlen_q,
                num_heads,
                num_heads_k,
                head_size_8x,
                q_padded,
                k_padded,
                v_padded,
                cu_seqlens_q,
                cu_seqlens_k,
                alibi_slopes_,
                out,
                softmax_lse,
                p,
                softmax_scale,
                p_dropout,
                drop_seed,
                drop_offset);

        fmha_fwd(traits, args, stream_config);
    }
    else {
        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
        out.zero_();
        softmax_lse.fill_(std::numeric_limits<float>::infinity());
    }
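
    // Strip the head-dimension padding from the returned output and, if the
    // caller supplied an out_ tensor with the original head size, copy the
    // result back into it.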
    at::Tensor out_padded = out;
    if (head_size_og % 8 != 0) {
        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        if (out_.has_value()) { out_.value().copy_(out); }
    }

    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
}