mha_bwd.cpp

/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#include "flash_common.hpp"

#include "fmha_bwd.hpp"
#include "mask.hpp"
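
// Descriptive note (added): this file is the glue between PyTorch and the CK
// (ck_tile) fused multi-head attention backward kernel. The two helpers below
// translate tensors and flags into fmha_bwd traits/args, and mha_bwd() validates
// inputs, prepares buffers, and launches the kernel.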
fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
                                       std::string dtype,
                                       int head_size,
                                       bool has_dropout,
                                       bool enable_alibi)
{
    return fmha_bwd_traits{head_size,
                           head_size,
                           dtype,
                           false, // is_group_mode
                           mask.type,
                           enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
                           false, // has_dbias
                           has_dropout};
}
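
// Descriptive note (added): packs the kernel arguments for the fixed-length
// (batch) backward path. All strides are taken straight from the PyTorch tensors,
// i.e. they are element counts, not bytes; softmax_lse and d use a
// (batch, nheads, seqlen_q) layout, so only their batch and nhead strides are needed.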
fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                                   // sizes
                                   const int b,
                                   const int seqlen_q,
                                   const int seqlen_k,
                                   const int h,
                                   const int h_k,
                                   const int hdim,
                                   // device pointers
                                   const at::Tensor q,
                                   const at::Tensor k,
                                   const at::Tensor v,
                                   c10::optional<at::Tensor> &alibi_slopes_,
                                   const at::Tensor out,
                                   const at::Tensor softmax_lse,
                                   const at::Tensor dout,
                                   at::Tensor d,
                                   at::Tensor dq,
                                   at::Tensor dk,
                                   at::Tensor dv,
                                   float softmax_scale,
                                   float p_dropout,
                                   uint64_t drop_seed,
                                   uint64_t drop_offset)
{
    // q: (batch_size, seqlen_q, nheads, hdim)
    // k: (batch_size, seqlen_k, nheads_k, hdim)
    // v: (batch_size, seqlen_k, nheads_k, hdim)
    // o: (batch_size, seqlen_q, nheads, hdim)
    // dq: (batch_size, seqlen_q, nheads, hdim)
    // dk_expanded: (batch_size, seqlen_k, nheads, hdim)
    // dv_expanded: (batch_size, seqlen_k, nheads, hdim)
    // do: (batch_size, seqlen_q, nheads, hdim)
    // alibi_slopes: (batch_size, nheads) or (nheads)
    // lse: (batch_size, nheads, seqlen_q)
    // d: (batch_size, nheads, seqlen_q)
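
    // Descriptive note (added): the kernel addresses these tensors through explicit
    // per-seqlen, per-head, and per-batch strides rather than assuming contiguity;
    // the caller only enforces that the last (hdim) dimension is contiguous.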
    ck_tile::index_t stride_q = q.stride(1);
    ck_tile::index_t stride_k = k.stride(1);
    ck_tile::index_t stride_v = v.stride(1);
    ck_tile::index_t stride_o = out.stride(1);
    ck_tile::index_t stride_do = dout.stride(1);
    ck_tile::index_t stride_dk = dk.stride(1);
    ck_tile::index_t stride_dv = dv.stride(1);

    ck_tile::index_t nhead_stride_q = q.stride(2);
    ck_tile::index_t nhead_stride_k = k.stride(2);
    ck_tile::index_t nhead_stride_v = v.stride(2);
    ck_tile::index_t nhead_stride_o = out.stride(2);
    ck_tile::index_t nhead_stride_do = dout.stride(2);
    ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1);

    ck_tile::index_t batch_stride_q = q.stride(0);
    ck_tile::index_t batch_stride_k = k.stride(0);
    ck_tile::index_t batch_stride_v = v.stride(0);
    ck_tile::index_t batch_stride_o = out.stride(0);
    ck_tile::index_t batch_stride_do = dout.stride(0);
    ck_tile::index_t batch_stride_lse = softmax_lse.stride(0);
    ck_tile::index_t batch_stride_dk = dk.stride(0);
    ck_tile::index_t batch_stride_dv = dv.stride(0);

    float p_undrop = 1.0 - p_dropout;

    void *alibi_slopes_ptr = nullptr;
    ck_tile::index_t stride_alibi_slopes = 0;

    if (alibi_slopes_.has_value()) {
        auto alibi_slopes = alibi_slopes_.value();
        CHECK_DEVICE(alibi_slopes);
        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
        alibi_slopes_ptr = alibi_slopes.data_ptr();
        stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
    }
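
    // Descriptive note (added): stride_alibi_slopes is the batch stride when the
    // slopes are per-batch with shape (batch_size, nheads); it stays 0 for shared
    // (nheads) slopes, so every batch reads the same row.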

    return fmha_bwd_args{q.data_ptr(),
                         k.data_ptr(),
                         v.data_ptr(),
                         alibi_slopes_ptr, // bias
                         out.data_ptr(),
                         softmax_lse.data_ptr(),
                         dout.data_ptr(),
                         d.data_ptr(),
                         nullptr, // rand_val
                         dq.data_ptr(),
                         dk.data_ptr(),
                         dv.data_ptr(),
                         nullptr, // dbias
                         nullptr, // seqstart_q
                         nullptr, // seqstart_k
                         nullptr, // seqlen_k_ptr
                         seqlen_q,
                         seqlen_k,
                         b,
                         seqlen_q, // max_seqlen_q
                         seqlen_k, // max_seqlen_k
                         hdim, // hdim_q
                         hdim, // hdim_v
                         h, // nhead
                         h_k, // nhead_k
                         softmax_scale,
                         stride_q,
                         stride_k,
                         stride_v,
                         stride_alibi_slopes,
                         stride_o,
                         0, // stride_randval
                         stride_do,
                         stride_dk,
                         stride_dv,
                         0, // stride_dbias, FA without bias
                         nhead_stride_q,
                         nhead_stride_k,
                         nhead_stride_v,
                         0, // nhead_stride_bias, FA without bias
                         nhead_stride_o,
                         0, // nhead_stride_randval
                         nhead_stride_do,
                         nhead_stride_lse,
                         0, // nhead_stride_dbias, FA without dbias
                         batch_stride_q,
                         batch_stride_k,
                         batch_stride_v,
                         0, // batch_stride_bias, FA without bias
                         batch_stride_o,
                         0, // batch_stride_randval
                         batch_stride_do,
                         batch_stride_lse,
                         batch_stride_dk,
                         batch_stride_dv,
                         0, // batch_stride_dbias, FA without dbias
                         mask.left,
                         mask.right,
                         static_cast<ck_tile::index_t>(mask.type),
                         p_dropout,
                         p_undrop,
                         false, // s_randval
                         {drop_seed, drop_offset}};
}
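
// Descriptive note (added): backward pass for the dense (non-varlen) layout. The
// flow is: validate dtypes/devices/shapes, build the attention-mask descriptor,
// allocate any missing gradient buffers, pad dout up to the 8-aligned head size,
// expand dK/dV to one slice per query head for MQA/GQA, launch the CK kernel, then
// reduce and slice the results back to the caller's shapes.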
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num_heads x head_size_og
        const at::Tensor &q,                      // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &out,                    // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &softmax_lse,            // b x h x seqlen_q
        c10::optional<at::Tensor> &dq_,           // batch_size x seqlen_q x num_heads x head_size
        c10::optional<at::Tensor> &dk_,           // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &dv_,           // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const float p_dropout,                    // probability to drop
        const float softmax_scale,
        const bool is_causal,
        int window_size_left,
        int window_size_right,
        const float /*softcap*/,
        const bool deterministic,
        c10::optional<at::Generator> gen_,
        c10::optional<at::Tensor> &rng_state)
{
#ifdef FLASHATTENTION_DISABLE_BACKWARD
    TORCH_CHECK(false, "This flash attention build does not support backward.");
#endif

    if (is_causal) { window_size_right = 0; }

    bool is_dropout = p_dropout > 0.0;
    auto stream = at::cuda::getCurrentHIPStream().stream();

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only supports fp16 and bf16 data types");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
    TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
    TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");

    std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
    TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");

    const auto sizes = q.sizes();

    const int batch_size = sizes[0];
    const int seqlen_q = sizes[1];
    const int num_heads = sizes[2];
    const int head_size_og = dout.size(3); // unpadded hdim
    const int head_size_8x = sizes[3];
    const int seqlen_k = k.size(1);
    const int num_heads_k = k.size(2);

    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size_8x % 8 == 0, "head_size_8x should be a multiple of 8");
    TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    TORCH_CHECK(head_size_8x == round_multiple(head_size_og, 8), "head_size_8x must be head_size_og rounded to a multiple of 8");
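
    // Descriptive note (added): a window at least as long as the key sequence is
    // effectively unlimited; following the FlashAttention convention, -1 means
    // "no limit" on that side.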
    if (window_size_left >= seqlen_k) { window_size_left = -1; }
    if (window_size_right >= seqlen_k) { window_size_right = -1; }

    mask_info mask;
    if (is_causal) {
        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
        mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // causal
    }
    else if (window_size_left == -1 && window_size_right == -1) {
        mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask
    }
    else {
        // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
        mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local
    }
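
    // Descriptive note (added): mask_info::decode takes a spec string ("0" for no
    // mask, "b:<left>,<right>" for a windowed mask, causal being the right-extent-0
    // case) and yields the left/right extents and type enum that
    // get_ck_fmha_bwd_args forwards to the kernel.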

    // q, k, v, out were already padded in mha_fwd; dq_, dk_, dv_ (if provided) are padded as well.
    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_8x);
    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_8x);
    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_8x);
    CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_8x);
    CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);

    at::Tensor dq, dk, dv;
    if (dq_.has_value()) {
        dq = dq_.value();
        TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
        CHECK_DEVICE(dq);
        TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
        CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size_8x);
    } else {
        dq = torch::empty_like(q);
    }
    if (dk_.has_value()) {
        dk = dk_.value();
        TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
        CHECK_DEVICE(dk);
        TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
        CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size_8x);
    } else {
        dk = torch::empty_like(k);
    }
    if (dv_.has_value()) {
        dv = dv_.value();
        TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
        CHECK_DEVICE(dv);
        TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
        CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_8x);
    } else {
        dv = torch::empty_like(v);
    }

    at::Tensor dout_padded;
    if (head_size_og % 8 != 0) {
        dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    } else {
        dout_padded = dout;
    }
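
    // Descriptive note (added): dout arrives with the original head size, while
    // q/k/v/out already carry the 8-aligned head size from the forward pass, so
    // dout's last dimension is padded to match before handing it to the kernel.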

    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();
    auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));

    // TODO - CK does not support dq_accum
    at::Tensor dk_expanded, dv_expanded;
    if (num_heads_k != num_heads) { // MQA / GQA
        dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts);
        dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size_8x}, opts);
    } else {
        dk_expanded = dk;
        dv_expanded = dv;
    }
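
    // Descriptive note (added): the kernel writes one dK/dV slice per query head,
    // so for MQA/GQA these temporaries hold num_heads heads and are reduced back
    // down to num_heads_k after the launch; when the head counts match, the kernel
    // writes straight into dk/dv.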

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());

    uint64_t drop_seed = 1, drop_offset = 0;
    int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();

    if (rng_state.has_value()) {
        uint64_t* d = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
        drop_seed = d[0];
        drop_offset = d[1];
    } else if (is_dropout) {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        auto philox_args = gen->philox_cuda_state(counter_offset);
        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
    }
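
    // Descriptive note (added): the Philox seed/offset must match the forward pass
    // for the dropout mask to be reproduced, which is why a saved rng_state (when
    // provided) takes precedence over drawing a fresh state from the generator.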

    if (seqlen_q > 0) {
        ck_tile::stream_config stream_config{stream};

        dq.zero_(); // CK accumulates dq with atomics, so it must start from zero

        auto traits =
            get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value());

        auto args =
            get_ck_fmha_bwd_args(
                mask,
                batch_size,
                seqlen_q,
                seqlen_k,
                num_heads,
                num_heads_k,
                head_size_8x,
                q,
                k,
                v,
                alibi_slopes_,
                out,
                softmax_lse,
                dout_padded,
                softmax_d,
                dq,
                dk_expanded,
                dv_expanded,
                softmax_scale,
                p_dropout,
                drop_seed,
                drop_offset);

        fmha_bwd(traits, args, stream_config);
    } else {
        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
        dk_expanded.zero_();
        dv_expanded.zero_();
        softmax_d.zero_();
    }

    // For MQA/GQA we need to sum dK and dV across the groups
    if (num_heads_k != num_heads) {
        at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3});
        at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {3});
    }

    if (head_size_og % 8 != 0) {
        dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
    }

    return { dq, dk, dv, softmax_d };
}