flash_api.cpp 74 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571
  1. /******************************************************************************
  2. * Copyright (c) 2024, Tri Dao.
  3. ******************************************************************************/
  4. // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
  5. #include <torch/python.h>
  6. #include <torch/nn/functional.h>
  7. #include <ATen/cuda/CUDAContext.h>
  8. #include <c10/cuda/CUDAGuard.h>
  9. #include <cutlass/numeric_types.h>
  10. #include "flash.h"
  11. #include "static_switch.h"
  12. #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
  13. #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
  14. #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
  15. void set_params_fprop(Flash_fwd_params &params,
  16. // sizes
  17. const size_t b,
  18. const size_t seqlen_q,
  19. const size_t seqlen_k,
  20. const size_t seqlen_q_rounded,
  21. const size_t seqlen_k_rounded,
  22. const size_t h,
  23. const size_t h_k,
  24. const size_t d,
  25. const size_t d_rounded,
  26. // device pointers
  27. const at::Tensor q,
  28. const at::Tensor k,
  29. const at::Tensor v,
  30. at::Tensor out,
  31. void *cu_seqlens_q_d,
  32. void *cu_seqlens_k_d,
  33. void *seqused_k,
  34. void *p_d,
  35. void *softmax_lse_d,
  36. float p_dropout,
  37. float softmax_scale,
  38. int window_size_left,
  39. int window_size_right,
  40. const float softcap,
  41. bool seqlenq_ngroups_swapped=false,
  42. const bool unpadded_lse=false) {
  43. // Reset the parameters
  44. params = {};
  45. params.is_bf16 = q.dtype() == torch::kBFloat16;
  46. // Set the pointers and strides.
  47. params.q_ptr = q.data_ptr();
  48. params.k_ptr = k.data_ptr();
  49. params.v_ptr = v.data_ptr();
  50. // All stride are in elements, not bytes.
  51. params.q_row_stride = q.stride(-3);
  52. params.k_row_stride = k.stride(-3);
  53. params.v_row_stride = v.stride(-3);
  54. params.q_head_stride = q.stride(-2);
  55. params.k_head_stride = k.stride(-2);
  56. params.v_head_stride = v.stride(-2);
  57. params.o_ptr = out.data_ptr();
  58. params.o_row_stride = out.stride(-3);
  59. params.o_head_stride = out.stride(-2);
  60. if (cu_seqlens_q_d == nullptr) {
  61. params.q_batch_stride = q.stride(0);
  62. params.k_batch_stride = k.stride(0);
  63. params.v_batch_stride = v.stride(0);
  64. params.o_batch_stride = out.stride(0);
  65. if (seqlenq_ngroups_swapped) {
  66. params.q_batch_stride *= seqlen_q;
  67. params.o_batch_stride *= seqlen_q;
  68. }
  69. }
  70. params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
  71. params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
  72. params.seqused_k = static_cast<int *>(seqused_k);
  73. // P = softmax(QK^T)
  74. params.p_ptr = p_d;
  75. // Softmax sum
  76. params.softmax_lse_ptr = softmax_lse_d;
  77. // Set the dimensions.
  78. params.b = b;
  79. params.h = h;
  80. params.h_k = h_k;
  81. params.h_h_k_ratio = h / h_k;
  82. params.seqlen_q = seqlen_q;
  83. params.seqlen_k = seqlen_k;
  84. params.seqlen_q_rounded = seqlen_q_rounded;
  85. params.seqlen_k_rounded = seqlen_k_rounded;
  86. params.d = d;
  87. params.d_rounded = d_rounded;
  88. // Set the different scale values.
  89. #ifdef FLASHATTENTION_DISABLE_SOFTCAP
  90. TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
  91. #endif
  92. if (softcap > 0.0) {
  93. params.softcap = softmax_scale / softcap;
  94. params.scale_softmax = softcap;
  95. params.scale_softmax_log2 = softcap * M_LOG2E;
  96. } else{
  97. // Remove potential NaN
  98. params.softcap = 0.0;
  99. params.scale_softmax = softmax_scale;
  100. params.scale_softmax_log2 = softmax_scale * M_LOG2E;
  101. }
  102. // Set this to probability of keeping an element to simplify things.
  103. params.p_dropout = 1.f - p_dropout;
  104. // Convert p from float to int so we don't have to convert the random uint to float to compare.
  105. // [Minor] We want to round down since when we do the comparison we use <= instead of <
  106. // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
  107. // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
  108. params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
  109. params.rp_dropout = 1.f / params.p_dropout;
  110. params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
  111. TORCH_CHECK(p_dropout < 1.f);
  112. #ifdef FLASHATTENTION_DISABLE_DROPOUT
  113. TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
  114. #endif
  115. // Causal is the special case where window_size_right == 0 and window_size_left < 0.
  116. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
  117. params.is_causal = window_size_left < 0 && window_size_right == 0;
  118. if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
  119. if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
  120. params.window_size_left = window_size_left;
  121. params.window_size_right = window_size_right;
  122. #ifdef FLASHATTENTION_DISABLE_LOCAL
  123. TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
  124. "This flash attention build does not support local attention.");
  125. #endif
  126. params.is_seqlens_k_cumulative = true;
  127. #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
  128. TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
  129. #endif
  130. params.unpadded_lse = unpadded_lse;
  131. params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
  132. }
  133. void set_params_dgrad(Flash_bwd_params &params,
  134. // sizes
  135. const size_t b,
  136. const size_t seqlen_q,
  137. const size_t seqlen_k,
  138. const size_t seqlen_q_rounded,
  139. const size_t seqlen_k_rounded,
  140. const size_t h,
  141. const size_t h_k,
  142. const size_t d,
  143. const size_t d_rounded,
  144. // device pointers
  145. const at::Tensor q,
  146. const at::Tensor k,
  147. const at::Tensor v,
  148. const at::Tensor out,
  149. const at::Tensor dout,
  150. at::Tensor dq,
  151. at::Tensor dk,
  152. at::Tensor dv,
  153. void *cu_seqlens_q_d,
  154. void *cu_seqlens_k_d,
  155. void *dq_accum_d,
  156. void *dk_accum_d,
  157. void *dv_accum_d,
  158. void *softmax_lse_d,
  159. void *dsoftmax_sum_d,
  160. float p_dropout,
  161. float softmax_scale,
  162. int window_size_left,
  163. int window_size_right,
  164. const float softcap,
  165. bool deterministic,
  166. const bool unpadded_lse) {
  167. set_params_fprop(params,
  168. b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
  169. q, k, v, out,
  170. cu_seqlens_q_d,
  171. cu_seqlens_k_d,
  172. nullptr,
  173. nullptr,
  174. softmax_lse_d,
  175. p_dropout,
  176. softmax_scale,
  177. window_size_left,
  178. window_size_right,
  179. softcap,
  180. false, // seqlenq_ngroups_swapped
  181. unpadded_lse);
  182. // Set the pointers and strides.
  183. params.do_ptr = dout.data_ptr();
  184. params.do_row_stride = dout.stride(-3);
  185. params.do_head_stride = dout.stride(-2);
  186. params.dq_ptr = dq.data_ptr();
  187. params.dk_ptr = dk.data_ptr();
  188. params.dv_ptr = dv.data_ptr();
  189. params.dq_row_stride = dq.stride(-3);
  190. params.dk_row_stride = dk.stride(-3);
  191. params.dv_row_stride = dv.stride(-3);
  192. params.dq_head_stride = dq.stride(-2);
  193. params.dk_head_stride = dk.stride(-2);
  194. params.dv_head_stride = dv.stride(-2);
  195. if (cu_seqlens_q_d == nullptr) {
  196. params.do_batch_stride = dout.stride(0);
  197. params.dq_batch_stride = dq.stride(0);
  198. params.dk_batch_stride = dk.stride(0);
  199. params.dv_batch_stride = dv.stride(0);
  200. }
  201. params.dq_accum_ptr = dq_accum_d;
  202. params.dk_accum_ptr = dk_accum_d;
  203. params.dv_accum_ptr = dv_accum_d;
  204. // Softmax sum
  205. params.dsoftmax_sum = dsoftmax_sum_d;
  206. params.deterministic = deterministic;
  207. }
  208. void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
  209. FP16_SWITCH(!params.is_bf16, [&] {
  210. HEADDIM_SWITCH(params.d, [&] {
  211. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  212. if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
  213. run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
  214. } else {
  215. run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
  216. }
  217. });
  218. });
  219. });
  220. }
  221. // Find the number of splits that maximizes the occupancy. For example, if we have
  222. // batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
  223. // better than having 3 splits (efficiency = 0.67). However, we also don't want too many
  224. // splits as that would incur more HBM reads/writes.
  225. // So we find the best efficiency, then find the smallest number of splits that gets 85%
  226. // of the best efficiency.
  227. inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
  228. // If we have enough to almost fill the SMs, then just use 1 split
  229. if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
  230. max_splits = std::min({max_splits, num_SMs, num_n_blocks});
  231. float max_efficiency = 0.f;
  232. std::vector<float> efficiency;
  233. efficiency.reserve(max_splits);
  234. auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
  235. // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
  236. // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
  237. // (i.e. it's 11 splits anyway).
  238. // So we check if the number of blocks per split is the same as the previous num_splits.
  239. auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
  240. return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
  241. };
  242. for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
  243. if (!is_split_eligible(num_splits)) {
  244. efficiency.push_back(0.f);
  245. } else {
  246. float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
  247. float eff = n_waves / ceil(n_waves);
  248. // printf("num_splits = %d, eff = %f\n", num_splits, eff);
  249. if (eff > max_efficiency) { max_efficiency = eff; }
  250. efficiency.push_back(eff);
  251. }
  252. }
  253. for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
  254. if (!is_split_eligible(num_splits)) { continue; }
  255. if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
  256. // printf("num_splits chosen = %d\n", num_splits);
  257. return num_splits;
  258. }
  259. }
  260. return 1;
  261. }
  262. void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
  263. const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
  264. const int head_size_rounded, const float p_dropout,
  265. const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) {
  266. // This needs to match with run_mha_fwd_splitkv_dispatch
  267. const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
  268. const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
  269. // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
  270. // In any case we don't expect seqlen_q to be larger than 64 for inference.
  271. const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
  272. params.num_splits = num_splits;
  273. if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
  274. if (num_splits < 1) {
  275. // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
  276. params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128);
  277. }
  278. if (params.num_splits > 1) {
  279. at::Tensor softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
  280. at::Tensor out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
  281. params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
  282. params.oaccum_ptr = out_accum.data_ptr();
  283. }
  284. TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
  285. }
  286. }
  287. void set_params_alibi(Flash_fwd_params &params, c10::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){
  288. #ifdef FLASHATTENTION_DISABLE_ALIBI
  289. TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi.");
  290. params.alibi_slopes_ptr = nullptr;
  291. #else
  292. if (alibi_slopes_.has_value()) {
  293. auto alibi_slopes = alibi_slopes_.value();
  294. TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
  295. CHECK_DEVICE(alibi_slopes);
  296. TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
  297. TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
  298. params.alibi_slopes_ptr = alibi_slopes.data_ptr();
  299. params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
  300. } else {
  301. params.alibi_slopes_ptr = nullptr;
  302. }
  303. #endif
  304. }
  305. std::vector<at::Tensor>
  306. mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  307. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  308. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  309. c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
  310. c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
  311. const float p_dropout,
  312. const float softmax_scale,
  313. bool is_causal,
  314. int window_size_left,
  315. int window_size_right,
  316. const float softcap,
  317. const bool return_softmax,
  318. c10::optional<at::Generator> gen_) {
  319. auto dprops = at::cuda::getCurrentDeviceProperties();
  320. // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
  321. bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  322. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  323. TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
  324. // We will support Turing in the near future
  325. // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
  326. auto q_dtype = q.dtype();
  327. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  328. "FlashAttention only support fp16 and bf16 data type");
  329. if (q_dtype == torch::kBFloat16) {
  330. TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
  331. }
  332. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  333. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  334. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  335. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  336. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  337. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  338. const auto sizes = q.sizes();
  339. const int batch_size = sizes[0];
  340. int seqlen_q = sizes[1];
  341. int num_heads = sizes[2];
  342. const int head_size_og = sizes[3];
  343. const int seqlen_k = k.size(1);
  344. const int num_heads_k = k.size(2);
  345. TORCH_CHECK(batch_size > 0, "batch size must be postive");
  346. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  347. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  348. if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
  349. if (window_size_left >= seqlen_k) { window_size_left = -1; }
  350. if (window_size_right >= seqlen_k) { window_size_right = -1; }
  351. // causal=true is the same as causal=false in this case
  352. if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
  353. if (is_causal) { window_size_right = 0; }
  354. // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
  355. // H/t Daniel Haziza
  356. const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
  357. const int ngroups = num_heads / num_heads_k;
  358. if (seqlenq_ngroups_swapped) {
  359. q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
  360. seqlen_q = ngroups;
  361. num_heads = num_heads_k;
  362. }
  363. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
  364. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
  365. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
  366. at::Tensor q_padded, k_padded, v_padded;
  367. if (head_size_og % 8 != 0) {
  368. q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  369. k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  370. v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  371. } else {
  372. q_padded = q;
  373. k_padded = k;
  374. v_padded = v;
  375. }
  376. at::Tensor out;
  377. if (out_.has_value()) {
  378. out = out_.value();
  379. TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  380. CHECK_DEVICE(out);
  381. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  382. CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
  383. if (seqlenq_ngroups_swapped) {
  384. out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
  385. }
  386. if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
  387. } else {
  388. out = torch::empty_like(q_padded);
  389. }
  390. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  391. const int head_size = round_multiple(head_size_og, 8);
  392. const int head_size_rounded = round_multiple(head_size, 32);
  393. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  394. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  395. // Otherwise the kernel will be launched from cuda:0 device
  396. // Cast to char to avoid compiler warning about narrowing
  397. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  398. auto opts = q.options();
  399. auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  400. at::Tensor p;
  401. // Only return softmax if there's dropout to reduce compilation time
  402. if (return_softmax) {
  403. TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
  404. p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
  405. }
  406. Flash_fwd_params params;
  407. set_params_fprop(params,
  408. batch_size,
  409. seqlen_q, seqlen_k,
  410. seqlen_q_rounded, seqlen_k_rounded,
  411. num_heads, num_heads_k,
  412. head_size, head_size_rounded,
  413. q_padded, k_padded, v_padded, out,
  414. /*cu_seqlens_q_d=*/nullptr,
  415. /*cu_seqlens_k_d=*/nullptr,
  416. /*seqused_k=*/nullptr,
  417. return_softmax ? p.data_ptr() : nullptr,
  418. softmax_lse.data_ptr(),
  419. p_dropout,
  420. softmax_scale,
  421. window_size_left,
  422. window_size_right,
  423. softcap
  424. );
  425. set_params_splitkv(params, batch_size, num_heads,
  426. head_size, seqlen_k, seqlen_q,
  427. head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
  428. // number of times random will be generated per thread, to offset philox counter in thc random
  429. // state
  430. // We use a custom RNG that increases the offset by batch_size * nheads * 32.
  431. int64_t counter_offset = params.b * params.h * 32;
  432. auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  433. auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
  434. // Forward kernel will populate memory with the seed and offset.
  435. params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
  436. if (p_dropout > 0.0) {
  437. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  438. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  439. // See Note [Acquire lock when using random generators]
  440. std::lock_guard<std::mutex> lock(gen->mutex_);
  441. params.philox_args = gen->philox_cuda_state(counter_offset);
  442. }
  443. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  444. if (seqlen_k > 0) {
  445. auto stream = at::cuda::getCurrentCUDAStream().stream();
  446. run_mha_fwd(params, stream);
  447. } else {
  448. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  449. out.zero_();
  450. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  451. }
  452. at::Tensor out_padded = out;
  453. if (head_size_og % 8 != 0) {
  454. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  455. if (out_.has_value()) { out_.value().copy_(out); }
  456. }
  457. if (seqlenq_ngroups_swapped) {
  458. out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
  459. out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
  460. q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
  461. softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
  462. }
  463. return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
  464. }
  465. std::vector<at::Tensor>
  466. mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  467. const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  468. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  469. c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
  470. const at::Tensor &cu_seqlens_q, // b+1
  471. const at::Tensor &cu_seqlens_k, // b+1
  472. c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
  473. c10::optional<const at::Tensor> &leftpad_k_, // batch_size
  474. c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
  475. c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
  476. int max_seqlen_q,
  477. const int max_seqlen_k,
  478. const float p_dropout,
  479. const float softmax_scale,
  480. const bool zero_tensors,
  481. bool is_causal,
  482. int window_size_left,
  483. int window_size_right,
  484. const float softcap,
  485. const bool return_softmax,
  486. c10::optional<at::Generator> gen_) {
  487. auto dprops = at::cuda::getCurrentDeviceProperties();
  488. // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
  489. bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  490. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  491. TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
  492. // We will support Turing in the near future
  493. // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
  494. auto q_dtype = q.dtype();
  495. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  496. "FlashAttention only support fp16 and bf16 data type");
  497. if (q_dtype == torch::kBFloat16) {
  498. TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
  499. }
  500. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  501. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  502. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
  503. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
  504. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  505. CHECK_DEVICE(cu_seqlens_q);
  506. CHECK_DEVICE(cu_seqlens_k);
  507. at::Tensor block_table;
  508. const bool paged_KV = block_table_.has_value();
  509. if (paged_KV) {
  510. block_table = block_table_.value();
  511. CHECK_DEVICE(block_table);
  512. TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
  513. TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
  514. }
  515. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  516. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  517. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  518. CHECK_CONTIGUOUS(cu_seqlens_q);
  519. CHECK_CONTIGUOUS(cu_seqlens_k);
  520. const auto sizes = q.sizes();
  521. const int batch_size = cu_seqlens_q.numel() - 1;
  522. int num_heads = sizes[1];
  523. const int head_size_og = sizes[2];
  524. const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
  525. if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
  526. const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
  527. const int num_blocks = !paged_KV ? 0 : k.size(0);
  528. const int page_block_size = !paged_KV ? 1 : k.size(1);
  529. TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
  530. if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
  531. if (is_causal) { window_size_right = 0; }
  532. void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
  533. // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
  534. // H/t Daniel Haziza
  535. const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
  536. const int ngroups = num_heads / num_heads_k;
  537. if (seqlenq_ngroups_swapped) {
  538. q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
  539. max_seqlen_q = ngroups;
  540. num_heads = num_heads_k;
  541. cu_seqlens_q_d = nullptr;
  542. }
  543. const int total_q = q.sizes()[0];
  544. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  545. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  546. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  547. if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
  548. if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
  549. CHECK_SHAPE(q, total_q, num_heads, head_size_og);
  550. if (!paged_KV) {
  551. const int total_k = k.size(0);
  552. CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
  553. CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
  554. } else {
  555. CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
  556. CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
  557. CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
  558. }
  559. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  560. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  561. if (seqused_k.has_value()){
  562. auto seqused_k_ = seqused_k.value();
  563. TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
  564. TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
  565. TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
  566. CHECK_SHAPE(seqused_k_, batch_size);
  567. }
  568. at::Tensor q_padded, k_padded, v_padded;
  569. if (head_size_og % 8 != 0) {
  570. q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  571. k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  572. v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  573. } else {
  574. q_padded = q;
  575. k_padded = k;
  576. v_padded = v;
  577. }
  578. at::Tensor out;
  579. if (out_.has_value()) {
  580. out = out_.value();
  581. TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  582. CHECK_DEVICE(out);
  583. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  584. CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
  585. if (seqlenq_ngroups_swapped) {
  586. out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
  587. }
  588. if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
  589. } else {
  590. out = torch::empty_like(q_padded);
  591. }
  592. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  593. const int head_size = round_multiple(head_size_og, 8);
  594. const int head_size_rounded = round_multiple(head_size, 32);
  595. const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
  596. const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
  597. // Otherwise the kernel will be launched from cuda:0 device
  598. // Cast to char to avoid compiler warning about narrowing
  599. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  600. auto opts = q.options();
  601. auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
  602. at::Tensor p;
  603. // Only return softmax if there's dropout to reduce compilation time
  604. if (return_softmax) {
  605. TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
  606. p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
  607. }
  608. if (zero_tensors) {
  609. out.zero_();
  610. softmax_lse.fill_(-std::numeric_limits<float>::infinity());
  611. if (return_softmax) {p.zero_();}
  612. }
  613. Flash_fwd_params params;
  614. set_params_fprop(params,
  615. batch_size,
  616. max_seqlen_q, max_seqlen_k,
  617. seqlen_q_rounded, seqlen_k_rounded,
  618. num_heads, num_heads_k,
  619. head_size, head_size_rounded,
  620. q_padded, k_padded, v_padded, out,
  621. cu_seqlens_q_d,
  622. cu_seqlens_k.data_ptr(),
  623. seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
  624. return_softmax ? p.data_ptr() : nullptr,
  625. softmax_lse.data_ptr(),
  626. p_dropout,
  627. softmax_scale,
  628. window_size_left,
  629. window_size_right,
  630. softcap,
  631. seqlenq_ngroups_swapped,
  632. /*unpadded_lse*/true);
  633. params.total_q = total_q;
  634. if (paged_KV) {
  635. params.block_table = block_table.data_ptr<int>();
  636. params.block_table_batch_stride = block_table.stride(0);
  637. params.k_batch_stride = k_padded.stride(0);
  638. params.v_batch_stride = v_padded.stride(0);
  639. }
  640. params.page_block_size = page_block_size;
  641. if (seqlenq_ngroups_swapped) {
  642. // Only apply split-k for decoding
  643. set_params_splitkv(params, batch_size, num_heads,
  644. head_size, max_seqlen_k, max_seqlen_q,
  645. head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
  646. }
  647. if (leftpad_k_.has_value()) {
  648. auto leftpad_k = leftpad_k_.value();
  649. TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
  650. TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
  651. CHECK_DEVICE(leftpad_k);
  652. CHECK_CONTIGUOUS(leftpad_k);
  653. CHECK_SHAPE(leftpad_k, batch_size);
  654. params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
  655. }
  656. // number of times random will be generated per thread, to offset philox counter in thc random
  657. // state
  658. // We use a custom RNG that increases the offset by batch_size * nheads * 32.
  659. int64_t counter_offset = params.b * params.h * 32;
  660. auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  661. auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
  662. // Forward kernel will populate memory with the seed and offset.
  663. params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
  664. if (p_dropout > 0.0) {
  665. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  666. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  667. // See Note [Acquire lock when using random generators]
  668. std::lock_guard<std::mutex> lock(gen->mutex_);
  669. params.philox_args = gen->philox_cuda_state(counter_offset);
  670. }
  671. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  672. if (max_seqlen_k > 0) {
  673. auto stream = at::cuda::getCurrentCUDAStream().stream();
  674. run_mha_fwd(params, stream, paged_KV);
  675. } else {
  676. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  677. out.zero_();
  678. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  679. }
  680. at::Tensor out_padded = out;
  681. if (head_size_og % 8 != 0) {
  682. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  683. if (out_.has_value()) { out_.value().copy_(out); }
  684. }
  685. if (seqlenq_ngroups_swapped) {
  686. int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size_og};
  687. int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size_og};
  688. out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
  689. out_padded = out_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
  690. q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
  691. softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
  692. }
  693. return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
  694. }
  695. void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
  696. FP16_SWITCH(!params.is_bf16, [&] {
  697. HEADDIM_SWITCH(params.d, [&] {
  698. run_mha_bwd_<elem_type, kHeadDim>(params, stream);
  699. });
  700. });
  701. }
  702. std::vector<at::Tensor>
  703. mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
  704. const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  705. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  706. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  707. const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
  708. const at::Tensor &softmax_lse, // b x h x seqlen_q
  709. c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
  710. c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
  711. c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
  712. c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
  713. const float p_dropout, // probability to drop
  714. const float softmax_scale,
  715. const bool is_causal,
  716. int window_size_left,
  717. int window_size_right,
  718. const float softcap,
  719. const bool deterministic,
  720. c10::optional<at::Generator> gen_,
  721. c10::optional<at::Tensor> &rng_state) {
  722. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  723. TORCH_CHECK(false, "This flash attention build does not support backward.");
  724. #endif
  725. if (is_causal) { window_size_right = 0; }
  726. auto dprops = at::cuda::getCurrentDeviceProperties();
  727. // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
  728. bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  729. bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
  730. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  731. TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
  732. // We will support Turing in the near future
  733. // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
  734. bool is_dropout = p_dropout > 0.0;
  735. auto stream = at::cuda::getCurrentCUDAStream().stream();
  736. auto q_dtype = q.dtype();
  737. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  738. "FlashAttention only support fp16 and bf16 data type");
  739. if (q_dtype == torch::kBFloat16) {
  740. TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
  741. }
  742. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  743. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  744. TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  745. TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
  746. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  747. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  748. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  749. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  750. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  751. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  752. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
  753. const auto sizes = q.sizes();
  754. const int batch_size = sizes[0];
  755. const int seqlen_q = sizes[1];
  756. const int num_heads = sizes[2];
  757. const int head_size_og = dout.size(3);
  758. const int head_size = sizes[3];
  759. const int seqlen_k = k.size(1);
  760. const int num_heads_k = k.size(2);
  761. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  762. TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  763. TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
  764. if (head_size > 192 && (head_size <= 224 || is_dropout)) {
  765. TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
  766. }
  767. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  768. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  769. const int head_size_rounded = round_multiple(head_size, 32);
  770. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  771. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  772. TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
  773. if (window_size_left >= seqlen_k) { window_size_left = -1; }
  774. if (window_size_right >= seqlen_k) { window_size_right = -1; }
  775. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
  776. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
  777. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
  778. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
  779. CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
  780. at::Tensor dq, dk, dv;
  781. if (dq_.has_value()) {
  782. dq = dq_.value();
  783. TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
  784. CHECK_DEVICE(dq);
  785. TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
  786. CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
  787. } else {
  788. dq = torch::empty_like(q);
  789. }
  790. if (dk_.has_value()) {
  791. dk = dk_.value();
  792. TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
  793. CHECK_DEVICE(dk);
  794. TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
  795. CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
  796. } else {
  797. dk = torch::empty_like(k);
  798. }
  799. if (dv_.has_value()) {
  800. dv = dv_.value();
  801. TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
  802. CHECK_DEVICE(dv);
  803. TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
  804. CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
  805. } else {
  806. dv = torch::empty_like(v);
  807. }
  808. at::Tensor dout_padded;
  809. if (head_size_og % 8 != 0) {
  810. dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  811. } else {
  812. dout_padded = dout;
  813. }
  814. // bool loop = seqlen_k > blocksize_c;
  815. // TODO: change later, for now set to true for simplicity
  816. bool loop = true;
  817. // Otherwise the kernel will be launched from cuda:0 device
  818. // Cast to char to avoid compiler warning about narrowing
  819. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  820. auto opts = q.options();
  821. auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
  822. at::Tensor dq_accum;
  823. at::Tensor dk_accum, dv_accum;
  824. if (loop) {
  825. if (!deterministic) {
  826. dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
  827. } else {
  828. const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
  829. dq_accum = torch::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
  830. }
  831. // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
  832. // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
  833. }
  834. at::Tensor dk_expanded, dv_expanded;
  835. if (num_heads_k != num_heads) { // MQA / GQA
  836. dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
  837. dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
  838. } else {
  839. dk_expanded = dk;
  840. dv_expanded = dv;
  841. }
  842. Flash_bwd_params params;
  843. set_params_dgrad(params,
  844. batch_size,
  845. seqlen_q, seqlen_k,
  846. seqlen_q_rounded, seqlen_k_rounded,
  847. num_heads, num_heads_k,
  848. head_size, head_size_rounded,
  849. q, k, v, out,
  850. dout_padded, dq, dk_expanded, dv_expanded,
  851. nullptr,
  852. nullptr,
  853. loop ? dq_accum.data_ptr() : nullptr,
  854. // loop ? dk_accum.data_ptr() : nullptr,
  855. // loop ? dv_accum.data_ptr() : nullptr,
  856. nullptr,
  857. nullptr,
  858. softmax_lse.data_ptr(),
  859. softmax_d.data_ptr(),
  860. p_dropout,
  861. softmax_scale,
  862. window_size_left,
  863. window_size_right,
  864. softcap,
  865. deterministic,
  866. /*unpadded_lse*/false);
  867. params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
  868. auto launch = &run_mha_bwd;
  869. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  870. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  871. // We use a custom RNG that increases the offset by batch_size * nheads * 32.
  872. int64_t counter_offset = params.b * params.h * 32;
  873. if ( rng_state.has_value() ) {
  874. params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
  875. } else if( is_dropout ) {
  876. // See Note [Acquire lock when using random generators]
  877. std::lock_guard<std::mutex> lock(gen->mutex_);
  878. params.philox_args = gen->philox_cuda_state(counter_offset);
  879. auto seeds = at::cuda::philox::unpack(params.philox_args);
  880. params.rng_state[0] = std::get<0>(seeds);
  881. params.rng_state[1] = std::get<1>(seeds);
  882. }
  883. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  884. if (seqlen_q > 0) {
  885. launch(params, stream);
  886. } else {
  887. // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
  888. dk_expanded.zero_();
  889. dv_expanded.zero_();
  890. softmax_d.zero_();
  891. }
  892. // For MQA/GQA we need to sum dK and dV across the groups
  893. if (num_heads_k != num_heads) {
  894. at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
  895. at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
  896. }
  897. if (head_size_og % 8 != 0) {
  898. dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  899. dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  900. dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  901. }
  902. return { dq, dk, dv, softmax_d };
  903. }
  904. std::vector<at::Tensor>
  905. mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
  906. const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  907. const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  908. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  909. const at::Tensor &out, // total_q x num_heads x head_size
  910. const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp
  911. c10::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  912. c10::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  913. c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  914. const at::Tensor &cu_seqlens_q, // b+1
  915. const at::Tensor &cu_seqlens_k, // b+1
  916. c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
  917. const int max_seqlen_q,
  918. const int max_seqlen_k, // max sequence length to choose the kernel
  919. const float p_dropout, // probability to drop
  920. const float softmax_scale,
  921. const bool zero_tensors,
  922. const bool is_causal,
  923. int window_size_left,
  924. int window_size_right,
  925. const float softcap,
  926. const bool deterministic,
  927. c10::optional<at::Generator> gen_,
  928. c10::optional<at::Tensor> &rng_state) {
  929. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  930. TORCH_CHECK(false, "This flash attention build does not support backward.");
  931. #endif
  932. if (is_causal) { window_size_right = 0; }
  933. auto dprops = at::cuda::getCurrentDeviceProperties();
  934. // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
  935. bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  936. bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
  937. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  938. TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
  939. // We will support Turing in the near future
  940. // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
  941. bool is_dropout = p_dropout > 0.0;
  942. auto stream = at::cuda::getCurrentCUDAStream().stream();
  943. auto q_dtype = q.dtype();
  944. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  945. "FlashAttention only support fp16 and bf16 data type");
  946. if (q_dtype == torch::kBFloat16) {
  947. TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
  948. }
  949. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  950. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  951. TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  952. TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
  953. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
  954. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
  955. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  956. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  957. CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
  958. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  959. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  960. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  961. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  962. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
  963. CHECK_CONTIGUOUS(cu_seqlens_q);
  964. CHECK_CONTIGUOUS(cu_seqlens_k);
  965. const auto sizes = q.sizes();
  966. const int total_q = sizes[0];
  967. const int batch_size = cu_seqlens_q.numel() - 1;
  968. const int num_heads = sizes[1];
  969. const int head_size_og = dout.size(2);
  970. const int head_size = sizes[2];
  971. const int total_k = k.size(0);
  972. const int num_heads_k = k.size(1);
  973. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  974. TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  975. TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
  976. if (head_size > 192 && (head_size <= 224 || is_dropout)) {
  977. TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
  978. }
  979. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  980. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  981. const int head_size_rounded = round_multiple(head_size, 32);
  982. const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
  983. const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
  984. TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
  985. if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
  986. if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
  987. CHECK_SHAPE(q, total_q, num_heads, head_size);
  988. CHECK_SHAPE(k, total_k, num_heads_k, head_size);
  989. CHECK_SHAPE(v, total_k, num_heads_k, head_size);
  990. CHECK_SHAPE(out, total_q, num_heads, head_size);
  991. CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
  992. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  993. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  994. at::Tensor dq, dk, dv;
  995. if (dq_.has_value()) {
  996. dq = dq_.value();
  997. TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
  998. CHECK_DEVICE(dq);
  999. TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
  1000. CHECK_SHAPE(dq, total_q, num_heads, head_size);
  1001. } else {
  1002. dq = torch::empty_like(q);
  1003. }
  1004. if (dk_.has_value()) {
  1005. dk = dk_.value();
  1006. TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
  1007. CHECK_DEVICE(dk);
  1008. TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
  1009. CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
  1010. } else {
  1011. dk = torch::empty_like(k);
  1012. }
  1013. if (dv_.has_value()) {
  1014. dv = dv_.value();
  1015. TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
  1016. CHECK_DEVICE(dv);
  1017. TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
  1018. CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
  1019. } else {
  1020. dv = torch::empty_like(v);
  1021. }
  1022. at::Tensor dout_padded;
  1023. if (head_size_og % 8 != 0) {
  1024. dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1025. } else {
  1026. dout_padded = dout;
  1027. }
  1028. // bool loop = max_seqlen_k > blocksize_c;
  1029. // TODO: change later, for now set to true for simplicity
  1030. bool loop = true;
  1031. // Otherwise the kernel will be launched from cuda:0 device
  1032. // Cast to char to avoid compiler warning about narrowing
  1033. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  1034. auto opts = q.options();
  1035. auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat));
  1036. at::Tensor dq_accum;
  1037. if (loop) {
  1038. // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded)
  1039. // because that would be too large if there is a very long sequence and the rest of the sequences are short.
  1040. // Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded).
  1041. // Note that 128 is the max block size on the seqlen_q dimension.
  1042. // For dQ, the i-th sequence is stored in indices from cu_seqlens[i] + 128 * i to
  1043. // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
  1044. // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
  1045. // allowed to do. So we won't have to do any bound checking, and performance should stay the same.
  1046. // Same holds for softmax_d, since LSE is stored in unpadded format.
  1047. if (!deterministic) {
  1048. dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
  1049. } else {
  1050. const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
  1051. dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
  1052. }
  1053. }
  1054. at::Tensor dk_expanded, dv_expanded;
  1055. if (num_heads_k != num_heads) { // MQA / GQA
  1056. dk_expanded = torch::empty({total_k, num_heads, head_size}, opts);
  1057. dv_expanded = torch::empty({total_k, num_heads, head_size}, opts);
  1058. } else {
  1059. dk_expanded = dk;
  1060. dv_expanded = dv;
  1061. }
  1062. if( zero_tensors ) {
  1063. dq.zero_();
  1064. dk_expanded.zero_();
  1065. dv_expanded.zero_();
  1066. softmax_d.zero_();
  1067. }
  1068. Flash_bwd_params params;
  1069. set_params_dgrad(params,
  1070. batch_size,
  1071. max_seqlen_q, max_seqlen_k,
  1072. seqlen_q_rounded, seqlen_k_rounded,
  1073. num_heads, num_heads_k,
  1074. head_size, head_size_rounded,
  1075. q, k, v, out,
  1076. dout_padded, dq, dk_expanded, dv_expanded,
  1077. cu_seqlens_q.data_ptr(),
  1078. cu_seqlens_k.data_ptr(),
  1079. loop ? dq_accum.data_ptr() : nullptr,
  1080. nullptr,
  1081. nullptr,
  1082. softmax_lse.data_ptr(),
  1083. softmax_d.data_ptr(),
  1084. p_dropout,
  1085. softmax_scale,
  1086. window_size_left,
  1087. window_size_right,
  1088. softcap,
  1089. deterministic,
  1090. /*unpadded_lse*/true);
  1091. params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
  1092. params.total_q = total_q;
  1093. auto launch = &run_mha_bwd;
  1094. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  1095. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  1096. // We use a custom RNG that increases the offset by batch_size * nheads * 32.
  1097. int64_t counter_offset = params.b * params.h * 32;
  1098. if ( rng_state.has_value() ) {
  1099. params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
  1100. } else if( is_dropout ) {
  1101. // See Note [Acquire lock when using random generators]
  1102. std::lock_guard<std::mutex> lock(gen->mutex_);
  1103. params.philox_args = gen->philox_cuda_state(counter_offset);
  1104. auto seeds = at::cuda::philox::unpack(params.philox_args);
  1105. params.rng_state[0] = std::get<0>(seeds);
  1106. params.rng_state[1] = std::get<1>(seeds);
  1107. }
  1108. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  1109. if (max_seqlen_q > 0) {
  1110. launch(params, stream);
  1111. } else {
  1112. // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
  1113. dk_expanded.zero_();
  1114. dv_expanded.zero_();
  1115. softmax_d.zero_();
  1116. }
  1117. // For MQA/GQA we need to sum dK and dV across the groups
  1118. if (num_heads_k != num_heads) {
  1119. at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
  1120. at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
  1121. }
  1122. if (head_size_og % 8 != 0) {
  1123. dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  1124. dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  1125. dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  1126. }
  1127. return { dq, dk, dv, softmax_d };
  1128. }
  1129. std::vector<at::Tensor>
  1130. mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  1131. const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  1132. const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  1133. c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
  1134. c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
  1135. c10::optional<const at::Tensor> &seqlens_k_, // batch_size
  1136. c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
  1137. c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
  1138. c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
  1139. c10::optional<const at::Tensor> &leftpad_k_, // batch_size
  1140. c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
  1141. c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
  1142. c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
  1143. const float softmax_scale,
  1144. bool is_causal,
  1145. int window_size_left,
  1146. int window_size_right,
  1147. const float softcap,
  1148. bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
  1149. int num_splits
  1150. ) {
  1151. auto dprops = at::cuda::getCurrentDeviceProperties();
  1152. // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
  1153. bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  1154. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  1155. TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
  1156. // We will support Turing in the near future
  1157. // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
  1158. auto q_dtype = q.dtype();
  1159. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  1160. "FlashAttention only support fp16 and bf16 data type");
  1161. if (q_dtype == torch::kBFloat16) {
  1162. TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
  1163. }
  1164. TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
  1165. TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");
  1166. CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
  1167. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1168. TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1169. TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1170. at::Tensor block_table;
  1171. const bool paged_KV = block_table_.has_value();
  1172. if (paged_KV) {
  1173. TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx");
  1174. block_table = block_table_.value();
  1175. CHECK_DEVICE(block_table);
  1176. TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
  1177. TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
  1178. }
  1179. const auto sizes = q.sizes();
  1180. const int batch_size = sizes[0];
  1181. int seqlen_q = sizes[1];
  1182. int num_heads = sizes[2];
  1183. const int head_size_og = sizes[3];
  1184. const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
  1185. const int num_blocks = !paged_KV ? 0 : kcache.size(0);
  1186. const int page_block_size = !paged_KV ? 1 : kcache.size(1);
  1187. TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
  1188. const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
  1189. const int num_heads_k = kcache.size(2);
  1190. const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
  1191. TORCH_CHECK(batch_size > 0, "batch size must be postive");
  1192. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  1193. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  1194. // causal=true is the same as causal=false in this case
  1195. if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
  1196. if (is_causal) { window_size_right = 0; }
  1197. // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
  1198. // H/t Daniel Haziza
  1199. const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
  1200. if (seqlenq_ngroups_swapped) {
  1201. const int ngroups = num_heads / num_heads_k;
  1202. q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
  1203. seqlen_q = ngroups;
  1204. num_heads = num_heads_k;
  1205. }
  1206. if (window_size_left >= seqlen_k) { window_size_left = -1; }
  1207. if (window_size_right >= seqlen_k) { window_size_right = -1; }
  1208. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
  1209. if (!paged_KV) {
  1210. CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
  1211. CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
  1212. } else {
  1213. CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
  1214. CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
  1215. CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
  1216. }
  1217. at::Tensor q_padded, kcache_padded, vcache_padded;
  1218. if (head_size_og % 8 != 0) {
  1219. q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1220. kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1221. vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1222. } else {
  1223. q_padded = q;
  1224. kcache_padded = kcache;
  1225. vcache_padded = vcache;
  1226. }
  1227. at::Tensor out;
  1228. if (out_.has_value()) {
  1229. out = out_.value();
  1230. TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  1231. CHECK_DEVICE(out);
  1232. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  1233. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
  1234. if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
  1235. } else {
  1236. out = torch::empty_like(q_padded);
  1237. }
  1238. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  1239. const int head_size = round_multiple(head_size_og, 8);
  1240. const int head_size_rounded = round_multiple(head_size, 32);
  1241. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  1242. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  1243. // Otherwise the kernel will be launched from cuda:0 device
  1244. // Cast to char to avoid compiler warning about narrowing
  1245. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  1246. auto opts = q.options();
  1247. auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  1248. Flash_fwd_params params;
  1249. set_params_fprop(params,
  1250. batch_size,
  1251. seqlen_q, seqlen_k,
  1252. seqlen_q_rounded, seqlen_k_rounded,
  1253. num_heads, num_heads_k,
  1254. head_size, head_size_rounded,
  1255. q_padded, kcache_padded, vcache_padded, out,
  1256. /*cu_seqlens_q_d=*/nullptr,
  1257. /*cu_seqlens_k_d=*/nullptr,
  1258. /*seqused_k=*/nullptr,
  1259. /*p_ptr=*/nullptr,
  1260. softmax_lse.data_ptr(),
  1261. /*p_dropout=*/0.f,
  1262. softmax_scale,
  1263. window_size_left,
  1264. window_size_right,
  1265. softcap
  1266. );
  1267. at::Tensor k, v, k_padded, v_padded;
  1268. if (k_.has_value()) {
  1269. TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
  1270. TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
  1271. TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache");
  1272. k = k_.value();
  1273. v = v_.value();
  1274. TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
  1275. TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
  1276. CHECK_DEVICE(k); CHECK_DEVICE(v);
  1277. TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
  1278. TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
  1279. int seqlen_knew = k.size(1);
  1280. CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
  1281. CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
  1282. if (head_size_og % 8 != 0) {
  1283. k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1284. v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1285. } else {
  1286. k_padded = k;
  1287. v_padded = v;
  1288. }
  1289. params.seqlen_knew = seqlen_knew;
  1290. params.knew_ptr = k_padded.data_ptr();
  1291. params.vnew_ptr = v_padded.data_ptr();
  1292. // All stride are in elements, not bytes.
  1293. params.knew_batch_stride = k_padded.stride(0);
  1294. params.vnew_batch_stride = v_padded.stride(0);
  1295. params.knew_row_stride = k_padded.stride(-3);
  1296. params.vnew_row_stride = v_padded.stride(-3);
  1297. params.knew_head_stride = k_padded.stride(-2);
  1298. params.vnew_head_stride = v_padded.stride(-2);
  1299. }
  1300. if (seqlens_k_.has_value()) {
  1301. auto seqlens_k = seqlens_k_.value();
  1302. TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
  1303. CHECK_DEVICE(seqlens_k);
  1304. CHECK_CONTIGUOUS(seqlens_k);
  1305. CHECK_SHAPE(seqlens_k, batch_size);
  1306. params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
  1307. }
  1308. params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
  1309. if (leftpad_k_.has_value()) {
  1310. TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
  1311. auto leftpad_k = leftpad_k_.value();
  1312. TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
  1313. CHECK_DEVICE(leftpad_k);
  1314. CHECK_CONTIGUOUS(leftpad_k);
  1315. CHECK_SHAPE(leftpad_k, batch_size);
  1316. params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
  1317. }
  1318. if (rotary_cos_.has_value()) {
  1319. TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
  1320. auto rotary_cos = rotary_cos_.value();
  1321. CHECK_DEVICE(rotary_cos);
  1322. params.rotary_dim = rotary_cos.size(1) * 2;
  1323. TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
  1324. TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
  1325. const int seqlen_ro = rotary_cos.size(0);
  1326. TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
  1327. CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
  1328. CHECK_CONTIGUOUS(rotary_cos);
  1329. TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
  1330. TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
  1331. auto rotary_sin = rotary_sin_.value();
  1332. CHECK_DEVICE(rotary_sin);
  1333. CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
  1334. CHECK_CONTIGUOUS(rotary_sin);
  1335. TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
  1336. params.rotary_cos_ptr = rotary_cos.data_ptr();
  1337. params.rotary_sin_ptr = rotary_sin.data_ptr();
  1338. params.is_rotary_interleaved = is_rotary_interleaved;
  1339. } else {
  1340. params.rotary_dim = 0;
  1341. }
  1342. if (cache_batch_idx_.has_value()) {
  1343. auto cache_batch_idx = cache_batch_idx_.value();
  1344. CHECK_DEVICE(cache_batch_idx);
  1345. CHECK_CONTIGUOUS(cache_batch_idx);
  1346. TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
  1347. params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
  1348. }
  1349. set_params_splitkv(params, batch_size, num_heads,
  1350. head_size, seqlen_k, seqlen_q,
  1351. head_size_rounded, /*dropout*/0.f, num_splits, dprops, opts);
  1352. if (paged_KV) {
  1353. params.block_table = block_table.data_ptr<int>();
  1354. params.block_table_batch_stride = block_table.stride(0);
  1355. }
  1356. params.page_block_size = page_block_size;
  1357. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  1358. auto stream = at::cuda::getCurrentCUDAStream().stream();
  1359. // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx,
  1360. // or paged KV cache
  1361. run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV);
  1362. if (head_size_og % 8 != 0) {
  1363. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  1364. if (out_.has_value()) { out_.value().copy_(out); }
  1365. if (k_.has_value()) {
  1366. // It's expensive to copy the KV cache here for the case where head size not divisible by 8,
  1367. // but we don't expect to get this case in practice. This is just so that the code works for that case.
  1368. kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
  1369. vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
  1370. }
  1371. }
  1372. if (seqlenq_ngroups_swapped) {
  1373. out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
  1374. softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
  1375. }
  1376. return {out, softmax_lse};
  1377. }
  1378. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  1379. m.doc() = "FlashAttention";
  1380. m.def("fwd", &mha_fwd, "Forward pass");
  1381. m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
  1382. m.def("bwd", &mha_bwd, "Backward pass");
  1383. m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
  1384. m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
  1385. }