flash_api.cpp
/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cutlass/numeric_types.h>

#include "flash.h"
#include "static_switch.h"

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
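
// set_params_fprop populates the Flash_fwd_params struct (declared in flash.h) that the
// forward kernels consume: data pointers, element-count strides, problem sizes, softmax
// scaling, dropout, and the causal/local window settings.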
void set_params_fprop(Flash_fwd_params &params,
                      // sizes
                      const size_t b,
                      const size_t seqlen_q,
                      const size_t seqlen_k,
                      const size_t seqlen_q_rounded,
                      const size_t seqlen_k_rounded,
                      const size_t h,
                      const size_t h_k,
                      const size_t d,
                      const size_t d_rounded,
                      // device pointers
                      const at::Tensor q,
                      const at::Tensor k,
                      const at::Tensor v,
                      at::Tensor out,
                      void *cu_seqlens_q_d,
                      void *cu_seqlens_k_d,
                      void *seqused_q,
                      void *seqused_k,
                      void *p_d,
                      void *softmax_lse_d,
                      float p_dropout,
                      float softmax_scale,
                      int window_size_left,
                      int window_size_right,
                      bool seqlenq_ngroups_swapped=false,
                      bool unpadded_lse=false) {

    // Reset the parameters
    params = {};

    params.is_bf16 = q.dtype() == torch::kBFloat16;
    params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;

    // Set the pointers and strides.
    params.q_ptr = q.data_ptr();
    params.k_ptr = k.data_ptr();
    params.v_ptr = v.data_ptr();
    // All strides are in elements, not bytes.
    params.q_row_stride = q.stride(-3);
    params.k_row_stride = k.stride(-3);
    params.v_row_stride = v.stride(-3);
    params.q_head_stride = q.stride(-2);
    params.k_head_stride = k.stride(-2);
    params.v_head_stride = v.stride(-2);
    params.o_ptr = out.data_ptr();
    params.o_row_stride = out.stride(-3);
    params.o_head_stride = out.stride(-2);

    if (cu_seqlens_q_d == nullptr) {
        params.q_batch_stride = q.stride(0);
        params.k_batch_stride = k.stride(0);
        params.v_batch_stride = v.stride(0);
        params.o_batch_stride = out.stride(0);
        if (seqlenq_ngroups_swapped) {
            params.q_batch_stride *= seqlen_q;
            params.o_batch_stride *= seqlen_q;
        }
    }

    params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
    params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
    params.seqused_q = static_cast<int *>(seqused_q);
    params.seqused_k = static_cast<int *>(seqused_k);

    TORCH_CHECK(
        bool(params.cu_seqlens_q) == bool(params.cu_seqlens_k),
        "cu_seqlens_q and cu_seqlens_k must be both null or non-null"
    );

    // P = softmax(QK^T)
    params.p_ptr = p_d;

    // Softmax sum
    params.softmax_lse_ptr = softmax_lse_d;

    // Set the dimensions.
    params.b = b;
    params.h = h;
    params.h_k = h_k;
    params.h_h_k_ratio = h / h_k;
    params.seqlen_q = seqlen_q;
    params.seqlen_k = seqlen_k;
    params.seqlen_q_rounded = seqlen_q_rounded;
    params.seqlen_k_rounded = seqlen_k_rounded;
    params.d = d;
    params.d_rounded = d_rounded;

    // Set the different scale values.
    params.scale_softmax = softmax_scale;
    params.scale_softmax_log2 = softmax_scale * M_LOG2E;
    __half scale_softmax_log2_half = __float2half(params.scale_softmax_log2);
    __half2 scale_softmax_log2_half2 = __half2(scale_softmax_log2_half, scale_softmax_log2_half);
    params.scale_softmax_log2_half2 = reinterpret_cast<uint32_t&>(scale_softmax_log2_half2);

    // Set this to probability of keeping an element to simplify things.
    params.p_dropout = 1.f - p_dropout;
    // Convert p from float to int so we don't have to convert the random uint to float to compare.
    // [Minor] We want to round down since when we do the comparison we use <= instead of <
    // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
    // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
    params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
    params.rp_dropout = 1.f / params.p_dropout;
    params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
    TORCH_CHECK(p_dropout < 1.f);
    #ifdef FLASHATTENTION_DISABLE_DROPOUT
        TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
    #endif

    // Causal is the special case where window_size_right == 0 and window_size_left < 0.
    // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
    params.is_causal = window_size_left < 0 && window_size_right == 0;
    if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
    if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
    params.window_size_left = window_size_left;
    params.window_size_right = window_size_right;

    #ifdef FLASHATTENTION_DISABLE_LOCAL
        TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
            "This flash attention build does not support local attention.");
    #endif

    params.is_seqlens_k_cumulative = true;

    #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
        TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
    #endif

    params.unpadded_lse = unpadded_lse;
}
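
// set_params_dgrad reuses set_params_fprop for the fields shared with the forward pass,
// then fills in the gradient pointers/strides (dout, dq, dk, dv), the fp32 accumulation
// buffers, and the deterministic flag.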
void set_params_dgrad(Flash_bwd_params &params,
                      // sizes
                      const size_t b,
                      const size_t seqlen_q,
                      const size_t seqlen_k,
                      const size_t seqlen_q_rounded,
                      const size_t seqlen_k_rounded,
                      const size_t h,
                      const size_t h_k,
                      const size_t d,
                      const size_t d_rounded,
                      // device pointers
                      const at::Tensor q,
                      const at::Tensor k,
                      const at::Tensor v,
                      const at::Tensor out,
                      const at::Tensor dout,
                      at::Tensor dq,
                      at::Tensor dk,
                      at::Tensor dv,
                      void *cu_seqlens_q_d,
                      void *cu_seqlens_k_d,
                      void *seqused_q,
                      void *seqused_k,
                      void *dq_accum_d,
                      void *dk_accum_d,
                      void *dv_accum_d,
                      void *softmax_lse_d,
                      void *dsoftmax_sum_d,
                      float p_dropout,
                      float softmax_scale,
                      int window_size_left,
                      int window_size_right,
                      bool deterministic) {

    set_params_fprop(params,
                     b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
                     q, k, v, out,
                     cu_seqlens_q_d,
                     cu_seqlens_k_d,
                     seqused_q,
                     seqused_k,
                     nullptr,
                     softmax_lse_d,
                     p_dropout,
                     softmax_scale,
                     window_size_left,
                     window_size_right);

    // Set the pointers and strides.
    params.do_ptr = dout.data_ptr();
    params.do_row_stride = dout.stride(-3);
    params.do_head_stride = dout.stride(-2);
    params.dq_ptr = dq.data_ptr();
    params.dk_ptr = dk.data_ptr();
    params.dv_ptr = dv.data_ptr();
    params.dq_row_stride = dq.stride(-3);
    params.dk_row_stride = dk.stride(-3);
    params.dv_row_stride = dv.stride(-3);
    params.dq_head_stride = dq.stride(-2);
    params.dk_head_stride = dk.stride(-2);
    params.dv_head_stride = dv.stride(-2);

    if (cu_seqlens_q_d == nullptr) {
        params.do_batch_stride = dout.stride(0);
        params.dq_batch_stride = dq.stride(0);
        params.dk_batch_stride = dk.stride(0);
        params.dv_batch_stride = dv.stride(0);
    }

    params.dq_accum_ptr = dq_accum_d;
    params.dk_accum_ptr = dk_accum_d;
    params.dv_accum_ptr = dv_accum_d;

    // Softmax sum
    params.dsoftmax_sum = dsoftmax_sum_d;

    params.deterministic = deterministic;
}
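
// run_mha_fwd is a runtime dispatcher over the pre-instantiated kernel templates, keyed on
// input dtype (fp16 / bf16 / fp8 e4m3) and head dimension. For fp16/bf16, any head dim other
// than 64 or 128 falls through to the 256 instantiation (mha_fwd only accepts 64/128/256).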
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
    // HEADDIM_SWITCH(params.d, [&] {
    //     run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
    // });
    if (!params.is_e4m3) {
        if (params.is_bf16) {
            if (params.d == 64) {
                run_mha_fwd_<cutlass::bfloat16_t, 64>(params, stream);
            } else if (params.d == 128) {
                run_mha_fwd_<cutlass::bfloat16_t, 128>(params, stream);
            } else {
                run_mha_fwd_<cutlass::bfloat16_t, 256>(params, stream);
            }
        } else {
            if (params.d == 64) {
                run_mha_fwd_<cutlass::half_t, 64>(params, stream);
            } else if (params.d == 128) {
                run_mha_fwd_<cutlass::half_t, 128>(params, stream);
            } else {
                run_mha_fwd_<cutlass::half_t, 256>(params, stream);
            }
        }
    } else {
        if (params.d == 64) {
            run_mha_fwd_<cutlass::float_e4m3_t, 64>(params, stream);
        } else if (params.d == 128) {
            run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream);
        } else if (params.d == 256) {
            run_mha_fwd_<cutlass::float_e4m3_t, 256>(params, stream);
        }
    }
}
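
// mha_fwd: forward pass for fixed-length batches. q is (batch, seqlen_q, num_heads, head_size);
// k/v are (batch, seqlen_k, num_heads_k, head_size) with num_heads % num_heads_k == 0 (MQA/GQA).
// Head dims that are not a multiple of 8 are zero-padded before the kernel runs and the returned
// `out` is sliced back to the original head dim; for fp8 (e4m3) inputs the output is fp16.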
std::vector<at::Tensor>
mha_fwd(at::Tensor &q,  // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,  // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,  // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &out_,  // batch_size x seqlen_q x num_heads x head_size
        const float softmax_scale,
        c10::optional<at::Tensor> &descale_q_,  // 1
        c10::optional<at::Tensor> &descale_k_,  // 1
        c10::optional<at::Tensor> &descale_v_,  // 1
        bool is_causal) {

    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
    TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");

    auto q_dtype = q.dtype();
    // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
    //             "FlashAttention only support fp16 and bf16 data type for now");
    // TODO: will add e4m3 later
    // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kFloat8_e4m3fn,
    //             "FlashAttention only support fp16 and bf16 data type");
    //             "FlashAttention only support fp16 and fp8 (e4m3) data type for now");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");

    const auto sizes = q.sizes();

    const int batch_size = sizes[0];
    int seqlen_q = sizes[1];
    int num_heads = sizes[2];
    const int head_size_og = sizes[3];
    const int seqlen_k = k.size(1);
    const int num_heads_k = k.size(2);
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    TORCH_CHECK(head_size_og == 64 || head_size_og == 128 || head_size_og == 256, "Only support head size 64, 128, and 256 for now");

    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);

    at::Tensor q_padded, k_padded, v_padded;
    if (head_size_og % 8 != 0) {
        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    } else {
        q_padded = q;
        k_padded = k;
        v_padded = v;
    }

    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        // TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
        TORCH_CHECK(q_dtype == at::ScalarType::Float8_e4m3fn
                    ? (out.dtype() == at::kHalf)
                    : (out.dtype() == q_dtype),
                    "Output must have the same dtype as input dtype if dtype is "
                    "not fp8, or fp16 for fp8 input.");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
    } else {
        if (q_dtype == at::ScalarType::Float8_e4m3fn)
            out = torch::empty_like(q_padded, at::kHalf);
        else
            out = torch::empty_like(q_padded);
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size = round_multiple(head_size_og, 8);
    const int head_size_rounded = round_multiple(head_size, 32);
    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
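    // For example, seqlen_q = 1 rounds up to 128 and seqlen_k = 1000 rounds up to 1024;
    // with head_size_og = 128 (already a multiple of 8 and 32), head_size and
    // head_size_rounded both stay 128.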
    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();
    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
    at::Tensor p;

    Flash_fwd_params params;
    set_params_fprop(params,
                     batch_size,
                     seqlen_q, seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q_padded, k_padded, v_padded, out,
                     /*cu_seqlens_q_d=*/nullptr,
                     /*cu_seqlens_k_d=*/nullptr,
                     /*seqused_q=*/nullptr,
                     /*seqused_k=*/nullptr,
                     nullptr,
                     softmax_lse.data_ptr(),
                     /*p_dropout=*/0.f,
                     softmax_scale,
                     /*window_size_left=*/-1,
                     /*window_size_right=*/is_causal ? 0 : -1);

    auto tile_count_semaphore = is_causal ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
    params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();

    if (q_dtype == at::ScalarType::Float8_e4m3fn) {
        at::Tensor descale_q, descale_k, descale_v;
        if (descale_q_.has_value() && descale_k_.has_value() && descale_v_.has_value()) {
            descale_q = descale_q_.value();
            descale_k = descale_k_.value();
            descale_v = descale_v_.value();
            CHECK_DEVICE(descale_q);
            CHECK_DEVICE(descale_k);
            CHECK_DEVICE(descale_v);
            CHECK_SHAPE(descale_q, 1);
            CHECK_SHAPE(descale_k, 1);
            CHECK_SHAPE(descale_v, 1);
        } else {
            descale_q = torch::ones({1}, opts.dtype(at::kFloat));
            descale_k = torch::ones({1}, opts.dtype(at::kFloat));
            descale_v = torch::ones({1}, opts.dtype(at::kFloat));
        }
        params.descale_q_ptr = descale_q.data_ptr<float>();
        params.descale_k_ptr = descale_k.data_ptr<float>();
        params.descale_v_ptr = descale_v.data_ptr<float>();
    } else {
        params.descale_q_ptr = nullptr;
        params.descale_k_ptr = nullptr;
        params.descale_v_ptr = nullptr;
    }

    if (seqlen_k > 0) {
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        run_mha_fwd(params, stream);
    } else {
        // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
        out.zero_();
        softmax_lse.fill_(std::numeric_limits<float>::infinity());
    }

    at::Tensor out_padded = out;
    if (head_size_og % 8 != 0) {
        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        if (out_.has_value()) { out_.value().copy_(out); }
    }

    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
}
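
// mha_varlen_fwd: forward pass for variable-length (packed) batches. Sequences are concatenated
// along the first dimension and delimited by cu_seqlens_q / cu_seqlens_k (int32, length b+1);
// e.g. two sequences of lengths 3 and 5 give cu_seqlens = [0, 3, 8] and total = 8 rows.
// softmax_lse is returned in the unpadded (num_heads, total_q) layout.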
std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
               c10::optional<at::Tensor> &out_,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,  // b+1
               const at::Tensor &cu_seqlens_k,  // b+1
               c10::optional<at::Tensor> &seqused_q,  // b. If given, only this many elements of each batch element's queries and outputs are used.
               c10::optional<at::Tensor> &seqused_k,  // b. If given, only this many elements of each batch element's keys are used.
               int max_seqlen_q,
               const int max_seqlen_k,
               const float softmax_scale,
               bool is_causal) {

    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
    TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(cu_seqlens_q);
    CHECK_DEVICE(cu_seqlens_k);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    CHECK_CONTIGUOUS(cu_seqlens_q);
    CHECK_CONTIGUOUS(cu_seqlens_k);

    const auto sizes = q.sizes();

    const int batch_size = cu_seqlens_q.numel() - 1;
    int num_heads = sizes[1];
    const int head_size_og = sizes[2];
    const int num_heads_k = k.size(1);

    int window_size_left = -1;
    int window_size_right = -1;
    if (is_causal) { window_size_right = 0; }

    void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();

    const int total_q = q.sizes()[0];

    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
    if (window_size_right >= max_seqlen_k) { window_size_right = -1; }

    CHECK_SHAPE(q, total_q, num_heads, head_size_og);
    const int total_k = k.size(0);
    CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
    CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);

    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    if (seqused_q.has_value()) {
        auto seqused_q_ = seqused_q.value();
        TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32");
        TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device");
        TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous");
        CHECK_SHAPE(seqused_q_, batch_size);
    }

    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
    if (seqused_k.has_value()) {
        auto seqused_k_ = seqused_k.value();
        TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
        TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
        TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
        CHECK_SHAPE(seqused_k_, batch_size);
    }

    at::Tensor q_padded, k_padded, v_padded;
    if (head_size_og % 8 != 0) {
        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    } else {
        q_padded = q;
        k_padded = k;
        v_padded = v;
    }

    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
    } else {
        out = torch::empty_like(q_padded);
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size = round_multiple(head_size_og, 8);
    const int head_size_rounded = round_multiple(head_size, 32);
    const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();
    auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));

    Flash_fwd_params params;
    set_params_fprop(params,
                     batch_size,
                     max_seqlen_q, max_seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q_padded, k_padded, v_padded, out,
                     cu_seqlens_q_d,
                     cu_seqlens_k.data_ptr(),
                     seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr,
                     seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
                     /*p_d=*/nullptr,
                     softmax_lse.data_ptr(),
                     /*p_dropout=*/0.f,
                     softmax_scale,
                     window_size_left,
                     window_size_right,
                     /*seqlenq_ngroups_swapped=*/false,
                     /*unpadded_lse=*/true);
    params.total_q = total_q;
    params.total_k = total_k;

    if (max_seqlen_k > 0) {
        auto stream = at::cuda::getCurrentCUDAStream().stream();
        run_mha_fwd(params, stream);
    } else {
        // If max_seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
        out.zero_();
        softmax_lse.fill_(std::numeric_limits<float>::infinity());
    }

    at::Tensor out_padded = out;
    if (head_size_og % 8 != 0) {
        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        if (out_.has_value()) { out_.value().copy_(out); }
    }

    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse};
}
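
// run_mha_bwd dispatches the backward kernels by dtype (fp16 / bf16) and head-dim bucket
// (<= 64, <= 96, otherwise 128); fp8 is not handled here, matching the fp16/bf16-only
// checks in mha_bwd and mha_varlen_bwd.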
void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
    // FP16_SWITCH(!params.is_bf16, [&] {
    //     HEADDIM_SWITCH(params.d, [&] {
    //         run_mha_bwd_<elem_type, kHeadDim>(params, stream);
    //     });
    // });
    if (!params.is_bf16) {
        if (params.d <= 64) {
            run_mha_bwd_<cutlass::half_t, 64>(params, stream);
        } else if (params.d <= 96) {
            run_mha_bwd_<cutlass::half_t, 96>(params, stream);
        } else {
            run_mha_bwd_<cutlass::half_t, 128>(params, stream);
        }
    } else {
        if (params.d <= 64) {
            run_mha_bwd_<cutlass::bfloat16_t, 64>(params, stream);
        } else if (params.d <= 96) {
            run_mha_bwd_<cutlass::bfloat16_t, 96>(params, stream);
        } else {
            run_mha_bwd_<cutlass::bfloat16_t, 128>(params, stream);
        }
    }
}
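
// mha_bwd: backward pass for fixed-length batches. dq/dk/dv are allocated when not supplied.
// For MQA/GQA the kernel writes per-query-head dK/dV into expanded buffers that are then
// summed across the query-head groups. softmax_d, softmax_lse_log2 and dq_accum are fp32
// scratch buffers sized with rounded sequence lengths to keep their addresses aligned for
// TMA / LDG.64.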
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads x head_size_og
        const at::Tensor &q,    // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,    // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,    // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &out,  // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &softmax_lse,  // b x h x seqlen_q
        c10::optional<at::Tensor> &dq_,  // batch_size x seqlen_q x num_heads x head_size
        c10::optional<at::Tensor> &dk_,  // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &dv_,  // batch_size x seqlen_k x num_heads_k x head_size
        const float softmax_scale,
        const bool is_causal,
        const bool deterministic) {

    #ifdef FLASHATTENTION_DISABLE_BACKWARD
        TORCH_CHECK(false, "This flash attention build does not support backward.");
    #endif

    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm9x = dprops->major == 9 && dprops->minor >= 0;
    TORCH_CHECK(is_sm9x, "FlashAttentionHopper only supports Hopper GPUs or newer.");

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
    TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
    TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
    TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");

    const auto sizes = q.sizes();

    const int batch_size = sizes[0];
    const int seqlen_q = sizes[1];
    const int num_heads = sizes[2];
    const int head_size_og = dout.size(3);
    const int head_size = sizes[3];
    const int seqlen_k = k.size(1);
    const int num_heads_k = k.size(2);
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
    TORCH_CHECK(head_size <= 128, "FlashAttention backward only supports head dimension at most 128");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size_rounded = head_size <= 64 ? 64 : round_multiple(head_size, 32);
    // This should match the kernel configs
    const int kBlockM = head_size <= 64 ? 128 : (head_size < 256 ? 64 : 32);
    const int seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);
    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
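    // For example, head_size = 128 gives head_size_rounded = 128 and kBlockM = 64, so a
    // seqlen_q of 1000 rounds up to 1024 and a seqlen_k of 1000 rounds up to 1024.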
    TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");

    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
    CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);

    at::Tensor dq, dk, dv;
    if (dq_.has_value()) {
        dq = dq_.value();
        TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
        CHECK_DEVICE(dq);
        TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
        CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
    } else {
        dq = torch::empty_like(q);
    }
    if (dk_.has_value()) {
        dk = dk_.value();
        TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
        CHECK_DEVICE(dk);
        TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
        CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
    } else {
        dk = torch::empty_like(k);
    }
    if (dv_.has_value()) {
        dv = dv_.value();
        TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
        CHECK_DEVICE(dv);
        TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
        CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
    } else {
        dv = torch::empty_like(v);
    }

    at::Tensor dout_padded;
    if (head_size_og % 8 != 0) {
        dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    } else {
        dout_padded = dout;
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();
    // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
    auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
    auto softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
    at::Tensor dq_accum;
    at::Tensor dk_accum, dv_accum;
    dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat));
    // dk_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
    // dv_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));

    at::Tensor dk_expanded, dv_expanded;
    if (num_heads_k != num_heads) {  // MQA / GQA
        dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
        dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
    } else {
        dk_expanded = dk;
        dv_expanded = dv;
    }

    Flash_bwd_params params;
    set_params_dgrad(params,
                     batch_size,
                     seqlen_q, seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q, k, v, out,
                     dout_padded, dq, dk_expanded, dv_expanded,
                     /*cu_seqlens_q_d=*/nullptr,
                     /*cu_seqlens_k_d=*/nullptr,
                     /*seqused_q=*/nullptr,
                     /*seqused_k=*/nullptr,
                     dq_accum.data_ptr(),
                     // loop ? dk_accum.data_ptr() : nullptr,
                     // loop ? dv_accum.data_ptr() : nullptr,
                     nullptr,
                     nullptr,
                     softmax_lse.data_ptr(),
                     softmax_d.data_ptr(),
                     /*p_dropout=*/0.f,
                     softmax_scale,
                     /*window_size_left=*/-1,
                     /*window_size_right=*/is_causal ? 0 : -1,
                     deterministic);
    params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();

    // Will be zero'ed out in the backward preprocess kernel
    at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
    params.dq_semaphore = dq_semaphore.data_ptr<int>();
    // printf("dq_semaphore: %p, [%d, %d, %d]\n", params.dq_semaphore, (seqlen_q + 64 - 1) / 64, batch_size, num_heads);

    if (seqlen_q > 0) {
        run_mha_bwd(params, stream);
    } else {
        // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
        dk_expanded.zero_();
        dv_expanded.zero_();
        softmax_d.zero_();
    }

    // For MQA/GQA we need to sum dK and dV across the groups
    if (num_heads_k != num_heads) {
        at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
        at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
    }
    if (head_size_og % 8 != 0) {
        dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
    }

    return { dq, dk, dv, softmax_d, dq_accum };
}
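
// mha_varlen_bwd mirrors mha_bwd for the packed variable-length layout: the fp32 scratch
// buffers (softmax_d, softmax_lse_log2, dq_accum) are sized by total_q_padded_rounded instead
// of batch_size x seqlen_q_rounded, and the MQA/GQA reduction sums over the group dimension of
// (total_k, num_heads_k, num_heads / num_heads_k, head_size).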
std::vector<at::Tensor>
mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads x head_size_og
               const at::Tensor &q,     // total_q x num_heads x head_size
               const at::Tensor &k,     // total_k x num_heads_k x head_size
               const at::Tensor &v,     // total_k x num_heads_k x head_size
               const at::Tensor &out,   // total_q x num_heads x head_size
               const at::Tensor &softmax_lse,  // b x h x seqlen_q
               c10::optional<at::Tensor> &dq_,  // total_q x num_heads x head_size
               c10::optional<at::Tensor> &dk_,  // total_k x num_heads_k x head_size
               c10::optional<at::Tensor> &dv_,  // total_k x num_heads_k x head_size
               const at::Tensor &cu_seqlens_q,  // b+1
               const at::Tensor &cu_seqlens_k,  // b+1
               c10::optional<at::Tensor> &seqused_q,  // b. If given, only this many elements of each batch element's queries and outputs are used.
               c10::optional<at::Tensor> &seqused_k,  // b. If given, only this many elements of each batch element's keys are used.
               const int max_seqlen_q,
               const int max_seqlen_k,  // max sequence length to choose the kernel
               const float softmax_scale,
               const bool is_causal,
               const bool deterministic) {

    #ifdef FLASHATTENTION_DISABLE_BACKWARD
        TORCH_CHECK(false, "This flash attention build does not support backward.");
    #endif

    auto dprops = at::cuda::getCurrentDeviceProperties();
    bool is_sm9x = dprops->major == 9 && dprops->minor >= 0;
    TORCH_CHECK(is_sm9x, "FlashAttentionHopper only supports Hopper GPUs or newer.");

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
    TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
    TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
    CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
    TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
    CHECK_CONTIGUOUS(cu_seqlens_q);
    CHECK_CONTIGUOUS(cu_seqlens_k);

    const auto sizes = q.sizes();

    const int total_q = sizes[0];
    const int batch_size = cu_seqlens_q.numel() - 1;
    const int num_heads = sizes[1];
    const int head_size_og = dout.size(2);
    const int head_size = sizes[2];
    const int total_k = k.size(0);
    const int num_heads_k = k.size(1);
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
    TORCH_CHECK(head_size <= 128, "FlashAttention backward only supports head dimension at most 128");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size_rounded = head_size <= 64 ? 64 : round_multiple(head_size, 32);
    // This should match the kernel configs
    const int kBlockM = head_size <= 64 ? 128 : (head_size < 256 ? 64 : 32);
    const int seqlen_q_rounded = round_multiple(max_seqlen_q, kBlockM);
    const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
    int const total_q_padded_rounded = round_multiple(total_q + batch_size * 128, 128);

    TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");

    CHECK_SHAPE(q, total_q, num_heads, head_size_og);
    CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
    CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
    CHECK_SHAPE(out, total_q, num_heads, head_size);
    CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    if (seqused_q.has_value()) {
        auto seqused_q_ = seqused_q.value();
        TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32");
        TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device");
        TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous");
        CHECK_SHAPE(seqused_q_, batch_size);
    }
    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
    if (seqused_k.has_value()) {
        auto seqused_k_ = seqused_k.value();
        TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
        TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
        TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
        CHECK_SHAPE(seqused_k_, batch_size);
    }

    at::Tensor dq, dk, dv;
    if (dq_.has_value()) {
        dq = dq_.value();
        TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
        CHECK_DEVICE(dq);
        TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
        CHECK_SHAPE(dq, total_q, num_heads, head_size);
    } else {
        dq = torch::empty_like(q);
    }
    if (dk_.has_value()) {
        dk = dk_.value();
        TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
        CHECK_DEVICE(dk);
        TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
        CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
    } else {
        dk = torch::empty_like(k);
    }
    if (dv_.has_value()) {
        dv = dv_.value();
        TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
        CHECK_DEVICE(dv);
        TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
        CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
    } else {
        dv = torch::empty_like(v);
    }

    at::Tensor dout_padded;
    if (head_size_og % 8 != 0) {
        dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    } else {
        dout_padded = dout;
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();
    // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
    auto softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
    auto softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
    at::Tensor dq_accum;
    at::Tensor dk_accum, dv_accum;
    dq_accum = torch::empty({num_heads, total_q_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
    // dk_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
    // dv_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));

    at::Tensor dk_expanded, dv_expanded;
    if (num_heads_k != num_heads) {  // MQA / GQA
        dk_expanded = torch::empty({total_k, num_heads, head_size}, opts);
        dv_expanded = torch::empty({total_k, num_heads, head_size}, opts);
    } else {
        dk_expanded = dk;
        dv_expanded = dv;
    }

    Flash_bwd_params params;
    set_params_dgrad(params,
                     batch_size,
                     max_seqlen_q, max_seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q, k, v, out,
                     dout_padded, dq, dk_expanded, dv_expanded,
                     cu_seqlens_q.data_ptr(),
                     cu_seqlens_k.data_ptr(),
                     seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr,
                     seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
                     dq_accum.data_ptr(),
                     // loop ? dk_accum.data_ptr() : nullptr,
                     // loop ? dv_accum.data_ptr() : nullptr,
                     nullptr,
                     nullptr,
                     softmax_lse.data_ptr(),
                     softmax_d.data_ptr(),
                     /*p_dropout=*/0.f,
                     softmax_scale,
                     /*window_size_left=*/-1,
                     /*window_size_right=*/is_causal ? 0 : -1,
                     deterministic);
    params.total_q = total_q;
    params.total_k = total_k;
    params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();

    // Will be zero'ed out in the backward preprocess kernel
    at::Tensor dq_semaphore = torch::empty({(max_seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
    params.dq_semaphore = dq_semaphore.data_ptr<int>();

    if (max_seqlen_q > 0) {
        run_mha_bwd(params, stream);
    } else {
        // If max_seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
        dk_expanded.zero_();
        dv_expanded.zero_();
        softmax_d.zero_();
    }

    // For MQA/GQA we need to sum dK and dV across the groups
    if (num_heads_k != num_heads) {
        at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
        at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
    }
    if (head_size_og % 8 != 0) {
        dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
    }

    return { dq, dk, dv, softmax_d, dq_accum, softmax_lse_log2 };
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "FlashAttention";
    m.def("fwd", &mha_fwd, "Forward pass");
    m.def("bwd", &mha_bwd, "Backward pass");
    m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
    m.def("varlen_bwd", &mha_varlen_bwd, "Varlen backward pass");
}
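
// Rough usage sketch from Python (the module name depends on how the extension is built;
// `flash_attn_3_cuda` here is just a placeholder). Arguments follow mha_fwd positionally:
// q, k, v, optional out, softmax_scale, optional descale_q/k/v (fp8 only), is_causal.
//
//   out, q_pad, k_pad, v_pad, out_pad, lse, p = flash_attn_3_cuda.fwd(
//       q, k, v, None, softmax_scale, None, None, None, causal)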