mha_bwd.cpp

/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#include "flash_common.hpp"
#include "fmha_bwd.hpp"
#include "mask.hpp"
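
// Builds the compile-time dispatch descriptor for the CK tile backward kernel:
// head size, data type, mask/bias variants, dropout, and deterministic mode.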
fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
                                       std::string dtype,
                                       int head_size,
                                       bool has_dropout,
                                       bool enable_alibi,
                                       bool deterministic)
{
    return fmha_bwd_traits{head_size,
                           head_size,
                           dtype,
                           false, // is_group_mode
                           mask.type,
                           enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
                           false, // has_dbias
                           has_dropout,
                           false, // s_randval
                           deterministic};
}
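
// Packs per-call sizes, device pointers, and strides into the CK fmha_bwd_args
// struct for the dense (non-varlen) batch layout used by mha_bwd below.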
fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                                   // sizes
                                   const int b,
                                   const int seqlen_q,
                                   const int seqlen_k,
                                   const int h,
                                   const int h_k,
                                   const int hdim,
                                   // device pointers
                                   const at::Tensor q,
                                   const at::Tensor k,
                                   const at::Tensor v,
                                   c10::optional<at::Tensor> &alibi_slopes_,
                                   const at::Tensor out,
                                   const at::Tensor softmax_lse,
                                   const at::Tensor dout,
                                   at::Tensor dq_acc,
                                   at::Tensor d,
                                   at::Tensor dq,
                                   at::Tensor dk,
                                   at::Tensor dv,
                                   float softmax_scale,
                                   float p_dropout,
                                   std::pair<uint64_t*, uint64_t*> drop_seed_offset)
{
    // q: (batch_size, seqlen_q, nheads, hdim)
    ck_tile::index_t batch_stride_q = q.stride(0);
    ck_tile::index_t stride_q = q.stride(1);
    ck_tile::index_t nhead_stride_q = q.stride(2);

    // k: (batch_size, seqlen_k, nheads_k, hdim)
    ck_tile::index_t batch_stride_k = k.stride(0);
    ck_tile::index_t stride_k = k.stride(1);
    ck_tile::index_t nhead_stride_k = k.stride(2);

    // v: (batch_size, seqlen_k, nheads_k, hdim)
    ck_tile::index_t batch_stride_v = v.stride(0);
    ck_tile::index_t stride_v = v.stride(1);
    ck_tile::index_t nhead_stride_v = v.stride(2);

    // o: (batch_size, seqlen_q, nheads, hdim)
    ck_tile::index_t batch_stride_o = out.stride(0);
    ck_tile::index_t stride_o = out.stride(1);
    ck_tile::index_t nhead_stride_o = out.stride(2);

    // lse: (batch_size, nheads, seqlen_q)
    ck_tile::index_t batch_stride_lse = softmax_lse.stride(0);
    ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1);

    // do: (batch_size, seqlen_q, nheads, hdim)
    ck_tile::index_t batch_stride_do = dout.stride(0);
    ck_tile::index_t stride_do = dout.stride(1);
    ck_tile::index_t nhead_stride_do = dout.stride(2);

    // d: (batch_size, nheads, seqlen_q)
    // CK assumes d shares the same strides as lse

    // dq: (batch_size, seqlen_q, nheads, hdim)
    ck_tile::index_t batch_stride_dq = dq.stride(0);
    ck_tile::index_t stride_dq = dq.stride(1);
    ck_tile::index_t nhead_stride_dq = dq.stride(2);

    // dk_expanded: (batch_size, seqlen_k, nheads, hdim)
    ck_tile::index_t batch_stride_dk = dk.stride(0);
    ck_tile::index_t stride_dk = dk.stride(1);
    ck_tile::index_t nhead_stride_dk = dk.stride(2);

    // dv_expanded: (batch_size, seqlen_k, nheads, hdim)
    ck_tile::index_t batch_stride_dv = dv.stride(0);
    ck_tile::index_t stride_dv = dv.stride(1);
    ck_tile::index_t nhead_stride_dv = dv.stride(2);

    // dq_acc: (split, batch_size, seqlen_q, nheads, hdim)
    ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0);
    ck_tile::index_t batch_stride_dq_acc = dq_acc.stride(1);
    ck_tile::index_t stride_dq_acc = dq_acc.stride(2);
    ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(3);

    float p_undrop = 1.0 - p_dropout;

    void *alibi_slopes_ptr = nullptr;
    ck_tile::index_t stride_alibi_slopes = 0;

    if (alibi_slopes_.has_value()) {
        auto alibi_slopes = alibi_slopes_.value();
        CHECK_DEVICE(alibi_slopes);
        TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
        TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
        alibi_slopes_ptr = alibi_slopes.data_ptr();
        // alibi_slopes: (batch_size, nheads) or (nheads)
        stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
    }
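
    // Assemble the CK argument struct. Features not used by this path (bias,
    // stored random values, dbias, group-mode seqstart/seqlen pointers) are
    // passed as null pointers or zero strides.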
    return fmha_bwd_args{q.data_ptr(),
                         k.data_ptr(),
                         v.data_ptr(),
                         alibi_slopes_ptr, // bias
                         out.data_ptr(),
                         softmax_lse.data_ptr(),
                         dout.data_ptr(),
                         d.data_ptr(),
                         nullptr, // rand_val
                         dq.data_ptr(),
                         dk.data_ptr(),
                         dv.data_ptr(),
                         nullptr, // dbias
                         dq_acc.data_ptr(), // dq_acc
                         nullptr, // seqstart_q
                         nullptr, // seqstart_k
                         nullptr, // seqlen_k_ptr
                         seqlen_q,
                         seqlen_k,
                         b,
                         seqlen_q, // max_seqlen_q
                         seqlen_k, // max_seqlen_k
                         hdim, // hdim_q
                         hdim, // hdim_v
                         h, // nhead
                         h_k, // nhead_k
                         softmax_scale,
                         stride_q,
                         stride_k,
                         stride_v,
                         stride_alibi_slopes,
                         stride_o,
                         0, // stride_randval
                         stride_do,
                         stride_dq_acc,
                         stride_dq,
                         stride_dk,
                         stride_dv,
                         0, // stride_dbias, FA without bias
                         nhead_stride_q,
                         nhead_stride_k,
                         nhead_stride_v,
                         0, // nhead_stride_bias, FA without bias
                         nhead_stride_o,
                         0, // nhead_stride_randval
                         nhead_stride_do,
                         nhead_stride_lse,
                         nhead_stride_dq_acc,
                         nhead_stride_dq,
                         nhead_stride_dk,
                         nhead_stride_dv,
                         0, // nhead_stride_dbias, FA without dbias
                         batch_stride_q,
                         batch_stride_k,
                         batch_stride_v,
                         0, // batch_stride_bias, FA without bias
                         batch_stride_o,
                         0, // batch_stride_randval
                         batch_stride_do,
                         batch_stride_lse,
                         batch_stride_dq_acc,
                         batch_stride_dq,
                         batch_stride_dk,
                         batch_stride_dv,
                         0, // batch_stride_dbias, FA without dbias
                         split_stride_dq_acc,
                         mask.left,
                         mask.right,
                         static_cast<ck_tile::index_t>(mask.type),
                         p_dropout,
                         p_undrop,
                         drop_seed_offset};
}
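
// Backward pass of FlashAttention on ROCm via composable_kernel (ck_tile).
// Returns { dq, dk, dv, softmax_d } for dense (non-varlen) inputs.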
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout,                   // batch_size x seqlen_q x num_heads x multiple_of(head_size, 8)
        const at::Tensor &q,                      // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,                      // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,                      // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &out,                    // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &softmax_lse,            // b x h x seqlen_q
        c10::optional<at::Tensor> &dq_,           // batch_size x seqlen_q x num_heads x head_size
        c10::optional<at::Tensor> &dk_,           // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &dv_,           // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
        const float p_dropout,                    // probability to drop
        const float softmax_scale,
        const bool is_causal,
        int window_size_left,
        int window_size_right,
        const float /*softcap*/,
        const bool deterministic,
        c10::optional<at::Generator> gen_,
        c10::optional<at::Tensor> &rng_state_)
{
#ifdef FLASHATTENTION_DISABLE_BACKWARD
    TORCH_CHECK(false, "This flash attention build does not support backward.");
#endif
    if (is_causal) { window_size_right = 0; }

    bool is_dropout = p_dropout > 0.0;
    auto stream = at::cuda::getCurrentHIPStream().stream();

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only supports fp16 and bf16 data types");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
    TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
    TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");

    std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
    TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");

    const auto sizes = q.sizes();

    const int batch_size = sizes[0];
    const int seqlen_q = sizes[1];
    const int num_heads = sizes[2];
    const int head_size = sizes[3];
    const int seqlen_k = k.size(1);
    const int num_heads_k = k.size(2);

    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
    TORCH_CHECK(head_size <= 256, "CK FlashAttention backward only supports head dimension at most 256");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    if (window_size_left >= seqlen_k) { window_size_left = -1; }
    if (window_size_right >= seqlen_k) { window_size_right = -1; }
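
    // mask_info::decode takes a "b:left,right" window specification together
    // with the actual sequence lengths; "0" means no masking.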
    mask_info mask;
    if (is_causal) {
        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
        mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // causal
    }
    else if (window_size_left == -1 && window_size_right == -1) {
        mask = mask_info::decode("0", seqlen_q, seqlen_k); // no mask
    }
    else {
        // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
        std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
        mask = mask_info::decode(mask_identify, seqlen_q, seqlen_k); // local
    }

    // q, k, v, out were already padded in mha_fwd;
    // dq_, dk_, dv_ are expected to be padded as well.
    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
    CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size);

    at::Tensor dq, dk, dv;
    if (dq_.has_value()) {
        dq = dq_.value();
        TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
        CHECK_DEVICE(dq);
        TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
        CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
    } else {
        dq = torch::empty_like(q);
    }
    if (dk_.has_value()) {
        dk = dk_.value();
        TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
        CHECK_DEVICE(dk);
        TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
        CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
    } else {
        dk = torch::empty_like(k);
    }
    if (dv_.has_value()) {
        dv = dv_.value();
        TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
        CHECK_DEVICE(dv);
        TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
        CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
    } else {
        dv = torch::empty_like(v);
    }

    at::cuda::CUDAGuard device_guard{q.device()};

    auto opts = q.options();
    auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
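
    // dq is accumulated in fp32. In deterministic mode one accumulation slice is
    // allocated per block of kN0 key columns so the kernel can avoid the
    // non-deterministic ordering of atomic adds into a single slice (kN0 here is
    // assumed to match the CK kernel's key-block tile size).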
    at::Tensor dq_accum;
    if (!deterministic) {
        dq_accum = torch::zeros({1, batch_size, seqlen_q, num_heads, head_size}, opts.dtype(at::kFloat));
    } else {
        const ck_tile::index_t kN0 = head_size <= 128 ? 128 : 64;
        const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0);
        dq_accum = torch::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size}, opts.dtype(at::kFloat));
    }
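
    // For MQA/GQA the kernel produces per-query-head dK/dV; they are summed
    // back down to num_heads_k after the launch below.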
    at::Tensor dk_expanded, dv_expanded;
    if (num_heads_k != num_heads) { // MQA / GQA
        dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
        dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
    } else {
        dk_expanded = dk;
        dv_expanded = dv;
    }

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());

    int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
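
    // The backward kernel must see the same Philox seed/offset pair that the
    // forward pass used for dropout. If the caller did not hand one in via
    // rng_state_, capture a fresh state on the device with
    // flash::ParsePhiloxCudaState.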
    at::Tensor rng_state;
    if (rng_state_.has_value()) {
        rng_state = rng_state_.value();
    } else if (is_dropout) {
        rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        auto philox_args = gen->philox_cuda_state(counter_offset);
        hipLaunchKernelGGL(
            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0,
            philox_args, reinterpret_cast<uint64_t*>(rng_state.data_ptr()));
    }

    if (seqlen_q > 0) {
        auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
        ck_tile::stream_config stream_config{stream};

        auto traits =
            get_ck_fmha_bwd_traits(mask, q_dtype_str, head_size, is_dropout, alibi_slopes_.has_value(), deterministic);
        auto args =
            get_ck_fmha_bwd_args(
                mask,
                batch_size,
                seqlen_q,
                seqlen_k,
                num_heads,
                num_heads_k,
                head_size,
                q,
                k,
                v,
                alibi_slopes_,
                out,
                softmax_lse,
                dout,
                dq_accum,
                softmax_d,
                dq,
                dk_expanded,
                dv_expanded,
                softmax_scale,
                p_dropout,
                drop_seed_offset);

        float t = fmha_bwd(traits, args, stream_config);
        TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");
    } else {
        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
        dk_expanded.zero_();
        dv_expanded.zero_();
        softmax_d.zero_();
    }

    // For MQA/GQA we need to sum dK and dV across the groups
    if (num_heads_k != num_heads) {
        at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
        at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
    }

    return { dq, dk, dv, softmax_d };
}
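
// Usage sketch (not part of the original source): a rough illustration of how
// this entry point might be called from C++ once mha_fwd has produced out and
// softmax_lse. Tensor shapes follow the comments on the signature above, and
// the optional arguments are simply left empty here.
//
//   c10::optional<at::Tensor> dq_, dk_, dv_, alibi_slopes_, rng_state_; // none provided
//   c10::optional<at::Generator> gen_;                                  // default generator
//   auto grads = mha_bwd(dout, q, k, v, out, softmax_lse,
//                        dq_, dk_, dv_, alibi_slopes_,
//                        /*p_dropout=*/0.f, softmax_scale,
//                        /*is_causal=*/true,
//                        /*window_size_left=*/-1, /*window_size_right=*/-1,
//                        /*softcap=*/0.f, /*deterministic=*/false,
//                        gen_, rng_state_);
//   // grads[0] = dq, grads[1] = dk, grads[2] = dv, grads[3] = softmax_d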