mha_varlen_bwd.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. /******************************************************************************
  2. * Copyright (c) 2024, Tri Dao.
  3. ******************************************************************************/
  4. #include "flash_common.hpp"
  5. #include "fmha_bwd.hpp"
  6. #include "mask.hpp"
  7. fmha_bwd_traits get_ck_fmha_varlen_bwd_traits(const mask_info &mask,
  8. std::string dtype,
  9. int head_size,
  10. bool has_dropout,
  11. bool enable_alibi)
  12. {
  13. return fmha_bwd_traits{head_size,
  14. head_size,
  15. dtype,
  16. true, // is_group_mode
  17. mask.type,
  18. enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
  19. false, // has_dbias
  20. has_dropout};
  21. }
  22. fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
  23. // sizes
  24. const int b,
  25. const int max_seqlen_q,
  26. const int max_seqlen_k,
  27. const int h,
  28. const int h_k,
  29. const int hdim,
  30. // device pointers
  31. const at::Tensor q,
  32. const at::Tensor k,
  33. const at::Tensor v,
  34. const at::Tensor seqlens_q,
  35. const at::Tensor seqlens_k,
  36. c10::optional<at::Tensor> &alibi_slopes_,
  37. const at::Tensor out,
  38. const at::Tensor softmax_lse,
  39. const at::Tensor dout,
  40. at::Tensor d,
  41. at::Tensor dq,
  42. at::Tensor dk,
  43. at::Tensor dv,
  44. float softmax_scale,
  45. float p_dropout,
  46. uint64_t drop_seed,
  47. uint64_t drop_offset)
  48. {
  49. // q: (total_q, nheads, hdim)
  50. // k: (total_k, nheads_k, hdim)
  51. // v: (total_k, nheads_k, hdim)
  52. // o: (total_q, nheads, hdim)
  53. // dq: (total_q, nheads, hdim)
  54. // dk_expanded: (total_k, nheads, hdim)
  55. // dv_expanded: (total_k, nheads, hdim)
  56. // do: (total_q, nheads, hdim)
  57. // alibi_slopes:(batch_size, nheads) or (nhead)
  58. // lse: (batch_size, nheads, max_seqlen_q)
  59. // d: (batch_size, nheads, max_seqlen_q)
  60. ck_tile::index_t total_q = q.size(0);
  61. ck_tile::index_t total_k = k.size(0);
  62. ck_tile::index_t stride_q = q.stride(0);
  63. ck_tile::index_t stride_k = k.stride(0);
  64. ck_tile::index_t stride_v = v.stride(0);
  65. ck_tile::index_t stride_o = out.stride(0);
  66. ck_tile::index_t stride_do = dout.stride(0);
  67. ck_tile::index_t stride_dk = dk.stride(0);
  68. ck_tile::index_t stride_dv = dv.stride(0);
  69. ck_tile::index_t nhead_stride_q = q.stride(1);
  70. ck_tile::index_t nhead_stride_k = k.stride(1);
  71. ck_tile::index_t nhead_stride_v = v.stride(1);
  72. ck_tile::index_t nhead_stride_o = out.stride(1);
  73. ck_tile::index_t nhead_stride_do = dout.stride(1);
  74. ck_tile::index_t nhead_stride_lse = softmax_lse.stride(1);
  75. ck_tile::index_t batch_stride_q = 0;
  76. ck_tile::index_t batch_stride_k = 0;
  77. ck_tile::index_t batch_stride_v = 0;
  78. ck_tile::index_t batch_stride_o = 0;
  79. ck_tile::index_t batch_stride_do = 0;
  80. ck_tile::index_t batch_stride_lse = softmax_lse.stride(0);;
  81. ck_tile::index_t batch_stride_dk = 0;
  82. ck_tile::index_t batch_stride_dv = 0;
  83. float p_undrop = 1.0 - p_dropout;
  84. void *alibi_slopes_ptr = nullptr;
  85. ck_tile::index_t stride_alibi_slopes = 0;
  86. if (alibi_slopes_.has_value()) {
  87. auto alibi_slopes = alibi_slopes_.value();
  88. CHECK_DEVICE(alibi_slopes);
  89. TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
  90. TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({h}) || alibi_slopes.sizes() == torch::IntArrayRef({b, h}));
  91. alibi_slopes_ptr = alibi_slopes.data_ptr();
  92. stride_alibi_slopes = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
  93. }
  94. return fmha_bwd_args{q.data_ptr(),
  95. k.data_ptr(),
  96. v.data_ptr(),
  97. alibi_slopes_ptr, // bias
  98. out.data_ptr(),
  99. softmax_lse.data_ptr(),
  100. dout.data_ptr(),
  101. d.data_ptr(),
  102. nullptr, // rand_val
  103. dq.data_ptr(),
  104. dk.data_ptr(),
  105. dv.data_ptr(),
  106. nullptr, // dbias
  107. seqlens_q.data_ptr(), // seqstart_q
  108. seqlens_k.data_ptr(), // seqstart_k
  109. nullptr, // seqlen_k_ptr
  110. total_q,
  111. total_k,
  112. b,
  113. max_seqlen_q, // max_seqlen_q
  114. max_seqlen_k, // max_seqlen_k
  115. hdim, // hdim_q
  116. hdim, // hdim_v
  117. h, // nhead
  118. h_k, // nhead_k
  119. softmax_scale,
  120. stride_q,
  121. stride_k,
  122. stride_v,
  123. stride_alibi_slopes,
  124. stride_o,
  125. 0, // stride_randval
  126. stride_do,
  127. stride_dk,
  128. stride_dv,
  129. 0, // stride_dbias, FA without bias
  130. nhead_stride_q,
  131. nhead_stride_k,
  132. nhead_stride_v,
  133. 0, // nhead_stride_bias, FA without bias
  134. nhead_stride_o,
  135. 0, // nhead_stride_randval
  136. nhead_stride_do,
  137. nhead_stride_lse,
  138. 0, // nhead_stride_dbias, FA without dbias
  139. batch_stride_q,
  140. batch_stride_k,
  141. batch_stride_v,
  142. 0 , // batch_stride_bias, FA without bias
  143. batch_stride_o,
  144. 0, // batch_stride_randval
  145. batch_stride_do,
  146. batch_stride_lse,
  147. batch_stride_dk,
  148. batch_stride_dv,
  149. 0 , // batch_stride_dbias, FA without dbias
  150. mask.left,
  151. mask.right,
  152. static_cast<ck_tile::index_t>(mask.type),
  153. p_dropout,
  154. p_undrop,
  155. false, // s_randval
  156. {drop_seed, drop_offset}};
  157. }
  158. std::vector<at::Tensor>
  159. mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads x head_size
  160. const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  161. const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  162. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  163. const at::Tensor &out, // total_q x num_heads x head_size
  164. const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
  165. c10::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  166. c10::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  167. c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  168. const at::Tensor &cu_seqlens_q, // b+1
  169. const at::Tensor &cu_seqlens_k, // b+1
  170. c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
  171. const int max_seqlen_q,
  172. const int max_seqlen_k, // max sequence length to choose the kernel
  173. const float p_dropout, // probability to drop
  174. const float softmax_scale,
  175. const bool zero_tensors,
  176. const bool is_causal,
  177. int window_size_left,
  178. int window_size_right,
  179. const float /*softcap*/,
  180. const bool deterministic,
  181. c10::optional<at::Generator> gen_,
  182. c10::optional<at::Tensor> &rng_state)
  183. {
  184. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  185. TORCH_CHECK(false, "This flash attention build does not support backward.");
  186. #endif
  187. if (is_causal) { window_size_right = 0; }
  188. bool is_dropout = p_dropout > 0.0;
  189. auto stream = at::cuda::getCurrentCUDAStream().stream();
  190. auto q_dtype = q.dtype();
  191. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  192. "FlashAttention only support fp16 and bf16 data type");
  193. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  194. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  195. TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  196. TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
  197. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
  198. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
  199. std::string q_dtype_str = q_dtype == torch::kFloat16 ? "fp16" : "bf16";
  200. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  201. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  202. CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
  203. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  204. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  205. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  206. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  207. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
  208. CHECK_CONTIGUOUS(cu_seqlens_q);
  209. CHECK_CONTIGUOUS(cu_seqlens_k);
  210. const auto sizes = q.sizes();
  211. const int total_q = sizes[0];
  212. const int batch_size = cu_seqlens_q.numel() - 1;
  213. const int num_heads = sizes[1];
  214. const int head_size_og = dout.size(2);
  215. const int head_size_8x = sizes[2];
  216. const int total_k = k.size(0);
  217. const int num_heads_k = k.size(1);
  218. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  219. TORCH_CHECK(head_size_8x % 8 == 0, "head_size should be a multiple of 8");
  220. TORCH_CHECK(head_size_8x <= 128, "CK FlashAttention backward only supports head dimension at most 128");
  221. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  222. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  223. TORCH_CHECK(head_size_8x == round_multiple(head_size_og, 8), "head_size_8x must be head_size_og rounded to a multiple of 8");
  224. if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
  225. if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
  226. mask_info mask;
  227. if (is_causal) {
  228. std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + "0";
  229. mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // casual
  230. }
  231. else if (window_size_left == -1 && window_size_right == -1) {
  232. mask = mask_info::decode("0", max_seqlen_q, max_seqlen_k); // no mask
  233. }
  234. else {
  235. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
  236. std::string mask_identify = "b:" + std::to_string(window_size_left) + "," + std::to_string(window_size_right);
  237. mask = mask_info::decode(mask_identify, max_seqlen_q, max_seqlen_k); // local
  238. }
  239. // q, k, v, out had been padded in mha_fwd
  240. // dq_, dk_, dv_ are also padded tensor
  241. CHECK_SHAPE(q, total_q, num_heads, head_size_8x);
  242. CHECK_SHAPE(k, total_k, num_heads_k, head_size_8x);
  243. CHECK_SHAPE(v, total_k, num_heads_k, head_size_8x);
  244. CHECK_SHAPE(out, total_q, num_heads, head_size_8x);
  245. CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
  246. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  247. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  248. at::Tensor dq, dk, dv;
  249. if (dq_.has_value()) {
  250. dq = dq_.value();
  251. TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
  252. CHECK_DEVICE(dq);
  253. TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
  254. CHECK_SHAPE(dq, total_q, num_heads, head_size_8x);
  255. } else {
  256. dq = torch::empty_like(q);
  257. }
  258. if (dk_.has_value()) {
  259. dk = dk_.value();
  260. TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
  261. CHECK_DEVICE(dk);
  262. TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
  263. CHECK_SHAPE(dk, total_k, num_heads_k, head_size_8x);
  264. } else {
  265. dk = torch::empty_like(k);
  266. }
  267. if (dv_.has_value()) {
  268. dv = dv_.value();
  269. TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
  270. CHECK_DEVICE(dv);
  271. TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
  272. CHECK_SHAPE(dv, total_k, num_heads_k, head_size_8x);
  273. } else {
  274. dv = torch::empty_like(v);
  275. }
  276. at::Tensor dout_padded;
  277. if (head_size_og % 8 != 0) {
  278. dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  279. } else {
  280. dout_padded = dout;
  281. }
  282. // Cast to char to avoid compiler warning about narrowing
  283. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  284. auto opts = q.options();
  285. auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
  286. // TODO - CK does not support dq_accum
  287. at::Tensor dk_expanded, dv_expanded;
  288. if (num_heads_k != num_heads) { // MQA / GQA
  289. dk_expanded = torch::empty({total_k, num_heads, head_size_8x}, opts);
  290. dv_expanded = torch::empty({total_k, num_heads, head_size_8x}, opts);
  291. } else {
  292. dk_expanded = dk;
  293. dv_expanded = dv;
  294. }
  295. if(zero_tensors) {
  296. dq.zero_();
  297. dk_expanded.zero_();
  298. dv_expanded.zero_();
  299. softmax_d.zero_();
  300. }
  301. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  302. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  303. uint64_t drop_seed = 1, drop_offset = 0;
  304. int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
  305. if (rng_state.has_value()) {
  306. uint64_t* d = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
  307. drop_seed = d[0];
  308. drop_offset = d[1];
  309. } else if(is_dropout) {
  310. // See Note [Acquire lock when using random generators]
  311. std::lock_guard<std::mutex> lock(gen->mutex_);
  312. auto philox_args = gen->philox_cuda_state(counter_offset);
  313. std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
  314. }
  315. if (max_seqlen_q > 0) {
  316. ck_tile::stream_config stream_config{stream};
  317. dq.zero_(); // ck use atomic operation on dq
  318. auto traits =
  319. get_ck_fmha_varlen_bwd_traits(mask, q_dtype_str, head_size_8x, is_dropout, alibi_slopes_.has_value());
  320. auto args =
  321. get_ck_fmha_varlen_bwd_args(
  322. mask,
  323. batch_size,
  324. max_seqlen_q,
  325. max_seqlen_k,
  326. num_heads,
  327. num_heads_k,
  328. head_size_8x,
  329. q,
  330. k,
  331. v,
  332. cu_seqlens_q,
  333. cu_seqlens_k,
  334. alibi_slopes_,
  335. out,
  336. softmax_lse,
  337. dout_padded,
  338. softmax_d,
  339. dq,
  340. dk_expanded,
  341. dv_expanded,
  342. softmax_scale,
  343. p_dropout,
  344. drop_seed,
  345. drop_offset);
  346. fmha_bwd(traits, args, stream_config);
  347. } else {
  348. // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
  349. dk_expanded.zero_();
  350. dv_expanded.zero_();
  351. softmax_d.zero_();
  352. }
  353. // For MQA/GQA we need to sum dK and dV across the groups
  354. if (num_heads_k != num_heads) {
  355. at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {2});
  356. at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size_8x}), {2});
  357. }
  358. if (head_size_og % 8 != 0) {
  359. dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  360. dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  361. dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  362. }
  363. return { dq, dk, dv, softmax_d };
  364. }