1
0

flash_api.cpp 74 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586
  1. /******************************************************************************
  2. * Copyright (c) 2024, Tri Dao.
  3. ******************************************************************************/
  4. // Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
  5. #include <torch/python.h>
  6. #include <torch/nn/functional.h>
  7. #include <ATen/cuda/CUDAContext.h>
  8. #include <c10/cuda/CUDAGuard.h>
  9. #include <cutlass/numeric_types.h>
  10. #include "flash.h"
  11. #include "static_switch.h"
  // Argument-validation helpers used by the entry points below. Each one expands
  // to a TORCH_CHECK whose message stringifies the macro argument(s), so the
  // error names the expression exactly as it was written at the call site.
  #define CHECK_DEVICE(tensor) TORCH_CHECK(tensor.is_cuda(), #tensor " must be on CUDA")
  #define CHECK_SHAPE(tensor, ...) TORCH_CHECK(tensor.sizes() == torch::IntArrayRef({__VA_ARGS__}), #tensor " must have shape (" #__VA_ARGS__ ")")
  #define CHECK_CONTIGUOUS(tensor) TORCH_CHECK(tensor.is_contiguous(), #tensor " must be contiguous")
  15. void set_params_fprop(Flash_fwd_params &params,
  16. // sizes
  17. const size_t b,
  18. const size_t seqlen_q,
  19. const size_t seqlen_k,
  20. const size_t seqlen_q_rounded,
  21. const size_t seqlen_k_rounded,
  22. const size_t h,
  23. const size_t h_k,
  24. const size_t d,
  25. const size_t d_rounded,
  26. // device pointers
  27. const at::Tensor q,
  28. const at::Tensor k,
  29. const at::Tensor v,
  30. at::Tensor out,
  31. void *cu_seqlens_q_d,
  32. void *cu_seqlens_k_d,
  33. void *seqused_k,
  34. void *p_d,
  35. void *softmax_lse_d,
  36. float p_dropout,
  37. float softmax_scale,
  38. int window_size_left,
  39. int window_size_right,
  40. const float softcap,
  41. bool seqlenq_ngroups_swapped=false,
  42. const bool unpadded_lse=false) {
  43. // Reset the parameters
  44. params = {};
  45. params.is_bf16 = q.dtype() == torch::kBFloat16;
  46. // Set the pointers and strides.
  47. params.q_ptr = q.data_ptr();
  48. params.k_ptr = k.data_ptr();
  49. params.v_ptr = v.data_ptr();
  50. // All stride are in elements, not bytes.
  51. params.q_row_stride = q.stride(-3);
  52. params.k_row_stride = k.stride(-3);
  53. params.v_row_stride = v.stride(-3);
  54. params.q_head_stride = q.stride(-2);
  55. params.k_head_stride = k.stride(-2);
  56. params.v_head_stride = v.stride(-2);
  57. params.o_ptr = out.data_ptr();
  58. params.o_row_stride = out.stride(-3);
  59. params.o_head_stride = out.stride(-2);
  60. if (cu_seqlens_q_d == nullptr) {
  61. params.q_batch_stride = q.stride(0);
  62. params.k_batch_stride = k.stride(0);
  63. params.v_batch_stride = v.stride(0);
  64. params.o_batch_stride = out.stride(0);
  65. if (seqlenq_ngroups_swapped) {
  66. params.q_batch_stride *= seqlen_q;
  67. params.o_batch_stride *= seqlen_q;
  68. }
  69. }
  70. params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
  71. params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
  72. params.seqused_k = static_cast<int *>(seqused_k);
  73. // P = softmax(QK^T)
  74. params.p_ptr = p_d;
  75. // Softmax sum
  76. params.softmax_lse_ptr = softmax_lse_d;
  77. // Set the dimensions.
  78. params.b = b;
  79. params.h = h;
  80. params.h_k = h_k;
  81. params.h_h_k_ratio = h / h_k;
  82. params.seqlen_q = seqlen_q;
  83. params.seqlen_k = seqlen_k;
  84. params.seqlen_q_rounded = seqlen_q_rounded;
  85. params.seqlen_k_rounded = seqlen_k_rounded;
  86. params.d = d;
  87. params.d_rounded = d_rounded;
  88. // Set the different scale values.
  89. #ifdef FLASHATTENTION_DISABLE_SOFTCAP
  90. TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
  91. #endif
  92. if (softcap > 0.0) {
  93. params.softcap = softmax_scale / softcap;
  94. params.scale_softmax = softcap;
  95. params.scale_softmax_log2 = softcap * M_LOG2E;
  96. } else{
  97. // Remove potential NaN
  98. params.softcap = 0.0;
  99. params.scale_softmax = softmax_scale;
  100. params.scale_softmax_log2 = softmax_scale * M_LOG2E;
  101. }
  102. // Set this to probability of keeping an element to simplify things.
  103. params.p_dropout = 1.f - p_dropout;
  104. // Convert p from float to int so we don't have to convert the random uint to float to compare.
  105. // [Minor] We want to round down since when we do the comparison we use <= instead of <
  106. // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
  107. // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
  108. params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
  109. params.rp_dropout = 1.f / params.p_dropout;
  110. params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
  111. TORCH_CHECK(p_dropout < 1.f);
  112. #ifdef FLASHATTENTION_DISABLE_DROPOUT
  113. TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
  114. #endif
  115. // Causal is the special case where window_size_right == 0 and window_size_left < 0.
  116. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
  117. params.is_causal = window_size_left < 0 && window_size_right == 0;
  118. if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
  119. if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
  120. params.window_size_left = window_size_left;
  121. params.window_size_right = window_size_right;
  122. #ifdef FLASHATTENTION_DISABLE_LOCAL
  123. TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
  124. "This flash attention build does not support local attention.");
  125. #endif
  126. params.is_seqlens_k_cumulative = true;
  127. #ifdef FLASHATTENTION_DISABLE_UNEVEN_K
  128. TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
  129. #endif
  130. params.unpadded_lse = unpadded_lse;
  131. params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
  132. }
  133. void set_params_dgrad(Flash_bwd_params &params,
  134. // sizes
  135. const size_t b,
  136. const size_t seqlen_q,
  137. const size_t seqlen_k,
  138. const size_t seqlen_q_rounded,
  139. const size_t seqlen_k_rounded,
  140. const size_t h,
  141. const size_t h_k,
  142. const size_t d,
  143. const size_t d_rounded,
  144. // device pointers
  145. const at::Tensor q,
  146. const at::Tensor k,
  147. const at::Tensor v,
  148. const at::Tensor out,
  149. const at::Tensor dout,
  150. at::Tensor dq,
  151. at::Tensor dk,
  152. at::Tensor dv,
  153. void *cu_seqlens_q_d,
  154. void *cu_seqlens_k_d,
  155. void *dq_accum_d,
  156. void *dk_accum_d,
  157. void *dv_accum_d,
  158. void *softmax_lse_d,
  159. void *dsoftmax_sum_d,
  160. float p_dropout,
  161. float softmax_scale,
  162. int window_size_left,
  163. int window_size_right,
  164. const float softcap,
  165. bool deterministic,
  166. const bool unpadded_lse) {
  167. set_params_fprop(params,
  168. b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
  169. q, k, v, out,
  170. cu_seqlens_q_d,
  171. cu_seqlens_k_d,
  172. nullptr,
  173. nullptr,
  174. softmax_lse_d,
  175. p_dropout,
  176. softmax_scale,
  177. window_size_left,
  178. window_size_right,
  179. softcap,
  180. false, // seqlenq_ngroups_swapped
  181. unpadded_lse);
  182. // Set the pointers and strides.
  183. params.do_ptr = dout.data_ptr();
  184. params.do_row_stride = dout.stride(-3);
  185. params.do_head_stride = dout.stride(-2);
  186. params.dq_ptr = dq.data_ptr();
  187. params.dk_ptr = dk.data_ptr();
  188. params.dv_ptr = dv.data_ptr();
  189. params.dq_row_stride = dq.stride(-3);
  190. params.dk_row_stride = dk.stride(-3);
  191. params.dv_row_stride = dv.stride(-3);
  192. params.dq_head_stride = dq.stride(-2);
  193. params.dk_head_stride = dk.stride(-2);
  194. params.dv_head_stride = dv.stride(-2);
  195. if (cu_seqlens_q_d == nullptr) {
  196. params.do_batch_stride = dout.stride(0);
  197. params.dq_batch_stride = dq.stride(0);
  198. params.dk_batch_stride = dk.stride(0);
  199. params.dv_batch_stride = dv.stride(0);
  200. }
  201. params.dq_accum_ptr = dq_accum_d;
  202. params.dk_accum_ptr = dk_accum_d;
  203. params.dv_accum_ptr = dv_accum_d;
  204. // Softmax sum
  205. params.dsoftmax_sum = dsoftmax_sum_d;
  206. params.deterministic = deterministic;
  207. }
  208. void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
  209. FP16_SWITCH(!params.is_bf16, [&] {
  210. HEADDIM_SWITCH(params.d, [&] {
  211. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  212. if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
  213. run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
  214. } else {
  215. run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
  216. }
  217. });
  218. });
  219. });
  220. }
  // Choose how many ways to split the KV sequence so that the launch keeps the SMs busy.
  //
  // The wave efficiency for s splits is  w / ceil(w)  with  w = batch_nheads_mblocks * s / num_SMs,
  // i.e. how full the last wave of thread blocks is. Example: batch * n_heads = 48 on 108 SMs
  // gives 0.89 for 2 splits but only 0.67 for 3. More splits also cost more HBM reads/writes,
  // so the function returns the *smallest* eligible split count whose efficiency reaches 85%
  // of the best one. It returns 1 when there is already enough work to (nearly) fill the SMs,
  // or when no split count qualifies.
  inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
      // Enough work to almost fill every SM already: a single split is best.
      if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
      const int limit = std::min({max_splits, num_SMs, num_n_blocks});

      // Number of KV blocks each split handles, rounded up.
      const auto blocks_per_split = [num_n_blocks](int splits) {
          return (num_n_blocks + splits - 1) / splits;
      };
      // Some split counts are redundant. With 64 blocks, 11 splits gives 6 * 10 + 4 blocks,
      // while 12 splits would give 6 * 11 + (-2) blocks, i.e. really 11 splits again.
      // A count is eligible only if it changes blocks-per-split relative to one split fewer.
      const auto eligible = [&blocks_per_split](int splits) {
          return splits == 1 || blocks_per_split(splits) != blocks_per_split(splits - 1);
      };

      // score_for[s - 1] holds the wave efficiency for s splits (0 when ineligible).
      std::vector<float> score_for;
      score_for.reserve(limit);
      float best = 0.f;
      for (int s = 1; s <= limit; ++s) {
          float score = 0.f;
          if (eligible(s)) {
              const float waves = float(batch_nheads_mblocks * s) / num_SMs;
              score = waves / ceil(waves);
              if (score > best) { best = score; }
          }
          score_for.push_back(score);
      }

      const double threshold = 0.85 * best;
      for (int s = 1; s <= limit; ++s) {
          if (eligible(s) && score_for[s - 1] >= threshold) { return s; }
      }
      return 1;
  }
  262. std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params &params, const int batch_size,
  263. const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
  264. const int head_size_rounded, const float p_dropout,
  265. const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) {
  266. // This needs to match with run_mha_fwd_splitkv_dispatch
  267. const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
  268. const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
  269. // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
  270. // In any case we don't expect seqlen_q to be larger than 64 for inference.
  271. const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
  272. params.num_splits = num_splits;
  273. at::Tensor softmax_lse_accum;
  274. at::Tensor out_accum;
  275. if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
  276. if (num_splits < 1) {
  277. // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
  278. params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount * 2, num_n_blocks, 128);
  279. }
  280. if (params.num_splits > 1) {
  281. softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
  282. out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
  283. params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
  284. params.oaccum_ptr = out_accum.data_ptr();
  285. }
  286. TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
  287. }
  288. return std::make_tuple(softmax_lse_accum, out_accum);
  289. }
  290. void set_params_alibi(Flash_fwd_params &params, c10::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){
  291. #ifdef FLASHATTENTION_DISABLE_ALIBI
  292. TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi.");
  293. params.alibi_slopes_ptr = nullptr;
  294. #else
  295. if (alibi_slopes_.has_value()) {
  296. auto alibi_slopes = alibi_slopes_.value();
  297. TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32, "ALiBi slopes must have dtype fp32");
  298. CHECK_DEVICE(alibi_slopes);
  299. TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
  300. TORCH_CHECK(alibi_slopes.sizes() == torch::IntArrayRef({num_heads}) || alibi_slopes.sizes() == torch::IntArrayRef({batch_size, num_heads}));
  301. params.alibi_slopes_ptr = alibi_slopes.data_ptr();
  302. params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
  303. } else {
  304. params.alibi_slopes_ptr = nullptr;
  305. }
  306. #endif
  307. }
  308. std::vector<at::Tensor>
  309. mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  310. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  311. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  312. c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
  313. c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
  314. const float p_dropout,
  315. const float softmax_scale,
  316. bool is_causal,
  317. int window_size_left,
  318. int window_size_right,
  319. const float softcap,
  320. const bool return_softmax,
  321. c10::optional<at::Generator> gen_) {
  322. auto dprops = at::cuda::getCurrentDeviceProperties();
  323. // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
  324. bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  325. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  326. TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
  327. // We will support Turing in the near future
  328. // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
  329. auto q_dtype = q.dtype();
  330. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  331. "FlashAttention only support fp16 and bf16 data type");
  332. if (q_dtype == torch::kBFloat16) {
  333. TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
  334. }
  335. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  336. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  337. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  338. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  339. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  340. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  341. const auto sizes = q.sizes();
  342. const int batch_size = sizes[0];
  343. int seqlen_q = sizes[1];
  344. int num_heads = sizes[2];
  345. const int head_size_og = sizes[3];
  346. const int seqlen_k = k.size(1);
  347. const int num_heads_k = k.size(2);
  348. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  349. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  350. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  351. if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
  352. if (window_size_left >= seqlen_k) { window_size_left = -1; }
  353. if (window_size_right >= seqlen_k) { window_size_right = -1; }
  354. // causal=true is the same as causal=false in this case
  355. if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
  356. if (is_causal) { window_size_right = 0; }
  357. // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
  358. // H/t Daniel Haziza
  359. const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
  360. const int ngroups = num_heads / num_heads_k;
  361. if (seqlenq_ngroups_swapped) {
  362. q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
  363. seqlen_q = ngroups;
  364. num_heads = num_heads_k;
  365. }
  366. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
  367. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
  368. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
  369. at::Tensor q_padded, k_padded, v_padded;
  370. if (head_size_og % 8 != 0) {
  371. q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  372. k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  373. v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  374. } else {
  375. q_padded = q;
  376. k_padded = k;
  377. v_padded = v;
  378. }
  379. at::Tensor out;
  380. if (out_.has_value()) {
  381. out = out_.value();
  382. TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  383. CHECK_DEVICE(out);
  384. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  385. CHECK_SHAPE(out, batch_size, sizes[1], sizes[2], head_size_og);
  386. if (seqlenq_ngroups_swapped) {
  387. out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
  388. }
  389. if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
  390. } else {
  391. out = torch::empty_like(q_padded);
  392. }
  393. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  394. const int head_size = round_multiple(head_size_og, 8);
  395. const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
  396. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  397. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  398. // Otherwise the kernel will be launched from cuda:0 device
  399. // Cast to char to avoid compiler warning about narrowing
  400. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  401. auto opts = q.options();
  402. auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  403. at::Tensor p;
  404. // Only return softmax if there's dropout to reduce compilation time
  405. if (return_softmax) {
  406. TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
  407. p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
  408. }
  409. Flash_fwd_params params;
  410. set_params_fprop(params,
  411. batch_size,
  412. seqlen_q, seqlen_k,
  413. seqlen_q_rounded, seqlen_k_rounded,
  414. num_heads, num_heads_k,
  415. head_size, head_size_rounded,
  416. q_padded, k_padded, v_padded, out,
  417. /*cu_seqlens_q_d=*/nullptr,
  418. /*cu_seqlens_k_d=*/nullptr,
  419. /*seqused_k=*/nullptr,
  420. return_softmax ? p.data_ptr() : nullptr,
  421. softmax_lse.data_ptr(),
  422. p_dropout,
  423. softmax_scale,
  424. window_size_left,
  425. window_size_right,
  426. softcap
  427. );
  428. // Keep references to these tensors to extend their lifetime
  429. at::Tensor softmax_lse_accum, out_accum;
  430. std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
  431. params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
  432. head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts);
  433. // number of times random will be generated per thread, to offset philox counter in thc random
  434. // state
  435. // We use a custom RNG that increases the offset by batch_size * nheads * 32.
  436. int64_t counter_offset = params.b * params.h * 32;
  437. auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  438. auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
  439. // Forward kernel will populate memory with the seed and offset.
  440. params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
  441. if (p_dropout > 0.0) {
  442. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  443. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  444. // See Note [Acquire lock when using random generators]
  445. std::lock_guard<std::mutex> lock(gen->mutex_);
  446. params.philox_args = gen->philox_cuda_state(counter_offset);
  447. }
  448. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  449. if (seqlen_k > 0) {
  450. auto stream = at::cuda::getCurrentCUDAStream().stream();
  451. run_mha_fwd(params, stream);
  452. } else {
  453. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  454. out.zero_();
  455. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  456. }
  457. at::Tensor out_padded = out;
  458. if (head_size_og % 8 != 0) {
  459. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  460. if (out_.has_value()) { out_.value().copy_(out); }
  461. }
  462. if (seqlenq_ngroups_swapped) {
  463. out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
  464. out_padded = out_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
  465. q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
  466. softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
  467. }
  468. return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
  469. }
  470. std::vector<at::Tensor>
  471. mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  472. const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  473. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  474. c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
  475. const at::Tensor &cu_seqlens_q, // b+1
  476. const at::Tensor &cu_seqlens_k, // b+1
  477. c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
  478. c10::optional<const at::Tensor> &leftpad_k_, // batch_size
  479. c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
  480. c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
  481. int max_seqlen_q,
  482. const int max_seqlen_k,
  483. const float p_dropout,
  484. const float softmax_scale,
  485. const bool zero_tensors,
  486. bool is_causal,
  487. int window_size_left,
  488. int window_size_right,
  489. const float softcap,
  490. const bool return_softmax,
  491. c10::optional<at::Generator> gen_) {
  492. auto dprops = at::cuda::getCurrentDeviceProperties();
  493. // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
  494. bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  495. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  496. TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
  497. // We will support Turing in the near future
  498. // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
  499. auto q_dtype = q.dtype();
  500. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  501. "FlashAttention only support fp16 and bf16 data type");
  502. if (q_dtype == torch::kBFloat16) {
  503. TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
  504. }
  505. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  506. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  507. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
  508. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
  509. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  510. CHECK_DEVICE(cu_seqlens_q);
  511. CHECK_DEVICE(cu_seqlens_k);
  512. at::Tensor block_table;
  513. const bool paged_KV = block_table_.has_value();
  514. if (paged_KV) {
  515. block_table = block_table_.value();
  516. CHECK_DEVICE(block_table);
  517. TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
  518. TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
  519. }
  520. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  521. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  522. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  523. CHECK_CONTIGUOUS(cu_seqlens_q);
  524. CHECK_CONTIGUOUS(cu_seqlens_k);
  525. const auto sizes = q.sizes();
  526. const int batch_size = cu_seqlens_q.numel() - 1;
  527. int num_heads = sizes[1];
  528. const int head_size_og = sizes[2];
  529. const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
  530. if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
  531. const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
  532. const int num_blocks = !paged_KV ? 0 : k.size(0);
  533. const int page_block_size = !paged_KV ? 1 : k.size(1);
  534. TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
  535. if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
  536. if (is_causal) { window_size_right = 0; }
  537. void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();
  538. // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
  539. // H/t Daniel Haziza
  540. const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
  541. const int ngroups = num_heads / num_heads_k;
  542. if (seqlenq_ngroups_swapped) {
  543. q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
  544. max_seqlen_q = ngroups;
  545. num_heads = num_heads_k;
  546. cu_seqlens_q_d = nullptr;
  547. }
  548. const int total_q = q.sizes()[0];
  549. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  550. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  551. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  552. if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
  553. if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
  554. CHECK_SHAPE(q, total_q, num_heads, head_size_og);
  555. if (!paged_KV) {
  556. const int total_k = k.size(0);
  557. CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
  558. CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
  559. } else {
  560. CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
  561. CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
  562. CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
  563. }
  564. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  565. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  566. if (seqused_k.has_value()){
  567. auto seqused_k_ = seqused_k.value();
  568. TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
  569. TORCH_CHECK(seqused_k_.is_cuda(), "seqused_k must be on CUDA device");
  570. TORCH_CHECK(seqused_k_.is_contiguous(), "seqused_k must be contiguous");
  571. CHECK_SHAPE(seqused_k_, batch_size);
  572. }
  573. at::Tensor q_padded, k_padded, v_padded;
  574. if (head_size_og % 8 != 0) {
  575. q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  576. k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  577. v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  578. } else {
  579. q_padded = q;
  580. k_padded = k;
  581. v_padded = v;
  582. }
  583. at::Tensor out;
  584. if (out_.has_value()) {
  585. out = out_.value();
  586. TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  587. CHECK_DEVICE(out);
  588. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  589. CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
  590. if (seqlenq_ngroups_swapped) {
  591. out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
  592. }
  593. if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
  594. } else {
  595. out = torch::empty_like(q_padded);
  596. }
  597. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  598. const int head_size = round_multiple(head_size_og, 8);
  599. const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
  600. const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
  601. const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
  602. // Otherwise the kernel will be launched from cuda:0 device
  603. // Cast to char to avoid compiler warning about narrowing
  604. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  605. auto opts = q.options();
  606. auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
  607. at::Tensor p;
  608. // Only return softmax if there's dropout to reduce compilation time
  609. if (return_softmax) {
  610. TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
  611. p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
  612. }
  613. if (zero_tensors) {
  614. out.zero_();
  615. softmax_lse.fill_(-std::numeric_limits<float>::infinity());
  616. if (return_softmax) {p.zero_();}
  617. }
  618. Flash_fwd_params params;
  619. set_params_fprop(params,
  620. batch_size,
  621. max_seqlen_q, max_seqlen_k,
  622. seqlen_q_rounded, seqlen_k_rounded,
  623. num_heads, num_heads_k,
  624. head_size, head_size_rounded,
  625. q_padded, k_padded, v_padded, out,
  626. cu_seqlens_q_d,
  627. cu_seqlens_k.data_ptr(),
  628. seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
  629. return_softmax ? p.data_ptr() : nullptr,
  630. softmax_lse.data_ptr(),
  631. p_dropout,
  632. softmax_scale,
  633. window_size_left,
  634. window_size_right,
  635. softcap,
  636. seqlenq_ngroups_swapped,
  637. /*unpadded_lse*/true);
  638. params.total_q = total_q;
  639. if (paged_KV) {
  640. params.block_table = block_table.data_ptr<int>();
  641. params.block_table_batch_stride = block_table.stride(0);
  642. params.k_batch_stride = k_padded.stride(0);
  643. params.v_batch_stride = v_padded.stride(0);
  644. }
  645. params.page_block_size = page_block_size;
  646. // Keep references to these tensors to extend their lifetime
  647. at::Tensor softmax_lse_accum, out_accum;
  648. if (seqlenq_ngroups_swapped) {
  649. // Only apply split-k for decoding
  650. std::tie(softmax_lse_accum, out_accum) =
  651. set_params_splitkv(params, batch_size, num_heads, head_size,
  652. max_seqlen_k, max_seqlen_q, head_size_rounded,
  653. p_dropout, /*num_splits*/ 0, dprops, opts);
  654. }
  655. if (leftpad_k_.has_value()) {
  656. auto leftpad_k = leftpad_k_.value();
  657. TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
  658. TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
  659. CHECK_DEVICE(leftpad_k);
  660. CHECK_CONTIGUOUS(leftpad_k);
  661. CHECK_SHAPE(leftpad_k, batch_size);
  662. params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
  663. }
  664. // number of times random will be generated per thread, to offset philox counter in thc random
  665. // state
  666. // We use a custom RNG that increases the offset by batch_size * nheads * 32.
  667. int64_t counter_offset = params.b * params.h * 32;
  668. auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
  669. auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
  670. // Forward kernel will populate memory with the seed and offset.
  671. params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
  672. if (p_dropout > 0.0) {
  673. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  674. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  675. // See Note [Acquire lock when using random generators]
  676. std::lock_guard<std::mutex> lock(gen->mutex_);
  677. params.philox_args = gen->philox_cuda_state(counter_offset);
  678. }
  679. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  680. if (max_seqlen_k > 0) {
  681. auto stream = at::cuda::getCurrentCUDAStream().stream();
  682. run_mha_fwd(params, stream, paged_KV);
  683. } else {
  684. // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
  685. out.zero_();
  686. softmax_lse.fill_(std::numeric_limits<float>::infinity());
  687. }
  688. at::Tensor out_padded = out;
  689. if (head_size_og % 8 != 0) {
  690. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  691. if (out_.has_value()) { out_.value().copy_(out); }
  692. }
  693. if (seqlenq_ngroups_swapped) {
  694. int64_t size_before[] = {batch_size, max_seqlen_q, num_heads_k, head_size_og};
  695. int64_t size_after[] = {batch_size, num_heads_k * max_seqlen_q, head_size_og};
  696. out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
  697. out_padded = out_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
  698. q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
  699. softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
  700. }
  701. return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state};
  702. }
  703. void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
  704. FP16_SWITCH(!params.is_bf16, [&] {
  705. HEADDIM_SWITCH(params.d, [&] {
  706. BOOL_SWITCH(params.is_causal, Is_causal, [&] {
  707. run_mha_bwd_<elem_type, kHeadDim, Is_causal>(params, stream);
  708. });
  709. });
  710. });
  711. }
  712. std::vector<at::Tensor>
  713. mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
  714. const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  715. const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
  716. const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
  717. const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
  718. const at::Tensor &softmax_lse, // b x h x seqlen_q
  719. c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
  720. c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
  721. c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
  722. c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
  723. const float p_dropout, // probability to drop
  724. const float softmax_scale,
  725. const bool is_causal,
  726. int window_size_left,
  727. int window_size_right,
  728. const float softcap,
  729. const bool deterministic,
  730. c10::optional<at::Generator> gen_,
  731. c10::optional<at::Tensor> &rng_state) {
  732. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  733. TORCH_CHECK(false, "This flash attention build does not support backward.");
  734. #endif
  735. if (is_causal) { window_size_right = 0; }
  736. auto dprops = at::cuda::getCurrentDeviceProperties();
  737. // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
  738. bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  739. bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
  740. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  741. TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
  742. // We will support Turing in the near future
  743. // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
  744. bool is_dropout = p_dropout > 0.0;
  745. auto stream = at::cuda::getCurrentCUDAStream().stream();
  746. auto q_dtype = q.dtype();
  747. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  748. "FlashAttention only support fp16 and bf16 data type");
  749. if (q_dtype == torch::kBFloat16) {
  750. TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
  751. }
  752. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  753. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  754. TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  755. TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
  756. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  757. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  758. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  759. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  760. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  761. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  762. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
  763. const auto sizes = q.sizes();
  764. const int batch_size = sizes[0];
  765. const int seqlen_q = sizes[1];
  766. const int num_heads = sizes[2];
  767. const int head_size_og = dout.size(3);
  768. const int head_size = sizes[3];
  769. const int seqlen_k = k.size(1);
  770. const int num_heads_k = k.size(2);
  771. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  772. TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  773. TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
  774. if (head_size > 192 && is_dropout) {
  775. TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 with dropout requires A100/A800 or H100/H800");
  776. }
  777. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  778. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  779. const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
  780. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  781. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  782. TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
  783. if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
  784. if (window_size_left >= seqlen_k) { window_size_left = -1; }
  785. if (window_size_right >= seqlen_k) { window_size_right = -1; }
  786. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
  787. CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
  788. CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
  789. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
  790. CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
  791. at::Tensor dq, dk, dv;
  792. if (dq_.has_value()) {
  793. dq = dq_.value();
  794. TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
  795. CHECK_DEVICE(dq);
  796. TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
  797. CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
  798. } else {
  799. dq = torch::empty_like(q);
  800. }
  801. if (dk_.has_value()) {
  802. dk = dk_.value();
  803. TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
  804. CHECK_DEVICE(dk);
  805. TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
  806. CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
  807. } else {
  808. dk = torch::empty_like(k);
  809. }
  810. if (dv_.has_value()) {
  811. dv = dv_.value();
  812. TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
  813. CHECK_DEVICE(dv);
  814. TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
  815. CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
  816. } else {
  817. dv = torch::empty_like(v);
  818. }
  819. at::Tensor dout_padded;
  820. if (head_size_og % 8 != 0) {
  821. dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  822. } else {
  823. dout_padded = dout;
  824. }
  825. // bool loop = seqlen_k > blocksize_c;
  826. // TODO: change later, for now set to true for simplicity
  827. bool loop = true;
  828. // Otherwise the kernel will be launched from cuda:0 device
  829. // Cast to char to avoid compiler warning about narrowing
  830. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  831. auto opts = q.options();
  832. auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
  833. at::Tensor dq_accum;
  834. at::Tensor dk_accum, dv_accum;
  835. if (loop) {
  836. if (!deterministic) {
  837. dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
  838. } else {
  839. const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
  840. dq_accum = torch::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
  841. }
  842. // dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
  843. // dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
  844. }
  845. at::Tensor dk_expanded, dv_expanded;
  846. if (num_heads_k != num_heads) { // MQA / GQA
  847. dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
  848. dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
  849. } else {
  850. dk_expanded = dk;
  851. dv_expanded = dv;
  852. }
  853. Flash_bwd_params params;
  854. set_params_dgrad(params,
  855. batch_size,
  856. seqlen_q, seqlen_k,
  857. seqlen_q_rounded, seqlen_k_rounded,
  858. num_heads, num_heads_k,
  859. head_size, head_size_rounded,
  860. q, k, v, out,
  861. dout_padded, dq, dk_expanded, dv_expanded,
  862. nullptr,
  863. nullptr,
  864. loop ? dq_accum.data_ptr() : nullptr,
  865. // loop ? dk_accum.data_ptr() : nullptr,
  866. // loop ? dv_accum.data_ptr() : nullptr,
  867. nullptr,
  868. nullptr,
  869. softmax_lse.data_ptr(),
  870. softmax_d.data_ptr(),
  871. p_dropout,
  872. softmax_scale,
  873. window_size_left,
  874. window_size_right,
  875. softcap,
  876. deterministic,
  877. /*unpadded_lse*/false);
  878. params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
  879. auto launch = &run_mha_bwd;
  880. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  881. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  882. // We use a custom RNG that increases the offset by batch_size * nheads * 32.
  883. int64_t counter_offset = params.b * params.h * 32;
  884. if ( rng_state.has_value() ) {
  885. params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
  886. } else if( is_dropout ) {
  887. // See Note [Acquire lock when using random generators]
  888. std::lock_guard<std::mutex> lock(gen->mutex_);
  889. params.philox_args = gen->philox_cuda_state(counter_offset);
  890. auto seeds = at::cuda::philox::unpack(params.philox_args);
  891. params.rng_state[0] = std::get<0>(seeds);
  892. params.rng_state[1] = std::get<1>(seeds);
  893. }
  894. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  895. if (seqlen_q > 0) {
  896. launch(params, stream);
  897. } else {
  898. // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
  899. dk_expanded.zero_();
  900. dv_expanded.zero_();
  901. softmax_d.zero_();
  902. }
  903. // For MQA/GQA we need to sum dK and dV across the groups
  904. if (num_heads_k != num_heads) {
  905. at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
  906. at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
  907. }
  908. if (head_size_og % 8 != 0) {
  909. dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  910. dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  911. dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  912. }
  913. return { dq, dk, dv, softmax_d };
  914. }
  915. std::vector<at::Tensor>
  916. mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
  917. const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  918. const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  919. const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  920. const at::Tensor &out, // total_q x num_heads x head_size
  921. const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp
  922. c10::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
  923. c10::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  924. c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
  925. const at::Tensor &cu_seqlens_q, // b+1
  926. const at::Tensor &cu_seqlens_k, // b+1
  927. c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
  928. const int max_seqlen_q,
  929. const int max_seqlen_k, // max sequence length to choose the kernel
  930. const float p_dropout, // probability to drop
  931. const float softmax_scale,
  932. const bool zero_tensors,
  933. const bool is_causal,
  934. int window_size_left,
  935. int window_size_right,
  936. const float softcap,
  937. const bool deterministic,
  938. c10::optional<at::Generator> gen_,
  939. c10::optional<at::Tensor> &rng_state) {
  940. #ifdef FLASHATTENTION_DISABLE_BACKWARD
  941. TORCH_CHECK(false, "This flash attention build does not support backward.");
  942. #endif
  943. if (is_causal) { window_size_right = 0; }
  944. auto dprops = at::cuda::getCurrentDeviceProperties();
  945. // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
  946. bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  947. bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
  948. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  949. TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
  950. // We will support Turing in the near future
  951. // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
  952. bool is_dropout = p_dropout > 0.0;
  953. auto stream = at::cuda::getCurrentCUDAStream().stream();
  954. auto q_dtype = q.dtype();
  955. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  956. "FlashAttention only support fp16 and bf16 data type");
  957. if (q_dtype == torch::kBFloat16) {
  958. TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
  959. }
  960. TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
  961. TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
  962. TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
  963. TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
  964. TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype int32");
  965. TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype int32");
  966. CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
  967. CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
  968. CHECK_DEVICE(cu_seqlens_q); CHECK_DEVICE(cu_seqlens_k);
  969. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  970. TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  971. TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  972. TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
  973. TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
  974. CHECK_CONTIGUOUS(cu_seqlens_q);
  975. CHECK_CONTIGUOUS(cu_seqlens_k);
  976. const auto sizes = q.sizes();
  977. const int total_q = sizes[0];
  978. const int batch_size = cu_seqlens_q.numel() - 1;
  979. const int num_heads = sizes[1];
  980. const int head_size_og = dout.size(2);
  981. const int head_size = sizes[2];
  982. const int total_k = k.size(0);
  983. const int num_heads_k = k.size(1);
  984. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  985. TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
  986. TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
  987. if (head_size > 192 && is_dropout) {
  988. TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 with dropout requires A100/A800 or H100/H800");
  989. }
  990. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  991. if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
  992. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  993. const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
  994. const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
  995. const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
  996. TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
  997. if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
  998. if (window_size_right >= max_seqlen_k) { window_size_right = -1; }
  999. CHECK_SHAPE(q, total_q, num_heads, head_size);
  1000. CHECK_SHAPE(k, total_k, num_heads_k, head_size);
  1001. CHECK_SHAPE(v, total_k, num_heads_k, head_size);
  1002. CHECK_SHAPE(out, total_q, num_heads, head_size);
  1003. CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
  1004. CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
  1005. CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
  1006. at::Tensor dq, dk, dv;
  1007. if (dq_.has_value()) {
  1008. dq = dq_.value();
  1009. TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
  1010. CHECK_DEVICE(dq);
  1011. TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
  1012. CHECK_SHAPE(dq, total_q, num_heads, head_size);
  1013. } else {
  1014. dq = torch::empty_like(q);
  1015. }
  1016. if (dk_.has_value()) {
  1017. dk = dk_.value();
  1018. TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
  1019. CHECK_DEVICE(dk);
  1020. TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
  1021. CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
  1022. } else {
  1023. dk = torch::empty_like(k);
  1024. }
  1025. if (dv_.has_value()) {
  1026. dv = dv_.value();
  1027. TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
  1028. CHECK_DEVICE(dv);
  1029. TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
  1030. CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
  1031. } else {
  1032. dv = torch::empty_like(v);
  1033. }
  1034. at::Tensor dout_padded;
  1035. if (head_size_og % 8 != 0) {
  1036. dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1037. } else {
  1038. dout_padded = dout;
  1039. }
  1040. // bool loop = max_seqlen_k > blocksize_c;
  1041. // TODO: change later, for now set to true for simplicity
  1042. bool loop = true;
  1043. // Otherwise the kernel will be launched from cuda:0 device
  1044. // Cast to char to avoid compiler warning about narrowing
  1045. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  1046. auto opts = q.options();
  1047. auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat));
  1048. at::Tensor dq_accum;
  1049. if (loop) {
  1050. // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded)
  1051. // because that would be too large if there is a very long sequence and the rest of the sequences are short.
  1052. // Instead, we allocate dq_accum of size (total_q + 128 * batch, num_heads, head_size_rounded).
  1053. // Note that 128 is the max block size on the seqlen_q dimension.
  1054. // For dQ, the i-th sequence is stored in indices from cu_seqlens[i] + 128 * i to
  1055. // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
  1056. // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
  1057. // allowed to do. So we won't have to do any bound checking, and performance should stay the same.
  1058. // Same holds for softmax_d, since LSE is stored in unpadded format.
  1059. if (!deterministic) {
  1060. dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
  1061. } else {
  1062. const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
  1063. dq_accum = torch::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
  1064. }
  1065. }
  1066. at::Tensor dk_expanded, dv_expanded;
  1067. if (num_heads_k != num_heads) { // MQA / GQA
  1068. dk_expanded = torch::empty({total_k, num_heads, head_size}, opts);
  1069. dv_expanded = torch::empty({total_k, num_heads, head_size}, opts);
  1070. } else {
  1071. dk_expanded = dk;
  1072. dv_expanded = dv;
  1073. }
  1074. if( zero_tensors ) {
  1075. dq.zero_();
  1076. dk_expanded.zero_();
  1077. dv_expanded.zero_();
  1078. softmax_d.zero_();
  1079. }
  1080. Flash_bwd_params params;
  1081. set_params_dgrad(params,
  1082. batch_size,
  1083. max_seqlen_q, max_seqlen_k,
  1084. seqlen_q_rounded, seqlen_k_rounded,
  1085. num_heads, num_heads_k,
  1086. head_size, head_size_rounded,
  1087. q, k, v, out,
  1088. dout_padded, dq, dk_expanded, dv_expanded,
  1089. cu_seqlens_q.data_ptr(),
  1090. cu_seqlens_k.data_ptr(),
  1091. loop ? dq_accum.data_ptr() : nullptr,
  1092. nullptr,
  1093. nullptr,
  1094. softmax_lse.data_ptr(),
  1095. softmax_d.data_ptr(),
  1096. p_dropout,
  1097. softmax_scale,
  1098. window_size_left,
  1099. window_size_right,
  1100. softcap,
  1101. deterministic,
  1102. /*unpadded_lse*/true);
  1103. params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
  1104. params.total_q = total_q;
  1105. auto launch = &run_mha_bwd;
  1106. auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
  1107. gen_, at::cuda::detail::getDefaultCUDAGenerator());
  1108. // We use a custom RNG that increases the offset by batch_size * nheads * 32.
  1109. int64_t counter_offset = params.b * params.h * 32;
  1110. if ( rng_state.has_value() ) {
  1111. params.rng_state = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
  1112. } else if( is_dropout ) {
  1113. // See Note [Acquire lock when using random generators]
  1114. std::lock_guard<std::mutex> lock(gen->mutex_);
  1115. params.philox_args = gen->philox_cuda_state(counter_offset);
  1116. auto seeds = at::cuda::philox::unpack(params.philox_args);
  1117. params.rng_state[0] = std::get<0>(seeds);
  1118. params.rng_state[1] = std::get<1>(seeds);
  1119. }
  1120. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  1121. if (max_seqlen_q > 0) {
  1122. launch(params, stream);
  1123. } else {
  1124. // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
  1125. dk_expanded.zero_();
  1126. dv_expanded.zero_();
  1127. softmax_d.zero_();
  1128. }
  1129. // For MQA/GQA we need to sum dK and dV across the groups
  1130. if (num_heads_k != num_heads) {
  1131. at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
  1132. at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
  1133. }
  1134. if (head_size_og % 8 != 0) {
  1135. dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  1136. dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  1137. dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  1138. }
  1139. return { dq, dk, dv, softmax_d };
  1140. }
  1141. std::vector<at::Tensor>
  1142. mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
  1143. const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  1144. const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
  1145. c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
  1146. c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
  1147. c10::optional<const at::Tensor> &seqlens_k_, // batch_size
  1148. c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
  1149. c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
  1150. c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
  1151. c10::optional<const at::Tensor> &leftpad_k_, // batch_size
  1152. c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
  1153. c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
  1154. c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
  1155. const float softmax_scale,
  1156. bool is_causal,
  1157. int window_size_left,
  1158. int window_size_right,
  1159. const float softcap,
  1160. bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
  1161. int num_splits
  1162. ) {
  1163. auto dprops = at::cuda::getCurrentDeviceProperties();
  1164. // bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
  1165. bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
  1166. bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
  1167. TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
  1168. // We will support Turing in the near future
  1169. // TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
  1170. auto q_dtype = q.dtype();
  1171. TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
  1172. "FlashAttention only support fp16 and bf16 data type");
  1173. if (q_dtype == torch::kBFloat16) {
  1174. TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
  1175. }
  1176. TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
  1177. TORCH_CHECK(vcache.dtype() == q_dtype, "query and value must have the same dtype");
  1178. CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache);
  1179. TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1180. TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1181. TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
  1182. at::Tensor block_table;
  1183. const bool paged_KV = block_table_.has_value();
  1184. if (paged_KV) {
  1185. TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx");
  1186. block_table = block_table_.value();
  1187. CHECK_DEVICE(block_table);
  1188. TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
  1189. TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
  1190. }
  1191. const auto sizes = q.sizes();
  1192. const int batch_size = sizes[0];
  1193. int seqlen_q = sizes[1];
  1194. int num_heads = sizes[2];
  1195. const int head_size_og = sizes[3];
  1196. const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
  1197. const int num_blocks = !paged_KV ? 0 : kcache.size(0);
  1198. const int page_block_size = !paged_KV ? 1 : kcache.size(1);
  1199. TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
  1200. const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
  1201. const int num_heads_k = kcache.size(2);
  1202. const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
  1203. TORCH_CHECK(batch_size > 0, "batch size must be positive");
  1204. TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
  1205. TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
  1206. // causal=true is the same as causal=false in this case
  1207. if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
  1208. if (is_causal) { window_size_right = 0; }
  1209. // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
  1210. // H/t Daniel Haziza
  1211. const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
  1212. if (seqlenq_ngroups_swapped) {
  1213. const int ngroups = num_heads / num_heads_k;
  1214. q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
  1215. seqlen_q = ngroups;
  1216. num_heads = num_heads_k;
  1217. }
  1218. if (window_size_left >= seqlen_k) { window_size_left = -1; }
  1219. if (window_size_right >= seqlen_k) { window_size_right = -1; }
  1220. CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
  1221. if (!paged_KV) {
  1222. CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
  1223. CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
  1224. } else {
  1225. CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
  1226. CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
  1227. CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
  1228. }
  1229. at::Tensor q_padded, kcache_padded, vcache_padded;
  1230. if (head_size_og % 8 != 0) {
  1231. q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1232. kcache_padded = torch::nn::functional::pad(kcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1233. vcache_padded = torch::nn::functional::pad(vcache, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1234. } else {
  1235. q_padded = q;
  1236. kcache_padded = kcache;
  1237. vcache_padded = vcache;
  1238. }
  1239. at::Tensor out;
  1240. if (out_.has_value()) {
  1241. out = out_.value();
  1242. TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
  1243. CHECK_DEVICE(out);
  1244. TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
  1245. CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
  1246. if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
  1247. } else {
  1248. out = torch::empty_like(q_padded);
  1249. }
  1250. auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
  1251. const int head_size = round_multiple(head_size_og, 8);
  1252. const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
  1253. const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
  1254. const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
  1255. // Otherwise the kernel will be launched from cuda:0 device
  1256. // Cast to char to avoid compiler warning about narrowing
  1257. at::cuda::CUDAGuard device_guard{(char)q.get_device()};
  1258. auto opts = q.options();
  1259. auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
  1260. Flash_fwd_params params;
  1261. set_params_fprop(params,
  1262. batch_size,
  1263. seqlen_q, seqlen_k,
  1264. seqlen_q_rounded, seqlen_k_rounded,
  1265. num_heads, num_heads_k,
  1266. head_size, head_size_rounded,
  1267. q_padded, kcache_padded, vcache_padded, out,
  1268. /*cu_seqlens_q_d=*/nullptr,
  1269. /*cu_seqlens_k_d=*/nullptr,
  1270. /*seqused_k=*/nullptr,
  1271. /*p_ptr=*/nullptr,
  1272. softmax_lse.data_ptr(),
  1273. /*p_dropout=*/0.f,
  1274. softmax_scale,
  1275. window_size_left,
  1276. window_size_right,
  1277. softcap
  1278. );
  1279. at::Tensor k, v, k_padded, v_padded;
  1280. if (k_.has_value()) {
  1281. TORCH_CHECK(v_.has_value(), "If key is supplied, value must also be passed in");
  1282. TORCH_CHECK(seqlens_k_.has_value(), "If key is supplied, seqlens_k must also be passed in");
  1283. TORCH_CHECK(seqlen_q <= seqlen_k, "If key is supplied, it must have seqlen <= the seqlen of the KV cache");
  1284. k = k_.value();
  1285. v = v_.value();
  1286. TORCH_CHECK(k.dtype() == q_dtype, "Key must have the same dtype as query");
  1287. TORCH_CHECK(v.dtype() == q_dtype, "Value must have the same dtype as query");
  1288. CHECK_DEVICE(k); CHECK_DEVICE(v);
  1289. TORCH_CHECK(k.stride(-1) == 1, "Key tensor must have contiguous last dimension");
  1290. TORCH_CHECK(v.stride(-1) == 1, "Value tensor must have contiguous last dimension");
  1291. int seqlen_knew = k.size(1);
  1292. CHECK_SHAPE(k, batch_size, seqlen_knew, num_heads_k, head_size_og);
  1293. CHECK_SHAPE(v, batch_size, seqlen_knew, num_heads_k, head_size_og);
  1294. if (head_size_og % 8 != 0) {
  1295. k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1296. v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
  1297. } else {
  1298. k_padded = k;
  1299. v_padded = v;
  1300. }
  1301. params.seqlen_knew = seqlen_knew;
  1302. params.knew_ptr = k_padded.data_ptr();
  1303. params.vnew_ptr = v_padded.data_ptr();
  1304. // All stride are in elements, not bytes.
  1305. params.knew_batch_stride = k_padded.stride(0);
  1306. params.vnew_batch_stride = v_padded.stride(0);
  1307. params.knew_row_stride = k_padded.stride(-3);
  1308. params.vnew_row_stride = v_padded.stride(-3);
  1309. params.knew_head_stride = k_padded.stride(-2);
  1310. params.vnew_head_stride = v_padded.stride(-2);
  1311. }
  1312. if (seqlens_k_.has_value()) {
  1313. auto seqlens_k = seqlens_k_.value();
  1314. TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
  1315. CHECK_DEVICE(seqlens_k);
  1316. CHECK_CONTIGUOUS(seqlens_k);
  1317. CHECK_SHAPE(seqlens_k, batch_size);
  1318. params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
  1319. }
  1320. params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());
  1321. if (leftpad_k_.has_value()) {
  1322. TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet");
  1323. auto leftpad_k = leftpad_k_.value();
  1324. TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
  1325. CHECK_DEVICE(leftpad_k);
  1326. CHECK_CONTIGUOUS(leftpad_k);
  1327. CHECK_SHAPE(leftpad_k, batch_size);
  1328. params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
  1329. }
  1330. if (rotary_cos_.has_value()) {
  1331. TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
  1332. auto rotary_cos = rotary_cos_.value();
  1333. CHECK_DEVICE(rotary_cos);
  1334. params.rotary_dim = rotary_cos.size(1) * 2;
  1335. TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
  1336. TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
  1337. const int seqlen_ro = rotary_cos.size(0);
  1338. TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
  1339. CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
  1340. CHECK_CONTIGUOUS(rotary_cos);
  1341. TORCH_CHECK(rotary_cos.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
  1342. TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
  1343. auto rotary_sin = rotary_sin_.value();
  1344. CHECK_DEVICE(rotary_sin);
  1345. CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
  1346. CHECK_CONTIGUOUS(rotary_sin);
  1347. TORCH_CHECK(rotary_sin.scalar_type() == q_dtype, "rotary_cos must have the same dtype as query");
  1348. params.rotary_cos_ptr = rotary_cos.data_ptr();
  1349. params.rotary_sin_ptr = rotary_sin.data_ptr();
  1350. params.is_rotary_interleaved = is_rotary_interleaved;
  1351. } else {
  1352. params.rotary_dim = 0;
  1353. }
  1354. if (cache_batch_idx_.has_value()) {
  1355. auto cache_batch_idx = cache_batch_idx_.value();
  1356. CHECK_DEVICE(cache_batch_idx);
  1357. CHECK_CONTIGUOUS(cache_batch_idx);
  1358. TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
  1359. params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
  1360. }
  1361. // Keep references to these tensors to extend their lifetime
  1362. at::Tensor softmax_lse_accum, out_accum;
  1363. std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
  1364. params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
  1365. head_size_rounded, /*dropout*/ 0.f, num_splits, dprops, opts);
  1366. if (paged_KV) {
  1367. params.block_table = block_table.data_ptr<int>();
  1368. params.block_table_batch_stride = block_table.stride(0);
  1369. }
  1370. params.page_block_size = page_block_size;
  1371. set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
  1372. auto stream = at::cuda::getCurrentCUDAStream().stream();
  1373. // Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx,
  1374. // or paged KV cache
  1375. run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV);
  1376. if (head_size_og % 8 != 0) {
  1377. out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
  1378. if (out_.has_value()) { out_.value().copy_(out); }
  1379. if (k_.has_value()) {
  1380. // It's expensive to copy the KV cache here for the case where head size not divisible by 8,
  1381. // but we don't expect to get this case in practice. This is just so that the code works for that case.
  1382. kcache.copy_(kcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
  1383. vcache.copy_(vcache_padded.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}));
  1384. }
  1385. }
  1386. if (seqlenq_ngroups_swapped) {
  1387. out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
  1388. softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
  1389. }
  1390. return {out, softmax_lse};
  1391. }
  1392. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  1393. m.doc() = "FlashAttention";
  1394. m.def("fwd", &mha_fwd, "Forward pass");
  1395. m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
  1396. m.def("bwd", &mha_bwd, "Backward pass");
  1397. m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
  1398. m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
  1399. }