flash_api.cpp 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
  5. #include <torch/python.h>
  6. #include <torch/nn/functional.h>
  7. #include <ATen/cuda/CUDAContext.h>
  8. #include <c10/cuda/CUDAGuard.h>
  9. #include <cutlass/numeric_types.h>
  10. #include "flash.h"
  11. #include "static_switch.h"
  // Argument-validation helpers built on TORCH_CHECK.
  // The macro argument is parenthesized so that non-trivial expressions
  // (e.g. `*ptr`) expand correctly; `#x` still stringifies the argument exactly
  // as it was written at the call site, so the error text is unchanged.
  #define CHECK_DEVICE(x) TORCH_CHECK((x).is_cuda(), #x " must be on CUDA")
  #define CHECK_SHAPE(x, ...) TORCH_CHECK((x).sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
  #define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
  15. void set_params_fprop(Flash_fwd_params &params,
  16. // sizes
  17. const size_t b,
  18. const size_t seqlen_q,
  19. const size_t seqlen_k,
  20. const size_t seqlen_q_rounded,
  21. const size_t seqlen_k_rounded,
  22. const size_t h,
  23. const size_t h_k,
  24. const size_t d,
  25. const size_t d_rounded,
  26. // device pointers
  27. const at::Tensor q,
  28. const at::Tensor k,
  29. const at::Tensor v,
  30. at::Tensor out,
  31. void *cu_seqlens_q_d,
  32. void *cu_seqlens_k_d,
  33. void *seqused_k,
  34. void *p_d,
  35. void *softmax_lse_d,
  36. float p_dropout,
  37. float softmax_scale,
  38. int window_size_left,
  39. int window_size_right,
  40. bool seqlenq_ngroups_swapped=false,
  41. bool unpadded_lse=false) {
  42. // Reset the parameters
  43. params = {};
  44. params.is_bf16 = q.dtype() == torch::kBFloat16;
  45. params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;
  46. // Set the pointers and strides.
  47. params.q_ptr = q.data_ptr();
  48. params.k_ptr = k.data_ptr();
  49. params.v_ptr = v.data_ptr();
  50. // All stride are in elements, not bytes.
  51. params.q_row_stride = q.stride(-3);
  52. params.k_row_stride = k.stride(-3);
  53. params.v_row_stride = v.stride(-3);
  54. params.q_head_stride = q.stride(-2);
  55. params.k_head_stride = k.stride(-2);
  56. params.v_head_stride = v.stride(-2);
  57. params.o_ptr = out.data_ptr();
  58. params.o_row_stride = out.stride(-3);
  59. params.o_head_stride = out.stride(-2);
  60. if (cu_seqlens_q_d == nullptr) {
  61. params.q_batch_stride = q.stride(0);
  62. params.k_batch_stride = k.stride(0);
  63. params.v_batch_stride = v.stride(0);
  64. params.o_batch_stride = out.stride(0);
  65. if (seqlenq_ngroups_swapped) {
  66. params.q_batch_stride *= seqlen_q;
  67. params.o_batch_stride *= seqlen_q;
  68. }
  69. }
  70. params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
  71. params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
  72. params.seqused_k = static_cast<int *>(seqused_k);
  73. TORCH_CHECK(
  74. bool(params.cu_seqlens_q) == bool(params.cu_seqlens_k),
  75. "cu_seqlens_q and cu_seqlens_k must be both null or non-null"
  76. );
  77. // P = softmax(QK^T)
  78. params.p_ptr = p_d;
  79. // Softmax sum
  80. params.softmax_lse_ptr = softmax_lse_d;
  81. // Set the dimensions.
  82. params.b = b;
  83. params.h = h;
  84. params.h_k = h_k;
  85. params.h_h_k_ratio = h / h_k;
  86. params.seqlen_q = seqlen_q;
  87. params.seqlen_k = seqlen_k;
  88. params.seqlen_q_rounded = seqlen_q_rounded;
  89. params.seqlen_k_rounded = seqlen_k_rounded;
  90. params.d = d;
  91. params.d_rounded = d_rounded;
  92. // Set the different scale values.
  93. params.scale_softmax = softmax_scale;
  94. params.scale_softmax_log2 = softmax_scale * M_LOG2E;
  95. __half scale_softmax_log2_half = __float2half(params.scale_softmax_log2);
  96. __half2 scale_softmax_log2_half2 = __half2(scale_softmax_log2_half, scale_softmax_log2_half);
  97. params.scale_softmax_log2_half2 = reinterpret_cast<uint32_t&>(scale_softmax_log2_half2);
  98. // Set this to probability of keeping an element to simplify things.
  99. params.p_dropout = 1.f - p_dropout;
  100. // Convert p from float to int so we don't have to convert the random uint to float to compare.
  101. // [Minor] We want to round down since when we do the comparison we use <= instead of <
  102. // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
  103. // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
  104. params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
  105. params.rp_dropout = 1.f / params.p_dropout;
  106. params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
  107. TORCH_CHECK(p_dropout < 1.f);
  108. #ifdef FLASHATTENTION_DISABLE_DROPOUT
  109. TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
  110. #endif
  111. // Causal is the special case where window_size_right == 0 and window_size_left < 0.
  112. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
  113. params.is_causal = window_size_left < 0 && window_size_right == 0;
  114. if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
  115. if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
  116. params.window_size_left = window_size_left;
  117. params.window_size_right = window_size_right;
  118. #ifdef FLASHATTENTION_DISABLE_LOCAL
  119. TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
  120. "This flash attention build does not support local attention.");
  121. #endif
  122. params.is_seqlens_k_cumulative = true;
  123. #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
  124. TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
  125. #endif
  126. params.unpadded_lse = unpadded_lse;
  127. }
  128. void set_params_dgrad(Flash_bwd_params &params,
  129. // sizes
  130. const size_t b,
  131. const size_t seqlen_q,
  132. const size_t seqlen_k,
  133. const size_t seqlen_q_rounded,
  134. const size_t seqlen_k_rounded,
  135. const size_t h,
  136. const size_t h_k,
  137. const size_t d,
  138. const size_t d_rounded,
  139. // device pointers
  140. const at::Tensor q,
  141. const at::Tensor k,
  142. const at::Tensor v,
  143. const at::Tensor out,
  144. const at::Tensor dout,
  145. at::Tensor dq,
  146. at::Tensor dk,
  147. at::Tensor dv,
  148. void *cu_seqlens_q_d,
  149. void *cu_seqlens_k_d,
  150. void *dq_accum_d,
  151. void *dk_accum_d,
  152. void *dv_accum_d,
  153. void *softmax_lse_d,
  154. void *dsoftmax_sum_d,
  155. float p_dropout,
  156. float softmax_scale,
  157. int window_size_left,
  158. int window_size_right,
  159. bool deterministic) {
  160. set_params_fprop(params,
  161. b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
  162. q, k, v, out,
  163. cu_seqlens_q_d,
  164. cu_seqlens_k_d,
  165. nullptr,
  166. nullptr,
  167. softmax_lse_d,
  168. p_dropout,
  169. softmax_scale,
  170. window_size_left,
  171. window_size_right);
  172. // Set the pointers and strides.
  173. params.do_ptr = dout.data_ptr();
  174. params.do_row_stride = dout.stride(-3);
  175. params.do_head_stride = dout.stride(-2);
  176. params.dq_ptr = dq.data_ptr();
  177. params.dk_ptr = dk.data_ptr();
  178. params.dv_ptr = dv.data_ptr();
  179. params.dq_row_stride = dq.stride(-3);
  180. params.dk_row_stride = dk.stride(-3);
  181. params.dv_row_stride = dv.stride(-3);
  182. params.dq_head_stride = dq.stride(-2);
  183. params.dk_head_stride = dk.stride(-2);
  184. params.dv_head_stride = dv.stride(-2);
  185. if (cu_seqlens_q_d == nullptr) {
  186. params.do_batch_stride = dout.stride(0);
  187. params.dq_batch_stride = dq.stride(0);
  188. params.dk_batch_stride = dk.stride(0);
  189. params.dv_batch_stride = dv.stride(0);
  190. }
  191. params.dq_accum_ptr = dq_accum_d;
  192. params.dk_accum_ptr = dk_accum_d;
  193. params.dv_accum_ptr = dv_accum_d;
  194. // Softmax sum
  195. params.dsoftmax_sum = dsoftmax_sum_d;
  196. params.deterministic = deterministic;
  197. }
  198. void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
  199. // HEADDIM_SWITCH(params.d, [&] {
  200. // run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
  201. // });
  202. if (!params.is_e4m3) {
  203. if (params.is_bf16) {
  204. if (params.d == 64) {
  205. run_mha_fwd_<cutlass::bfloat16_t, 64>(params, stream);
  206. } else if (params.d == 128) {
  207. run_mha_fwd_<cutlass::bfloat16_t, 128>(params, stream);
  208. } else {
  209. run_mha_fwd_<cutlass::bfloat16_t, 256>(params, stream);
  210. }
  211. } else {
  212. if (params.d == 64) {
  213. run_mha_fwd_<cutlass::half_t, 64>(params, stream);
  214. } else if (params.d == 128) {
  215. run_mha_fwd_<cutlass::half_t, 128>(params, stream);
  216. } else {
  217. run_mha_fwd_<cutlass::half_t, 256>(params, stream);
  218. }
  219. }
  220. } else {
  221. if (params.d == 64) {
  222. run_mha_fwd_<cutlass::float_e4m3_t, 64>(params, stream);
  223. } else if (params.d == 128) {
  224. run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream);
  225. } else if (params.d == 256) {
  226. run_mha_fwd_<cutlass::float_e4m3_t, 256>(params, stream);
  227. }
  228. }
  229. }
  230. std::vector<at::Tensor>
  231. mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  232. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  233. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  234. c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
  235. const float softmax_scale,
  236. c10::optional<at::Tensor> &descale_q_, // 1
  237. c10::optional<at::Tensor> &descale_k_, // 1
  238. c10::optional<at::Tensor> &descale_v_, // 1
  239. bool is_causal) {
  240. auto dprops = at::cuda::getCurrentDeviceProperties();
  241. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  242. TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
  243. auto q_dtype = q.dtype();
  244. // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  245. // "FlashAttention only support fp16 and bf16 data type for now");
  246. // TODO: will add e4m3 later
  247. // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kFloat8_e4m3fn,
  248. // "FlashAttention only support fp16 and bf16 data type");
  249. // "FlashAttention only support fp16 and fp8 (e4m3) data type for now");
  250. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  251. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  252. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  253. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  254. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  255. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  256. const auto sizes = q.sizes();
  257. const int batch_size = sizes[0];
  258. int seqlen_q = sizes[1];
  259. int num_heads = sizes[2];
  260. const int head_size_og = sizes[3];
  261. const int seqlen_k = k.size(1);
  262. const int num_heads_k = k.size(2);
  263. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  264. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  265. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  266. TORCH_CHECK(head_size_og == 64 || head_size_og == 128 || head_size_og == 256, "Only support head size 64, 128, and 256 for now");
  267. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
  268. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
  269. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
  270. at::Tensor q_padded, k_padded, v_padded;
  271. if (head_size_og % 8 != 0) {
  272. q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  273. k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  274. v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  275. } else {
  276. q_padded = q;
  277. k_padded = k;
  278. v_padded = v;
  279. }
  280. at::Tensor out;
  281. if (out_.has_value()) {
  282. out = out_.value();
  283. // TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  284. TORCH_CHECK(q_dtype == at::ScalarType::Float8_e4m3fn
  285. ? (out.dtype() == at::kHalf)
  286. : (out.dtype() == q_dtype),
  287. "Output must have the same dtype as input dtype if dtype is "
  288. "not fp8, or fp16 for fp8 input.");
  289. CHECK_DEVICE(out);
  290. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  291. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
  292. if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
  293. } else {
  294. if (q_dtype == at::ScalarType::Float8_e4m3fn)
  295. out = torch::empty_like(q_padded, at::kHalf);
  296. else
  297. out = torch::empty_like(q_padded);
  298. }
  299. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  300. const int head_size = round_multiple(head_size_og, 8);
  301. const int head_size_rounded = round_multiple(head_size, 32);
  302. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  303. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  304. // Otherwise the kernel will be launched from cuda:0 device
  305. // Cast to char to avoid compiler warning about narrowing
  306. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  307. auto opts = q.options();
  308. auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  309. at::Tensor p;
  310. Flash_fwd_params params;
  311. set_params_fprop(params,
  312. batch_size,
  313. seqlen_q, seqlen_k,
  314. seqlen_q_rounded, seqlen_k_rounded,
  315. num_heads, num_heads_k,
  316. head_size, head_size_rounded,
  317. q_padded, k_padded, v_padded, out,
  318. /*cu_seqlens_q_d=*/nullptr,
  319. /*cu_seqlens_k_d=*/nullptr,
  320. /*seqused_k=*/nullptr,
  321. nullptr,
  322. softmax_lse.data_ptr(),
  323. /*p_dropout=*/0.f,
  324. softmax_scale,
  325. /*window_size_left=*/-1,
  326. /*window_size_right=*/is_causal ? 0 : -1);
  327. auto tile_count_semaphore = is_causal ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
  328. params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
  329. if(q_dtype == at::ScalarType::Float8_e4m3fn) {
  330. at::Tensor descale_q, descale_k, descale_v;
  331. if (descale_q_.has_value() && descale_k_.has_value() && descale_k_.has_value()) {
  332. descale_q = descale_q_.value();
  333. descale_k = descale_k_.value();
  334. descale_v = descale_v_.value();
  335. CHECK_DEVICE(descale_q);
  336. CHECK_DEVICE(descale_k);
  337. CHECK_DEVICE(descale_v);
  338. CHECK_SHAPE(descale_q, 1);
  339. CHECK_SHAPE(descale_k, 1);
  340. CHECK_SHAPE(descale_v, 1);
  341. } else {
  342. descale_q = torch::ones({1}, opts.dtype(at::kFloat));
  343. descale_k = torch::ones({1}, opts.dtype(at::kFloat));
  344. descale_v = torch::ones({1}, opts.dtype(at::kFloat));
  345. }
  346. params.descale_q_ptr = descale_q.data_ptr<float>();
  347. params.descale_k_ptr = descale_k.data_ptr<float>();
  348. params.descale_v_ptr = descale_v.data_ptr<float>();
  349. } else {
  350. params.descale_q_ptr = nullptr;
  351. params.descale_k_ptr = nullptr;
  352. params.descale_v_ptr = nullptr;
  353. }
  354. if (seqlen_k > 0) {
  355. auto stream = at::cuda::getCurrentCUDAStream().stream();
  356. run_mha_fwd(params, stream);
  357. } else {
  358. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  359. out.zero_();
  360. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  361. }
  362. at::Tensor out_padded = out;
  363. if (head_size_og % 8 != 0) {
  364. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  365. if (out_.has_value()) { out_.value().copy_(out); }
  366. }
  367. return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
  368. }
  369. std::vector<at::Tensor>
  370. mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  371. const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  372. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  373. c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
  374. const at::Tensor &cu_seqlens_q, // b+1
  375. const at::Tensor &cu_seqlens_k, // b+1
  376. c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
  377. int max_seqlen_q,
  378. const int max_seqlen_k,
  379. const float softmax_scale,
  380. bool is_causal) {
  381. auto dprops = at::cuda::getCurrentDeviceProperties();
  382. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  383. TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
  384. auto q_dtype = q.dtype();
  385. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  386. "FlashAttention only support fp16 and bf16 data type");
  387. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  388. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  389. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
  390. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
  391. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  392. CHECK_DEVICE(cu_seqlens_q);
  393. CHECK_DEVICE(cu_seqlens_k);
  394. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  395. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  396. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  397. CHECK_CONTIGUOUS(cu_seqlens_q);
  398. CHECK_CONTIGUOUS(cu_seqlens_k);
  399. const auto sizes = q.sizes();
  400. const int batch_size = cu_seqlens_q.numel() - 1;
  401. int num_heads = sizes[1];
  402. const int head_size_og = sizes[2];
  403. const int num_heads_k = k.size(1);
  404. int window_size_left = -1;
  405. int window_size_right = -1;
  406. if (is_causal) { window_size_right = 0; }
  407. void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
  408. const int total_q = q.sizes()[0];
  409. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  410. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  411. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  412. if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
  413. if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
  414. CHECK_SHAPE(q, total_q, num_heads, head_size_og);
  415. const int total_k = k.size(0);
  416. CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
  417. CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
  418. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  419. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  420. if (seqused_k.has_value()){
  421. auto seqused_k_ = seqused_k.value();
  422. TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
  423. TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
  424. TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
  425. CHECK_SHAPE(seqused_k_, batch_size);
  426. }
  427. at::Tensor q_padded, k_padded, v_padded;
  428. if (head_size_og % 8 != 0) {
  429. q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  430. k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  431. v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  432. } else {
  433. q_padded = q;
  434. k_padded = k;
  435. v_padded = v;
  436. }
  437. at::Tensor out;
  438. if (out_.has_value()) {
  439. out = out_.value();
  440. TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  441. CHECK_DEVICE(out);
  442. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  443. CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
  444. if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
  445. } else {
  446. out = torch::empty_like(q_padded);
  447. }
  448. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  449. const int head_size = round_multiple(head_size_og, 8);
  450. const int head_size_rounded = round_multiple(head_size, 32);
  451. const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
  452. const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
  453. // Otherwise the kernel will be launched from cuda:0 device
  454. // Cast to char to avoid compiler warning about narrowing
  455. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  456. auto opts = q.options();
  457. auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
  458. Flash_fwd_params params;
  459. set_params_fprop(params,
  460. batch_size,
  461. max_seqlen_q, max_seqlen_k,
  462. seqlen_q_rounded, seqlen_k_rounded,
  463. num_heads, num_heads_k,
  464. head_size, head_size_rounded,
  465. q_padded, k_padded, v_padded, out,
  466. cu_seqlens_q_d,
  467. cu_seqlens_k.data_ptr(),
  468. seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
  469. /*p_d=*/nullptr,
  470. softmax_lse.data_ptr(),
  471. /*p_dropout=*/0.f,
  472. softmax_scale,
  473. window_size_left,
  474. window_size_right,
  475. /*seqlenq_ngroups_swapped=*/false,
  476. /*unpadded_lse=*/true);
  477. params.total_q = total_q;
  478. params.total_k = total_k;
  479. if (max_seqlen_k > 0) {
  480. auto stream = at::cuda::getCurrentCUDAStream().stream();
  481. run_mha_fwd(params, stream);
  482. } else {
  483. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  484. out.zero_();
  485. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  486. }
  487. at::Tensor out_padded = out;
  488. if (head_size_og % 8 != 0) {
  489. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  490. if (out_.has_value()) { out_.value().copy_(out); }
  491. }
  492. return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse};
  493. }
  494. void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
  495. // FP16_SWITCH(!params.is_bf16, [&] {
  496. // HEADDIM_SWITCH(params.d, [&] {
  497. // run_mha_bwd_<elem_type, kHeadDim>(params, stream);
  498. // });
  499. // });
  500. if (!params.is_bf16) {
  501. if (params.d <= 64) {
  502. run_mha_bwd_<cutlass::half_t, 64>(params, stream);
  503. } else if (params.d <= 96) {
  504. run_mha_bwd_<cutlass::half_t, 96>(params, stream);
  505. } else {
  506. run_mha_bwd_<cutlass::half_t, 128>(params, stream);
  507. }
  508. } else {
  509. if (params.d <= 64) {
  510. run_mha_bwd_<cutlass::bfloat16_t, 64>(params, stream);
  511. } else if (params.d <= 96) {
  512. run_mha_bwd_<cutlass::bfloat16_t, 96>(params, stream);
  513. } else {
  514. run_mha_bwd_<cutlass::bfloat16_t, 128>(params, stream);
  515. }
  516. }
  517. }
  518. std::vector<at::Tensor>
  519. mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
  520. const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  521. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  522. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  523. const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
  524. const at::Tensor &softmax_lse, // b x h x seqlen_q
  525. c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
  526. c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
  527. c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
  528. const float softmax_scale,
  529. const bool is_causal,
  530. const bool deterministic) {
  531. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  532. TORCH_CHECK(false, "This flash attention build does not support backward.");
  533. #endif
  534. auto dprops = at::cuda::getCurrentDeviceProperties();
  535. bool is_sm9x = dprops->major == 9 && dprops->minor >= 0;
  536. TORCH_CHECK(is_sm9x, "FlashAttentionHopper only supports Hopper GPUs or newer.");
  537. auto stream = at::cuda::getCurrentCUDAStream().stream();
  538. auto q_dtype = q.dtype();
  539. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  540. "FlashAttention only support fp16 and bf16 data type");
  541. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  542. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  543. TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  544. TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
  545. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  546. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  547. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  548. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  549. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  550. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  551. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
  552. const auto sizes = q.sizes();
  553. const int batch_size = sizes[0];
  554. const int seqlen_q = sizes[1];
  555. const int num_heads = sizes[2];
  556. const int head_size_og = dout.size(3);
  557. const int head_size = sizes[3];
  558. const int seqlen_k = k.size(1);
  559. const int num_heads_k = k.size(2);
  560. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  561. TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  562. TORCH_CHECK(head_size <= 128, "FlashAttention backward only supports head dimension at most 128");
  563. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  564. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  565. const int head_size_rounded = head_size <= 64 ? 64 : round_multiple(head_size, 32);
  566. // This should match the kernel configs
  567. const int kBlockM = head_size <= 64 ? 128 : (head_size < 256 ? 64 : 32);
  568. const int seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);
  569. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  570. TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
  571. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
  572. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
  573. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
  574. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
  575. CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
  576. at::Tensor dq, dk, dv;
  577. if (dq_.has_value()) {
  578. dq = dq_.value();
  579. TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
  580. CHECK_DEVICE(dq);
  581. TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
  582. CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
  583. } else {
  584. dq = torch::empty_like(q);
  585. }
  586. if (dk_.has_value()) {
  587. dk = dk_.value();
  588. TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
  589. CHECK_DEVICE(dk);
  590. TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
  591. CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
  592. } else {
  593. dk = torch::empty_like(k);
  594. }
  595. if (dv_.has_value()) {
  596. dv = dv_.value();
  597. TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
  598. CHECK_DEVICE(dv);
  599. TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
  600. CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
  601. } else {
  602. dv = torch::empty_like(v);
  603. }
  604. at::Tensor dout_padded;
  605. if (head_size_og % 8 != 0) {
  606. dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  607. } else {
  608. dout_padded = dout;
  609. }
  610. // Otherwise the kernel will be launched from cuda:0 device
  611. // Cast to char to avoid compiler warning about narrowing
  612. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  613. auto opts = q.options();
  614. // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
  615. auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
  616. auto softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
  617. at::Tensor dq_accum;
  618. at::Tensor dk_accum, dv_accum;
  619. dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat));
  620. // dk_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
  621. // dv_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
  622. at::Tensor dk_expanded, dv_expanded;
  623. if (num_heads_k != num_heads) { // MQA / GQA
  624. dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
  625. dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
  626. } else {
  627. dk_expanded = dk;
  628. dv_expanded = dv;
  629. }
  630. Flash_bwd_params params;
  631. set_params_dgrad(params,
  632. batch_size,
  633. seqlen_q, seqlen_k,
  634. seqlen_q_rounded, seqlen_k_rounded,
  635. num_heads, num_heads_k,
  636. head_size, head_size_rounded,
  637. q, k, v, out,
  638. dout_padded, dq, dk_expanded, dv_expanded,
  639. nullptr,
  640. nullptr,
  641. dq_accum.data_ptr(),
  642. // loop ? dk_accum.data_ptr() : nullptr,
  643. // loop ? dv_accum.data_ptr() : nullptr,
  644. nullptr,
  645. nullptr,
  646. softmax_lse.data_ptr(),
  647. softmax_d.data_ptr(),
  648. /*p_dropout=*/0.f,
  649. softmax_scale,
  650. /*window_size_left=*/-1,
  651. /*window_size_right=*/is_causal ? 0 : -1,
  652. deterministic);
  653. params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
  654. // Will be zero'ed out in the backward preprocess kernel
  655. at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
  656. params.dq_semaphore = dq_semaphore.data_ptr<int>();
  657. // printf("dq_semaphore: %p, [%d, %d, %d]\n", params.dq_semaphore, (seqlen_q + 64 - 1) / 64, batch_size, num_heads);
  658. if (seqlen_q > 0) {
  659. run_mha_bwd(params, stream);
  660. } else {
  661. // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
  662. dk_expanded.zero_();
  663. dv_expanded.zero_();
  664. softmax_d.zero_();
  665. }
  666. // For MQA/GQA we need to sum dK and dV across the groups
  667. if (num_heads_k != num_heads) {
  668. at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
  669. at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
  670. }
  671. if (head_size_og % 8 != 0) {
  672. dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  673. dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  674. dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  675. }
  676. return { dq, dk, dv, softmax_d, dq_accum};
  677. }
  678. std::vector<at::Tensor>
  679. mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
  680. const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  681. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  682. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  683. const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
  684. const at::Tensor &softmax_lse, // b x h x seqlen_q
  685. c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
  686. c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
  687. c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
  688. const at::Tensor &cu_seqlens_q, // b+1
  689. const at::Tensor &cu_seqlens_k, // b+1
  690. const int max_seqlen_q,
  691. const int max_seqlen_k, // max sequence length to choose the kernel
  692. const float softmax_scale,
  693. const bool is_causal,
  694. const bool deterministic) {
  695. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  696. TORCH_CHECK(false, "This flash attention build does not support backward.");
  697. #endif
  698. auto dprops = at::cuda::getCurrentDeviceProperties();
  699. bool is_sm9x = dprops->major == 9 && dprops->minor >= 0;
  700. TORCH_CHECK(is_sm9x, "FlashAttentionHopper only supports Hopper GPUs or newer.");
  701. auto stream = at::cuda::getCurrentCUDAStream().stream();
  702. auto q_dtype = q.dtype();
  703. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  704. "FlashAttention only support fp16 and bf16 data type");
  705. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  706. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  707. TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  708. TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
  709. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
  710. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
  711. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  712. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  713. CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
  714. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  715. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  716. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  717. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  718. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
  719. CHECK_CONTIGUOUS(cu_seqlens_q);
  720. CHECK_CONTIGUOUS(cu_seqlens_k);
  721. const auto sizes = q.sizes();
  722. const int total_q = sizes[0];
  723. const int batch_size = cu_seqlens_q.numel() - 1;
  724. const int num_heads = sizes[1];
  725. const int head_size_og = dout.size(2);
  726. const int head_size = sizes[2];
  727. const int total_k = k.size(0);
  728. const int num_heads_k = k.size(1);
  729. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  730. TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  731. TORCH_CHECK(head_size <= 128, "FlashAttention backward only supports head dimension at most 128");
  732. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  733. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  734. const int head_size_rounded = head_size <= 64 ? 64 : round_multiple(head_size, 32);
  735. // This should match the kernel configs
  736. const int kBlockM = head_size <= 64 ? 128 : (head_size < 256 ? 64 : 32);
  737. const int seqlen_q_rounded = round_multiple(max_seqlen_q, kBlockM);
  738. const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
  739. int const total_q_padded_rounded = round_multiple(total_q + batch_size * 128, 128);
  740. TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
  741. CHECK_SHAPE(q, total_q, num_heads, head_size_og);
  742. CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
  743. CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
  744. CHECK_SHAPE(out, total_q, num_heads, head_size);
  745. CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
  746. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  747. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  748. at::Tensor dq, dk, dv;
  749. if (dq_.has_value()) {
  750. dq = dq_.value();
  751. TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
  752. CHECK_DEVICE(dq);
  753. TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
  754. CHECK_SHAPE(dq, total_q, num_heads, head_size);
  755. } else {
  756. dq = torch::empty_like(q);
  757. }
  758. if (dk_.has_value()) {
  759. dk = dk_.value();
  760. TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
  761. CHECK_DEVICE(dk);
  762. TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
  763. CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
  764. } else {
  765. dk = torch::empty_like(k);
  766. }
  767. if (dv_.has_value()) {
  768. dv = dv_.value();
  769. TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
  770. CHECK_DEVICE(dv);
  771. TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
  772. CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
  773. } else {
  774. dv = torch::empty_like(v);
  775. }
  776. at::Tensor dout_padded;
  777. if (head_size_og % 8 != 0) {
  778. dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  779. } else {
  780. dout_padded = dout;
  781. }
  782. // Otherwise the kernel will be launched from cuda:0 device
  783. // Cast to char to avoid compiler warning about narrowing
  784. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  785. auto opts = q.options();
  786. // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
  787. auto softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
  788. auto softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
  789. at::Tensor dq_accum;
  790. at::Tensor dk_accum, dv_accum;
  791. dq_accum = torch::empty({num_heads, total_q_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
  792. // dk_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
  793. // dv_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
  794. at::Tensor dk_expanded, dv_expanded;
  795. if (num_heads_k != num_heads) { // MQA / GQA
  796. dk_expanded = torch::empty({total_k, num_heads, head_size}, opts);
  797. dv_expanded = torch::empty({total_k, num_heads, head_size}, opts);
  798. } else {
  799. dk_expanded = dk;
  800. dv_expanded = dv;
  801. }
  802. Flash_bwd_params params;
  803. set_params_dgrad(params,
  804. batch_size,
  805. max_seqlen_q, max_seqlen_k,
  806. seqlen_q_rounded, seqlen_k_rounded,
  807. num_heads, num_heads_k,
  808. head_size, head_size_rounded,
  809. q, k, v, out,
  810. dout_padded, dq, dk_expanded, dv_expanded,
  811. cu_seqlens_q.data_ptr(),
  812. cu_seqlens_k.data_ptr(),
  813. dq_accum.data_ptr(),
  814. // loop ? dk_accum.data_ptr() : nullptr,
  815. // loop ? dv_accum.data_ptr() : nullptr,
  816. nullptr,
  817. nullptr,
  818. softmax_lse.data_ptr(),
  819. softmax_d.data_ptr(),
  820. /*p_dropout=*/0.f,
  821. softmax_scale,
  822. /*window_size_left=*/-1,
  823. /*window_size_right=*/is_causal ? 0 : -1,
  824. deterministic);
  825. params.total_q = total_q;
  826. params.total_k = total_k;
  827. params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
  828. // Will be zero'ed out in the backward preprocess kernel
  829. at::Tensor dq_semaphore = torch::empty({(max_seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
  830. params.dq_semaphore = dq_semaphore.data_ptr<int>();
  831. if (max_seqlen_q > 0) {
  832. run_mha_bwd(params, stream);
  833. } else {
  834. // If max_seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
  835. dk_expanded.zero_();
  836. dv_expanded.zero_();
  837. softmax_d.zero_();
  838. }
  839. // For MQA/GQA we need to sum dK and dV across the groups
  840. if (num_heads_k != num_heads) {
  841. at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
  842. at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
  843. }
  844. if (head_size_og % 8 != 0) {
  845. dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  846. dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  847. dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  848. }
  849. return { dq, dk, dv, softmax_d, dq_accum, softmax_lse_log2 };
  850. }
  851. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  852. m.doc() = "FlashAttention";
  853. m.def("fwd", &mha_fwd, "Forward pass");
  854. m.def("bwd", &mha_bwd, "Backward pass");
  855. m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
  856. m.def("varlen_bwd", &mha_varlen_bwd, "Varlen backward pass");
  857. }