flash_api.cpp 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
#include <cstring>  // std::memcpy

// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cutlass/numeric_types.h>

#include "flash.h"
#include "static_switch.h"
  12. #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
  13. #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
  14. #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
  15. void set_params_fprop(Flash_fwd_params &params,
  16. // sizes
  17. const size_t b,
  18. const size_t seqlen_q,
  19. const size_t seqlen_k,
  20. const size_t seqlen_q_rounded,
  21. const size_t seqlen_k_rounded,
  22. const size_t h,
  23. const size_t h_k,
  24. const size_t d,
  25. const size_t d_rounded,
  26. // device pointers
  27. const at::Tensor q,
  28. const at::Tensor k,
  29. const at::Tensor v,
  30. at::Tensor out,
  31. void *cu_seqlens_q_d,
  32. void *cu_seqlens_k_d,
  33. void *seqused_k,
  34. void *p_d,
  35. void *softmax_lse_d,
  36. float p_dropout,
  37. float softmax_scale,
  38. int window_size_left,
  39. int window_size_right,
  40. bool seqlenq_ngroups_swapped=false) {
  41. // Reset the parameters
  42. params = {};
  43. params.is_bf16 = q.dtype() == torch::kBFloat16;
  44. params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;
  45. // Set the pointers and strides.
  46. params.q_ptr = q.data_ptr();
  47. params.k_ptr = k.data_ptr();
  48. params.v_ptr = v.data_ptr();
  49. // All stride are in elements, not bytes.
  50. params.q_row_stride = q.stride(-3);
  51. params.k_row_stride = k.stride(-3);
  52. params.v_row_stride = v.stride(-3);
  53. params.q_head_stride = q.stride(-2);
  54. params.k_head_stride = k.stride(-2);
  55. params.v_head_stride = v.stride(-2);
  56. params.o_ptr = out.data_ptr();
  57. params.o_row_stride = out.stride(-3);
  58. params.o_head_stride = out.stride(-2);
  59. if (cu_seqlens_q_d == nullptr) {
  60. params.q_batch_stride = q.stride(0);
  61. params.k_batch_stride = k.stride(0);
  62. params.v_batch_stride = v.stride(0);
  63. params.o_batch_stride = out.stride(0);
  64. if (seqlenq_ngroups_swapped) {
  65. params.q_batch_stride *= seqlen_q;
  66. params.o_batch_stride *= seqlen_q;
  67. }
  68. }
  69. params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
  70. params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
  71. params.seqused_k = static_cast<int *>(seqused_k);
  72. // P = softmax(QK^T)
  73. params.p_ptr = p_d;
  74. // Softmax sum
  75. params.softmax_lse_ptr = softmax_lse_d;
  76. // Set the dimensions.
  77. params.b = b;
  78. params.h = h;
  79. params.h_k = h_k;
  80. params.h_h_k_ratio = h / h_k;
  81. params.seqlen_q = seqlen_q;
  82. params.seqlen_k = seqlen_k;
  83. params.seqlen_q_rounded = seqlen_q_rounded;
  84. params.seqlen_k_rounded = seqlen_k_rounded;
  85. params.d = d;
  86. params.d_rounded = d_rounded;
  87. // Set the different scale values.
  88. params.scale_softmax = softmax_scale;
  89. params.scale_softmax_log2 = softmax_scale * M_LOG2E;
  90. __half scale_softmax_log2_half = __float2half(params.scale_softmax_log2);
  91. __half2 scale_softmax_log2_half2 = __half2(scale_softmax_log2_half, scale_softmax_log2_half);
  92. params.scale_softmax_log2_half2 = reinterpret_cast<uint32_t&>(scale_softmax_log2_half2);
  93. // Set this to probability of keeping an element to simplify things.
  94. params.p_dropout = 1.f - p_dropout;
  95. // Convert p from float to int so we don't have to convert the random uint to float to compare.
  96. // [Minor] We want to round down since when we do the comparison we use <= instead of <
  97. // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
  98. // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
  99. params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
  100. params.rp_dropout = 1.f / params.p_dropout;
  101. params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
  102. TORCH_CHECK(p_dropout < 1.f);
  103. #ifdef FLASHATTENTION_DISABLE_DROPOUT
  104. TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
  105. #endif
  106. // Causal is the special case where window_size_right == 0 and window_size_left < 0.
  107. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
  108. params.is_causal = window_size_left < 0 && window_size_right == 0;
  109. if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
  110. if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
  111. params.window_size_left = window_size_left;
  112. params.window_size_right = window_size_right;
  113. #ifdef FLASHATTENTION_DISABLE_LOCAL
  114. TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
  115. "This flash attention build does not support local attention.");
  116. #endif
  117. params.is_seqlens_k_cumulative = true;
  118. #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
  119. TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
  120. #endif
  121. }
  122. void set_params_dgrad(Flash_bwd_params &params,
  123. // sizes
  124. const size_t b,
  125. const size_t seqlen_q,
  126. const size_t seqlen_k,
  127. const size_t seqlen_q_rounded,
  128. const size_t seqlen_k_rounded,
  129. const size_t h,
  130. const size_t h_k,
  131. const size_t d,
  132. const size_t d_rounded,
  133. // device pointers
  134. const at::Tensor q,
  135. const at::Tensor k,
  136. const at::Tensor v,
  137. const at::Tensor out,
  138. const at::Tensor dout,
  139. at::Tensor dq,
  140. at::Tensor dk,
  141. at::Tensor dv,
  142. void *cu_seqlens_q_d,
  143. void *cu_seqlens_k_d,
  144. void *dq_accum_d,
  145. void *dk_accum_d,
  146. void *dv_accum_d,
  147. void *softmax_lse_d,
  148. void *dsoftmax_sum_d,
  149. float p_dropout,
  150. float softmax_scale,
  151. int window_size_left,
  152. int window_size_right,
  153. bool deterministic) {
  154. set_params_fprop(params,
  155. b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
  156. q, k, v, out,
  157. cu_seqlens_q_d,
  158. cu_seqlens_k_d,
  159. nullptr,
  160. nullptr,
  161. softmax_lse_d,
  162. p_dropout,
  163. softmax_scale,
  164. window_size_left,
  165. window_size_right);
  166. // Set the pointers and strides.
  167. params.do_ptr = dout.data_ptr();
  168. params.do_row_stride = dout.stride(-3);
  169. params.do_head_stride = dout.stride(-2);
  170. params.dq_ptr = dq.data_ptr();
  171. params.dk_ptr = dk.data_ptr();
  172. params.dv_ptr = dv.data_ptr();
  173. params.dq_row_stride = dq.stride(-3);
  174. params.dk_row_stride = dk.stride(-3);
  175. params.dv_row_stride = dv.stride(-3);
  176. params.dq_head_stride = dq.stride(-2);
  177. params.dk_head_stride = dk.stride(-2);
  178. params.dv_head_stride = dv.stride(-2);
  179. if (cu_seqlens_q_d == nullptr) {
  180. params.do_batch_stride = dout.stride(0);
  181. params.dq_batch_stride = dq.stride(0);
  182. params.dk_batch_stride = dk.stride(0);
  183. params.dv_batch_stride = dv.stride(0);
  184. }
  185. params.dq_accum_ptr = dq_accum_d;
  186. params.dk_accum_ptr = dk_accum_d;
  187. params.dv_accum_ptr = dv_accum_d;
  188. // Softmax sum
  189. params.dsoftmax_sum = dsoftmax_sum_d;
  190. params.deterministic = deterministic;
  191. }
  192. void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
  193. // HEADDIM_SWITCH(params.d, [&] {
  194. // run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
  195. // });
  196. if (!params.is_e4m3) {
  197. if (params.is_bf16) {
  198. if (params.d == 64) {
  199. run_mha_fwd_<cutlass::bfloat16_t, 64>(params, stream);
  200. } else if (params.d == 128) {
  201. run_mha_fwd_<cutlass::bfloat16_t, 128>(params, stream);
  202. } else {
  203. run_mha_fwd_<cutlass::bfloat16_t, 256>(params, stream);
  204. }
  205. } else {
  206. if (params.d == 64) {
  207. run_mha_fwd_<cutlass::half_t, 64>(params, stream);
  208. } else if (params.d == 128) {
  209. run_mha_fwd_<cutlass::half_t, 128>(params, stream);
  210. } else {
  211. run_mha_fwd_<cutlass::half_t, 256>(params, stream);
  212. }
  213. }
  214. } else {
  215. // run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream);
  216. }
  217. }
  218. std::vector<at::Tensor>
  219. mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  220. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  221. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  222. c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
  223. const float softmax_scale,
  224. bool is_causal) {
  225. auto dprops = at::cuda::getCurrentDeviceProperties();
  226. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  227. TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
  228. auto q_dtype = q.dtype();
  229. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  230. "FlashAttention only support fp16 and bf16 data type for now");
  231. // TODO: will add e4m3 later
  232. // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kFloat8_e4m3fn,
  233. // "FlashAttention only support fp16 and bf16 data type");
  234. // "FlashAttention only support fp16 and fp8 (e4m3) data type for now");
  235. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  236. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  237. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  238. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  239. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  240. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  241. TORCH_CHECK(q.is_contiguous(), "Input tensor must be contiguous");
  242. TORCH_CHECK(k.is_contiguous(), "Input tensor must be contiguous");
  243. TORCH_CHECK(v.is_contiguous(), "Input tensor must be contiguous");
  244. const auto sizes = q.sizes();
  245. const int batch_size = sizes[0];
  246. int seqlen_q = sizes[1];
  247. int num_heads = sizes[2];
  248. const int head_size_og = sizes[3];
  249. const int seqlen_k = k.size(1);
  250. const int num_heads_k = k.size(2);
  251. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  252. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  253. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  254. TORCH_CHECK(head_size_og == 64 || head_size_og == 128 || head_size_og == 256, "Only support head size 64, 128, and 256 for now");
  255. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
  256. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
  257. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
  258. at::Tensor q_padded, k_padded, v_padded;
  259. if (head_size_og % 8 != 0) {
  260. q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  261. k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  262. v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  263. } else {
  264. q_padded = q;
  265. k_padded = k;
  266. v_padded = v;
  267. }
  268. at::Tensor out;
  269. if (out_.has_value()) {
  270. out = out_.value();
  271. TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  272. CHECK_DEVICE(out);
  273. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  274. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
  275. if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
  276. } else {
  277. out = torch::empty_like(q_padded);
  278. }
  279. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  280. const int head_size = round_multiple(head_size_og, 8);
  281. const int head_size_rounded = round_multiple(head_size, 32);
  282. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  283. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  284. // Otherwise the kernel will be launched from cuda:0 device
  285. // Cast to char to avoid compiler warning about narrowing
  286. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  287. auto opts = q.options();
  288. auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  289. at::Tensor p;
  290. Flash_fwd_params params;
  291. set_params_fprop(params,
  292. batch_size,
  293. seqlen_q, seqlen_k,
  294. seqlen_q_rounded, seqlen_k_rounded,
  295. num_heads, num_heads_k,
  296. head_size, head_size_rounded,
  297. q_padded, k_padded, v_padded, out,
  298. /*cu_seqlens_q_d=*/nullptr,
  299. /*cu_seqlens_k_d=*/nullptr,
  300. /*seqused_k=*/nullptr,
  301. nullptr,
  302. softmax_lse.data_ptr(),
  303. /*p_dropout=*/0.f,
  304. softmax_scale,
  305. /*window_size_left=*/-1,
  306. /*window_size_right=*/is_causal ? 0 : -1);
  307. auto tile_count_semaphore = is_causal ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
  308. params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
  309. if (seqlen_k > 0) {
  310. auto stream = at::cuda::getCurrentCUDAStream().stream();
  311. run_mha_fwd(params, stream);
  312. } else {
  313. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  314. out.zero_();
  315. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  316. }
  317. at::Tensor out_padded = out;
  318. if (head_size_og % 8 != 0) {
  319. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  320. if (out_.has_value()) { out_.value().copy_(out); }
  321. }
  322. return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
  323. }
  324. void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
  325. // FP16_SWITCH(!params.is_bf16, [&] {
  326. // HEADDIM_SWITCH(params.d, [&] {
  327. // run_mha_bwd_<elem_type, kHeadDim>(params, stream);
  328. // });
  329. // });
  330. if (params.d == 64) {
  331. run_mha_bwd_<cutlass::half_t, 64>(params, stream);
  332. } else if (params.d == 128) {
  333. run_mha_bwd_<cutlass::half_t, 128>(params, stream);
  334. } else {
  335. run_mha_bwd_<cutlass::half_t, 256>(params, stream);
  336. }
  337. }
  338. std::vector<at::Tensor>
  339. mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
  340. const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  341. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  342. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  343. const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
  344. const at::Tensor &softmax_lse, // b x h x seqlen_q
  345. c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
  346. c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
  347. c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
  348. const float softmax_scale,
  349. const bool is_causal) {
  350. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  351. TORCH_CHECK(false, "This flash attention build does not support backward.");
  352. #endif
  353. auto dprops = at::cuda::getCurrentDeviceProperties();
  354. bool is_sm9x = dprops->major == 9 && dprops->minor >= 0;
  355. TORCH_CHECK(is_sm9x, "FlashAttentionHopper only supports Hopper GPUs or newer.");
  356. auto stream = at::cuda::getCurrentCUDAStream().stream();
  357. auto q_dtype = q.dtype();
  358. TORCH_CHECK(q_dtype == torch::kFloat16,
  359. // "FlashAttention only support fp16 and bf16 data type");
  360. "FlashAttention only support fp16 data type for now");
  361. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  362. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  363. TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  364. TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
  365. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  366. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  367. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  368. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  369. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  370. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  371. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
  372. TORCH_CHECK(q.is_contiguous(), "Input tensor must be contiguous");
  373. TORCH_CHECK(k.is_contiguous(), "Input tensor must be contiguous");
  374. TORCH_CHECK(v.is_contiguous(), "Input tensor must be contiguous");
  375. const auto sizes = q.sizes();
  376. const int batch_size = sizes[0];
  377. const int seqlen_q = sizes[1];
  378. const int num_heads = sizes[2];
  379. const int head_size_og = dout.size(3);
  380. const int head_size = sizes[3];
  381. const int seqlen_k = k.size(1);
  382. const int num_heads_k = k.size(2);
  383. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  384. TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  385. TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
  386. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  387. TORCH_CHECK(head_size_og == 64 || head_size_og == 128, "Only support head size 64 and 128 for now");
  388. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  389. const int head_size_rounded = round_multiple(head_size, 32);
  390. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  391. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  392. TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
  393. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
  394. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
  395. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
  396. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
  397. CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
  398. at::Tensor dq, dk, dv;
  399. if (dq_.has_value()) {
  400. dq = dq_.value();
  401. TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
  402. CHECK_DEVICE(dq);
  403. TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
  404. CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
  405. } else {
  406. dq = torch::empty_like(q);
  407. }
  408. if (dk_.has_value()) {
  409. dk = dk_.value();
  410. TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
  411. CHECK_DEVICE(dk);
  412. TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
  413. CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
  414. } else {
  415. dk = torch::empty_like(k);
  416. }
  417. if (dv_.has_value()) {
  418. dv = dv_.value();
  419. TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
  420. CHECK_DEVICE(dv);
  421. TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
  422. CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
  423. } else {
  424. dv = torch::empty_like(v);
  425. }
  426. at::Tensor dout_padded;
  427. if (head_size_og % 8 != 0) {
  428. dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  429. } else {
  430. dout_padded = dout;
  431. }
  432. // bool loop = seqlen_k > blocksize_c;
  433. // TODO: change later, for now set to true for simplicity
  434. bool loop = true;
  435. // Otherwise the kernel will be launched from cuda:0 device
  436. // Cast to char to avoid compiler warning about narrowing
  437. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  438. auto opts = q.options();
  439. auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
  440. at::Tensor dq_accum;
  441. at::Tensor dk_accum, dv_accum;
  442. if (loop) {
  443. dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
  444. // dk_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
  445. // dv_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
  446. }
  447. at::Tensor dk_expanded, dv_expanded;
  448. if (num_heads_k != num_heads) { // MQA / GQA
  449. dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
  450. dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
  451. } else {
  452. dk_expanded = dk;
  453. dv_expanded = dv;
  454. }
  455. Flash_bwd_params params;
  456. set_params_dgrad(params,
  457. batch_size,
  458. seqlen_q, seqlen_k,
  459. seqlen_q_rounded, seqlen_k_rounded,
  460. num_heads, num_heads_k,
  461. head_size, head_size_rounded,
  462. q, k, v, out,
  463. dout_padded, dq, dk_expanded, dv_expanded,
  464. nullptr,
  465. nullptr,
  466. loop ? dq_accum.data_ptr() : nullptr,
  467. // loop ? dk_accum.data_ptr() : nullptr,
  468. // loop ? dv_accum.data_ptr() : nullptr,
  469. nullptr,
  470. nullptr,
  471. softmax_lse.data_ptr(),
  472. softmax_d.data_ptr(),
  473. /*p_dropout=*/0.f,
  474. softmax_scale,
  475. /*window_size_left=*/-1,
  476. /*window_size_right=*/-1,
  477. /*deterministic=*/false);
  478. at::Tensor dq_semaphore = torch::zeros({(seqlen_q + 64 - 1) / 64, batch_size, num_heads}, opts.dtype(torch::kInt32));
  479. params.dq_semaphore = dq_semaphore.data_ptr<int>();
  480. // printf("dq_semaphore: %p, [%d, %d, %d]\n", params.dq_semaphore, (seqlen_q + 64 - 1) / 64, batch_size, num_heads);
  481. auto launch = &run_mha_bwd;
  482. if (seqlen_q > 0) {
  483. launch(params, stream);
  484. } else {
  485. // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
  486. dk_expanded.zero_();
  487. dv_expanded.zero_();
  488. softmax_d.zero_();
  489. }
  490. if (head_size_og % 8 != 0) {
  491. dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  492. dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  493. dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  494. }
  495. return { dq, dk, dv, softmax_d };
  496. }
  497. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  498. m.doc() = "FlashAttention";
  499. m.def("fwd", &mha_fwd, "Forward pass");
  500. m.def("bwd", &mha_bwd, "Backward pass");
  501. }