mha_varlen_fwd.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. /******************************************************************************
  2. * Copyright (c) 2024, Tri Dao.
  3. ******************************************************************************/
  4. #include "flash_common.hpp"
  5. #include "fmha_fwd.hpp"
  6. #include "mask.hpp"
  7. fmha_fwd_traits get_ck_fmha_varlen_fwd_traits(const mask_info &mask,
  8. std::string dtype,
  9. int head_size,
  10. bool has_dropout,
  11. bool has_lse,
  12. bool enable_alibi)
  13. {
  14. return fmha_fwd_traits{head_size,
  15. head_size,
  16. dtype,
  17. true, // is_group_mode
  18. true, // is_v_rowmajor
  19. mask.type,
  20. enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
  21. has_lse,
  22. has_dropout,
  23. false}; // do_fp8_static_quant
  24. }
  25. fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
  26. bool has_dropout_randval,
  27. const mask_info &mask,
  28. // sizes
  29. const int b,
  30. const int max_seqlen_q,
  31. const int h,
  32. const int h_k,
  33. const int d,
  34. // device pointers
  35. const at::Tensor q,
  36. const at::Tensor k,
  37. const at::Tensor v,
  38. const at::Tensor seqlens_q,
  39. const at::Tensor seqlens_k,
  40. c10::optional<at::Tensor> &alibi_slopes_,
  41. at::Tensor out,
  42. at::Tensor softmax_lse,
  43. at::Tensor dropout_randval,
  44. float softmax_scale,
  45. float p_dropout,
  46. std::pair<uint64_t*, uint64_t*> drop_seed_offset)
  47. {
  48. // q: (total_q, nheads, d)
  49. // k: (total_k, nheads_k, d)
  50. // v: (total_k, nheads_k, d)
  51. // o: (total_q, nheads, d)
  52. // alibi_slopes:(batch, nheads) or (nhead)
  53. // lse: (nheads, total_q)
  54. // randval: (nheads, total_q, max_seqlen_k)
  55. ck_tile::index_t total_q = q.size(0);
  56. ck_tile::index_t total_k = k.size(0);
  57. ck_tile::index_t stride_q = q.stride(0);
  58. ck_tile::index_t stride_k = k.stride(0);
  59. ck_tile::index_t stride_v = v.stride(0);
  60. ck_tile::index_t stride_o = out.stride(0);
  61. ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0;
  62. ck_tile::index_t nhead_stride_q = q.stride(1);
  63. ck_tile::index_t nhead_stride_k = k.stride(1);
  64. ck_tile::index_t nhead_stride_v = v.stride(1);
  65. ck_tile::index_t nhead_stride_o = out.stride(1);
  66. ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(0) : 0;
  67. ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0;
  68. ck_tile::index_t batch_stride_q = 0;
  69. ck_tile::index_t batch_stride_k = 0;
  70. ck_tile::index_t batch_stride_v = 0;
  71. ck_tile::index_t batch_stride_o = 0;
  72. ck_tile::index_t batch_stride_lse = 0;
  73. ck_tile::index_t batch_stride_randval = 0;
  74. void *alibi_slopes_ptr = nullptr;
  75. ck_tile::index_t stride_alibi_slopes = 0;
  76. if (alibi_slopes_.has_value()) {
  77. auto alibi_slopes = alibi_slopes_.value();
  78. CHECK_DEVICE(alibi_slopes);
  79. TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
  80. TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
  81. alibi_slopes_ptr = alibi_slopes.data_ptr();
  82. stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
  83. }
  84. return fmha_fwd_args{q.data_ptr(),
  85. k.data_ptr(),
  86. v.data_ptr(),
  87. alibi_slopes_ptr, // bias
  88. has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
  89. has_lse ? softmax_lse.data_ptr() : nullptr,
  90. out.data_ptr(),
  91. seqlens_q.data_ptr(), // seqstart_q
  92. seqlens_k.data_ptr(), // seqstart_k
  93. nullptr, // seqlen_kpads
  94. total_q,
  95. total_k,
  96. b,
  97. max_seqlen_q,
  98. d, // hdim_q
  99. d, // hdim_v
  100. h, // nhead
  101. h_k, // nhead_k
  102. softmax_scale, // scale_s
  103. 1, // scale_p
  104. 1, // scale_o
  105. stride_q,
  106. stride_k,
  107. stride_v,
  108. stride_alibi_slopes,
  109. stride_randval,
  110. stride_o,
  111. nhead_stride_q,
  112. nhead_stride_k,
  113. nhead_stride_v,
  114. 0, // nhead_stride_bias, FA without bias
  115. nhead_stride_randval,
  116. nhead_stride_lse,
  117. nhead_stride_o,
  118. batch_stride_q,
  119. batch_stride_k,
  120. batch_stride_v,
  121. 0, // batch_stride_bias, FA without bias
  122. batch_stride_randval,
  123. batch_stride_lse,
  124. batch_stride_o,
  125. mask.left,
  126. mask.right,
  127. static_cast<ck_tile::index_t>(mask.type),
  128. p_dropout,
  129. has_dropout_randval,
  130. drop_seed_offset};
  131. }
  132. std::vector<at::Tensor>
  133. mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  134. const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  135. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  136. c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
  137. const at::Tensor &cu_seqlens_q, // b+1
  138. const at::Tensor &cu_seqlens_k, // b+1
  139. c10::optional<at::Tensor> & /*seqused_k*/,
  140. c10::optional<const at::Tensor> &/*leftpad_k_*/, // batch_size
  141. c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
  142. c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
  143. int max_seqlen_q,
  144. const int max_seqlen_k,
  145. const float p_dropout,
  146. const float softmax_scale,
  147. const bool zero_tensors,
  148. bool is_causal,
  149. int window_size_left,
  150. int window_size_right,
  151. const float /*softcap*/,
  152. const bool return_dropout_randval,
  153. c10::optional<at::Generator> gen_)
  154. {
  155. auto q_dtype = q.dtype();
  156. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  157. "FlashAttention only support fp16 and bf16 data type");
  158. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  159. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  160. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
  161. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
  162. std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
  163. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  164. CHECK_DEVICE(cu_seqlens_q);
  165. CHECK_DEVICE(cu_seqlens_k);
  166. // TODO - Support paged_KV
  167. const bool paged_KV = block_table_.has_value();
  168. TORCH_CHECK(!paged_KV, "CK does not support paged_KV yet");
  169. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  170. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  171. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  172. CHECK_CONTIGUOUS(cu_seqlens_q);
  173. CHECK_CONTIGUOUS(cu_seqlens_k);
  174. const auto sizes = q.sizes();
  175. const int batch_size = cu_seqlens_q.numel() - 1;
  176. int num_heads = sizes[1];
  177. const int head_size = sizes[2];
  178. const int num_heads_k = k.size(1);
  179. const int max_num_blocks_per_seq = 0;
  180. const int num_blocks = 0;
  181. if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
  182. // TODO
  183. // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
  184. // H/t Daniel Haziza
  185. const int total_q = q.size(0);
  186. const int total_k = k.size(0);
  187. TORCH_CHECK(batch_size > 0, "batch size must be postive");
  188. TORCH_CHECK(head_size <= 256, "CK only supports head dimension at most 256");
  189. TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
  190. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  191. if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
  192. if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
  193. mask_info mask;
  194. if (is_causal) {
  195. // Causal is the special case where window_size_right == 0 and window_size_left < 0.
  196. window_size_right = 0;
  197. std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
  198. mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // casual
  199. }
  200. else if (window_size_left == -1 && window_size_right == -1) {
  201. mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask
  202. }
  203. else {
  204. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
  205. std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
  206. mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local
  207. }
  208. CHECK_SHAPE(q, total_q, num_heads, head_size);
  209. CHECK_SHAPE(k, total_k, num_heads_k, head_size);
  210. CHECK_SHAPE(v, total_k, num_heads_k, head_size);
  211. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  212. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  213. at::Tensor out;
  214. if (out_.has_value()) {
  215. out = out_.value();
  216. TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  217. CHECK_DEVICE(out);
  218. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  219. CHECK_SHAPE(out, total_q, num_heads, head_size);
  220. }
  221. else {
  222. out = torch::empty_like(q);
  223. }
  224. // Otherwise the kernel will be launched from cuda:0 device
  225. at::cuda::CUDAGuard device_guard{q.device()};
  226. auto opts = q.options();
  227. bool has_lse = true;
  228. bool has_dropout = p_dropout > 0.0f;
  229. at::Tensor softmax_lse;
  230. // TODO - check gradient, only training require lse
  231. softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(torch::kFloat32));
  232. at::Tensor p;
  233. if (return_dropout_randval) {
  234. TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0");
  235. p = torch::empty({num_heads, total_q, max_seqlen_k}, opts.dtype(torch::kUInt8));
  236. }
  237. else {
  238. p = torch::empty({ 0 }, opts);
  239. }
  240. if (zero_tensors)
  241. {
  242. out.zero_();
  243. softmax_lse.fill_(-std::numeric_limits<float>::infinity());
  244. if (return_dropout_randval) {p.zero_();}
  245. }
  246. int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
  247. auto rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
  248. auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
  249. if (p_dropout > 0.0) {
  250. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  251. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  252. // See Note [Acquire lock when using random generators]
  253. std::lock_guard<std::mutex> lock(gen->mutex_);
  254. auto philox_args = gen->philox_cuda_state(counter_offset);
  255. hipLaunchKernelGGL(
  256. flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, rng_state_ptr);
  257. }
  258. if (max_seqlen_k > 0) {
  259. auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
  260. auto stream = at::cuda::getCurrentHIPStream().stream();
  261. ck_tile::stream_config stream_config{stream};
  262. auto traits =
  263. get_ck_fmha_varlen_fwd_traits(
  264. mask,
  265. q_dtype_str,
  266. head_size,
  267. has_dropout,
  268. has_lse,
  269. alibi_slopes_.has_value());
  270. auto args =
  271. get_ck_fmha_varlen_fwd_args(
  272. has_lse,
  273. return_dropout_randval,
  274. mask,
  275. batch_size,
  276. max_seqlen_q,
  277. num_heads,
  278. num_heads_k,
  279. head_size,
  280. q,
  281. k,
  282. v,
  283. cu_seqlens_q,
  284. cu_seqlens_k,
  285. alibi_slopes_,
  286. out,
  287. softmax_lse,
  288. p,
  289. softmax_scale,
  290. p_dropout,
  291. drop_seed_offset);
  292. float t = fmha_fwd(traits, args, stream_config);
  293. TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
  294. }
  295. else {
  296. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  297. out.zero_();
  298. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  299. }
  300. return {out, softmax_lse, p, rng_state};
  301. }