// flash_api.cpp

/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cutlass/numeric_types.h>

#include "flash.h"
#include "static_switch.h"

#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

void set_params_fprop(Flash_fwd_params &params,
                      // sizes
                      const size_t b,
                      const size_t seqlen_q,
                      const size_t seqlen_k,
                      const size_t seqlen_q_rounded,
                      const size_t seqlen_k_rounded,
                      const size_t h,
                      const size_t h_k,
                      const size_t d,
                      const size_t d_rounded,
                      // device pointers
                      const at::Tensor q,
                      const at::Tensor k,
                      const at::Tensor v,
                      at::Tensor out,
                      void *cu_seqlens_q_d,
                      void *cu_seqlens_k_d,
                      void *p_d,
                      void *softmax_lse_d,
                      float p_dropout,
                      float softmax_scale,
                      bool is_causal) {
    // Reset the parameters
    memset(&params, 0, sizeof(params));

    params.is_bf16 = q.dtype() == torch::kBFloat16;

    // Set the pointers and strides.
    params.q_ptr = q.data_ptr();
    params.k_ptr = k.data_ptr();
    params.v_ptr = v.data_ptr();
    // All strides are in elements, not bytes.
    params.q_row_stride = q.stride(-3);
    params.k_row_stride = k.stride(-3);
    params.v_row_stride = v.stride(-3);
    params.q_head_stride = q.stride(-2);
    params.k_head_stride = k.stride(-2);
    params.v_head_stride = v.stride(-2);
    params.o_ptr = out.data_ptr();
    params.o_row_stride = out.stride(-3);
    params.o_head_stride = out.stride(-2);

    if (cu_seqlens_q_d == nullptr) {
        params.q_batch_stride = q.stride(0);
        params.k_batch_stride = k.stride(0);
        params.v_batch_stride = v.stride(0);
        params.o_batch_stride = out.stride(0);
    }

    params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
    params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);

    // P = softmax(QK^T)
    params.p_ptr = p_d;

    // Softmax sum
    params.softmax_lse_ptr = softmax_lse_d;

    // Set the dimensions.
    params.b = b;
    params.h = h;
    params.h_k = h_k;
    params.h_h_k_ratio = h / h_k;
    params.seqlen_q = seqlen_q;
    params.seqlen_k = seqlen_k;
    params.seqlen_q_rounded = seqlen_q_rounded;
    params.seqlen_k_rounded = seqlen_k_rounded;
    params.d = d;
    params.d_rounded = d_rounded;

    // Set the different scale values.
    params.scale_softmax = softmax_scale;
    params.scale_softmax_log2 = softmax_scale * M_LOG2E;

    // Set this to the probability of keeping an element to simplify things.
    params.p_dropout = 1.f - p_dropout;
    // Convert p from float to int so we don't have to convert the random uint to float to compare.
    // [Minor] We want to round down since when we do the comparison we use <= instead of <
    // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
    // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
    params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
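    // For example, p_dropout = 0.1 gives a keep probability of 0.9, so
    // p_dropout_in_uint8_t = floor(0.9 * 255) = 229 and an element is kept when
    // its random byte is <= 229.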
    params.rp_dropout = 1.f / params.p_dropout;
    params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
    TORCH_CHECK(p_dropout < 1.f);

    params.is_causal = is_causal;
    params.is_seqlens_k_cumulative = true;
}
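
// Fill the backward-pass parameter struct: the fields shared with the forward pass are
// set via set_params_fprop, then the gradient pointers and strides are added.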
void set_params_dgrad(Flash_bwd_params &params,
                      // sizes
                      const size_t b,
                      const size_t seqlen_q,
                      const size_t seqlen_k,
                      const size_t seqlen_q_rounded,
                      const size_t seqlen_k_rounded,
                      const size_t h,
                      const size_t h_k,
                      const size_t d,
                      const size_t d_rounded,
                      // device pointers
                      const at::Tensor q,
                      const at::Tensor k,
                      const at::Tensor v,
                      const at::Tensor out,
                      const at::Tensor dout,
                      at::Tensor dq,
                      at::Tensor dk,
                      at::Tensor dv,
                      void *cu_seqlens_q_d,
                      void *cu_seqlens_k_d,
                      void *dq_accum_d,
                      void *dk_accum_d,
                      void *dv_accum_d,
                      void *softmax_lse_d,
                      void *dsoftmax_sum_d,
                      float p_dropout,
                      float softmax_scale,
                      bool is_causal) {
    set_params_fprop(params,
                     b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
                     q, k, v, out,
                     cu_seqlens_q_d,
                     cu_seqlens_k_d,
                     nullptr,
                     softmax_lse_d,
                     p_dropout,
                     softmax_scale,
                     is_causal);

    // Set the pointers and strides.
    params.do_ptr = dout.data_ptr();
    params.do_row_stride = dout.stride(-3);
    params.do_head_stride = dout.stride(-2);
    params.dq_ptr = dq.data_ptr();
    params.dk_ptr = dk.data_ptr();
    params.dv_ptr = dv.data_ptr();
    params.dq_row_stride = dq.stride(-3);
    params.dk_row_stride = dk.stride(-3);
    params.dv_row_stride = dv.stride(-3);
    params.dq_head_stride = dq.stride(-2);
    params.dk_head_stride = dk.stride(-2);
    params.dv_head_stride = dv.stride(-2);

    if (cu_seqlens_q_d == nullptr) {
        params.do_batch_stride = dout.stride(0);
        params.dq_batch_stride = dq.stride(0);
        params.dk_batch_stride = dk.stride(0);
        params.dv_batch_stride = dv.stride(0);
    }

    params.dq_accum_ptr = dq_accum_d;
    params.dk_accum_ptr = dk_accum_d;
    params.dv_accum_ptr = dv_accum_d;

    // Softmax sum
    params.dsoftmax_sum = dsoftmax_sum_d;
}
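
// Dispatch the forward kernel on dtype (fp16/bf16) and head dimension. The split-KV
// kernel is used when num_splits > 1 or when force_split_kernel is set (e.g. when
// appending to a KV cache).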
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
    FP16_SWITCH(!params.is_bf16, [&] {
        FWD_HEADDIM_SWITCH(params.d, [&] {
            if (params.num_splits <= 1 && !force_split_kernel) {  // If num_splits was never set it is 0
                run_mha_fwd_<elem_type, kHeadDim>(params, stream);
            } else {
                run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
            }
        });
    });
}

// Find the number of splits that maximizes the occupancy. For example, if we have
// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
// splits as that would incur more HBM reads/writes.
// So we find the best efficiency, then find the smallest number of splits that gets 85%
// of the best efficiency.
inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
    // If we have enough to almost fill the SMs, then just use 1 split
    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
    float max_efficiency = 0.f;
    std::vector<float> efficiency;
    efficiency.reserve(max_splits);
    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
    // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
    // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
    // (i.e. it's 11 splits anyway).
    // So we check if the number of blocks per split is the same as the previous num_splits.
    auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
        return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
    };
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (!is_split_eligible(num_splits)) {
            efficiency.push_back(0.f);
        } else {
            float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
            float eff = n_waves / ceil(n_waves);
            // printf("num_splits = %d, eff = %f\n", num_splits, eff);
            if (eff > max_efficiency) { max_efficiency = eff; }
            efficiency.push_back(eff);
        }
    }
    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
        if (!is_split_eligible(num_splits)) { continue; }
        if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
            // printf("num_splits chosen = %d\n", num_splits);
            return num_splits;
        }
    }
    return 1;
}
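
// Worked example for the heuristic above: with batch_nheads_mblocks = 48 and 108 SMs,
// 1 split gives 48/108 ~ 0.44 waves (efficiency 0.44), 2 splits give 96/108 ~ 0.89, and
// 3 splits give 144/108 ~ 1.33 waves, i.e. efficiency 1.33 / 2 ~ 0.67. The best efficiency
// is 0.89, and 2 is the smallest split count within 85% of it, so 2 splits are chosen.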
std::vector<at::Tensor>
mha_fwd(const at::Tensor &q,  // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,  // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,  // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &out_,  // batch_size x seqlen_q x num_heads x head_size
        const float p_dropout,
        const float softmax_scale,
        const bool is_causal,
        const bool return_softmax,
        c10::optional<at::Generator> gen_) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
    bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
    // We will support Turing in the near future
    // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only supports fp16 and bf16 data types");
    if (q_dtype == torch::kBFloat16) {
        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
    }
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");

    const auto sizes = q.sizes();

    const int batch_size = sizes[0];
    const int seqlen_q = sizes[1];
    const int num_heads = sizes[2];
    const int head_size_og = sizes[3];
    const int seqlen_k = k.size(1);
    const int num_heads_k = k.size(2);
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);

    at::Tensor q_padded, k_padded, v_padded;
    if (head_size_og % 8 != 0) {
        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    } else {
        q_padded = q;
        k_padded = k;
        v_padded = v;
    }

    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
    } else {
        out = torch::empty_like(q_padded);
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size = round_multiple(head_size_og, 8);
    const int head_size_rounded = round_multiple(head_size, 32);
    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();

    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
    at::Tensor p;
    // Only return softmax if there's dropout to reduce compilation time
    if (return_softmax) {
        TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
        p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
    }

    Flash_fwd_params params;
    set_params_fprop(params,
                     batch_size,
                     seqlen_q, seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q_padded, k_padded, v_padded, out,
                     /*cu_seqlens_q_d=*/nullptr,
                     /*cu_seqlens_k_d=*/nullptr,
                     return_softmax ? p.data_ptr() : nullptr,
                     softmax_lse.data_ptr(),
                     p_dropout,
                     softmax_scale,
                     is_causal);

    // This needs to match with run_mha_fwd_splitkv_dispatch
    const int block_n = is_sm90 || is_sm8x
        ? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64))
        : (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64));
    const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
    // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
    // In any case we don't expect seqlen_q to be larger than 64 for inference.
    const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
    params.num_splits = 1;
    if (p_dropout == 0.0f) {  // SplitKV is not implemented for dropout
        params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
        if (params.num_splits > 1) {
            at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
            at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
            params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
            params.oaccum_ptr = out_accum.data_ptr();
        }
    }

    // Number of times random will be generated per thread, to offset the philox counter in
    // the THC random state.
    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
    int64_t counter_offset = params.b * params.h * 32;
    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
    // Forward kernel will populate memory with the seed and offset.
    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

    if (p_dropout > 0.0) {
        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
            gen_, at::cuda::detail::getDefaultCUDAGenerator());
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        params.philox_args = gen->philox_cuda_state(counter_offset);
    }

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    run_mha_fwd(params, stream);

    at::Tensor out_padded = out;
    if (head_size_og % 8 != 0) {
        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        if (out_.has_value()) { out_.value().copy_(out); }
    }

    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
}
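
// Variable-length (packed) forward pass: the sequences in the batch are concatenated
// along the first dimension and delimited by cu_seqlens_q / cu_seqlens_k, which hold
// the cumulative sequence lengths (b+1 int32 offsets).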
std::vector<at::Tensor>
mha_varlen_fwd(const at::Tensor &q,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor> &out_,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,  // b+1
               const at::Tensor &cu_seqlens_k,  // b+1
               const int max_seqlen_q,
               const int max_seqlen_k,
               const float p_dropout,
               const float softmax_scale,
               const bool zero_tensors,
               const bool is_causal,
               const bool return_softmax,
               c10::optional<at::Generator> gen_) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
    bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
    // We will support Turing in the near future
    // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only supports fp16 and bf16 data types");
    if (q_dtype == torch::kBFloat16) {
        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
    }
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(cu_seqlens_q);
    CHECK_DEVICE(cu_seqlens_k);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    CHECK_CONTIGUOUS(cu_seqlens_q);
    CHECK_CONTIGUOUS(cu_seqlens_k);

    const auto sizes = q.sizes();

    const int total_q = sizes[0];
    const int batch_size = cu_seqlens_q.numel() - 1;
    const int num_heads = sizes[1];
    const int head_size_og = sizes[2];
    const int total_k = k.size(0);
    const int num_heads_k = k.size(1);
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    CHECK_SHAPE(q, total_q, num_heads, head_size_og);
    CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
    CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);

    at::Tensor q_padded, k_padded, v_padded;
    if (head_size_og % 8 != 0) {
        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    } else {
        q_padded = q;
        k_padded = k;
        v_padded = v;
    }

    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, total_q, num_heads, head_size_og);
        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
    } else {
        out = torch::empty_like(q_padded);
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size = round_multiple(head_size_og, 8);
    const int head_size_rounded = round_multiple(head_size, 32);
    const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();

    auto softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
    at::Tensor p;
    // Only return softmax if there's dropout to reduce compilation time
    if (return_softmax) {
        TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
        p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
    }

    if (zero_tensors) {
        out.zero_();
        softmax_lse.fill_(-std::numeric_limits<float>::infinity());
        if (return_softmax) { p.zero_(); }
    }

    Flash_fwd_params params;
    set_params_fprop(params,
                     batch_size,
                     max_seqlen_q, max_seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q_padded, k_padded, v_padded, out,
                     cu_seqlens_q.data_ptr(),
                     cu_seqlens_k.data_ptr(),
                     return_softmax ? p.data_ptr() : nullptr,
                     softmax_lse.data_ptr(),
                     p_dropout,
                     softmax_scale,
                     is_causal);

    // Number of times random will be generated per thread, to offset the philox counter in
    // the THC random state.
    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
    int64_t counter_offset = params.b * params.h * 32;
    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
    // Forward kernel will populate memory with the seed and offset.
    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

    if (p_dropout > 0.0) {
        auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
            gen_, at::cuda::detail::getDefaultCUDAGenerator());
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        params.philox_args = gen->philox_cuda_state(counter_offset);
    }

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    run_mha_fwd(params, stream);

    at::Tensor out_padded = out;
    if (head_size_og % 8 != 0) {
        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        if (out_.has_value()) { out_.value().copy_(out); }
    }

    return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
}
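
// Dispatch the backward kernel on dtype and head dimension.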
void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    FP16_SWITCH(!params.is_bf16, [&] {
        if (params.d <= 32) {
            run_mha_bwd_<elem_type, 32>(params, stream, configure);
        } else if (params.d <= 64) {
            run_mha_bwd_<elem_type, 64>(params, stream, configure);
        } else if (params.d <= 96) {
            run_mha_bwd_<elem_type, 96>(params, stream, configure);
        } else if (params.d <= 128) {
            run_mha_bwd_<elem_type, 128>(params, stream, configure);
        } else if (params.d <= 160) {
            run_mha_bwd_<elem_type, 160>(params, stream, configure);
        } else if (params.d <= 192) {
            run_mha_bwd_<elem_type, 192>(params, stream, configure);
        } else if (params.d <= 224) {
            run_mha_bwd_<elem_type, 224>(params, stream, configure);
        } else if (params.d <= 256) {
            run_mha_bwd_<elem_type, 256>(params, stream, configure);
        }
    });
}
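
// Backward pass for the fixed-length (batch_size x seqlen) layout. dq/dk/dv may be
// supplied by the caller or are allocated here; for MQA/GQA, dK/dV are first written per
// query head into expanded buffers and then summed over the head groups.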
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout,  // batch_size x seqlen_q x num_heads x head_size_og
        const at::Tensor &q,     // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &k,     // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &v,     // batch_size x seqlen_k x num_heads_k x head_size
        const at::Tensor &out,   // batch_size x seqlen_q x num_heads x head_size
        const at::Tensor &softmax_lse,   // b x h x seqlen_q
        c10::optional<at::Tensor> &dq_,  // batch_size x seqlen_q x num_heads x head_size
        c10::optional<at::Tensor> &dk_,  // batch_size x seqlen_k x num_heads_k x head_size
        c10::optional<at::Tensor> &dv_,  // batch_size x seqlen_k x num_heads_k x head_size
        const float p_dropout,  // probability to drop
        const float softmax_scale,
        const bool is_causal,
        c10::optional<at::Generator> gen_,
        c10::optional<at::Tensor> &rng_state) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
    bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
    bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
    // We will support Turing in the near future
    // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
    bool is_dropout = p_dropout > 0.0;
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only supports fp16 and bf16 data types");
    if (q_dtype == torch::kBFloat16) {
        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
    }
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
    TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
    TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
    TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");

    const auto sizes = q.sizes();

    const int batch_size = sizes[0];
    const int seqlen_q = sizes[1];
    const int num_heads = sizes[2];
    const int head_size_og = dout.size(3);
    const int head_size = sizes[3];
    const int seqlen_k = k.size(1);
    const int num_heads_k = k.size(2);
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
    TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
    if (head_size > 192) {
        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800");
    }
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size_rounded = round_multiple(head_size, 32);
    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

    TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");

    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
    CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
    CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
    CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);

    at::Tensor dq, dk, dv;
    if (dq_.has_value()) {
        dq = dq_.value();
        TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
        CHECK_DEVICE(dq);
        TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
        CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
    } else {
        dq = torch::empty_like(q);
    }
    if (dk_.has_value()) {
        dk = dk_.value();
        TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
        CHECK_DEVICE(dk);
        TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
        CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
    } else {
        dk = torch::empty_like(k);
    }
    if (dv_.has_value()) {
        dv = dv_.value();
        TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
        CHECK_DEVICE(dv);
        TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
        CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
    } else {
        dv = torch::empty_like(k);
    }

    at::Tensor dout_padded;
    if (head_size_og % 8 != 0) {
        dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    } else {
        dout_padded = dout;
    }

    // bool loop = seqlen_k > blocksize_c;
    // TODO: change later, for now set to true for simplicity
    bool loop = true;

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();
    auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
    at::Tensor dq_accum;
    at::Tensor dk_accum, dv_accum;
    if (loop) {
        dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat));
        // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
        // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
    }

    at::Tensor dk_expanded, dv_expanded;
    if (num_heads_k != num_heads) {  // MQA / GQA
        dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
        dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
    } else {
        dk_expanded = dk;
        dv_expanded = dv;
    }

    Flash_bwd_params params;
    set_params_dgrad(params,
                     batch_size,
                     seqlen_q, seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q, k, v, out,
                     dout_padded, dq, dk_expanded, dv_expanded,
                     nullptr,
                     nullptr,
                     loop ? dq_accum.data_ptr() : nullptr,
                     // loop ? dk_accum.data_ptr() : nullptr,
                     // loop ? dv_accum.data_ptr() : nullptr,
                     nullptr,
                     nullptr,
                     softmax_lse.data_ptr(),
                     softmax_d.data_ptr(),
                     p_dropout,
                     softmax_scale,
                     is_causal);

    auto launch = &run_mha_bwd;
    // launch(params, stream, /*configure=*/true);

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());

    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
    int64_t counter_offset = params.b * params.h * 32;

    if (rng_state.has_value()) {
        params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
    } else if (is_dropout) {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        params.philox_args = gen->philox_cuda_state(counter_offset);
        auto seeds = at::cuda::philox::unpack(params.philox_args);
        params.rng_state[0] = std::get<0>(seeds);
        params.rng_state[1] = std::get<1>(seeds);
    }

    launch(params, stream, /*configure=*/false);

    // For MQA/GQA we need to sum dK and dV across the groups
    if (num_heads_k != num_heads) {
        at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
        at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
    }
    if (head_size_og % 8 != 0) {
        dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
    }

    return { dq, dk, dv, softmax_d };
}
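
// Variable-length (packed) backward pass; the layout and cu_seqlens conventions match
// mha_varlen_fwd.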
std::vector<at::Tensor>
mha_varlen_bwd(const at::Tensor &dout,  // total_q x num_heads x head_size
               const at::Tensor &q,     // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               const at::Tensor &k,     // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &v,     // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &out,   // total_q x num_heads x head_size
               const at::Tensor &softmax_lse,   // b x h x s softmax logsumexp
               c10::optional<at::Tensor> &dq_,  // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor> &dk_,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               c10::optional<at::Tensor> &dv_,  // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
               const at::Tensor &cu_seqlens_q,  // b+1
               const at::Tensor &cu_seqlens_k,  // b+1
               const int max_seqlen_q,
               const int max_seqlen_k,  // max sequence length to choose the kernel
               const float p_dropout,   // probability to drop
               const float softmax_scale,
               const bool zero_tensors,
               const bool is_causal,
               c10::optional<at::Generator> gen_,
               c10::optional<at::Tensor> &rng_state) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
    bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
    bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
    // We will support Turing in the near future
    // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
    bool is_dropout = p_dropout > 0.0;
    auto stream = at::cuda::getCurrentCUDAStream().stream();

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only supports fp16 and bf16 data types");
    if (q_dtype == torch::kBFloat16) {
        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
    }
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
    TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
    TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
    TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
    TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");

    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
    CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
    CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
    TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
    CHECK_CONTIGUOUS(cu_seqlens_q);
    CHECK_CONTIGUOUS(cu_seqlens_k);

    const auto sizes = q.sizes();

    const int total_q = sizes[0];
    const int batch_size = cu_seqlens_q.numel() - 1;
    const int num_heads = sizes[1];
    const int head_size_og = dout.size(2);
    const int head_size = sizes[2];
    const int total_k = k.size(0);
    const int num_heads_k = k.size(1);
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
    TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
    if (head_size > 192) {
        TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800");
    }
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size_rounded = round_multiple(head_size, 32);
    const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);

    TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");

    CHECK_SHAPE(q, total_q, num_heads, head_size);
    CHECK_SHAPE(k, total_k, num_heads_k, head_size);
    CHECK_SHAPE(v, total_k, num_heads_k, head_size);
    CHECK_SHAPE(out, total_q, num_heads, head_size);
    CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
    CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
    CHECK_SHAPE(cu_seqlens_k, batch_size + 1);

    at::Tensor dq, dk, dv;
    if (dq_.has_value()) {
        dq = dq_.value();
        TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
        CHECK_DEVICE(dq);
        TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
        CHECK_SHAPE(dq, total_q, num_heads, head_size);
    } else {
        dq = torch::empty_like(q);
    }
    if (dk_.has_value()) {
        dk = dk_.value();
        TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
        CHECK_DEVICE(dk);
        TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
        CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
    } else {
        dk = torch::empty_like(k);
    }
    if (dv_.has_value()) {
        dv = dv_.value();
        TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
        CHECK_DEVICE(dv);
        TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
        CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
    } else {
        dv = torch::empty_like(k);
    }

    at::Tensor dout_padded;
    if (head_size_og % 8 != 0) {
        dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    } else {
        dout_padded = dout;
    }

    // bool loop = max_seqlen_k > blocksize_c;
    // TODO: change later, for now set to true for simplicity
    bool loop = true;

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();
    auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
    at::Tensor dq_accum;
    if (loop) {
        dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat));
    }

    at::Tensor dk_expanded, dv_expanded;
    if (num_heads_k != num_heads) {  // MQA / GQA
        dk_expanded = torch::empty({total_k, num_heads, head_size}, opts);
        dv_expanded = torch::empty({total_k, num_heads, head_size}, opts);
    } else {
        dk_expanded = dk;
        dv_expanded = dv;
    }

    if (zero_tensors) {
        dq.zero_();
        dk_expanded.zero_();
        dv_expanded.zero_();
        softmax_d.zero_();
    }

    Flash_bwd_params params;
    set_params_dgrad(params,
                     batch_size,
                     max_seqlen_q, max_seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q, k, v, out,
                     dout_padded, dq, dk_expanded, dv_expanded,
                     cu_seqlens_q.data_ptr(),
                     cu_seqlens_k.data_ptr(),
                     loop ? dq_accum.data_ptr() : nullptr,
                     nullptr,
                     nullptr,
                     softmax_lse.data_ptr(),
                     softmax_d.data_ptr(),
                     p_dropout,
                     softmax_scale,
                     is_causal);

    auto launch = &run_mha_bwd;
    // launch(params, stream, /*configure=*/true);

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());

    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
    int64_t counter_offset = params.b * params.h * 32;

    if (rng_state.has_value()) {
        params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
    } else if (is_dropout) {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        params.philox_args = gen->philox_cuda_state(counter_offset);
        auto seeds = at::cuda::philox::unpack(params.philox_args);
        params.rng_state[0] = std::get<0>(seeds);
        params.rng_state[1] = std::get<1>(seeds);
    }

    launch(params, stream, /*configure=*/false);

    // For MQA/GQA we need to sum dK and dV across the groups
    if (num_heads_k != num_heads) {
        at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
        at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
    }
    if (head_size_og % 8 != 0) {
        dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
    }

    return { dq, dk, dv, softmax_d };
}
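
// Inference forward pass against an existing KV cache. Optionally appends new k/v to the
// cache and applies rotary embeddings before attention; only the split-KV kernel supports
// appending, so it is forced in that case.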
std::vector<at::Tensor>
mha_fwd_kvcache(at::Tensor &q,  // batch_size x seqlen_q x num_heads x head_size
                const at::Tensor &kcache,  // batch_size x seqlen_k x num_heads_k x head_size
                const at::Tensor &vcache,  // batch_size x seqlen_k x num_heads_k x head_size
                c10::optional<const at::Tensor> &k_,  // batch_size x seqlen_knew x num_heads_k x head_size
                c10::optional<const at::Tensor> &v_,  // batch_size x seqlen_knew x num_heads_k x head_size
                c10::optional<const at::Tensor> &seqlens_k_,  // batch_size
                c10::optional<const at::Tensor> &rotary_cos_,  // seqlen_ro x (rotary_dim / 2)
                c10::optional<const at::Tensor> &rotary_sin_,  // seqlen_ro x (rotary_dim / 2)
                c10::optional<at::Tensor> &out_,  // batch_size x seqlen_q x num_heads x head_size
                const float softmax_scale,
                bool is_causal,
                bool is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
                int num_splits) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
    bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
    TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
    // We will support Turing in the near future
    // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only supports fp16 and bf16 data types");
    if (q_dtype == torch::kBFloat16) {
        TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
    }
    TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");

    CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);

    TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
    TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");

    const auto sizes = q.sizes();

    const int batch_size = sizes[0];
    int seqlen_q = sizes[1];
    int num_heads = sizes[2];
    const int head_size_og = sizes[3];
    const int seqlen_k = kcache.size(1);
    const int num_heads_k = kcache.size(2);
    TORCH_CHECK(batch_size > 0, "batch size must be positive");
    TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
    TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

    if (seqlen_q == 1) { is_causal = false; }  // causal=true is the same as causal=false in this case

    // Faster to transpose q from (b, 1, h, d) to (b, h, 1, d) in this case
    const bool seqlenq_nheads_swapped = seqlen_q == 1 && num_heads_k == 1 && num_heads > 1;
    if (seqlenq_nheads_swapped) {
        q = q.transpose(1, 2);
        std::swap(seqlen_q, num_heads);
    }

    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
    CHECK_SHAPE(kcache, batch_size, seqlen_k, num_heads_k, head_size_og);
    CHECK_SHAPE(vcache, batch_size, seqlen_k, num_heads_k, head_size_og);

    at::Tensor q_padded, kcache_padded, vcache_padded;
    if (head_size_og % 8 != 0) {
        q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
    } else {
        q_padded = q;
        kcache_padded = kcache;
        vcache_padded = vcache;
    }

    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
    } else {
        out = torch::empty_like(q_padded);
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size = round_multiple(head_size_og, 8);
    const int head_size_rounded = round_multiple(head_size, 32);
    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    auto opts = q.options();

    auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));

    Flash_fwd_params params;
    set_params_fprop(params,
                     batch_size,
                     seqlen_q, seqlen_k,
                     seqlen_q_rounded, seqlen_k_rounded,
                     num_heads, num_heads_k,
                     head_size, head_size_rounded,
                     q_padded, kcache_padded, vcache_padded, out,
                     /*cu_seqlens_q_d=*/nullptr,
                     /*cu_seqlens_k_d=*/nullptr,
                     /*p_ptr=*/nullptr,
                     softmax_lse.data_ptr(),
                     /*p_dropout=*/0.f,
                     softmax_scale,
                     is_causal);

    at::Tensor k, v, k_padded, v_padded;
    if (k_.has_value()) {
        TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
        TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
        TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache");
        k = k_.value();
        v = v_.value();
        TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
        TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
        CHECK_DEVICE(k); CHECK_DEVICE(v);
        TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
        TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
        int seqlen_knew = k.size(1);
        CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
        CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
        if (head_size_og % 8 != 0) {
            k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
            v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        } else {
            k_padded = k;
            v_padded = v;
        }
        params.seqlen_knew = seqlen_knew;
        params.knew_ptr = k_padded.data_ptr();
        params.vnew_ptr = v_padded.data_ptr();
        // All strides are in elements, not bytes.
        params.knew_batch_stride = k_padded.stride(0);
        params.vnew_batch_stride = v_padded.stride(0);
        params.knew_row_stride = k_padded.stride(-3);
        params.vnew_row_stride = v_padded.stride(-3);
        params.knew_head_stride = k_padded.stride(-2);
        params.vnew_head_stride = v_padded.stride(-2);
    }

    if (seqlens_k_.has_value()) {
        auto seqlens_k = seqlens_k_.value();
        TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
        CHECK_DEVICE(seqlens_k);
        CHECK_CONTIGUOUS(seqlens_k);
        CHECK_SHAPE(seqlens_k, batch_size);
        params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
    }
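    // seqlens_k, when provided, has shape (batch_size) and is treated as per-batch cache
    // lengths rather than cumulative offsets, hence the flag below.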
    params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());

    if (rotary_cos_.has_value()) {
        TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
        auto rotary_cos = rotary_cos_.value();
        CHECK_DEVICE(rotary_cos);
        params.rotary_dim = rotary_cos.size(1) * 2;
        TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
        TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
        const int seqlen_ro = rotary_cos.size(0);
        TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
        CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
        CHECK_CONTIGUOUS(rotary_cos);
        TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");

        TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
        auto rotary_sin = rotary_sin_.value();
        CHECK_DEVICE(rotary_sin);
        CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
        CHECK_CONTIGUOUS(rotary_sin);
        TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_sin must have the same dtype as query");
        params.rotary_cos_ptr = rotary_cos.data_ptr();
        params.rotary_sin_ptr = rotary_sin.data_ptr();
        params.is_rotary_interleaved = is_rotary_interleaved;
    } else {
        params.rotary_dim = 0;
    }

    // This needs to match with run_mha_fwd_splitkv_dispatch
    const int block_n = is_sm90 || is_sm8x
        ? (head_size <= 64 ? 256 : (head_size <= 160 ? 128 : 64))
        : (head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64));
    const int num_n_blocks = (seqlen_k + (params.knew_ptr == nullptr ? 0 : seqlen_q) + block_n - 1) / block_n;
    // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
    // In any case we don't expect seqlen_q to be larger than 64 for inference.
    const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
    params.num_splits = num_splits;
    if (num_splits < 1) {
        params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
    }
    if (params.num_splits > 1) {
        at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
        at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
        params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
        params.oaccum_ptr = out_accum.data_ptr();
    }

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    // Only the split kernel supports appending to the KV cache
    run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value());

    if (head_size_og % 8 != 0) {
        out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
        if (out_.has_value()) { out_.value().copy_(out); }
        if (k_.has_value()) {
            // It's expensive to copy the KV cache here for the case where head size is not divisible by 8,
            // but we don't expect to get this case in practice. This is just so that the code works for that case.
            kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
            vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
        }
    }

    if (seqlenq_nheads_swapped) {
        out = out.transpose(1, 2);
        softmax_lse = softmax_lse.transpose(1, 2);
    }
    return {out, softmax_lse};
}
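
// Python bindings for the entry points above.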
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "FlashAttention";
    m.def("fwd", &mha_fwd, "Forward pass");
    m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
    m.def("bwd", &mha_bwd, "Backward pass");
    m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
    m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
}