// flash_api.cpp (32 KB)
  1. /******************************************************************************
  2. * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
  3. ******************************************************************************/
  4. // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
  5. #include <torch/python.h>
  6. #include <torch/nn/functional.h>
  7. #include <ATen/cuda/CUDAContext.h>
  8. #include <c10/cuda/CUDAGuard.h>
  9. #include <cutlass/numeric_types.h>
  10. #include "flash.h"
  11. #include "static_switch.h"
  // Argument-validation helpers. Each one raises through TORCH_CHECK with a message that
  // names the offending argument (the macro parameter is stringified with #x).
  // The tensor argument is parenthesized so that any expression, not just a plain
  // identifier, expands correctly.

  // Fails unless the tensor reports is_cuda().
  #define CHECK_DEVICE(x) TORCH_CHECK((x).is_cuda(), #x " must be on CUDA")
  // Fails unless the tensor's sizes equal exactly the listed dimensions.
  #define CHECK_SHAPE(x, ...) TORCH_CHECK((x).sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
  // Fails unless the tensor reports is_contiguous().
  #define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous")
  15. void set_params_fprop(Flash_fwd_params &params,
  16. // sizes
  17. const size_t b,
  18. const size_t seqlen_q,
  19. const size_t seqlen_k,
  20. const size_t seqlen_q_rounded,
  21. const size_t seqlen_k_rounded,
  22. const size_t h,
  23. const size_t h_k,
  24. const size_t d,
  25. const size_t d_rounded,
  26. // device pointers
  27. const at::Tensor q,
  28. const at::Tensor k,
  29. const at::Tensor v,
  30. at::Tensor out,
  31. void *cu_seqlens_q_d,
  32. void *cu_seqlens_k_d,
  33. void *seqused_k,
  34. void *p_d,
  35. void *softmax_lse_d,
  36. float p_dropout,
  37. float softmax_scale,
  38. int window_size_left,
  39. int window_size_right,
  40. bool seqlenq_ngroups_swapped=false,
  41. bool unpadded_lse=false) {
  42. // Reset the parameters
  43. params = {};
  44. params.is_bf16 = q.dtype() == torch::kBFloat16;
  45. params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;
  46. // Set the pointers and strides.
  47. params.q_ptr = q.data_ptr();
  48. params.k_ptr = k.data_ptr();
  49. params.v_ptr = v.data_ptr();
  50. // All stride are in elements, not bytes.
  51. params.q_row_stride = q.stride(-3);
  52. params.k_row_stride = k.stride(-3);
  53. params.v_row_stride = v.stride(-3);
  54. params.q_head_stride = q.stride(-2);
  55. params.k_head_stride = k.stride(-2);
  56. params.v_head_stride = v.stride(-2);
  57. params.o_ptr = out.data_ptr();
  58. params.o_row_stride = out.stride(-3);
  59. params.o_head_stride = out.stride(-2);
  60. if (cu_seqlens_q_d == nullptr) {
  61. params.q_batch_stride = q.stride(0);
  62. params.k_batch_stride = k.stride(0);
  63. params.v_batch_stride = v.stride(0);
  64. params.o_batch_stride = out.stride(0);
  65. if (seqlenq_ngroups_swapped) {
  66. params.q_batch_stride *= seqlen_q;
  67. params.o_batch_stride *= seqlen_q;
  68. }
  69. }
  70. params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
  71. params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
  72. params.seqused_k = static_cast<int *>(seqused_k);
  73. TORCH_CHECK(
  74. bool(params.cu_seqlens_q) == bool(params.cu_seqlens_k),
  75. "cu_seqlens_q and cu_seqlens_k must be both null or non-null"
  76. );
  77. // P = softmax(QK^T)
  78. params.p_ptr = p_d;
  79. // Softmax sum
  80. params.softmax_lse_ptr = softmax_lse_d;
  81. // Set the dimensions.
  82. params.b = b;
  83. params.h = h;
  84. params.h_k = h_k;
  85. params.h_h_k_ratio = h / h_k;
  86. params.seqlen_q = seqlen_q;
  87. params.seqlen_k = seqlen_k;
  88. params.seqlen_q_rounded = seqlen_q_rounded;
  89. params.seqlen_k_rounded = seqlen_k_rounded;
  90. params.d = d;
  91. params.d_rounded = d_rounded;
  92. // Set the different scale values.
  93. params.scale_softmax = softmax_scale;
  94. params.scale_softmax_log2 = softmax_scale * M_LOG2E;
  95. __half scale_softmax_log2_half = __float2half(params.scale_softmax_log2);
  96. __half2 scale_softmax_log2_half2 = __half2(scale_softmax_log2_half, scale_softmax_log2_half);
  97. params.scale_softmax_log2_half2 = reinterpret_cast<uint32_t&>(scale_softmax_log2_half2);
  98. // Set this to probability of keeping an element to simplify things.
  99. params.p_dropout = 1.f - p_dropout;
  100. // Convert p from float to int so we don't have to convert the random uint to float to compare.
  101. // [Minor] We want to round down since when we do the comparison we use <= instead of <
  102. // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
  103. // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
  104. params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
  105. params.rp_dropout = 1.f / params.p_dropout;
  106. params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
  107. TORCH_CHECK(p_dropout < 1.f);
  108. #ifdef FLASHATTENTION_DISABLE_DROPOUT
  109. TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
  110. #endif
  111. // Causal is the special case where window_size_right == 0 and window_size_left < 0.
  112. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
  113. params.is_causal = window_size_left < 0 && window_size_right == 0;
  114. if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
  115. if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
  116. params.window_size_left = window_size_left;
  117. params.window_size_right = window_size_right;
  118. #ifdef FLASHATTENTION_DISABLE_LOCAL
  119. TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
  120. "This flash attention build does not support local attention.");
  121. #endif
  122. params.is_seqlens_k_cumulative = true;
  123. #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
  124. TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
  125. #endif
  126. params.unpadded_lse = unpadded_lse;
  127. }
  128. void set_params_dgrad(Flash_bwd_params &params,
  129. // sizes
  130. const size_t b,
  131. const size_t seqlen_q,
  132. const size_t seqlen_k,
  133. const size_t seqlen_q_rounded,
  134. const size_t seqlen_k_rounded,
  135. const size_t h,
  136. const size_t h_k,
  137. const size_t d,
  138. const size_t d_rounded,
  139. // device pointers
  140. const at::Tensor q,
  141. const at::Tensor k,
  142. const at::Tensor v,
  143. const at::Tensor out,
  144. const at::Tensor dout,
  145. at::Tensor dq,
  146. at::Tensor dk,
  147. at::Tensor dv,
  148. void *cu_seqlens_q_d,
  149. void *cu_seqlens_k_d,
  150. void *dq_accum_d,
  151. void *dk_accum_d,
  152. void *dv_accum_d,
  153. void *softmax_lse_d,
  154. void *dsoftmax_sum_d,
  155. float p_dropout,
  156. float softmax_scale,
  157. int window_size_left,
  158. int window_size_right,
  159. bool deterministic) {
  160. set_params_fprop(params,
  161. b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
  162. q, k, v, out,
  163. cu_seqlens_q_d,
  164. cu_seqlens_k_d,
  165. nullptr,
  166. nullptr,
  167. softmax_lse_d,
  168. p_dropout,
  169. softmax_scale,
  170. window_size_left,
  171. window_size_right);
  172. // Set the pointers and strides.
  173. params.do_ptr = dout.data_ptr();
  174. params.do_row_stride = dout.stride(-3);
  175. params.do_head_stride = dout.stride(-2);
  176. params.dq_ptr = dq.data_ptr();
  177. params.dk_ptr = dk.data_ptr();
  178. params.dv_ptr = dv.data_ptr();
  179. params.dq_row_stride = dq.stride(-3);
  180. params.dk_row_stride = dk.stride(-3);
  181. params.dv_row_stride = dv.stride(-3);
  182. params.dq_head_stride = dq.stride(-2);
  183. params.dk_head_stride = dk.stride(-2);
  184. params.dv_head_stride = dv.stride(-2);
  185. if (cu_seqlens_q_d == nullptr) {
  186. params.do_batch_stride = dout.stride(0);
  187. params.dq_batch_stride = dq.stride(0);
  188. params.dk_batch_stride = dk.stride(0);
  189. params.dv_batch_stride = dv.stride(0);
  190. }
  191. params.dq_accum_ptr = dq_accum_d;
  192. params.dk_accum_ptr = dk_accum_d;
  193. params.dv_accum_ptr = dv_accum_d;
  194. // Softmax sum
  195. params.dsoftmax_sum = dsoftmax_sum_d;
  196. params.deterministic = deterministic;
  197. }
  198. void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
  199. // HEADDIM_SWITCH(params.d, [&] {
  200. // run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
  201. // });
  202. if (!params.is_e4m3) {
  203. if (params.is_bf16) {
  204. if (params.d == 64) {
  205. run_mha_fwd_<cutlass::bfloat16_t, 64>(params, stream);
  206. } else if (params.d == 128) {
  207. run_mha_fwd_<cutlass::bfloat16_t, 128>(params, stream);
  208. } else {
  209. run_mha_fwd_<cutlass::bfloat16_t, 256>(params, stream);
  210. }
  211. } else {
  212. if (params.d == 64) {
  213. run_mha_fwd_<cutlass::half_t, 64>(params, stream);
  214. } else if (params.d == 128) {
  215. run_mha_fwd_<cutlass::half_t, 128>(params, stream);
  216. } else {
  217. run_mha_fwd_<cutlass::half_t, 256>(params, stream);
  218. }
  219. }
  220. } else {
  221. if (params.d == 64) {
  222. run_mha_fwd_<cutlass::float_e4m3_t, 64>(params, stream);
  223. } else if (params.d == 128) {
  224. run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream);
  225. } else {
  226. run_mha_fwd_<cutlass::float_e4m3_t, 256>(params, stream);
  227. }
  228. }
  229. }
  230. std::vector<at::Tensor>
  231. mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  232. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  233. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  234. c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
  235. const float softmax_scale,
  236. bool is_causal) {
  237. auto dprops = at::cuda::getCurrentDeviceProperties();
  238. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  239. TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
  240. auto q_dtype = q.dtype();
  241. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat8_e4m3fn,
  242. "FlashAttention only support fp16, bf16 and fp8 (e4m3) data type for now");
  243. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  244. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  245. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  246. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  247. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  248. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  249. TORCH_CHECK(q.is_contiguous(), "Input tensor must be contiguous");
  250. TORCH_CHECK(k.is_contiguous(), "Input tensor must be contiguous");
  251. TORCH_CHECK(v.is_contiguous(), "Input tensor must be contiguous");
  252. const auto sizes = q.sizes();
  253. const int batch_size = sizes[0];
  254. int seqlen_q = sizes[1];
  255. int num_heads = sizes[2];
  256. const int head_size_og = sizes[3];
  257. const int seqlen_k = k.size(1);
  258. const int num_heads_k = k.size(2);
  259. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  260. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  261. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  262. TORCH_CHECK(head_size_og == 64 || head_size_og == 128 || head_size_og == 256, "Only support head size 64, 128, and 256 for now");
  263. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
  264. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
  265. if (q_dtype == torch::kFloat8_e4m3fn) {
  266. CHECK_SHAPE(v, batch_size, head_size_og, num_heads_k, seqlen_k);
  267. } else {
  268. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
  269. }
  270. at::Tensor q_padded, k_padded, v_padded;
  271. if (q_dtype == torch::kFloat8_e4m3fn)
  272. {
  273. if (head_size_og % 16 != 0) {
  274. q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 16 - head_size_og % 16}));
  275. k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 16 - head_size_og % 16}));
  276. } else {
  277. q_padded = q;
  278. k_padded = k;
  279. }
  280. if (seqlen_k % 16 != 0) {
  281. v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 16 - seqlen_k % 16}));
  282. } else {
  283. v_padded = v;
  284. }
  285. }
  286. else {
  287. if (head_size_og % 8 != 0) {
  288. q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  289. k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  290. v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  291. } else {
  292. q_padded = q;
  293. k_padded = k;
  294. v_padded = v;
  295. }
  296. }
  297. at::Tensor out;
  298. if (out_.has_value()) {
  299. out = out_.value();
  300. //TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  301. CHECK_DEVICE(out);
  302. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  303. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
  304. if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
  305. } else {
  306. out = q_dtype == torch::kFloat8_e4m3fn ? torch::empty_like(q_padded, at::kHalf) : torch::empty_like(q_padded);
  307. }
  308. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  309. const int head_size = round_multiple(head_size_og, 8);
  310. const int head_size_rounded = round_multiple(head_size, 32);
  311. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  312. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  313. // Otherwise the kernel will be launched from cuda:0 device
  314. // Cast to char to avoid compiler warning about narrowing
  315. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  316. auto opts = q.options();
  317. auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  318. at::Tensor p;
  319. Flash_fwd_params params;
  320. set_params_fprop(params,
  321. batch_size,
  322. seqlen_q, seqlen_k,
  323. seqlen_q_rounded, seqlen_k_rounded,
  324. num_heads, num_heads_k,
  325. head_size, head_size_rounded,
  326. q_padded, k_padded, v_padded, out,
  327. /*cu_seqlens_q_d=*/nullptr,
  328. /*cu_seqlens_k_d=*/nullptr,
  329. /*seqused_k=*/nullptr,
  330. nullptr,
  331. softmax_lse.data_ptr(),
  332. /*p_dropout=*/0.f,
  333. softmax_scale,
  334. /*window_size_left=*/-1,
  335. /*window_size_right=*/is_causal ? 0 : -1);
  336. auto tile_count_semaphore = is_causal ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
  337. params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
  338. if (seqlen_k > 0) {
  339. auto stream = at::cuda::getCurrentCUDAStream().stream();
  340. run_mha_fwd(params, stream);
  341. } else {
  342. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  343. out.zero_();
  344. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  345. }
  346. at::Tensor out_padded = out;
  347. if (head_size_og % 8 != 0) {
  348. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  349. if (out_.has_value()) { out_.value().copy_(out); }
  350. }
  351. return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
  352. }
  353. std::vector<at::Tensor>
  354. mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  355. const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  356. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  357. c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
  358. const at::Tensor &cu_seqlens_q, // b+1
  359. const at::Tensor &cu_seqlens_k, // b+1
  360. c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
  361. int max_seqlen_q,
  362. const int max_seqlen_k,
  363. const float softmax_scale,
  364. bool is_causal) {
  365. auto dprops = at::cuda::getCurrentDeviceProperties();
  366. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  367. TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");
  368. auto q_dtype = q.dtype();
  369. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  370. "FlashAttention only support fp16 and bf16 data type");
  371. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  372. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  373. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
  374. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
  375. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  376. CHECK_DEVICE(cu_seqlens_q);
  377. CHECK_DEVICE(cu_seqlens_k);
  378. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  379. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  380. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  381. CHECK_CONTIGUOUS(cu_seqlens_q);
  382. CHECK_CONTIGUOUS(cu_seqlens_k);
  383. const auto sizes = q.sizes();
  384. const int batch_size = cu_seqlens_q.numel() - 1;
  385. int num_heads = sizes[1];
  386. const int head_size_og = sizes[2];
  387. const int num_heads_k = k.size(1);
  388. int window_size_left = -1;
  389. int window_size_right = -1;
  390. if (is_causal) { window_size_right = 0; }
  391. void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
  392. const int total_q = q.sizes()[0];
  393. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  394. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  395. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  396. if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
  397. if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
  398. CHECK_SHAPE(q, total_q, num_heads, head_size_og);
  399. const int total_k = k.size(0);
  400. CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
  401. CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
  402. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  403. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  404. if (seqused_k.has_value()){
  405. auto seqused_k_ = seqused_k.value();
  406. TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
  407. TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
  408. TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
  409. CHECK_SHAPE(seqused_k_, batch_size);
  410. }
  411. at::Tensor q_padded, k_padded, v_padded;
  412. if (head_size_og % 8 != 0) {
  413. q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  414. k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  415. v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  416. } else {
  417. q_padded = q;
  418. k_padded = k;
  419. v_padded = v;
  420. }
  421. at::Tensor out;
  422. if (out_.has_value()) {
  423. out = out_.value();
  424. TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  425. CHECK_DEVICE(out);
  426. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  427. CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
  428. if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
  429. } else {
  430. out = torch::empty_like(q_padded);
  431. }
  432. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  433. const int head_size = round_multiple(head_size_og, 8);
  434. const int head_size_rounded = round_multiple(head_size, 32);
  435. const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
  436. const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
  437. // Otherwise the kernel will be launched from cuda:0 device
  438. // Cast to char to avoid compiler warning about narrowing
  439. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  440. auto opts = q.options();
  441. auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
  442. Flash_fwd_params params;
  443. set_params_fprop(params,
  444. batch_size,
  445. max_seqlen_q, max_seqlen_k,
  446. seqlen_q_rounded, seqlen_k_rounded,
  447. num_heads, num_heads_k,
  448. head_size, head_size_rounded,
  449. q_padded, k_padded, v_padded, out,
  450. cu_seqlens_q_d,
  451. cu_seqlens_k.data_ptr(),
  452. seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
  453. /*p_d=*/nullptr,
  454. softmax_lse.data_ptr(),
  455. /*p_dropout=*/0.f,
  456. softmax_scale,
  457. window_size_left,
  458. window_size_right,
  459. /*seqlenq_ngroups_swapped=*/false,
  460. /*unpadded_lse=*/true);
  461. params.total_q = total_q;
  462. params.total_k = total_k;
  463. if (max_seqlen_k > 0) {
  464. auto stream = at::cuda::getCurrentCUDAStream().stream();
  465. run_mha_fwd(params, stream);
  466. } else {
  467. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  468. out.zero_();
  469. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  470. }
  471. at::Tensor out_padded = out;
  472. if (head_size_og % 8 != 0) {
  473. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  474. if (out_.has_value()) { out_.value().copy_(out); }
  475. }
  476. return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse};
  477. }
  478. void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
  479. // FP16_SWITCH(!params.is_bf16, [&] {
  480. // HEADDIM_SWITCH(params.d, [&] {
  481. // run_mha_bwd_<elem_type, kHeadDim>(params, stream);
  482. // });
  483. // });
  484. if (params.d == 64) {
  485. run_mha_bwd_<cutlass::half_t, 64>(params, stream);
  486. } else if (params.d == 128) {
  487. run_mha_bwd_<cutlass::half_t, 128>(params, stream);
  488. } else {
  489. run_mha_bwd_<cutlass::half_t, 256>(params, stream);
  490. }
  491. }
  492. std::vector<at::Tensor>
  493. mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
  494. const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  495. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  496. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  497. const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
  498. const at::Tensor &softmax_lse, // b x h x seqlen_q
  499. c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
  500. c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
  501. c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
  502. const float softmax_scale,
  503. const bool is_causal) {
  504. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  505. TORCH_CHECK(false, "This flash attention build does not support backward.");
  506. #endif
  507. auto dprops = at::cuda::getCurrentDeviceProperties();
  508. bool is_sm9x = dprops->major == 9 && dprops->minor >= 0;
  509. TORCH_CHECK(is_sm9x, "FlashAttentionHopper only supports Hopper GPUs or newer.");
  510. auto stream = at::cuda::getCurrentCUDAStream().stream();
  511. auto q_dtype = q.dtype();
  512. TORCH_CHECK(q_dtype == torch::kFloat16,
  513. // "FlashAttention only support fp16 and bf16 data type");
  514. "FlashAttention only support fp16 data type for now");
  515. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  516. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  517. TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  518. TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
  519. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  520. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  521. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  522. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  523. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  524. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  525. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
  526. TORCH_CHECK(q.is_contiguous(), "Input tensor must be contiguous");
  527. TORCH_CHECK(k.is_contiguous(), "Input tensor must be contiguous");
  528. TORCH_CHECK(v.is_contiguous(), "Input tensor must be contiguous");
  529. const auto sizes = q.sizes();
  530. const int batch_size = sizes[0];
  531. const int seqlen_q = sizes[1];
  532. const int num_heads = sizes[2];
  533. const int head_size_og = dout.size(3);
  534. const int head_size = sizes[3];
  535. const int seqlen_k = k.size(1);
  536. const int num_heads_k = k.size(2);
  537. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  538. TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  539. TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
  540. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  541. TORCH_CHECK(head_size_og == 64 || head_size_og == 128, "Only support head size 64 and 128 for now");
  542. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  543. const int head_size_rounded = round_multiple(head_size, 32);
  544. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  545. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  546. TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
  547. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
  548. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
  549. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
  550. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
  551. CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
  552. at::Tensor dq, dk, dv;
  553. if (dq_.has_value()) {
  554. dq = dq_.value();
  555. TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
  556. CHECK_DEVICE(dq);
  557. TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
  558. CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
  559. } else {
  560. dq = torch::empty_like(q);
  561. }
  562. if (dk_.has_value()) {
  563. dk = dk_.value();
  564. TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
  565. CHECK_DEVICE(dk);
  566. TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
  567. CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
  568. } else {
  569. dk = torch::empty_like(k);
  570. }
  571. if (dv_.has_value()) {
  572. dv = dv_.value();
  573. TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
  574. CHECK_DEVICE(dv);
  575. TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
  576. CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
  577. } else {
  578. dv = torch::empty_like(v);
  579. }
  580. at::Tensor dout_padded;
  581. if (head_size_og % 8 != 0) {
  582. dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  583. } else {
  584. dout_padded = dout;
  585. }
  586. // bool loop = seqlen_k > blocksize_c;
  587. // TODO: change later, for now set to true for simplicity
  588. bool loop = true;
  589. // Otherwise the kernel will be launched from cuda:0 device
  590. // Cast to char to avoid compiler warning about narrowing
  591. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  592. auto opts = q.options();
  593. auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
  594. at::Tensor dq_accum;
  595. at::Tensor dk_accum, dv_accum;
  596. if (loop) {
  597. dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
  598. // dk_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
  599. // dv_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat));
  600. }
  601. at::Tensor dk_expanded, dv_expanded;
  602. if (num_heads_k != num_heads) { // MQA / GQA
  603. dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
  604. dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
  605. } else {
  606. dk_expanded = dk;
  607. dv_expanded = dv;
  608. }
  609. Flash_bwd_params params;
  610. set_params_dgrad(params,
  611. batch_size,
  612. seqlen_q, seqlen_k,
  613. seqlen_q_rounded, seqlen_k_rounded,
  614. num_heads, num_heads_k,
  615. head_size, head_size_rounded,
  616. q, k, v, out,
  617. dout_padded, dq, dk_expanded, dv_expanded,
  618. nullptr,
  619. nullptr,
  620. loop ? dq_accum.data_ptr() : nullptr,
  621. // loop ? dk_accum.data_ptr() : nullptr,
  622. // loop ? dv_accum.data_ptr() : nullptr,
  623. nullptr,
  624. nullptr,
  625. softmax_lse.data_ptr(),
  626. softmax_d.data_ptr(),
  627. /*p_dropout=*/0.f,
  628. softmax_scale,
  629. /*window_size_left=*/-1,
  630. /*window_size_right=*/-1,
  631. /*deterministic=*/false);
  632. at::Tensor dq_semaphore = torch::zeros({(seqlen_q + 64 - 1) / 64, batch_size, num_heads}, opts.dtype(torch::kInt32));
  633. params.dq_semaphore = dq_semaphore.data_ptr<int>();
  634. // printf("dq_semaphore: %p, [%d, %d, %d]\n", params.dq_semaphore, (seqlen_q + 64 - 1) / 64, batch_size, num_heads);
  635. auto launch = &run_mha_bwd;
  636. if (seqlen_q > 0) {
  637. launch(params, stream);
  638. } else {
  639. // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
  640. dk_expanded.zero_();
  641. dv_expanded.zero_();
  642. softmax_d.zero_();
  643. }
  644. if (head_size_og % 8 != 0) {
  645. dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  646. dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  647. dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  648. }
  649. return { dq, dk, dv, softmax_d };
  650. }
  651. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  652. m.doc() = "FlashAttention";
  653. m.def("fwd", &mha_fwd, "Forward pass");
  654. m.def("bwd", &mha_bwd, "Backward pass");
  655. m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
  656. }