// flash_api.cpp
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cutlass/numeric_types.h>

#include <cuda_fp16.h>  // __half, __half2, __float2half
#include <cmath>        // M_LOG2E, std::floor
#include <limits>       // std::numeric_limits

#include "flash.h"
#include "static_switch.h"

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
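
// Fill in the Flash_fwd_params struct shared by all forward kernels. Sizes are
// passed explicitly (rather than read from the tensors) because the varlen path
// packs all sequences into a single dimension; pointers and strides are taken
// from the (possibly padded) q/k/v/out tensors.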
void set_params_fprop(Flash_fwd_params &params,
                      // sizes
                      const size_t b,
                      const size_t seqlen_q,
                      const size_t seqlen_k,
                      const size_t seqlen_q_rounded,
                      const size_t seqlen_k_rounded,
                      const size_t h,
                      const size_t h_k,
                      const size_t d,
                      const size_t d_rounded,
                      // device pointers
                      const at::Tensor q,
                      const at::Tensor k,
                      const at::Tensor v,
                      at::Tensor out,
                      void *cu_seqlens_q_d,
                      void *cu_seqlens_k_d,
                      void *seqused_q,
                      void *seqused_k,
                      void *p_d,
                      void *softmax_lse_d,
                      float p_dropout,
                      float softmax_scale,
                      int window_size_left,
                      int window_size_right,
                      const float softcap=0.f,
                      bool seqlenq_ngroups_swapped=false) {

    // Reset the parameters
    params = {};

    params.is_bf16 = q.dtype() == torch::kBFloat16;
    params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;

    // Set the pointers and strides.
    params.q_ptr = q.data_ptr();
    params.k_ptr = k.data_ptr();
    params.v_ptr = v.data_ptr();
    // All strides are in elements, not bytes.
    params.q_row_stride = q.stride(-3);
    params.k_row_stride = k.stride(-3);
    params.v_row_stride = v.stride(-3);
    params.q_head_stride = q.stride(-2);
    params.k_head_stride = k.stride(-2);
    params.v_head_stride = v.stride(-2);
    params.v_dim_stride = v.stride(-1);
    params.o_ptr = out.data_ptr();
    params.o_row_stride = out.stride(-3);
    params.o_head_stride = out.stride(-2);

    if (cu_seqlens_q_d == nullptr) {
        params.q_batch_stride = q.stride(0);
        params.k_batch_stride = k.stride(0);
        params.v_batch_stride = v.stride(0);
        params.o_batch_stride = out.stride(0);
        if (seqlenq_ngroups_swapped) {
            params.q_batch_stride *= seqlen_q;
            params.o_batch_stride *= seqlen_q;
        }
    }

    params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
    params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
    params.seqused_q = static_cast<int *>(seqused_q);
    params.seqused_k = static_cast<int *>(seqused_k);

    // P = softmax(QK^T)
    params.p_ptr = p_d;

    // Softmax sum
    params.softmax_lse_ptr = softmax_lse_d;

    // Set the dimensions.
    params.b = b;
    params.h = h;
    params.h_k = h_k;
    params.h_h_k_ratio = h / h_k;
    params.seqlen_q = seqlen_q;
    params.seqlen_k = seqlen_k;
    params.seqlen_q_rounded = seqlen_q_rounded;
    params.seqlen_k_rounded = seqlen_k_rounded;
    params.d = d;
    params.d_rounded = d_rounded;

    // Set the different scale values.
    params.scale_softmax = softmax_scale;
    params.scale_softmax_log2 = softmax_scale * M_LOG2E;
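    // The scale is pre-multiplied by log2(e) so the kernels can use exp2 instead
    // of exp: exp(scale * x) == exp2(scale * log2(e) * x). The __half2 below packs
    // the same log2-scale into both halves for packed fp16 arithmetic in the kernel.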
    __half scale_softmax_log2_half = __float2half(params.scale_softmax_log2);
    __half2 scale_softmax_log2_half2 = __half2(scale_softmax_log2_half, scale_softmax_log2_half);
    params.scale_softmax_log2_half2 = reinterpret_cast<uint32_t&>(scale_softmax_log2_half2);

    params.softcap = softcap;

    // Set this to probability of keeping an element to simplify things.
    params.p_dropout = 1.f - p_dropout;
    // Convert p from float to int so we don't have to convert the random uint to float to compare.
    // [Minor] We want to round down since when we do the comparison we use <= instead of <
    // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
    // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
    params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
    params.rp_dropout = 1.f / params.p_dropout;
    params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
    TORCH_CHECK(p_dropout < 1.f);
    #ifdef FLASHATTENTION_DISABLE_DROPOUT
        TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
    #endif

    // Causal is the special case where window_size_right == 0 and window_size_left < 0.
    // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
    params.is_causal = window_size_left < 0 && window_size_right == 0;
    params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal;

    // TODO: check this
    if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k - 1; }
    if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_q - 1; }
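    // If exactly one side of the window is negative, it is widened to the full
    // sequence extent. For example, causal attention (left = -1, right = 0) with
    // seqlen_k = 1024 is stored as window_size_left = 1023, window_size_right = 0.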
    params.window_size_left = window_size_left;
    params.window_size_right = window_size_right;

    #ifdef FLASHATTENTION_DISABLE_LOCAL
        TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
            "This flash attention build does not support local attention.");
    #endif

    params.is_seqlens_k_cumulative = true;
}
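
// Fill in the Flash_bwd_params struct for the backward pass. This reuses
// set_params_fprop for everything the forward pass also needs, then adds the
// gradient tensors (dout, dq, dk, dv), their float accumulators, and d(softmax).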
void set_params_dgrad(Flash_bwd_params &params,
                      // sizes
                      const size_t b,
                      const size_t seqlen_q,
                      const size_t seqlen_k,
                      const size_t seqlen_q_rounded,
                      const size_t seqlen_k_rounded,
                      const size_t h,
                      const size_t h_k,
                      const size_t d,
                      const size_t d_rounded,
                      // device pointers
                      const at::Tensor q,
                      const at::Tensor k,
                      const at::Tensor v,
                      const at::Tensor out,
                      const at::Tensor dout,
                      at::Tensor dq,
                      at::Tensor dk,
                      at::Tensor dv,
                      void *cu_seqlens_q_d,
                      void *cu_seqlens_k_d,
                      void *seqused_q,
                      void *seqused_k,
                      void *dq_accum_d,
                      void *dk_accum_d,
                      void *dv_accum_d,
                      void *softmax_lse_d,
                      void *dsoftmax_sum_d,
                      float p_dropout,
                      float softmax_scale,
                      int window_size_left,
                      int window_size_right,
                      const float softcap=0.f,
                      bool deterministic=false) {

    set_params_fprop(params,
                     b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
                     q, k, v, out,
                     cu_seqlens_q_d,
                     cu_seqlens_k_d,
                     seqused_q,
                     seqused_k,
                     nullptr,
                     softmax_lse_d,
                     p_dropout,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap);

    // Set the pointers and strides.
    params.do_ptr = dout.data_ptr();
    params.do_row_stride = dout.stride(-3);
    params.do_head_stride = dout.stride(-2);
    params.dq_ptr = dq.data_ptr();
    params.dk_ptr = dk.data_ptr();
    params.dv_ptr = dv.data_ptr();
    params.dq_row_stride = dq.stride(-3);
    params.dk_row_stride = dk.stride(-3);
    params.dv_row_stride = dv.stride(-3);
    params.dq_head_stride = dq.stride(-2);
    params.dk_head_stride = dk.stride(-2);
    params.dv_head_stride = dv.stride(-2);

    if (cu_seqlens_q_d == nullptr) {
        params.do_batch_stride = dout.stride(0);
        params.dq_batch_stride = dq.stride(0);
        params.dk_batch_stride = dk.stride(0);
        params.dv_batch_stride = dv.stride(0);
    }

    params.dq_accum_ptr = dq_accum_d;
    params.dk_accum_ptr = dk_accum_d;
    params.dv_accum_ptr = dv_accum_d;

    // Softmax sum
    params.dsoftmax_sum = dsoftmax_sum_d;

    params.deterministic = deterministic;
}
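
// Dispatch the forward kernel on dtype (fp16 / bf16 / fp8 e4m3) and on the
// smallest supported head-dimension bucket (64, 96, 128, 192, 256) that fits
// params.d. The force_split_kernel flag is accepted for API compatibility but
// is not used in this dispatcher.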
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
    // HEADDIM_SWITCH(params.d, [&] {
    //     run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
    // });
    if (!params.is_e4m3) {
        if (params.is_bf16) {
            if (params.d <= 64) {
                run_mha_fwd_<cutlass::bfloat16_t, 64>(params, stream);
            } else if (params.d <= 96) {
                run_mha_fwd_<cutlass::bfloat16_t, 96>(params, stream);
            } else if (params.d <= 128) {
                run_mha_fwd_<cutlass::bfloat16_t, 128>(params, stream);
            } else if (params.d <= 192) {
                run_mha_fwd_<cutlass::bfloat16_t, 192>(params, stream);
            } else {
                run_mha_fwd_<cutlass::bfloat16_t, 256>(params, stream);
            }
        } else {
            if (params.d <= 64) {
                run_mha_fwd_<cutlass::half_t, 64>(params, stream);
            } else if (params.d <= 96) {
                run_mha_fwd_<cutlass::half_t, 96>(params, stream);
            } else if (params.d <= 128) {
                run_mha_fwd_<cutlass::half_t, 128>(params, stream);
            } else if (params.d <= 192) {
                run_mha_fwd_<cutlass::half_t, 192>(params, stream);
            } else {
                run_mha_fwd_<cutlass::half_t, 256>(params, stream);
            }
        }
    } else {
        if (params.d <= 64) {
            run_mha_fwd_<cutlass::float_e4m3_t, 64>(params, stream);
        } else if (params.d <= 96) {
            run_mha_fwd_<cutlass::float_e4m3_t, 96>(params, stream);
        } else if (params.d <= 128) {
            run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream);
        } else if (params.d <= 192) {
            run_mha_fwd_<cutlass::float_e4m3_t, 192>(params, stream);
        } else {
            run_mha_fwd_<cutlass::float_e4m3_t, 256>(params, stream);
        }
    }
}
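
// Fixed-length (batched) forward pass. Validates dtypes, devices, shapes, and
// strides, pads the head dimension to the required alignment (8 for fp16/bf16,
// 16 for fp8), builds Flash_fwd_params, and launches the kernel on the current
// CUDA stream. Returns {out, q_padded, k_padded, v_padded, out_padded, softmax_lse}.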
std::vector<at::Tensor>
mha_fwd(at::Tensor &q,                    // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,              // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,              // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &out_,  // batch_size x seqlen_q x num_heads x head_size
        const float softmax_scale,
        bool is_causal,
        c10::optional<at::Tensor> &q_scale_,  // 1
        c10::optional<at::Tensor> &k_scale_,  // 1
        c10::optional<at::Tensor> &v_scale_,  // 1
        int window_size_left,
        int window_size_right,
        const float softcap
        ) {

    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
    TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");

    auto q_type = q.scalar_type();
    TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn,
                "FlashAttention only supports fp16, bf16, and fp8_e4m3 data types");
    TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype");
    TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1 || v.stride(-3) == 1, "Input tensor V must have contiguous last dimension or contiguous seqlen dimension");
    if (v.stride(-1) != 1) {
        TORCH_CHECK(q_type == at::ScalarType::Float8_e4m3fn, "Only fp8_e4m3 data type supports input tensor V having contiguous seqlen dimension");
    }
    const auto sizes = q.sizes();
    const int batch_size = sizes[0];
    int seqlen_q = sizes[1];
    int num_heads = sizes[2];
    const int head_size_og = sizes[3];
    const int seqlen_k = k.size(1);
    const int num_heads_k = k.size(2);
    TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    if (softcap > 0.0) { TORCH_CHECK(q_type != at::ScalarType::Float8_e4m3fn, "Softcap is not yet supported for fp8_e4m3 data type"); }

    // TODO: check this
    if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
    if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
    if (is_causal) {
        window_size_left = -1;
        window_size_right = 0;
    }

    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);

    int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8;
    at::Tensor q_padded, k_padded, v_padded;
    auto pad = [](at::Tensor x, int alignment) {
        return x.size(-1) % alignment == 0 ? x : torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment}));
    };
    q_padded = pad(q, alignment);
    k_padded = pad(k, alignment);
    v_padded = pad(v, alignment);

    if (v_padded.stride(-1) != 1) {
        TORCH_CHECK(v_padded.stride(-1) % 16 == 0 && v_padded.stride(-2) % 16 == 0 && v_padded.stride(-4) % 16 == 0,
                    "If input tensor V has contiguous seqlen dimension, the other dimensions must have strides divisible by 16");
    }
    auto opts = q.options();
    auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type;
    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
        if (head_size_og % alignment != 0) { out = torch::empty_like(q_padded, opts.dtype(out_type)); }
    } else {
        out = torch::empty_like(q_padded, opts.dtype(out_type));
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size = round_multiple(head_size_og, alignment);
    const int head_size_rounded = head_size <= 64 ? 64 : (head_size <= 128 ? round_multiple(head_size, 32) : round_multiple(head_size, 64));
    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
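    // For example, with fp16 (alignment 8), head_size_og = 80 gives head_size = 80
    // and head_size_rounded = 96, and seqlen_q = 1000 rounds up to seqlen_q_rounded = 1024.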
    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));

    Flash_fwd_params params;
    set_params_fprop(params,
                     batch_size,
                     seqlen_q, seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q_padded, k_padded, v_padded, out,
                     /*cu_seqlens_q_d=*/nullptr,
                     /*cu_seqlens_k_d=*/nullptr,
                     /*seqused_q_=*/nullptr,
                     /*seqused_k=*/nullptr,
                     nullptr,
                     softmax_lse.data_ptr(),
                     /*p_dropout=*/0.f,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap);

    auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
    params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();

    if (q_type == at::ScalarType::Float8_e4m3fn) {
        if (q_scale_.has_value()) {
            auto q_scale = q_scale_.value();
            CHECK_DEVICE(q_scale);
            CHECK_SHAPE(q_scale, 1);
            params.q_scale_ptr = q_scale.data_ptr<float>();
        } else {
            params.q_scale_ptr = nullptr;
        }
        if (k_scale_.has_value()) {
            auto k_scale = k_scale_.value();
            CHECK_DEVICE(k_scale);
            CHECK_SHAPE(k_scale, 1);
            params.k_scale_ptr = k_scale.data_ptr<float>();
        } else {
            params.k_scale_ptr = nullptr;
        }
        if (v_scale_.has_value()) {
            auto v_scale = v_scale_.value();
            CHECK_DEVICE(v_scale);
            CHECK_SHAPE(v_scale, 1);
            params.v_scale_ptr = v_scale.data_ptr<float>();
        } else {
            params.v_scale_ptr = nullptr;
        }
    }

    if (seqlen_k > 0 && batch_size > 0) {
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        run_mha_fwd(params, stream);
    } else if (batch_size > 0) {
        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
        out.zero_();
        softmax_lse.fill_(std::numeric_limits<float>::infinity());
    }

    at::Tensor out_padded = out;
    if (head_size_og % alignment != 0) {
        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        if (out_.has_value()) { out_.value().copy_(out); }
    }

    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse};
}
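
// Variable-length forward pass. q/k/v are packed as (total_q, num_heads, head_size)
// and (total_k, num_heads_k, head_size); per-sequence boundaries come from the
// cumulative cu_seqlens_q / cu_seqlens_k arrays (plus optional seqused_q / seqused_k
// overrides). Otherwise this mirrors mha_fwd.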
std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q,                    // total_q x num_heads x head_size
               const at::Tensor &k,              // total_k x num_heads_k x head_size
               const at::Tensor &v,              // total_k x num_heads_k x head_size
               c10::optional<at::Tensor> &out_,  // total_q x num_heads x head_size
               const at::Tensor &cu_seqlens_q,   // b+1
               const at::Tensor &cu_seqlens_k,   // b+1
               c10::optional<at::Tensor> &seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
               c10::optional<at::Tensor> &seqused_k_,  // b. If given, only this many elements of each batch element's keys are used.
               int const max_seqlen_q,
               int const max_seqlen_k,
               const float softmax_scale,
               bool is_causal,
               c10::optional<at::Tensor> &q_scale_,  // 1
               c10::optional<at::Tensor> &k_scale_,  // 1
               c10::optional<at::Tensor> &v_scale_,  // 1
               int window_size_left,
               int window_size_right,
               const float softcap
               ) {

    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
    TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");

    auto q_type = q.scalar_type();
    TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn,
                "FlashAttention only supports fp16, bf16, and fp8_e4m3 data types");
    TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype");
    TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(cu_seqlens_q.stride(-1) == 1, "cu_seqlens_q must have contiguous last dimension");
    TORCH_CHECK(cu_seqlens_k.stride(-1) == 1, "cu_seqlens_k must have contiguous last dimension");
    const auto sizes = q.sizes();
    const int batch_size = cu_seqlens_q.numel() - 1;
    int num_heads = sizes[1];
    const int head_size_og = sizes[2];
    const int num_heads_k = k.size(1);
    const int total_q = q.sizes()[0];
    const int total_k = k.sizes()[0];

    TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    if (softcap > 0.0) { TORCH_CHECK(q_type != at::ScalarType::Float8_e4m3fn, "Softcap is not yet supported for fp8_e4m3 data type"); }

    if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; }
    if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; }
    if (is_causal) {
        window_size_left = -1;
        window_size_right = 0;
    }

    CHECK_SHAPE(q, total_q, num_heads, head_size_og);
    CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
    CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);

    if (seqused_q_.has_value()) {
        auto seqused_q = seqused_q_.value();
        TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
        TORCH_CHECK(seqused_q.is_cuda(), "seqused_q must be on CUDA device");
        TORCH_CHECK(seqused_q.is_contiguous(), "seqused_q must be contiguous");
        CHECK_SHAPE(seqused_q, batch_size);
    }

    if (seqused_k_.has_value()) {
        auto seqused_k = seqused_k_.value();
        TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
        TORCH_CHECK(seqused_k.is_cuda(), "seqused_k must be on CUDA device");
        TORCH_CHECK(seqused_k.is_contiguous(), "seqused_k must be contiguous");
        CHECK_SHAPE(seqused_k, batch_size);
    }

    int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8;
    at::Tensor q_padded, k_padded, v_padded;
    auto pad = [](at::Tensor x, int alignment) {
        return x.size(-1) % alignment == 0 ? x : torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment}));
    };
    q_padded = pad(q, alignment);
    k_padded = pad(k, alignment);
    v_padded = pad(v, alignment);

    auto opts = q.options();
    auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type;
    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, total_q, num_heads, head_size_og);
        if (head_size_og % alignment != 0) { out = torch::empty_like(q_padded, opts.dtype(out_type)); }
    } else {
        out = torch::empty_like(q_padded, opts.dtype(out_type));
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size = round_multiple(head_size_og, alignment);
    const int head_size_rounded = head_size <= 64 ? 64 : (head_size <= 128 ? round_multiple(head_size, 32) : round_multiple(head_size, 64));
    const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));

    Flash_fwd_params params;
    set_params_fprop(params,
                     batch_size,
                     max_seqlen_q, max_seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q_padded, k_padded, v_padded, out,
                     cu_seqlens_q.data_ptr(),
                     cu_seqlens_k.data_ptr(),
                     seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
                     seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
                     nullptr,
                     softmax_lse.data_ptr(),
                     /*p_dropout=*/0.f,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap);
    params.total_q = total_q;
    params.total_k = total_k;

    auto tile_count_semaphore = torch::zeros({1}, opts.dtype(torch::kInt32));
    params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();

    if (q_type == at::ScalarType::Float8_e4m3fn) {
        if (q_scale_.has_value()) {
            auto q_scale = q_scale_.value();
            CHECK_DEVICE(q_scale);
            CHECK_SHAPE(q_scale, 1);
            params.q_scale_ptr = q_scale.data_ptr<float>();
        } else {
            params.q_scale_ptr = nullptr;
        }
        if (k_scale_.has_value()) {
            auto k_scale = k_scale_.value();
            CHECK_DEVICE(k_scale);
            CHECK_SHAPE(k_scale, 1);
            params.k_scale_ptr = k_scale.data_ptr<float>();
        } else {
            params.k_scale_ptr = nullptr;
        }
        if (v_scale_.has_value()) {
            auto v_scale = v_scale_.value();
            CHECK_DEVICE(v_scale);
            CHECK_SHAPE(v_scale, 1);
            params.v_scale_ptr = v_scale.data_ptr<float>();
        } else {
            params.v_scale_ptr = nullptr;
        }
    }
    if (max_seqlen_k > 0 && batch_size > 0) {
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        run_mha_fwd(params, stream);
    } else if (batch_size > 0) {
        // If max_seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
        out.zero_();
        softmax_lse.fill_(std::numeric_limits<float>::infinity());
    }

    at::Tensor out_padded = out;
    if (head_size_og % alignment != 0) {
        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        if (out_.has_value()) { out_.value().copy_(out); }
    }

    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse};
}
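
// Dispatch the backward kernel on dtype (fp16 / bf16; no fp8 backward is
// provided here) and on the head-dimension bucket, mirroring run_mha_fwd.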
void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
    // FP16_SWITCH(!params.is_bf16, [&] {
    //     HEADDIM_SWITCH(params.d, [&] {
    //         run_mha_bwd_<elem_type, kHeadDim>(params, stream);
    //     });
    // });
    if (!params.is_bf16) {
        if (params.d <= 64) {
            run_mha_bwd_<cutlass::half_t, 64>(params, stream);
        } else if (params.d <= 96) {
            run_mha_bwd_<cutlass::half_t, 96>(params, stream);
        } else if (params.d <= 128) {
            run_mha_bwd_<cutlass::half_t, 128>(params, stream);
        } else if (params.d <= 192) {
            run_mha_bwd_<cutlass::half_t, 192>(params, stream);
        } else {
            run_mha_bwd_<cutlass::half_t, 256>(params, stream);
        }
    } else {
        if (params.d <= 64) {
            run_mha_bwd_<cutlass::bfloat16_t, 64>(params, stream);
        } else if (params.d <= 96) {
            run_mha_bwd_<cutlass::bfloat16_t, 96>(params, stream);
        } else if (params.d <= 128) {
            run_mha_bwd_<cutlass::bfloat16_t, 128>(params, stream);
        } else if (params.d <= 192) {
            run_mha_bwd_<cutlass::bfloat16_t, 192>(params, stream);
        } else {
            run_mha_bwd_<cutlass::bfloat16_t, 256>(params, stream);
        }
    }
}
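
// Fixed-length (batched) backward pass. Validates inputs, allocates (or reuses)
// dq/dk/dv, sets up float accumulators and per-block semaphores (dq always;
// dk/dv only for MQA/GQA), and launches the backward kernel. Returns
// {dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum}.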
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout,          // batch_size x seqlen_q x num_heads x head_size_og
        const at::Tensor &q,             // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,             // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,             // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &out,           // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &softmax_lse,   // b x h x seqlen_q
        c10::optional<at::Tensor> &dq_,  // batch_size x seqlen_q x num_heads x head_size
        c10::optional<at::Tensor> &dk_,  // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &dv_,  // batch_size x seqlen_k x num_heads_k x head_size
        const float softmax_scale,
        const bool is_causal,
        int window_size_left,
        int window_size_right,
        const float softcap,
        const bool deterministic) {

    #ifdef FLASHATTENTION_DISABLE_BACKWARD
        TORCH_CHECK(false, "This flash attention build does not support backward.");
    #endif

    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm9x = dprops->major == 9 && dprops->minor >= 0;
    TORCH_CHECK(is_sm9x, "FlashAttention Hopper only supports Hopper GPUs or newer.");
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    auto q_type = q.dtype();
    TORCH_CHECK(q_type == torch::kFloat16 || q_type == torch::kBFloat16,
                "FlashAttention only supports fp16 and bf16 data types");
    TORCH_CHECK(k.dtype() == q_type, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_type, "query and value must have the same dtype");
    TORCH_CHECK(out.dtype() == q_type, "query and out must have the same dtype");
    TORCH_CHECK(dout.dtype() == q_type, "query and dout must have the same dtype");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
    TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");

    const auto sizes = q.sizes();
    const int batch_size = sizes[0];
    const int seqlen_q = sizes[1];
    const int num_heads = sizes[2];
    const int head_size_og = dout.size(3);
    const int head_size = sizes[3];
    const int seqlen_k = k.size(1);
    const int num_heads_k = k.size(2);
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
    TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size_rounded = head_size <= 64 ? 64 : round_multiple(head_size, 32);
    // This needs to match the kernel configs
    const int kBlockM = head_size <= 64 ? (softcap == 0.0 ? 128 : 96) : 64;
    const int kBlockN = head_size <= 128 ? 128 : (head_size <= 192 ? 96 : 80);
    const int seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);
    const int seqlen_k_rounded = round_multiple(seqlen_k, kBlockN);
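    // For example, head_size = 128 with softcap == 0 selects kBlockM = 64 and
    // kBlockN = 128, so seqlen_q is rounded up to a multiple of 64 and seqlen_k
    // to a multiple of 128.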
    TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");

    if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
    if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
    if (is_causal) {
        window_size_left = -1;
        window_size_right = 0;
    }

    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
    CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);

    at::Tensor dq, dk, dv;
    if (dq_.has_value()) {
        dq = dq_.value();
        TORCH_CHECK(dq.dtype() == q_type, "dq must have the same dtype as q");
        CHECK_DEVICE(dq);
        TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
        CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
    } else {
        dq = torch::empty_like(q);
    }
    if (dk_.has_value()) {
        dk = dk_.value();
        TORCH_CHECK(dk.dtype() == q_type, "dk must have the same dtype as q");
        CHECK_DEVICE(dk);
        TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
        CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
    } else {
        dk = torch::empty_like(k);
    }
    if (dv_.has_value()) {
        dv = dv_.value();
        TORCH_CHECK(dv.dtype() == q_type, "dv must have the same dtype as q");
        CHECK_DEVICE(dv);
        TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
        CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
    } else {
        dv = torch::empty_like(v);
    }

    at::Tensor dout_padded;
    if (head_size_og % 8 != 0) {
        dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    } else {
        dout_padded = dout;
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();
    // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
    auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
    auto softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
    at::Tensor dq_accum;
    at::Tensor dk_accum, dv_accum;
    dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat));
    if (num_heads_k != num_heads) {  // MQA / GQA
        dk_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
        dv_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
    }

    Flash_bwd_params params;
    set_params_dgrad(params,
                     batch_size,
                     seqlen_q, seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q, k, v, out,
                     dout_padded, dq, dk, dv,
                     nullptr /*cu_seqlens_q*/,
                     nullptr /*cu_seqlens_k*/,
                     nullptr /*seqused_q_*/,
                     nullptr /*seqused_k*/,
                     dq_accum.data_ptr(),
                     num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr,
                     num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr,
                     // nullptr,
                     // nullptr,
                     softmax_lse.data_ptr(),
                     softmax_d.data_ptr(),
                     /*p_dropout=*/0.f,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap,
                     deterministic);
    params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();

    // Will be zero'ed out in the backward preprocess kernel
    at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
    params.dq_semaphore = dq_semaphore.data_ptr<int>();
    // printf("dq_semaphore: %p, [%d, %d, %d]\n", params.dq_semaphore, (seqlen_q + 64 - 1) / 64, batch_size, num_heads);
    if (num_heads_k != num_heads) {
        at::Tensor dk_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
        at::Tensor dv_semaphore = torch::zeros({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
        params.dk_semaphore = dk_semaphore.data_ptr<int>();
        params.dv_semaphore = dv_semaphore.data_ptr<int>();
    }

    if (seqlen_q > 0) {
        run_mha_bwd(params, stream);
    } else {
        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
        dk.zero_();
        dv.zero_();
        softmax_d.zero_();
    }

    if (head_size_og % 8 != 0) {
        dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
    }

    return { dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum };
}
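
// Variable-length backward pass. Uses the same layout conventions as
// mha_varlen_fwd (packed total_q / total_k tensors plus cu_seqlens / seqused
// metadata), with accumulators and semaphores sized from the padded-and-rounded
// totals computed below.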
std::vector<at::Tensor>
mha_varlen_bwd(const at::Tensor &dout,          // total_q x num_heads x head_size_og
               const at::Tensor &q,             // total_q x num_heads x head_size
               const at::Tensor &k,             // total_k x num_heads_k x head_size
               const at::Tensor &v,             // total_k x num_heads_k x head_size
               const at::Tensor &out,           // total_q x num_heads x head_size
               const at::Tensor &softmax_lse,   // h x total_q
               c10::optional<at::Tensor> &dq_,  // total_q x num_heads x head_size
               c10::optional<at::Tensor> &dk_,  // total_k x num_heads_k x head_size
               c10::optional<at::Tensor> &dv_,  // total_k x num_heads_k x head_size
               const at::Tensor &cu_seqlens_q,  // b+1
               const at::Tensor &cu_seqlens_k,  // b+1
               c10::optional<at::Tensor> &seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
               c10::optional<at::Tensor> &seqused_k_,  // b. If given, only this many elements of each batch element's keys are used.
               const int max_seqlen_q,
               const int max_seqlen_k,          // max sequence length to choose the kernel
               const float softmax_scale,
               const bool is_causal,
               int window_size_left,
               int window_size_right,
               const float softcap,
               const bool deterministic) {

    #ifdef FLASHATTENTION_DISABLE_BACKWARD
        TORCH_CHECK(false, "This flash attention build does not support backward.");
    #endif

    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm9x = dprops->major == 9 && dprops->minor >= 0;
    TORCH_CHECK(is_sm9x, "FlashAttention Hopper only supports Hopper GPUs or newer.");
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    auto q_type = q.dtype();
    TORCH_CHECK(q_type == torch::kFloat16 || q_type == torch::kBFloat16,
                "FlashAttention only supports fp16 and bf16 data types");
    TORCH_CHECK(k.dtype() == q_type, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_type, "query and value must have the same dtype");
    TORCH_CHECK(out.dtype() == q_type, "query and out must have the same dtype");
    TORCH_CHECK(dout.dtype() == q_type, "query and dout must have the same dtype");
    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
    CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
    TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
    CHECK_CONTIGUOUS(cu_seqlens_q);
    CHECK_CONTIGUOUS(cu_seqlens_k);

    const auto sizes = q.sizes();
    const int total_q = sizes[0];
    const int batch_size = cu_seqlens_q.numel() - 1;
    const int num_heads = sizes[1];
    const int head_size_og = dout.size(2);
    const int head_size = sizes[2];
    const int total_k = k.size(0);
    const int num_heads_k = k.size(1);
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
    TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size_rounded = head_size <= 64 ? 64 : round_multiple(head_size, 32);
    // This needs to match the kernel configs
    const int kBlockM = head_size <= 64 ? (softcap == 0.0 ? 128 : 96) : 64;
    const int kBlockN = head_size <= 128 ? 128 : (head_size <= 192 ? 96 : 80);
    const int seqlen_q_rounded = round_multiple(max_seqlen_q, kBlockM);
    const int seqlen_k_rounded = round_multiple(max_seqlen_k, kBlockN);
    int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM);
    int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN);
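    // The padded-and-rounded totals add batch_size * kBlock before rounding,
    // presumably so that every one of the batch_size sequences has room to be
    // padded up to its own block boundary in the packed accumulator buffers.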
    TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");

    if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; }
    if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; }
    if (is_causal) {
        window_size_left = -1;
        window_size_right = 0;
    }

    CHECK_SHAPE(q, total_q, num_heads, head_size_og);
    CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
    CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
    CHECK_SHAPE(out, total_q, num_heads, head_size);
    CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);

    if (seqused_q_.has_value()) {
        auto seqused_q = seqused_q_.value();
        TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
        TORCH_CHECK(seqused_q.is_cuda(), "seqused_q must be on CUDA device");
        TORCH_CHECK(seqused_q.is_contiguous(), "seqused_q must be contiguous");
        CHECK_SHAPE(seqused_q, batch_size);
    }

    if (seqused_k_.has_value()) {
        auto seqused_k = seqused_k_.value();
        TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
        TORCH_CHECK(seqused_k.is_cuda(), "seqused_k must be on CUDA device");
        TORCH_CHECK(seqused_k.is_contiguous(), "seqused_k must be contiguous");
        CHECK_SHAPE(seqused_k, batch_size);
    }

    at::Tensor dq, dk, dv;
    if (dq_.has_value()) {
        dq = dq_.value();
        TORCH_CHECK(dq.dtype() == q_type, "dq must have the same dtype as q");
        CHECK_DEVICE(dq);
        TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
        CHECK_SHAPE(dq, total_q, num_heads, head_size);
    } else {
        dq = torch::empty_like(q);
    }
    if (dk_.has_value()) {
        dk = dk_.value();
        TORCH_CHECK(dk.dtype() == q_type, "dk must have the same dtype as q");
        CHECK_DEVICE(dk);
        TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
        CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
    } else {
        dk = torch::empty_like(k);
    }
    if (dv_.has_value()) {
        dv = dv_.value();
        TORCH_CHECK(dv.dtype() == q_type, "dv must have the same dtype as q");
        CHECK_DEVICE(dv);
        TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
        CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
    } else {
        dv = torch::empty_like(v);
    }

    at::Tensor dout_padded;
    if (head_size_og % 8 != 0) {
        dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    } else {
        dout_padded = dout;
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();
    // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
    auto softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
    auto softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
    at::Tensor dq_accum;
    at::Tensor dk_accum, dv_accum;
    dq_accum = torch::empty({num_heads, total_q_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
    if (num_heads_k != num_heads) {  // MQA / GQA
        dk_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
        dv_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
    }

    Flash_bwd_params params;
    set_params_dgrad(params,
                     batch_size,
                     max_seqlen_q, max_seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q, k, v, out,
                     dout_padded, dq, dk, dv,
                     cu_seqlens_q.data_ptr(),
                     cu_seqlens_k.data_ptr(),
                     seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
                     seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
                     dq_accum.data_ptr(),
                     num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr,
                     num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr,
                     // nullptr,
                     // nullptr,
                     softmax_lse.data_ptr(),
                     softmax_d.data_ptr(),
                     /*p_dropout=*/0.f,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     softcap,
                     deterministic);
    params.total_q = total_q;
    params.total_k = total_k;
    params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();

    // Will be zero'ed out in the backward preprocess kernel
    at::Tensor dq_semaphore = torch::empty({(max_seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
    params.dq_semaphore = dq_semaphore.data_ptr<int>();
    if (num_heads_k != num_heads) {
        // TODO: do we need to zero them out?
        at::Tensor dk_semaphore = torch::empty({(max_seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
        at::Tensor dv_semaphore = torch::empty({(max_seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
        params.dk_semaphore = dk_semaphore.data_ptr<int>();
        params.dv_semaphore = dv_semaphore.data_ptr<int>();
    }

    if (max_seqlen_q > 0) {
        run_mha_bwd(params, stream);
    } else {
        // If max_seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
        dk.zero_();
        dv.zero_();
        softmax_d.zero_();
    }

    if (head_size_og % 8 != 0) {
        dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
    }

    return { dq, dk, dv, softmax_d, dq_accum, softmax_lse_log2 };
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "FlashAttention";
    m.def("fwd", &mha_fwd, "Forward pass");
    m.def("fwd_varlen", &mha_varlen_fwd, "Varlen forward pass");
    m.def("bwd", &mha_bwd, "Backward pass");
    m.def("bwd_varlen", &mha_varlen_bwd, "Varlen backward pass");
}
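
// Note on the binding surface: the functions are exposed positionally, so the
// package's Python wrapper must pass arguments in exactly the order of the C++
// signatures above. Per the return statements in this file, fwd and fwd_varlen
// return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse}; bwd
// returns {dq, dk, dv, softmax_d, dq_accum, dk_accum, dv_accum}; bwd_varlen
// returns {dq, dk, dv, softmax_d, dq_accum, softmax_lse_log2}.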