mha_fwd.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. /******************************************************************************
  2. * Copyright (c) 2024, Tri Dao.
  3. ******************************************************************************/
  4. #include "flash_common.hpp"
  5. #include "fmha_fwd.hpp"
  6. #include "mask.hpp"
  7. fmha_fwd_traits get_ck_fmha_fwd_traits(const mask_info &mask,
  8. std::string dtype,
  9. int head_size,
  10. bool has_dropout,
  11. bool has_lse,
  12. bool enable_alibi)
  13. {
  14. return fmha_fwd_traits{head_size,
  15. head_size,
  16. dtype,
  17. false, // is_group_mode
  18. true, // is_v_rowmajor
  19. mask.type,
  20. enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
  21. has_lse,
  22. has_dropout,
  23. false}; // do_fp8_static_quant
  24. }
  25. fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
  26. bool has_dropout_randval,
  27. const mask_info &mask,
  28. // sizes
  29. const int b,
  30. const int seqlen_q,
  31. const int seqlen_k,
  32. const int h,
  33. const int h_k,
  34. const int d,
  35. // device pointers
  36. const at::Tensor q,
  37. const at::Tensor k,
  38. const at::Tensor v,
  39. std::optional<at::Tensor> &alibi_slopes_,
  40. at::Tensor out,
  41. at::Tensor softmax_lse,
  42. at::Tensor dropout_randval,
  43. float softmax_scale,
  44. float p_dropout,
  45. std::pair<uint64_t*, uint64_t*> drop_seed_offset)
  46. {
  47. // q: (batch_size, seqlen_q, nheads, d)
  48. // k: (batch_size, seqlen_k, nheads_k, d)
  49. // v: (batch_size, seqlen_k, nheads_k, d)
  50. // o: (batch_size, seqlen_q, nheads, d)
  51. // alibi_slopes:(batch_size, nheads) or (nhead)
  52. // lse: (batch_size, nheads, seqlen_q)
  53. // randval: (batch_size, nheads, seqlen_q, seqlen_k)
  54. ck_tile::index_t stride_q = q.stride(1);
  55. ck_tile::index_t stride_k = k.stride(1);
  56. ck_tile::index_t stride_v = v.stride(1);
  57. ck_tile::index_t stride_o = out.stride(1);
  58. ck_tile::index_t stride_randval = has_dropout_randval ? dropout_randval.stride(2) : 0;
  59. ck_tile::index_t nhead_stride_q = q.stride(2);
  60. ck_tile::index_t nhead_stride_k = k.stride(2);
  61. ck_tile::index_t nhead_stride_v = v.stride(2);
  62. ck_tile::index_t nhead_stride_o = out.stride(2);
  63. ck_tile::index_t nhead_stride_lse = has_lse ? softmax_lse.stride(1) : 0;
  64. ck_tile::index_t nhead_stride_randval = has_dropout_randval ? dropout_randval.stride(1) : 0;
  65. ck_tile::index_t batch_stride_q = q.stride(0);
  66. ck_tile::index_t batch_stride_k = k.stride(0);
  67. ck_tile::index_t batch_stride_v = v.stride(0);
  68. ck_tile::index_t batch_stride_o = out.stride(0);
  69. ck_tile::index_t batch_stride_lse = has_lse ? softmax_lse.stride(0) : 0;
  70. ck_tile::index_t batch_stride_randval = has_dropout_randval ? dropout_randval.stride(0) : 0;
  71. void *alibi_slopes_ptr = nullptr;
  72. ck_tile::index_t stride_alibi_slopes = 0;
  73. if (alibi_slopes_.has_value()) {
  74. auto alibi_slopes = alibi_slopes_.value();
  75. CHECK_DEVICE(alibi_slopes);
  76. TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
  77. TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
  78. alibi_slopes_ptr = alibi_slopes.data_ptr();
  79. stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
  80. }
  81. return fmha_fwd_args{q.data_ptr(),
  82. k.data_ptr(),
  83. v.data_ptr(),
  84. alibi_slopes_ptr, // bias
  85. has_dropout_randval ? dropout_randval.data_ptr() : nullptr,
  86. has_lse ? softmax_lse.data_ptr() : nullptr,
  87. out.data_ptr(),
  88. nullptr, // seqstart_q
  89. nullptr, // seqstart_k
  90. nullptr,
  91. seqlen_q,
  92. seqlen_k,
  93. b,
  94. seqlen_q, // max_seqlen_q
  95. d, // hdim_q
  96. d, // hdim_v
  97. h, // nhead
  98. h_k, // nhead_k
  99. softmax_scale, // scale_s
  100. 1, // scale_p
  101. 1, // scale_o
  102. stride_q,
  103. stride_k,
  104. stride_v,
  105. stride_alibi_slopes,
  106. stride_randval,
  107. stride_o,
  108. nhead_stride_q,
  109. nhead_stride_k,
  110. nhead_stride_v,
  111. 0, // nhead_stride_bias, FA without bias
  112. nhead_stride_randval,
  113. nhead_stride_lse,
  114. nhead_stride_o,
  115. batch_stride_q,
  116. batch_stride_k,
  117. batch_stride_v,
  118. 0, // batch_stride_bias, FA without bias
  119. batch_stride_randval,
  120. batch_stride_lse,
  121. batch_stride_o,
  122. mask.left,
  123. mask.right,
  124. static_cast<ck_tile::index_t>(mask.type),
  125. p_dropout,
  126. has_dropout_randval,
  127. drop_seed_offset};
  128. }
  129. std::vector<at::Tensor>
  130. mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
  131. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
  132. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
  133. std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
  134. std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
  135. const float p_dropout,
  136. const float softmax_scale,
  137. bool is_causal,
  138. int window_size_left,
  139. int window_size_right,
  140. const float /*softcap*/,
  141. const bool return_dropout_randval,
  142. std::optional<at::Generator> gen_)
  143. {
  144. auto q_dtype = q.dtype();
  145. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  146. "FlashAttention only support fp16 and bf16 data type");
  147. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  148. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  149. std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
  150. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  151. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  152. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  153. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  154. const auto sizes = q.sizes();
  155. const int batch_size = sizes[0];
  156. int seqlen_q = sizes[1];
  157. int num_heads = sizes[2];
  158. const int head_size = sizes[3];
  159. const int seqlen_k = k.size(1);
  160. const int num_heads_k = k.size(2);
  161. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  162. TORCH_CHECK(head_size <= 256, "CK only supports head dimension at most 256");
  163. TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
  164. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  165. if (window_size_left >= seqlen_k) { window_size_left = -1; }
  166. if (window_size_right >= seqlen_k) { window_size_right = -1; }
  167. // causal=true is the same as causal=false in this case
  168. if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
  169. mask_info mask;
  170. if (is_causal) {
  171. // Causal is the special case where window_size_right == 0 and window_size_left < 0.
  172. window_size_right = 0;
  173. std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
  174. mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // casual
  175. }
  176. else if (window_size_left == -1 && window_size_right == -1) {
  177. mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask
  178. }
  179. else {
  180. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
  181. std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
  182. mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local
  183. }
  184. // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
  185. // H/t Daniel Haziza
  186. const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size % 8 == 0 && !alibi_slopes_.has_value();
  187. const int ngroups = num_heads / num_heads_k;
  188. if (seqlenq_ngroups_swapped) {
  189. q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
  190. seqlen_q = ngroups;
  191. num_heads = num_heads_k;
  192. }
  193. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
  194. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
  195. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
  196. at::Tensor out;
  197. if (out_.has_value()) {
  198. out = out_.value();
  199. TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  200. CHECK_DEVICE(out);
  201. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  202. CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size);
  203. if (seqlenq_ngroups_swapped) {
  204. out = out.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
  205. }
  206. }
  207. else {
  208. out = torch::empty_like(q);
  209. }
  210. // Otherwise the kernel will be launched from cuda:0 device
  211. at::cuda::CUDAGuard device_guard{q.device()};
  212. auto opts = q.options();
  213. bool has_lse = true;
  214. bool has_dropout = p_dropout > 0.0f;
  215. at::Tensor softmax_lse;
  216. // TODO - check gradient, only training require lse
  217. softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(torch::kFloat32));
  218. at::Tensor p;
  219. if (return_dropout_randval) {
  220. TORCH_CHECK(has_dropout, "return_dropout_randval require p_dropout > 0");
  221. p = torch::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts.dtype(torch::kUInt8));
  222. }
  223. else {
  224. p = torch::empty({ 0 }, opts);
  225. }
  226. int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
  227. auto rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
  228. auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
  229. if (p_dropout > 0.0) {
  230. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  231. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  232. // See Note [Acquire lock when using random generators]
  233. std::lock_guard<std::mutex> lock(gen->mutex_);
  234. auto philox_args = gen->philox_cuda_state(counter_offset);
  235. hipLaunchKernelGGL(
  236. flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, rng_state_ptr);
  237. }
  238. if (seqlen_k > 0) {
  239. auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
  240. auto stream = at::cuda::getCurrentHIPStream().stream();
  241. ck_tile::stream_config stream_config{stream};
  242. auto traits =
  243. get_ck_fmha_fwd_traits(
  244. mask,
  245. q_dtype_str,
  246. head_size,
  247. has_dropout,
  248. has_lse,
  249. alibi_slopes_.has_value());
  250. auto args =
  251. get_ck_fmha_fwd_args(
  252. has_lse,
  253. return_dropout_randval,
  254. mask,
  255. batch_size,
  256. seqlen_q,
  257. seqlen_k,
  258. num_heads,
  259. num_heads_k,
  260. head_size,
  261. q,
  262. k,
  263. v,
  264. alibi_slopes_,
  265. out,
  266. softmax_lse,
  267. p,
  268. softmax_scale,
  269. p_dropout,
  270. drop_seed_offset);
  271. float t = fmha_fwd(traits, args, stream_config);
  272. TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
  273. }
  274. else {
  275. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  276. out.zero_();
  277. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  278. }
  279. if (seqlenq_ngroups_swapped) {
  280. out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
  281. q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
  282. softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
  283. }
  284. return {out, softmax_lse, p, rng_state};
  285. }