mha_varlen_bwd.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. /******************************************************************************
  2. * Copyright (c) 2024, Tri Dao.
  3. ******************************************************************************/
  4. #include "flash_common.hpp"
  5. #include "fmha_bwd.hpp"
  6. #include "mask.hpp"
  7. fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask,
  8. std::string dtype,
  9. int head_size,
  10. bool has_dropout,
  11. bool enable_alibi,
  12. bool deterministic)
  13. {
  14. return fmha_bwd_traits{head_size,
  15. head_size,
  16. dtype,
  17. true, // is_group_mode
  18. mask.type,
  19. enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
  20. false, // has_dbias
  21. has_dropout,
  22. false, // s_randval
  23. deterministic};
  24. }
  25. fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
  26. // sizes
  27. const int b,
  28. const int max_seqlen_q,
  29. const int max_seqlen_k,
  30. const int h,
  31. const int h_k,
  32. const int hdim,
  33. // device pointers
  34. const at::Tensor q,
  35. const at::Tensor k,
  36. const at::Tensor v,
  37. const at::Tensor seqlens_q,
  38. const at::Tensor seqlens_k,
  39. c10::optional<at::Tensor> &alibi_slopes_,
  40. const at::Tensor out,
  41. const at::Tensor softmax_lse,
  42. const at::Tensor dout,
  43. at::Tensor dq_acc,
  44. at::Tensor d,
  45. at::Tensor dq,
  46. at::Tensor dk,
  47. at::Tensor dv,
  48. float softmax_scale,
  49. float p_dropout,
  50. std::pair<uint64_t*, uint64_t*> drop_seed_offset)
  51. {
  52. ck_tile::index_t total_q = q.size(0);
  53. ck_tile::index_t total_k = k.size(0);
  54. // q: (total_q, nheads, hdim)
  55. ck_tile::index_t batch_stride_q = 0;
  56. ck_tile::index_t stride_q = q.stride(0);
  57. ck_tile::index_t nhead_stride_q = q.stride(1);
  58. // k: (total_k, nheads_k, hdim)
  59. ck_tile::index_t batch_stride_k = 0;
  60. ck_tile::index_t stride_k = k.stride(0);
  61. ck_tile::index_t nhead_stride_k = k.stride(1);
  62. // v: (total_k, nheads_k, hdim)
  63. ck_tile::index_t batch_stride_v = 0;
  64. ck_tile::index_t stride_v = v.stride(0);
  65. ck_tile::index_t nhead_stride_v = v.stride(1);
  66. // o: (total_q, nheads, hdim)
  67. ck_tile::index_t batch_stride_o = 0;
  68. ck_tile::index_t stride_o = out.stride(0);
  69. ck_tile::index_t nhead_stride_o = out.stride(1);
  70. // lse: (nheads, total_q)
  71. ck_tile::index_t batch_stride_lse = 0;
  72. ck_tile::index_t nhead_stride_lse = softmax_lse.stride(0);
  73. // do: (total_q, nheads, hdim)
  74. ck_tile::index_t batch_stride_do = 0;
  75. ck_tile::index_t stride_do = dout.stride(0);
  76. ck_tile::index_t nhead_stride_do = dout.stride(1);
  77. // d: (batch_size, nheads, max_seqlen_q)
  78. // CK assume d share the same stride with lse
  79. // dq: (total_q, nheads, hdim)
  80. ck_tile::index_t batch_stride_dq = 0;
  81. ck_tile::index_t stride_dq = dq.stride(0);
  82. ck_tile::index_t nhead_stride_dq = dq.stride(1);
  83. // dk_expanded: (total_k, nheads, hdim)
  84. ck_tile::index_t batch_stride_dk = 0;
  85. ck_tile::index_t stride_dk = dk.stride(0);
  86. ck_tile::index_t nhead_stride_dk = dk.stride(1);
  87. // dv_expanded: (total_k, nheads, hdim)
  88. ck_tile::index_t batch_stride_dv = 0;
  89. ck_tile::index_t stride_dv = dv.stride(0);
  90. ck_tile::index_t nhead_stride_dv = dv.stride(1);
  91. // dq_acc: (split, total_q, nheads, hdim)
  92. ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0);
  93. ck_tile::index_t batch_stride_dq_acc = 0;
  94. ck_tile::index_t stride_dq_acc = dq_acc.stride(1);
  95. ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(2);
  96. float p_undrop = 1.0 - p_dropout;
  97. void *alibi_slopes_ptr = nullptr;
  98. ck_tile::index_t stride_alibi_slopes = 0;
  99. if (alibi_slopes_.has_value()) {
  100. auto alibi_slopes = alibi_slopes_.value();
  101. CHECK_DEVICE(alibi_slopes);
  102. TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
  103. TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
  104. alibi_slopes_ptr = alibi_slopes.data_ptr();
  105. // alibi_slopes:(batch_size, nheads) or (nhead)
  106. stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
  107. }
  108. return fmha_bwd_args{q.data_ptr(),
  109. k.data_ptr(),
  110. v.data_ptr(),
  111. alibi_slopes_ptr, // bias
  112. out.data_ptr(),
  113. softmax_lse.data_ptr(),
  114. dout.data_ptr(),
  115. d.data_ptr(),
  116. nullptr, // rand_val
  117. dq.data_ptr(),
  118. dk.data_ptr(),
  119. dv.data_ptr(),
  120. nullptr, // dbias
  121. dq_acc.data_ptr(), // dq_acc
  122. seqlens_q.data_ptr(), // seqstart_q
  123. seqlens_k.data_ptr(), // seqstart_k
  124. nullptr, // seqlen_k_ptr
  125. total_q,
  126. total_k,
  127. b,
  128. max_seqlen_q, // max_seqlen_q
  129. max_seqlen_k, // max_seqlen_k
  130. hdim, // hdim_q
  131. hdim, // hdim_v
  132. h, // nhead
  133. h_k, // nhead_k
  134. softmax_scale,
  135. stride_q,
  136. stride_k,
  137. stride_v,
  138. stride_alibi_slopes,
  139. stride_o,
  140. 0, // stride_randval
  141. stride_do,
  142. stride_dq_acc,
  143. stride_dq,
  144. stride_dk,
  145. stride_dv,
  146. 0, // stride_dbias, FA without bias
  147. nhead_stride_q,
  148. nhead_stride_k,
  149. nhead_stride_v,
  150. 0, // nhead_stride_bias, FA without bias
  151. nhead_stride_o,
  152. 0, // nhead_stride_randval
  153. nhead_stride_do,
  154. nhead_stride_lse,
  155. nhead_stride_dq_acc,
  156. nhead_stride_dq,
  157. nhead_stride_dk,
  158. nhead_stride_dv,
  159. 0, // nhead_stride_dbias, FA without dbias
  160. batch_stride_q,
  161. batch_stride_k,
  162. batch_stride_v,
  163. 0 , // batch_stride_bias, FA without bias
  164. batch_stride_o,
  165. 0, // batch_stride_randval
  166. batch_stride_do,
  167. batch_stride_lse,
  168. batch_stride_dq_acc,
  169. batch_stride_dq,
  170. batch_stride_dk,
  171. batch_stride_dv,
  172. 0 , // batch_stride_dbias, FA without dbias
  173. split_stride_dq_acc,
  174. mask.left,
  175. mask.right,
  176. static_cast<ck_tile::index_t>(mask.type),
  177. p_dropout,
  178. p_undrop,
  179. drop_seed_offset};
  180. }
  181. std::vector<at::Tensor>
  182. mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads x head_size
  183. const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  184. const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  185. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  186. const at::Tensor &out, // total_q x num_heads x head_size
  187. const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
  188. c10::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  189. c10::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  190. c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  191. const at::Tensor &cu_seqlens_q, // b+1
  192. const at::Tensor &cu_seqlens_k, // b+1
  193. c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
  194. const int max_seqlen_q,
  195. const int max_seqlen_k, // max sequence length to choose the kernel
  196. const float p_dropout, // probability to drop
  197. const float softmax_scale,
  198. const bool zero_tensors,
  199. const bool is_causal,
  200. int window_size_left,
  201. int window_size_right,
  202. const float /*softcap*/,
  203. const bool deterministic,
  204. c10::optional<at::Generator> gen_,
  205. c10::optional<at::Tensor> &rng_state_)
  206. {
  207. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  208. TORCH_CHECK(false, "This flash attention build does not support backward.");
  209. #endif
  210. if (is_causal) { window_size_right = 0; }
  211. bool is_dropout = p_dropout > 0.0;
  212. auto stream = at::cuda::getCurrentCUDAStream().stream();
  213. auto q_dtype = q.dtype();
  214. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  215. "FlashAttention only support fp16 and bf16 data type");
  216. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  217. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  218. TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  219. TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
  220. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
  221. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
  222. std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
  223. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  224. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  225. CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
  226. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  227. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  228. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  229. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  230. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
  231. CHECK_CONTIGUOUS(cu_seqlens_q);
  232. CHECK_CONTIGUOUS(cu_seqlens_k);
  233. const auto sizes = q.sizes();
  234. const int total_q = sizes[0];
  235. const int batch_size = cu_seqlens_q.numel() - 1;
  236. const int num_heads = sizes[1];
  237. const int head_size = sizes[2];
  238. const int total_k = k.size(0);
  239. const int num_heads_k = k.size(1);
  240. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  241. TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  242. TORCH_CHECK(head_size <= 256, "CK FlashAttention backward only supports head dimension at most 256");
  243. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  244. if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
  245. if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
  246. mask_info mask;
  247. if (is_causal) {
  248. std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
  249. mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // casual
  250. }
  251. else if (window_size_left == -1 && window_size_right == -1) {
  252. mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask
  253. }
  254. else {
  255. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
  256. std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
  257. mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local
  258. }
  259. // q, k, v, out had been padded in mha_fwd
  260. // dq_, dk_, dv_ are also padded tensor
  261. CHECK_SHAPE(q, total_q, num_heads, head_size);
  262. CHECK_SHAPE(k, total_k, num_heads_k, head_size);
  263. CHECK_SHAPE(v, total_k, num_heads_k, head_size);
  264. CHECK_SHAPE(out, total_q, num_heads, head_size);
  265. CHECK_SHAPE(dout, total_q, num_heads, head_size);
  266. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  267. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  268. at::Tensor dq, dk, dv;
  269. if (dq_.has_value()) {
  270. dq = dq_.value();
  271. TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
  272. CHECK_DEVICE(dq);
  273. TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
  274. CHECK_SHAPE(dq, total_q, num_heads, head_size);
  275. } else {
  276. dq = torch::empty_like(q);
  277. }
  278. if (dk_.has_value()) {
  279. dk = dk_.value();
  280. TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
  281. CHECK_DEVICE(dk);
  282. TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
  283. CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
  284. } else {
  285. dk = torch::empty_like(k);
  286. }
  287. if (dv_.has_value()) {
  288. dv = dv_.value();
  289. TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
  290. CHECK_DEVICE(dv);
  291. TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
  292. CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
  293. } else {
  294. dv = torch::empty_like(v);
  295. }
  296. at::cuda::CUDAGuard device_guard{q.device()};
  297. auto opts = q.options();
  298. auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
  299. at::Tensor dq_accum;
  300. if (!deterministic) {
  301. dq_accum = torch::zeros({1, total_q, num_heads, head_size}, opts.dtype(at::kFloat));
  302. } else {
  303. const ck_tile::index_t kN0 = head_size <= 128 ? 128 : 64;
  304. const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(max_seqlen_k, kN0);
  305. dq_accum = torch::zeros({nsplits, total_q, num_heads, head_size}, opts.dtype(at::kFloat));
  306. }
  307. at::Tensor dk_expanded, dv_expanded;
  308. if (num_heads_k != num_heads) { // MQA / GQA
  309. dk_expanded = torch::empty({total_k, num_heads, head_size}, opts);
  310. dv_expanded = torch::empty({total_k, num_heads, head_size}, opts);
  311. } else {
  312. dk_expanded = dk;
  313. dv_expanded = dv;
  314. }
  315. if(zero_tensors) {
  316. dq.zero_();
  317. dk_expanded.zero_();
  318. dv_expanded.zero_();
  319. softmax_d.zero_();
  320. }
  321. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  322. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  323. int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
  324. at::Tensor rng_state;
  325. if (rng_state_.has_value()) {
  326. rng_state = rng_state_.value();
  327. } else if(is_dropout) {
  328. rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
  329. // See Note [Acquire lock when using random generators]
  330. std::lock_guard<std::mutex> lock(gen->mutex_);
  331. auto philox_args = gen->philox_cuda_state(counter_offset);
  332. hipLaunchKernelGGL(
  333. flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0,
  334. philox_args, reinterpret_cast<uint64_t*>(rng_state.data_ptr()));
  335. } else {
  336. rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
  337. }
  338. if (max_seqlen_q > 0) {
  339. auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
  340. auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
  341. ck_tile::stream_config stream_config{stream};
  342. auto traits =
  343. get_ck_fmha_varlen_bwd_traits(mask, q_dtype_str, head_size, is_dropout, alibi_slopes_.has_value(), deterministic);
  344. auto args =
  345. get_ck_fmha_varlen_bwd_args(
  346. mask,
  347. batch_size,
  348. max_seqlen_q,
  349. max_seqlen_k,
  350. num_heads,
  351. num_heads_k,
  352. head_size,
  353. q,
  354. k,
  355. v,
  356. cu_seqlens_q,
  357. cu_seqlens_k,
  358. alibi_slopes_,
  359. out,
  360. softmax_lse,
  361. dout,
  362. dq_accum,
  363. softmax_d,
  364. dq,
  365. dk_expanded,
  366. dv_expanded,
  367. softmax_scale,
  368. p_dropout,
  369. drop_seed_offset);
  370. float t = fmha_bwd(traits, args, stream_config);
  371. TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");
  372. } else {
  373. // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
  374. dk_expanded.zero_();
  375. dv_expanded.zero_();
  376. softmax_d.zero_();
  377. }
  378. // For MQA/GQA we need to sum dK and dV across the groups
  379. if (num_heads_k != num_heads) {
  380. at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
  381. at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
  382. }
  383. return { dq, dk, dv, softmax_d };
  384. }