flash_api.cpp 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
  5. #include <torch/python.h>
  6. #include <torch/nn/functional.h>
  7. #include <ATen/cuda/CUDAContext.h>
  8. #include <c10/cuda/CUDAGuard.h>
  9. #include <cutlass/numeric_types.h>
  10. #include "flash.h"
  11. #include "static_switch.h"
  12. #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
  13. #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
  14. #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
  15. void set_params_fprop(Flash_fwd_params &params,
  16. // sizes
  17. const size_t b,
  18. const size_t seqlen_q,
  19. const size_t seqlen_k,
  20. const size_t seqlen_q_rounded,
  21. const size_t seqlen_k_rounded,
  22. const size_t h,
  23. const size_t h_k,
  24. const size_t d,
  25. const size_t d_rounded,
  26. // device pointers
  27. const at::Tensor q,
  28. const at::Tensor k,
  29. const at::Tensor v,
  30. at::Tensor out,
  31. void *cu_seqlens_q_d,
  32. void *cu_seqlens_k_d,
  33. void *seqused_q,
  34. void *seqused_k,
  35. void *p_d,
  36. void *softmax_lse_d,
  37. float p_dropout,
  38. float softmax_scale,
  39. int window_size_left,
  40. int window_size_right,
  41. bool seqlenq_ngroups_swapped=false,
  42. bool unpadded_lse=false,
  43. bool optimize_for_doc_masking=false) {
  44. // Reset the parameters
  45. params = {};
  46. params.is_bf16 = q.dtype() == torch::kBFloat16;
  47. params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;
  48. // Set the pointers and strides.
  49. params.q_ptr = q.data_ptr();
  50. params.k_ptr = k.data_ptr();
  51. params.v_ptr = v.data_ptr();
  52. // All stride are in elements, not bytes.
  53. params.q_row_stride = q.stride(-3);
  54. params.k_row_stride = k.stride(-3);
  55. params.v_row_stride = v.stride(-3);
  56. params.q_head_stride = q.stride(-2);
  57. params.k_head_stride = k.stride(-2);
  58. params.v_head_stride = v.stride(-2);
  59. params.o_ptr = out.data_ptr();
  60. params.o_row_stride = out.stride(-3);
  61. params.o_head_stride = out.stride(-2);
  62. if (cu_seqlens_q_d == nullptr) {
  63. params.q_batch_stride = q.stride(0);
  64. params.k_batch_stride = k.stride(0);
  65. params.v_batch_stride = v.stride(0);
  66. params.o_batch_stride = out.stride(0);
  67. if (seqlenq_ngroups_swapped) {
  68. params.q_batch_stride *= seqlen_q;
  69. params.o_batch_stride *= seqlen_q;
  70. }
  71. }
  72. params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
  73. params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
  74. params.seqused_q = static_cast<int *>(seqused_q);
  75. params.seqused_k = static_cast<int *>(seqused_k);
  76. TORCH_CHECK(
  77. bool(params.cu_seqlens_q) == bool(params.cu_seqlens_k),
  78. "cu_seqlens_q and cu_seqlens_k must be both null or non-null"
  79. );
  80. // P = softmax(QK^T)
  81. params.p_ptr = p_d;
  82. // Softmax sum
  83. params.softmax_lse_ptr = softmax_lse_d;
  84. // Set the dimensions.
  85. params.b = b;
  86. params.h = h;
  87. params.h_k = h_k;
  88. params.h_h_k_ratio = h / h_k;
  89. params.seqlen_q = seqlen_q;
  90. params.seqlen_k = seqlen_k;
  91. params.seqlen_q_rounded = seqlen_q_rounded;
  92. params.seqlen_k_rounded = seqlen_k_rounded;
  93. params.d = d;
  94. params.d_rounded = d_rounded;
  95. // Set the different scale values.
  96. params.scale_softmax = softmax_scale;
  97. params.scale_softmax_log2 = softmax_scale * M_LOG2E;
  98. __half scale_softmax_log2_half = __float2half(params.scale_softmax_log2);
  99. __half2 scale_softmax_log2_half2 = __half2(scale_softmax_log2_half, scale_softmax_log2_half);
  100. params.scale_softmax_log2_half2 = reinterpret_cast<uint32_t&>(scale_softmax_log2_half2);
  101. // Set this to probability of keeping an element to simplify things.
  102. params.p_dropout = 1.f - p_dropout;
  103. // Convert p from float to int so we don't have to convert the random uint to float to compare.
  104. // [Minor] We want to round down since when we do the comparison we use <= instead of <
  105. // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
  106. // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
  107. params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
  108. params.rp_dropout = 1.f / params.p_dropout;
  109. params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
  110. TORCH_CHECK(p_dropout < 1.f);
  111. #ifdef FLASHATTENTION_DISABLE_DROPOUT
  112. TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
  113. #endif
  114. // Causal is the special case where window_size_right == 0 and window_size_left < 0.
  115. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
  116. window_size_left = std::min(int(seqlen_k), window_size_left);
  117. window_size_right = std::min(int(seqlen_k), window_size_right);
  118. if (window_size_left < 0) { window_size_left = seqlen_k; }
  119. if (window_size_right < 0) { window_size_right = seqlen_k; }
  120. params.window_size_left = window_size_left;
  121. params.window_size_right = window_size_right;
  122. params.is_causal = window_size_left == seqlen_k && window_size_right == 0;
  123. if ((window_size_left < seqlen_k || window_size_right < seqlen_k) && !params.is_causal) {
  124. params.is_local = true;
  125. }
  126. #ifdef FLASHATTENTION_DISABLE_LOCAL
  127. TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
  128. "This flash attention build does not support local attention.");
  129. #endif
  130. params.is_seqlens_k_cumulative = true;
  131. #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
  132. TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
  133. #endif
  134. params.unpadded_lse = unpadded_lse;
  135. params.optimize_for_doc_masking = optimize_for_doc_masking;
  136. }
  137. void set_params_dgrad(Flash_bwd_params &params,
  138. // sizes
  139. const size_t b,
  140. const size_t seqlen_q,
  141. const size_t seqlen_k,
  142. const size_t seqlen_q_rounded,
  143. const size_t seqlen_k_rounded,
  144. const size_t h,
  145. const size_t h_k,
  146. const size_t d,
  147. const size_t d_rounded,
  148. // device pointers
  149. const at::Tensor q,
  150. const at::Tensor k,
  151. const at::Tensor v,
  152. const at::Tensor out,
  153. const at::Tensor dout,
  154. at::Tensor dq,
  155. at::Tensor dk,
  156. at::Tensor dv,
  157. void *cu_seqlens_q_d,
  158. void *cu_seqlens_k_d,
  159. void *seqused_q,
  160. void *seqused_k,
  161. void *dq_accum_d,
  162. void *dk_accum_d,
  163. void *dv_accum_d,
  164. void *softmax_lse_d,
  165. void *dsoftmax_sum_d,
  166. float p_dropout,
  167. float softmax_scale,
  168. int window_size_left,
  169. int window_size_right,
  170. bool deterministic,
  171. bool seqlenq_ngroups_swapped=false,
  172. bool unpadded_lse=false,
  173. bool optimize_for_doc_masking=false) {
  174. set_params_fprop(params,
  175. b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
  176. q, k, v, out,
  177. cu_seqlens_q_d,
  178. cu_seqlens_k_d,
  179. seqused_q,
  180. seqused_k,
  181. nullptr,
  182. softmax_lse_d,
  183. p_dropout,
  184. softmax_scale,
  185. window_size_left,
  186. window_size_right,
  187. seqlenq_ngroups_swapped,
  188. unpadded_lse,
  189. optimize_for_doc_masking);
  190. // Set the pointers and strides.
  191. params.do_ptr = dout.data_ptr();
  192. params.do_row_stride = dout.stride(-3);
  193. params.do_head_stride = dout.stride(-2);
  194. params.dq_ptr = dq.data_ptr();
  195. params.dk_ptr = dk.data_ptr();
  196. params.dv_ptr = dv.data_ptr();
  197. params.dq_row_stride = dq.stride(-3);
  198. params.dk_row_stride = dk.stride(-3);
  199. params.dv_row_stride = dv.stride(-3);
  200. params.dq_head_stride = dq.stride(-2);
  201. params.dk_head_stride = dk.stride(-2);
  202. params.dv_head_stride = dv.stride(-2);
  203. if (cu_seqlens_q_d == nullptr) {
  204. params.do_batch_stride = dout.stride(0);
  205. params.dq_batch_stride = dq.stride(0);
  206. params.dk_batch_stride = dk.stride(0);
  207. params.dv_batch_stride = dv.stride(0);
  208. }
  209. params.dq_accum_ptr = dq_accum_d;
  210. params.dk_accum_ptr = dk_accum_d;
  211. params.dv_accum_ptr = dv_accum_d;
  212. // Softmax sum
  213. params.dsoftmax_sum = dsoftmax_sum_d;
  214. params.deterministic = deterministic;
  215. }
  216. void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
  217. // HEADDIM_SWITCH(params.d, [&] {
  218. // run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
  219. // });
  220. if (!params.is_e4m3) {
  221. if (params.is_bf16) {
  222. if (params.d == 64) {
  223. run_mha_fwd_<cutlass::bfloat16_t, 64>(params, stream);
  224. } else if (params.d == 128) {
  225. run_mha_fwd_<cutlass::bfloat16_t, 128>(params, stream);
  226. } else {
  227. run_mha_fwd_<cutlass::bfloat16_t, 256>(params, stream);
  228. }
  229. } else {
  230. if (params.d == 64) {
  231. run_mha_fwd_<cutlass::half_t, 64>(params, stream);
  232. } else if (params.d == 128) {
  233. run_mha_fwd_<cutlass::half_t, 128>(params, stream);
  234. } else {
  235. run_mha_fwd_<cutlass::half_t, 256>(params, stream);
  236. }
  237. }
  238. } else {
  239. if (params.d == 64) {
  240. run_mha_fwd_<cutlass::float_e4m3_t, 64>(params, stream);
  241. } else if (params.d == 128) {
  242. run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream);
  243. } else if (params.d == 256) {
  244. run_mha_fwd_<cutlass::float_e4m3_t, 256>(params, stream);
  245. }
  246. }
  247. }
  248. std::vector<at::Tensor>
  249. mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  250. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  251. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  252. c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
  253. const float softmax_scale,
  254. c10::optional<at::Tensor> &descale_q_, // 1
  255. c10::optional<at::Tensor> &descale_k_, // 1
  256. c10::optional<at::Tensor> &descale_v_, // 1
  257. bool is_causal,
  258. int window_size_left,
  259. int window_size_right) {
  260. auto dprops = at::cuda::getCurrentDeviceProperties();
  261. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  262. TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
  263. auto q_dtype = q.dtype();
  264. // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  265. // "FlashAttention only support fp16 and bf16 data type for now");
  266. // TODO: will add e4m3 later
  267. // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kFloat8_e4m3fn,
  268. // "FlashAttention only support fp16 and bf16 data type");
  269. // "FlashAttention only support fp16 and fp8 (e4m3) data type for now");
  270. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  271. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  272. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  273. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  274. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  275. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  276. const auto sizes = q.sizes();
  277. const int batch_size = sizes[0];
  278. int seqlen_q = sizes[1];
  279. int num_heads = sizes[2];
  280. const int head_size_og = sizes[3];
  281. const int seqlen_k = k.size(1);
  282. const int num_heads_k = k.size(2);
  283. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  284. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  285. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  286. TORCH_CHECK(head_size_og == 64 || head_size_og == 128 || head_size_og == 256, "Only support head size 64, 128, and 256 for now");
  287. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
  288. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
  289. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
  290. at::Tensor q_padded, k_padded, v_padded;
  291. if (head_size_og % 8 != 0) {
  292. q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  293. k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  294. v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  295. } else {
  296. q_padded = q;
  297. k_padded = k;
  298. v_padded = v;
  299. }
  300. at::Tensor out;
  301. if (out_.has_value()) {
  302. out = out_.value();
  303. // TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  304. TORCH_CHECK(q_dtype == at::ScalarType::Float8_e4m3fn
  305. ? (out.dtype() == at::kHalf)
  306. : (out.dtype() == q_dtype),
  307. "Output must have the same dtype as input dtype if dtype is "
  308. "not fp8, or fp16 for fp8 input.");
  309. CHECK_DEVICE(out);
  310. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  311. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
  312. if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
  313. } else {
  314. if (q_dtype == at::ScalarType::Float8_e4m3fn)
  315. out = torch::empty_like(q_padded, at::kHalf);
  316. else
  317. out = torch::empty_like(q_padded);
  318. }
  319. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  320. const int head_size = round_multiple(head_size_og, 8);
  321. const int head_size_rounded = round_multiple(head_size, 32);
  322. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  323. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  324. if (is_causal) { window_size_right = 0; }
  325. // Otherwise the kernel will be launched from cuda:0 device
  326. // Cast to char to avoid compiler warning about narrowing
  327. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  328. auto opts = q.options();
  329. auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  330. at::Tensor p;
  331. Flash_fwd_params params;
  332. set_params_fprop(params,
  333. batch_size,
  334. seqlen_q, seqlen_k,
  335. seqlen_q_rounded, seqlen_k_rounded,
  336. num_heads, num_heads_k,
  337. head_size, head_size_rounded,
  338. q_padded, k_padded, v_padded, out,
  339. /*cu_seqlens_q_d=*/nullptr,
  340. /*cu_seqlens_k_d=*/nullptr,
  341. /*seqused_q=*/nullptr,
  342. /*seqused_k=*/nullptr,
  343. nullptr,
  344. softmax_lse.data_ptr(),
  345. /*p_dropout=*/0.f,
  346. softmax_scale,
  347. /*window_size_left=*/window_size_left,
  348. /*window_size_right=*/window_size_right);
  349. auto tile_count_semaphore = is_causal ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
  350. params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
  351. if(q_dtype == at::ScalarType::Float8_e4m3fn) {
  352. at::Tensor descale_q, descale_k, descale_v;
  353. if (descale_q_.has_value() && descale_k_.has_value() && descale_k_.has_value()) {
  354. descale_q = descale_q_.value();
  355. descale_k = descale_k_.value();
  356. descale_v = descale_v_.value();
  357. CHECK_DEVICE(descale_q);
  358. CHECK_DEVICE(descale_k);
  359. CHECK_DEVICE(descale_v);
  360. CHECK_SHAPE(descale_q, 1);
  361. CHECK_SHAPE(descale_k, 1);
  362. CHECK_SHAPE(descale_v, 1);
  363. } else {
  364. descale_q = torch::ones({1}, opts.dtype(at::kFloat));
  365. descale_k = torch::ones({1}, opts.dtype(at::kFloat));
  366. descale_v = torch::ones({1}, opts.dtype(at::kFloat));
  367. }
  368. params.descale_q_ptr = descale_q.data_ptr<float>();
  369. params.descale_k_ptr = descale_k.data_ptr<float>();
  370. params.descale_v_ptr = descale_v.data_ptr<float>();
  371. } else {
  372. params.descale_q_ptr = nullptr;
  373. params.descale_k_ptr = nullptr;
  374. params.descale_v_ptr = nullptr;
  375. }
  376. if (seqlen_k > 0) {
  377. auto stream = at::cuda::getCurrentCUDAStream().stream();
  378. run_mha_fwd(params, stream);
  379. } else {
  380. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  381. out.zero_();
  382. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  383. }
  384. at::Tensor out_padded = out;
  385. if (head_size_og % 8 != 0) {
  386. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  387. if (out_.has_value()) { out_.value().copy_(out); }
  388. }
  389. return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
  390. }
  391. std::vector<at::Tensor>
  392. mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  393. const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  394. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  395. c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
  396. const at::Tensor &cu_seqlens_q, // b+1
  397. const at::Tensor &cu_seqlens_k, // b+1
  398. c10::optional<at::Tensor> &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used.
  399. c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
  400. int max_seqlen_q,
  401. const int max_seqlen_k,
  402. const float softmax_scale,
  403. bool is_causal,
  404. int window_size_left,
  405. int window_size_right,
  406. bool optimize_for_doc_masking) {
  407. auto dprops = at::cuda::getCurrentDeviceProperties();
  408. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  409. TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
  410. auto q_dtype = q.dtype();
  411. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  412. "FlashAttention only support fp16 and bf16 data type");
  413. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  414. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  415. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
  416. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
  417. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  418. CHECK_DEVICE(cu_seqlens_q);
  419. CHECK_DEVICE(cu_seqlens_k);
  420. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  421. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  422. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  423. CHECK_CONTIGUOUS(cu_seqlens_q);
  424. CHECK_CONTIGUOUS(cu_seqlens_k);
  425. const auto sizes = q.sizes();
  426. const int batch_size = cu_seqlens_q.numel() - 1;
  427. int num_heads = sizes[1];
  428. const int head_size_og = sizes[2];
  429. const int num_heads_k = k.size(1);
  430. void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
  431. const int total_q = q.sizes()[0];
  432. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  433. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  434. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  435. CHECK_SHAPE(q, total_q, num_heads, head_size_og);
  436. const int total_k = k.size(0);
  437. CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
  438. CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
  439. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  440. if (seqused_q.has_value()){
  441. auto seqused_q_ = seqused_q.value();
  442. TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32");
  443. TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device");
  444. TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous");
  445. CHECK_SHAPE(seqused_q_, batch_size);
  446. }
  447. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  448. if (seqused_k.has_value()){
  449. auto seqused_k_ = seqused_k.value();
  450. TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
  451. TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
  452. TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
  453. CHECK_SHAPE(seqused_k_, batch_size);
  454. }
  455. at::Tensor q_padded, k_padded, v_padded;
  456. if (head_size_og % 8 != 0) {
  457. q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  458. k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  459. v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  460. } else {
  461. q_padded = q;
  462. k_padded = k;
  463. v_padded = v;
  464. }
  465. at::Tensor out;
  466. if (out_.has_value()) {
  467. out = out_.value();
  468. TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  469. CHECK_DEVICE(out);
  470. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  471. CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
  472. if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
  473. } else {
  474. out = torch::empty_like(q_padded);
  475. }
  476. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  477. const int head_size = round_multiple(head_size_og, 8);
  478. const int head_size_rounded = round_multiple(head_size, 32);
  479. const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
  480. const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
  481. if (is_causal) { window_size_right = 0; }
  482. // Otherwise the kernel will be launched from cuda:0 device
  483. // Cast to char to avoid compiler warning about narrowing
  484. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  485. auto opts = q.options();
  486. auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
  487. Flash_fwd_params params;
  488. set_params_fprop(params,
  489. batch_size,
  490. max_seqlen_q, max_seqlen_k,
  491. seqlen_q_rounded, seqlen_k_rounded,
  492. num_heads, num_heads_k,
  493. head_size, head_size_rounded,
  494. q_padded, k_padded, v_padded, out,
  495. cu_seqlens_q_d,
  496. cu_seqlens_k.data_ptr(),
  497. seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr,
  498. seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
  499. /*p_d=*/nullptr,
  500. softmax_lse.data_ptr(),
  501. /*p_dropout=*/0.f,
  502. softmax_scale,
  503. window_size_left,
  504. window_size_right,
  505. /*seqlenq_ngroups_swapped=*/false,
  506. /*unpadded_lse=*/true,
  507. /*optimize_for_doc_masking=*/optimize_for_doc_masking);
  508. params.total_q = total_q;
  509. params.total_k = total_k;
  510. if (max_seqlen_k > 0) {
  511. auto stream = at::cuda::getCurrentCUDAStream().stream();
  512. run_mha_fwd(params, stream);
  513. } else {
  514. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  515. out.zero_();
  516. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  517. }
  518. at::Tensor out_padded = out;
  519. if (head_size_og % 8 != 0) {
  520. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  521. if (out_.has_value()) { out_.value().copy_(out); }
  522. }
  523. return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse};
  524. }
  525. void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
  526. // FP16_SWITCH(!params.is_bf16, [&] {
  527. // HEADDIM_SWITCH(params.d, [&] {
  528. // run_mha_bwd_<elem_type, kHeadDim>(params, stream);
  529. // });
  530. // });
  531. if (!params.is_bf16) {
  532. if (params.d <= 64) {
  533. run_mha_bwd_<cutlass::half_t, 64>(params, stream);
  534. } else if (params.d <= 96) {
  535. run_mha_bwd_<cutlass::half_t, 96>(params, stream);
  536. } else {
  537. run_mha_bwd_<cutlass::half_t, 128>(params, stream);
  538. }
  539. } else {
  540. if (params.d <= 64) {
  541. run_mha_bwd_<cutlass::bfloat16_t, 64>(params, stream);
  542. } else if (params.d <= 96) {
  543. run_mha_bwd_<cutlass::bfloat16_t, 96>(params, stream);
  544. } else {
  545. run_mha_bwd_<cutlass::bfloat16_t, 128>(params, stream);
  546. }
  547. }
  548. }
  549. std::vector<at::Tensor>
  550. mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
  551. const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  552. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  553. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  554. const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
  555. const at::Tensor &softmax_lse, // b x h x seqlen_q
  556. c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
  557. c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
  558. c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
  559. const float softmax_scale,
  560. const bool is_causal,
  561. int window_size_left,
  562. int window_size_right,
  563. const bool deterministic) {
  564. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  565. TORCH_CHECK(false, "This flash attention build does not support backward.");
  566. #endif
  567. auto dprops = at::cuda::getCurrentDeviceProperties();
  568. bool is_sm9x = dprops->major == 9 && dprops->minor >= 0;
  569. TORCH_CHECK(is_sm9x, "FlashAttentionHopper only supports Hopper GPUs or newer.");
  570. auto stream = at::cuda::getCurrentCUDAStream().stream();
  571. auto q_dtype = q.dtype();
  572. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  573. "FlashAttention only support fp16 and bf16 data type");
  574. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  575. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  576. TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  577. TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
  578. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  579. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  580. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  581. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  582. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  583. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  584. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
  585. const auto sizes = q.sizes();
  586. const int batch_size = sizes[0];
  587. const int seqlen_q = sizes[1];
  588. const int num_heads = sizes[2];
  589. const int head_size_og = dout.size(3);
  590. const int head_size = sizes[3];
  591. const int seqlen_k = k.size(1);
  592. const int num_heads_k = k.size(2);
  593. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  594. TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  595. TORCH_CHECK(head_size <= 128, "FlashAttention backward only supports head dimension at most 128");
  596. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  597. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  598. const int head_size_rounded = head_size <= 64 ? 64 : round_multiple(head_size, 32);
  599. // This should match the kernel configs
  600. const int kBlockM = head_size <= 64 ? 128 : (head_size < 256 ? 64 : 32);
  601. const int seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);
  602. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  603. TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
  604. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
  605. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
  606. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
  607. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
  608. CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
  609. at::Tensor dq, dk, dv;
  610. if (dq_.has_value()) {
  611. dq = dq_.value();
  612. TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
  613. CHECK_DEVICE(dq);
  614. TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
  615. CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
  616. } else {
  617. dq = torch::empty_like(q);
  618. }
  619. if (dk_.has_value()) {
  620. dk = dk_.value();
  621. TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
  622. CHECK_DEVICE(dk);
  623. TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
  624. CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
  625. } else {
  626. dk = torch::empty_like(k);
  627. }
  628. if (dv_.has_value()) {
  629. dv = dv_.value();
  630. TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
  631. CHECK_DEVICE(dv);
  632. TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
  633. CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
  634. } else {
  635. dv = torch::empty_like(v);
  636. }
  637. at::Tensor dout_padded;
  638. if (head_size_og % 8 != 0) {
  639. dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  640. } else {
  641. dout_padded = dout;
  642. }
  643. // Otherwise the kernel will be launched from cuda:0 device
  644. // Cast to char to avoid compiler warning about narrowing
  645. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  646. auto opts = q.options();
  647. // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
  648. auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
  649. auto softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
  650. at::Tensor dq_accum;
  651. at::Tensor dk_accum, dv_accum;
  652. dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat));
  653. // dk_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
  654. // dv_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
  655. at::Tensor dk_expanded, dv_expanded;
  656. if (num_heads_k != num_heads) { // MQA / GQA
  657. dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
  658. dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
  659. } else {
  660. dk_expanded = dk;
  661. dv_expanded = dv;
  662. }
  663. if (is_causal) { window_size_right = 0; }
  664. Flash_bwd_params params;
  665. set_params_dgrad(params,
  666. batch_size,
  667. seqlen_q, seqlen_k,
  668. seqlen_q_rounded, seqlen_k_rounded,
  669. num_heads, num_heads_k,
  670. head_size, head_size_rounded,
  671. q, k, v, out,
  672. dout_padded, dq, dk_expanded, dv_expanded,
  673. /*cu_seqlens_q_d=*/nullptr,
  674. /*cu_seqlens_k_d=*/nullptr,
  675. /*seqused_q=*/nullptr,
  676. /*seqused_k=*/nullptr,
  677. dq_accum.data_ptr(),
  678. // loop ? dk_accum.data_ptr() : nullptr,
  679. // loop ? dv_accum.data_ptr() : nullptr,
  680. nullptr,
  681. nullptr,
  682. softmax_lse.data_ptr(),
  683. softmax_d.data_ptr(),
  684. /*p_dropout=*/0.f,
  685. softmax_scale,
  686. /*window_size_left=*/window_size_left,
  687. /*window_size_right=*/window_size_right,
  688. deterministic);
  689. params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
  690. // Will be zero'ed out in the backward preprocess kernel
  691. at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
  692. params.dq_semaphore = dq_semaphore.data_ptr<int>();
  693. // printf("dq_semaphore: %p, [%d, %d, %d]\n", params.dq_semaphore, (seqlen_q + 64 - 1) / 64, batch_size, num_heads);
  694. if (seqlen_q > 0) {
  695. run_mha_bwd(params, stream);
  696. } else {
  697. // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
  698. dk_expanded.zero_();
  699. dv_expanded.zero_();
  700. softmax_d.zero_();
  701. }
  702. // For MQA/GQA we need to sum dK and dV across the groups
  703. if (num_heads_k != num_heads) {
  704. at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
  705. at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
  706. }
  707. if (head_size_og % 8 != 0) {
  708. dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  709. dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  710. dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  711. }
  712. return { dq, dk, dv, softmax_d, dq_accum};
  713. }
  714. std::vector<at::Tensor>
  715. mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
  716. const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  717. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  718. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  719. const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
  720. const at::Tensor &softmax_lse, // b x h x seqlen_q
  721. c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
  722. c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
  723. c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
  724. const at::Tensor &cu_seqlens_q, // b+1
  725. const at::Tensor &cu_seqlens_k, // b+1
  726. c10::optional<at::Tensor> &seqused_q, // b. If given, only this many elements of each batch element's queries and outputs are used.
  727. c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
  728. const int max_seqlen_q,
  729. const int max_seqlen_k, // max sequence length to choose the kernel
  730. const float softmax_scale,
  731. const bool is_causal,
  732. int window_size_left,
  733. int window_size_right,
  734. const bool deterministic,
  735. const bool optimize_for_doc_masking) {
  736. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  737. TORCH_CHECK(false, "This flash attention build does not support backward.");
  738. #endif
  739. auto dprops = at::cuda::getCurrentDeviceProperties();
  740. bool is_sm9x = dprops->major == 9 && dprops->minor >= 0;
  741. TORCH_CHECK(is_sm9x, "FlashAttentionHopper only supports Hopper GPUs or newer.");
  742. auto stream = at::cuda::getCurrentCUDAStream().stream();
  743. auto q_dtype = q.dtype();
  744. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  745. "FlashAttention only support fp16 and bf16 data type");
  746. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  747. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  748. TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  749. TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
  750. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
  751. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
  752. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  753. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  754. CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
  755. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  756. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  757. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  758. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  759. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
  760. CHECK_CONTIGUOUS(cu_seqlens_q);
  761. CHECK_CONTIGUOUS(cu_seqlens_k);
  762. const auto sizes = q.sizes();
  763. const int total_q = sizes[0];
  764. const int batch_size = cu_seqlens_q.numel() - 1;
  765. const int num_heads = sizes[1];
  766. const int head_size_og = dout.size(2);
  767. const int head_size = sizes[2];
  768. const int total_k = k.size(0);
  769. const int num_heads_k = k.size(1);
  770. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  771. TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  772. TORCH_CHECK(head_size <= 128, "FlashAttention backward only supports head dimension at most 128");
  773. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  774. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  775. const int head_size_rounded = head_size <= 64 ? 64 : round_multiple(head_size, 32);
  776. // This should match the kernel configs
  777. const int kBlockM = head_size <= 64 ? 128 : (head_size < 256 ? 64 : 32);
  778. const int seqlen_q_rounded = round_multiple(max_seqlen_q, kBlockM);
  779. const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
  780. int const total_q_padded_rounded = round_multiple(total_q + batch_size * 128, 128);
  781. TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
  782. CHECK_SHAPE(q, total_q, num_heads, head_size_og);
  783. CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
  784. CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
  785. CHECK_SHAPE(out, total_q, num_heads, head_size);
  786. CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
  787. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  788. if (seqused_q.has_value()){
  789. auto seqused_q_ = seqused_q.value();
  790. TORCH_CHECK(seqused_q_.dtype() == torch::kInt32, "seqused_q must have dtype int32");
  791. TORCH_CHECK(seqused_q_.is_cuda(), "seqused_q must be on CUDA device");
  792. TORCH_CHECK(seqused_q_.is_contiguous(), "seqused_q must be contiguous");
  793. CHECK_SHAPE(seqused_q_, batch_size);
  794. }
  795. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  796. if (seqused_k.has_value()){
  797. auto seqused_k_ = seqused_k.value();
  798. TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
  799. TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
  800. TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
  801. CHECK_SHAPE(seqused_k_, batch_size);
  802. }
  803. at::Tensor dq, dk, dv;
  804. if (dq_.has_value()) {
  805. dq = dq_.value();
  806. TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
  807. CHECK_DEVICE(dq);
  808. TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
  809. CHECK_SHAPE(dq, total_q, num_heads, head_size);
  810. } else {
  811. dq = torch::empty_like(q);
  812. }
  813. if (dk_.has_value()) {
  814. dk = dk_.value();
  815. TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
  816. CHECK_DEVICE(dk);
  817. TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
  818. CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
  819. } else {
  820. dk = torch::empty_like(k);
  821. }
  822. if (dv_.has_value()) {
  823. dv = dv_.value();
  824. TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
  825. CHECK_DEVICE(dv);
  826. TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
  827. CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
  828. } else {
  829. dv = torch::empty_like(v);
  830. }
  831. at::Tensor dout_padded;
  832. if (head_size_og % 8 != 0) {
  833. dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  834. } else {
  835. dout_padded = dout;
  836. }
  837. if (is_causal) { window_size_right = 0; }
  838. // Otherwise the kernel will be launched from cuda:0 device
  839. // Cast to char to avoid compiler warning about narrowing
  840. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  841. auto opts = q.options();
  842. // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
  843. auto softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
  844. auto softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
  845. at::Tensor dq_accum;
  846. at::Tensor dk_accum, dv_accum;
  847. dq_accum = torch::empty({num_heads, total_q_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
  848. // dk_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
  849. // dv_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
  850. at::Tensor dk_expanded, dv_expanded;
  851. if (num_heads_k != num_heads) { // MQA / GQA
  852. dk_expanded = torch::empty({total_k, num_heads, head_size}, opts);
  853. dv_expanded = torch::empty({total_k, num_heads, head_size}, opts);
  854. } else {
  855. dk_expanded = dk;
  856. dv_expanded = dv;
  857. }
  858. Flash_bwd_params params;
  859. set_params_dgrad(params,
  860. batch_size,
  861. max_seqlen_q, max_seqlen_k,
  862. seqlen_q_rounded, seqlen_k_rounded,
  863. num_heads, num_heads_k,
  864. head_size, head_size_rounded,
  865. q, k, v, out,
  866. dout_padded, dq, dk_expanded, dv_expanded,
  867. cu_seqlens_q.data_ptr(),
  868. cu_seqlens_k.data_ptr(),
  869. seqused_q.has_value() ? seqused_q.value().data_ptr() : nullptr,
  870. seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
  871. dq_accum.data_ptr(),
  872. // loop ? dk_accum.data_ptr() : nullptr,
  873. // loop ? dv_accum.data_ptr() : nullptr,
  874. nullptr,
  875. nullptr,
  876. softmax_lse.data_ptr(),
  877. softmax_d.data_ptr(),
  878. /*p_dropout=*/0.f,
  879. softmax_scale,
  880. /*window_size_left=*/window_size_left,
  881. /*window_size_right=*/window_size_right,
  882. deterministic,
  883. /*seqlenq_ngroups_swapped=*/false,
  884. /*unpadded_lse=*/true,
  885. optimize_for_doc_masking);
  886. params.total_q = total_q;
  887. params.total_k = total_k;
  888. params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
  889. // Will be zero'ed out in the backward preprocess kernel
  890. at::Tensor dq_semaphore = torch::empty({(max_seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
  891. params.dq_semaphore = dq_semaphore.data_ptr<int>();
  892. if (max_seqlen_q > 0) {
  893. run_mha_bwd(params, stream);
  894. } else {
  895. // If max_seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
  896. dk_expanded.zero_();
  897. dv_expanded.zero_();
  898. softmax_d.zero_();
  899. }
  900. // For MQA/GQA we need to sum dK and dV across the groups
  901. if (num_heads_k != num_heads) {
  902. at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
  903. at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
  904. }
  905. if (head_size_og % 8 != 0) {
  906. dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  907. dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  908. dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  909. }
  910. return { dq, dk, dv, softmax_d, dq_accum, softmax_lse_log2 };
  911. }
  912. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  913. m.doc() = "FlashAttention";
  914. m.def("fwd", &mha_fwd, "Forward pass");
  915. m.def("bwd", &mha_bwd, "Backward pass");
  916. m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
  917. m.def("varlen_bwd", &mha_varlen_bwd, "Varlen backward pass");
  918. }