decoder_masked_multihead_attention_template.hpp

  1. // Downloaded from FasterTransformer v5.2.1
  2. // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
  3. /*
  4. * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #pragma once
  19. #include "decoder_masked_multihead_attention.h"
  20. #include "decoder_masked_multihead_attention_utils.h"
  21. #include "cuda_bf16_wrapper.h"
  22. #include "cuda_bf16_fallbacks.cuh"
  23. #include <assert.h>
  24. #include <float.h>
  25. #include <type_traits>
  26. // #define MMHA_USE_HMMA_FOR_REDUCTION
  27. // Below are knobs to extend FP32 accumulation for higher FP16 accuracy
  28. // Does not seem to affect the accuracy that much
  29. // #define MMHA_USE_FP32_ACUM_FOR_FMA
  30. // Seems to slightly improve the accuracy
  31. #define MMHA_USE_FP32_ACUM_FOR_OUT
  32. #if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT)
  33. // Does not seem to improve the accuracy
  34. //#define MMHA_USE_FP32_ACUM_FOR_LOGITS
  35. #endif
  36. namespace mmha {
  37. ////////////////////////////////////////////////////////////////////////////////////////////////////
  38. //
  39. // We use the following terminology to describe the different dimensions.
  40. //
  41. // B: Batch size (number of sequences),
  42. // L: Sequence length,
  43. // D: Hidden dimension,
  44. // H: Number of heads,
  45. // Dh: Hidden dimension per head - Dh = D / H.
  46. //
  47. // The different kernels assign a threadblock for each B x H pair. The grid has size (H, B, 1), i.e. blockIdx.x indexes the head and blockIdx.y the batch. We use
  48. // 64, 128 and 256 threads per block.
  49. //
  50. // Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to
  51. // compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The
  52. // cache buffer helps with memory accesses and contains keys with bias.
  53. //
  54. // The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and
  55. // x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The
  56. // values for x are chosen to create chunks of 16 bytes.
  57. //
  58. // The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs
  59. // depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At
  60. // the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an
  61. // HMMA instruction (Tensor Core). Each Q * K^T value is stored in shared memory in FP32.
  62. //
  63. // After that loop, a parallel softmax is computed across the different Q * K^T values stored in
  64. // shared memory.
  65. //
  66. // The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many
  67. // timesteps are computed per loop iteration. As with the keys, the values are read from a cache
  68. // except for the current timestep. The layout of the cache buffer for the values is much simpler
  69. // as it is [B, H, L, Dh].
  70. //
  71. ////////////////////////////////////////////////////////////////////////////////////////////////////
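// Illustrative sketch (not part of the original file): to make the two cache layouts described
// above concrete, the helpers below map a scalar element (b, h, d, t) to its linear offset,
// assuming a K cache of shape [B, H, Dh/x, L, x] with x = 16 / sizeof(T) and a V cache of shape
// [B, H, L, Dh]. The helper names are hypothetical and are not used by the kernels.
template<typename T>
inline __host__ __device__ size_t k_cache_offset_example(int b, int h, int d, int t, int num_heads, int L, int Dh)
{
    const int x = 16 / sizeof(T);  // 8 elements per 16B chunk for FP16, 4 for FP32.
    // The x elements of a chunk are contiguous; chunks belonging to consecutive timesteps of the
    // same head-dimension slice are separated by x elements.
    return ((size_t(b) * num_heads + h) * (Dh / x) * L + size_t(d / x) * L + t) * x + d % x;
}
template<typename T>
inline __host__ __device__ size_t v_cache_offset_example(int b, int h, int d, int t, int num_heads, int L, int Dh)
{
    // The V cache stores one full row of Dh elements per timestep.
    return (size_t(b) * num_heads + h) * L * Dh + size_t(t) * Dh + d;
}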
  72. template<typename T, int Dh>
  73. struct Qk_vec_ {
  74. };
  75. template<>
  76. struct Qk_vec_<float, 32> {
  77. using Type = float;
  78. };
  79. template<>
  80. struct Qk_vec_<float, 64> {
  81. using Type = float2;
  82. };
  83. template<>
  84. struct Qk_vec_<float, 128> {
  85. using Type = float4;
  86. };
  87. template<>
  88. struct Qk_vec_<float, 256> {
  89. using Type = float4;
  90. };
  91. template<>
  92. struct Qk_vec_<uint16_t, 32> {
  93. using Type = uint32_t;
  94. };
  95. template<>
  96. struct Qk_vec_<uint16_t, 64> {
  97. using Type = uint32_t;
  98. };
  99. template<>
  100. struct Qk_vec_<uint16_t, 128> {
  101. using Type = uint2;
  102. };
  103. template<>
  104. struct Qk_vec_<uint16_t, 256> {
  105. using Type = uint4;
  106. };
  107. #ifdef ENABLE_BF16
  108. template<>
  109. struct Qk_vec_<__nv_bfloat16, 32> {
  110. using Type = __nv_bfloat162;
  111. };
  112. template<>
  113. struct Qk_vec_<__nv_bfloat16, 64> {
  114. using Type = __nv_bfloat162;
  115. };
  116. template<>
  117. struct Qk_vec_<__nv_bfloat16, 128> {
  118. using Type = bf16_4_t;
  119. };
  120. template<>
  121. struct Qk_vec_<__nv_bfloat16, 256> {
  122. using Type = bf16_8_t;
  123. };
  124. #endif // ENABLE_BF16
  125. ////////////////////////////////////////////////////////////////////////////////////////////////////
  126. template<typename T, int THREADS_PER_KEY>
  127. struct K_vec_ {
  128. };
  129. template<>
  130. struct K_vec_<float, 4> {
  131. using Type = float;
  132. };
  133. template<>
  134. struct K_vec_<float, 2> {
  135. using Type = float2;
  136. };
  137. template<>
  138. struct K_vec_<float, 1> {
  139. using Type = float4;
  140. };
  141. template<>
  142. struct K_vec_<uint16_t, 4> {
  143. using Type = uint32_t;
  144. };
  145. template<>
  146. struct K_vec_<uint16_t, 2> {
  147. using Type = uint2;
  148. };
  149. template<>
  150. struct K_vec_<uint16_t, 1> {
  151. using Type = uint4;
  152. };
  153. #ifdef ENABLE_BF16
  154. template<>
  155. struct K_vec_<__nv_bfloat16, 4> {
  156. using Type = __nv_bfloat162;
  157. };
  158. template<>
  159. struct K_vec_<__nv_bfloat16, 2> {
  160. using Type = bf16_4_t;
  161. };
  162. template<>
  163. struct K_vec_<__nv_bfloat16, 1> {
  164. using Type = bf16_8_t;
  165. };
  166. #endif // ENABLE_BF16
  167. ////////////////////////////////////////////////////////////////////////////////////////////////////
  168. template<typename T, int V_VEC_SIZE>
  169. struct V_vec_ {
  170. };
  171. template<>
  172. struct V_vec_<float, 1> {
  173. using Type = float;
  174. };
  175. template<>
  176. struct V_vec_<float, 2> {
  177. using Type = float2;
  178. };
  179. template<>
  180. struct V_vec_<float, 4> {
  181. using Type = float4;
  182. };
  183. template<>
  184. struct V_vec_<uint16_t, 2> {
  185. using Type = uint32_t;
  186. };
  187. template<>
  188. struct V_vec_<uint16_t, 4> {
  189. using Type = uint2;
  190. };
  191. template<>
  192. struct V_vec_<uint16_t, 8> {
  193. using Type = uint4;
  194. };
  195. #ifdef ENABLE_BF16
  196. template<>
  197. struct V_vec_<__nv_bfloat16, 2> {
  198. using Type = __nv_bfloat162;
  199. };
  200. template<>
  201. struct V_vec_<__nv_bfloat16, 4> {
  202. using Type = bf16_4_t;
  203. };
  204. template<>
  205. struct V_vec_<__nv_bfloat16, 8> {
  206. using Type = bf16_8_t;
  207. };
  208. #endif // ENABLE_BF16
  209. ////////////////////////////////////////////////////////////////////////////////////////////////////
  210. #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
  211. template<typename T>
  212. struct Qk_vec_acum_fp32_ {
  213. };
  214. template<>
  215. struct Qk_vec_acum_fp32_<float> {
  216. using Type = float;
  217. };
  218. template<>
  219. struct Qk_vec_acum_fp32_<float2> {
  220. using Type = float2;
  221. };
  222. template<>
  223. struct Qk_vec_acum_fp32_<float4> {
  224. using Type = float4;
  225. };
  226. // template<> struct Qk_vec_acum_fp32_<uint16_t> { using Type = float; };
  227. template<>
  228. struct Qk_vec_acum_fp32_<uint32_t> {
  229. using Type = float2;
  230. };
  231. template<>
  232. struct Qk_vec_acum_fp32_<uint2> {
  233. using Type = Float4_;
  234. };
  235. template<>
  236. struct Qk_vec_acum_fp32_<uint4> {
  237. using Type = Float8_;
  238. };
  239. template<>
  240. struct Qk_vec_acum_fp32_<__nv_bfloat16> {
  241. using Type = float;
  242. };
  243. template<>
  244. struct Qk_vec_acum_fp32_<__nv_bfloat162> {
  245. using Type = float2;
  246. };
  247. template<>
  248. struct Qk_vec_acum_fp32_<bf16_4_t> {
  249. using Type = Float4_;
  250. };
  251. template<>
  252. struct Qk_vec_acum_fp32_<bf16_8_t> {
  253. using Type = Float8_;
  254. };
  275. ////////////////////////////////////////////////////////////////////////////////////////////////////
  276. template<typename T>
  277. struct K_vec_acum_fp32_ {
  278. };
  279. template<>
  280. struct K_vec_acum_fp32_<float> {
  281. using Type = float;
  282. };
  283. template<>
  284. struct K_vec_acum_fp32_<float2> {
  285. using Type = float2;
  286. };
  287. template<>
  288. struct K_vec_acum_fp32_<float4> {
  289. using Type = float4;
  290. };
  291. template<>
  292. struct K_vec_acum_fp32_<uint32_t> {
  293. using Type = float2;
  294. };
  295. template<>
  296. struct K_vec_acum_fp32_<uint2> {
  297. using Type = Float4_;
  298. };
  299. template<>
  300. struct K_vec_acum_fp32_<uint4> {
  301. using Type = Float8_;
  302. };
  303. template<>
  304. struct K_vec_acum_fp32_<__nv_bfloat16> {
  305. using Type = float;
  306. };
  307. template<>
  308. struct K_vec_acum_fp32_<__nv_bfloat162> {
  309. using Type = float2;
  310. };
  311. template<>
  312. struct K_vec_acum_fp32_<bf16_4_t> {
  313. using Type = Float4_;
  314. };
  315. template<>
  316. struct K_vec_acum_fp32_<bf16_8_t> {
  317. using Type = Float8_;
  318. };
  319. #endif
  320. ////////////////////////////////////////////////////////////////////////////////////////////////////
  321. #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
  322. template<typename T>
  323. struct V_vec_acum_fp32_ {
  324. };
  325. template<>
  326. struct V_vec_acum_fp32_<float> {
  327. using Type = float;
  328. };
  329. template<>
  330. struct V_vec_acum_fp32_<float2> {
  331. using Type = float2;
  332. };
  333. template<>
  334. struct V_vec_acum_fp32_<float4> {
  335. using Type = float4;
  336. };
  337. template<>
  338. struct V_vec_acum_fp32_<uint32_t> {
  339. using Type = float2;
  340. };
  341. template<>
  342. struct V_vec_acum_fp32_<uint2> {
  343. using Type = Float4_;
  344. };
  345. template<>
  346. struct V_vec_acum_fp32_<uint4> {
  347. using Type = Float8_;
  348. };
  349. #ifdef ENABLE_BF16
  350. template<>
  351. struct V_vec_acum_fp32_<__nv_bfloat162> {
  352. using Type = float2;
  353. };
  354. template<>
  355. struct V_vec_acum_fp32_<bf16_4_t> {
  356. using Type = Float4_;
  357. };
  358. template<>
  359. struct V_vec_acum_fp32_<bf16_8_t> {
  360. using Type = Float8_;
  361. };
  362. #endif // ENABLE_BF16
  363. #endif
  364. ////////////////////////////////////////////////////////////////////////////////////////////////////
  365. template<int THREADS_PER_KEY, typename K_vec, int N>
  366. inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N])
  367. {
  368. #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
  369. using K_vec_acum = typename K_vec_acum_fp32_<K_vec>::Type;
  370. #else
  371. using K_vec_acum = K_vec;
  372. #endif
  373. // Compute the parallel products for Q*K^T (treat vector lanes separately).
  374. K_vec_acum qk_vec = mul<K_vec_acum, K_vec, K_vec>(q[0], k[0]);
  375. #pragma unroll
  376. for (int ii = 1; ii < N; ++ii) {
  377. qk_vec = fma(q[ii], k[ii], qk_vec);
  378. }
  379. // Finalize the reduction across lanes.
  380. float qk = sum(qk_vec);
  381. #pragma unroll
  382. for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) {
  383. qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
  384. }
  385. return qk;
  386. }
  387. ////////////////////////////////////////////////////////////////////////////////////////////////////
  388. template<typename T, int THREADS_PER_KEY>
  389. struct Qk_dot {
  390. template<typename K_vec, int N>
  391. static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N])
  392. {
  393. return qk_dot_<THREADS_PER_KEY>(q, k);
  394. }
  395. };
  396. ////////////////////////////////////////////////////////////////////////////////////////////////////
  397. inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b)
  398. {
  399. float4 c;
  400. float zero = 0.f;
  401. asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
  402. " {%0, %1, %2, %3}, \n"
  403. " {%4, %5}, \n"
  404. " {%6}, \n"
  405. " {%7, %7, %7, %7}; \n"
  406. : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
  407. : "r"(a.x) "r"(a.y), "r"(b), "f"(zero));
  408. return c;
  409. }
  410. ////////////////////////////////////////////////////////////////////////////////////////////////////
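// Note (added for clarity, not in the original file): when MMHA_USE_HMMA_FOR_REDUCTION is defined
// on SM 7.5+, the HMMA path below replaces the shuffle-based reduction for THREADS_PER_KEY == 4.
// Each thread packs its partial Q*K^T sums as a half2; a single m16n8k8 HMMA against a B operand
// of 0x3c003c00 (half2{1, 1}) then adds the two halves and the contributions of the four threads
// of the quad, so every thread of the group can read the full dot product from c.x.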
  411. template<int N>
  412. inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N])
  413. {
  414. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
  415. #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
  416. using K_vec_acum = typename K_vec_acum_fp32_<uint32_t>::Type;
  417. #else
  418. using K_vec_acum = uint32_t;
  419. #endif
  420. K_vec_acum qk_vec = mul<K_vec_acum, uint32_t, uint32_t>(q[0], k[0]);
  421. #pragma unroll
  422. for (int ii = 1; ii < N; ++ii) {
  423. qk_vec = fma(q[ii], k[ii], qk_vec);
  424. }
  425. #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
  426. uint32_t qk_vec_ = float2_to_half2(qk_vec);
  427. return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x;
  428. #else
  429. return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x;
  430. #endif
  431. #else
  432. return 0.f;
  433. #endif
  434. }
  435. ////////////////////////////////////////////////////////////////////////////////////////////////////
  436. template<>
  437. struct Qk_dot<uint16_t, 4> {
  438. template<int N>
  439. static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N])
  440. {
  441. #if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION)
  442. return qk_hmma_dot_(q, k);
  443. #else
  444. return qk_dot_<4>(q, k);
  445. #endif // defined MMHA_USE_HMMA_FOR_REDUCTION
  446. }
  447. };
  448. ////////////////////////////////////////////////////////////////////////////////////////////////////
  449. template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
  450. inline __device__ float block_sum(float* red_smem, float sum)
  451. {
  452. // Decompose the thread index into warp / lane.
  453. int warp = threadIdx.x / WARP_SIZE;
  454. int lane = threadIdx.x % WARP_SIZE;
  455. // Compute the sum per warp.
  456. #pragma unroll
  457. for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
  458. sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
  459. }
  460. // Warp leaders store the data to shared memory.
  461. if (lane == 0) {
  462. red_smem[warp] = sum;
  463. }
  464. // Make sure the data is in shared memory.
  465. __syncthreads();
  466. // The warps compute the final sums.
  467. if (lane < WARPS_PER_BLOCK) {
  468. sum = red_smem[lane];
  469. }
  470. // Parallel reduction inside the warp.
  471. #pragma unroll
  472. for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
  473. sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
  474. }
  475. // Broadcast to other threads.
  476. return __shfl_sync(uint32_t(-1), sum, 0);
  477. }
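// Worked example (added note): with THREADS_PER_BLOCK = 128 and WARPS_PER_BLOCK = 4, every warp
// first reduces its 32 lanes with the butterfly shuffle above, lane 0 of each warp writes its
// partial sum to red_smem[0..3], and after the barrier the first four lanes of every warp reload
// those partials and finish the reduction in two more shuffle steps. The final __shfl_sync
// broadcasts lane 0's total so all threads return the same block-wide sum.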
  478. ////////////////////////////////////////////////////////////////////////////////////////////////////
  479. inline __device__ void convert_from_float(float& dst, float src)
  480. {
  481. dst = src;
  482. }
  483. ////////////////////////////////////////////////////////////////////////////////////////////////////
  484. inline __device__ void convert_from_float(uint16_t& dst, float src)
  485. {
  486. dst = float_to_half(src);
  487. }
  488. ////////////////////////////////////////////////////////////////////////////////////////////////////
  489. inline __device__ void convert_from_float(uint32_t& dst, float2 src)
  490. {
  491. dst = float2_to_half2(src);
  492. }
  493. ////////////////////////////////////////////////////////////////////////////////////////////////////
  494. #ifdef ENABLE_BF16
  495. inline __device__ void convert_from_float(__nv_bfloat16& dst, float src)
  496. {
  497. dst = __float2bfloat16(src);
  498. }
  499. ////////////////////////////////////////////////////////////////////////////////////////////////////
  500. inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src)
  501. {
  502. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  503. dst = __float22bfloat162_rn(src);
  504. #else
  505. dst = __floats2bfloat162_rn(src.x, src.y);
  506. #endif
  507. }
  508. #endif // ENABLE_BF16
  509. ////////////////////////////////////////////////////////////////////////////////////////////////////
  510. inline __device__ void convert_from_float(uint2& dst, Float4_ src)
  511. {
  512. dst.x = float2_to_half2(src.x);
  513. dst.y = float2_to_half2(src.y);
  514. }
  515. ////////////////////////////////////////////////////////////////////////////////////////////////////
  516. inline __device__ void convert_from_float(uint2& dst, float4 src)
  517. {
  518. convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)});
  519. }
  520. ////////////////////////////////////////////////////////////////////////////////////////////////////
  521. inline __device__ void convert_from_float(uint4& dst, Float8_ src)
  522. {
  523. dst.x = float2_to_half2(src.x);
  524. dst.y = float2_to_half2(src.y);
  525. dst.z = float2_to_half2(src.z);
  526. dst.w = float2_to_half2(src.w);
  527. }
  528. ////////////////////////////////////////////////////////////////////////////////////////////////////
  529. #ifdef ENABLE_BF16
  530. inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src)
  531. {
  532. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  533. dst.x = __float22bfloat162_rn(src.x);
  534. dst.y = __float22bfloat162_rn(src.y);
  535. #else
  536. dst.x = __floats2bfloat162_rn(src.x.x, src.x.y);
  537. dst.y = __floats2bfloat162_rn(src.y.x, src.y.y);
  538. #endif
  539. }
  540. ////////////////////////////////////////////////////////////////////////////////////////////////////
  541. inline __device__ void convert_from_float(bf16_4_t& dst, float4 src)
  542. {
  543. convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)});
  544. }
  545. ////////////////////////////////////////////////////////////////////////////////////////////////////
  546. inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src)
  547. {
  548. #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  549. dst.x = __float22bfloat162_rn(src.x);
  550. dst.y = __float22bfloat162_rn(src.y);
  551. dst.z = __float22bfloat162_rn(src.z);
  552. dst.w = __float22bfloat162_rn(src.w);
  553. #else
  554. dst.x = __floats2bfloat162_rn(src.x.x, src.x.y);
  555. dst.y = __floats2bfloat162_rn(src.y.x, src.y.y);
  556. dst.z = __floats2bfloat162_rn(src.z.x, src.z.y);
  557. dst.w = __floats2bfloat162_rn(src.w.x, src.w.y);
  558. #endif
  559. }
  560. #endif // ENABLE_BF16
  561. ////////////////////////////////////////////////////////////////////////////////////////////////////
  562. inline __device__ void convert_from_float(float2& dst, float2 src)
  563. {
  564. dst = src;
  565. }
  566. ////////////////////////////////////////////////////////////////////////////////////////////////////
  567. inline __device__ void convert_from_float(float4& dst, float4 src)
  568. {
  569. dst = src;
  570. }
  571. ////////////////////////////////////////////////////////////////////////////////////////////////////
  572. inline __device__ float convert_to_float(float4 u)
  573. {
  574. return u.x;
  575. }
  576. ////////////////////////////////////////////////////////////////////////////////////////////////////
  577. inline __device__ float convert_to_float(uint4 u)
  578. {
  579. float2 tmp = half2_to_float2(u.x);
  580. return tmp.x;
  581. }
  582. #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
  583. ////////////////////////////////////////////////////////////////////////////////////////////////////
  584. inline __device__ float cast_to_float(float u)
  585. {
  586. return u;
  587. }
  588. ////////////////////////////////////////////////////////////////////////////////////////////////////
  589. inline __device__ float2 cast_to_float(float2 u)
  590. {
  591. return u;
  592. }
  593. ////////////////////////////////////////////////////////////////////////////////////////////////////
  594. inline __device__ float4 cast_to_float(float4 u)
  595. {
  596. return u;
  597. }
  598. ////////////////////////////////////////////////////////////////////////////////////////////////////
  599. inline __device__ Float4_ cast_to_float(Float4_ u)
  600. {
  601. return u;
  602. }
  603. ////////////////////////////////////////////////////////////////////////////////////////////////////
  604. inline __device__ Float8_ cast_to_float(Float8_ u)
  605. {
  606. return u;
  607. }
  608. ////////////////////////////////////////////////////////////////////////////////////////////////////
  609. inline __device__ float2 cast_to_float(uint32_t u)
  610. {
  611. return half2_to_float2(u);
  612. }
  613. ////////////////////////////////////////////////////////////////////////////////////////////////////
  614. inline __device__ Float4_ cast_to_float(uint2 u)
  615. {
  616. Float4_ tmp;
  617. tmp.x = half2_to_float2(u.x);
  618. tmp.y = half2_to_float2(u.y);
  619. return tmp;
  620. }
  621. ////////////////////////////////////////////////////////////////////////////////////////////////////
  622. inline __device__ Float8_ cast_to_float(uint4 u)
  623. {
  624. Float8_ tmp;
  625. tmp.x = half2_to_float2(u.x);
  626. tmp.y = half2_to_float2(u.y);
  627. tmp.z = half2_to_float2(u.z);
  628. tmp.w = half2_to_float2(u.w);
  629. return tmp;
  630. }
  631. #endif
  632. ////////////////////////////////////////////////////////////////////////////////////////////////////
  633. inline __device__ float float_from_int8(int8_t u)
  634. {
  635. return u;
  636. }
  637. ////////////////////////////////////////////////////////////////////////////////////////////////////
  638. inline __device__ float2 float_from_int8(int16_t u)
  639. {
  640. union {
  641. int16_t int16;
  642. int8_t int8[2];
  643. };
  644. int16 = u;
  645. return make_float2(int8[0], int8[1]);
  646. }
  647. ////////////////////////////////////////////////////////////////////////////////////////////////////
  648. inline __device__ float4 float_from_int8(int32_t u)
  649. {
  650. union {
  651. int32_t int32;
  652. int8_t int8[4];
  653. };
  654. int32 = u;
  655. return make_float4(int8[0], int8[1], int8[2], int8[3]);
  656. }
  657. ////////////////////////////////////////////////////////////////////////////////////////////////////
  658. // clang-format off
  659. inline __device__ Float8_ float_from_int8(int64_t u)
  660. {
  661. union {
  662. int64_t int64;
  663. int16_t int16[4];
  664. };
  665. int64 = u;
  666. return Float8_ {float_from_int8(int16[0]),
  667. float_from_int8(int16[1]),
  668. float_from_int8(int16[2]),
  669. float_from_int8(int16[3])};
  670. }
  671. // clang-format on
  672. ////////////////////////////////////////////////////////////////////////////////////////////////////
  673. inline __device__ int8_t cast_to_int8(float val)
  674. {
  675. union {
  676. int8_t int8[2];
  677. int16_t int16;
  678. };
  679. asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val));
  680. return int8[0];
  681. }
  682. ////////////////////////////////////////////////////////////////////////////////////////////////////
  683. inline __device__ int32_t cast_to_int8(float4 val)
  684. {
  685. union {
  686. int8_t int8[4];
  687. int32_t int32;
  688. };
  689. int8[0] = cast_to_int8(val.x);
  690. int8[1] = cast_to_int8(val.y);
  691. int8[2] = cast_to_int8(val.z);
  692. int8[3] = cast_to_int8(val.w);
  693. return int32;
  694. }
  695. ////////////////////////////////////////////////////////////////////////////////////////////////////
  696. inline __device__ int64_t cast_to_int8(Float8_ val)
  697. {
  698. union {
  699. int8_t int8[8];
  700. int64_t int64;
  701. };
  702. int8[0] = cast_to_int8(val.x.x);
  703. int8[1] = cast_to_int8(val.x.y);
  704. int8[2] = cast_to_int8(val.y.x);
  705. int8[3] = cast_to_int8(val.y.y);
  706. int8[4] = cast_to_int8(val.z.x);
  707. int8[5] = cast_to_int8(val.z.y);
  708. int8[6] = cast_to_int8(val.w.x);
  709. int8[7] = cast_to_int8(val.w.y);
  710. return int64;
  711. }
  712. ////////////////////////////////////////////////////////////////////////////////////////////////////
  713. template<typename T>
  714. inline __device__ __host__ T div_up(T m, T n)
  715. {
  716. return (m + n - 1) / n;
  717. }
  718. ////////////////////////////////////////////////////////////////////////////////////////////////////
  719. template<typename T, bool DO_CROSS_ATTENTION>
  720. inline size_t smem_size_in_bytes(const Multihead_attention_params<T, DO_CROSS_ATTENTION>& params,
  721. int threads_per_value,
  722. int threads_per_block)
  723. {
  724. // The amount of shared memory needed to store the Q*K^T values in float.
  725. const int max_timesteps = min(params.timestep, params.memory_max_len);
  726. size_t qk_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16;
  727. // The extra memory needed if we are not using floats for the final logits.
  728. size_t logits_sz = 0;
  729. #ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
  730. if (sizeof(T) != 4) {
  731. // TODO
  732. logits_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 4 * sizeof(T) :
  733. div_up(max_timesteps + 1, 4) * 4 * sizeof(T);
  734. }
  735. #endif
  736. // The total size needed during softmax.
  737. size_t softmax_sz = qk_sz + logits_sz;
  738. // The number of partial rows to reduce in the final reduction.
  739. int rows_per_red = threads_per_block / threads_per_value;
  740. // The amount of storage needed to finalize the outputs.
  741. size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(T) / 2;
  742. size_t transpose_rotary_size = 0;
  743. if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
  744. transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(T);
  745. }
  746. // The max.
  747. return max(max(softmax_sz, red_sz), transpose_rotary_size);
  748. }
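// Illustrative sketch (not part of the original file, kept disabled): how a host-side launcher
// might use smem_size_in_bytes() to size the kernel's dynamic shared memory. The helper name, the
// use of params.batch_size and the exact launch shape are assumptions for illustration only; the
// grid order matches the kernel's use of blockIdx.x for the head and blockIdx.y for the batch.
#if 0
template<typename T, int Dh, int Dh_MAX, int THDS_PER_KEY, int THDS_PER_VALUE, int THDS_PER_BLOCK, bool CROSS>
void example_launch(const Multihead_attention_params<T, CROSS>& params, cudaStream_t stream)
{
    const size_t smem_sz = smem_size_in_bytes<T, CROSS>(params, THDS_PER_VALUE, THDS_PER_BLOCK);
    auto kernel = masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, CROSS>;
    if (smem_sz >= 48 * 1024) {
        // Kernels must opt in to more than 48 KB of dynamic shared memory.
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, int(smem_sz));
    }
    dim3 grid(params.num_heads, params.batch_size);  // One threadblock per (head, batch) pair.
    kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params);
}
#endif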
  749. ////////////////////////////////////////////////////////////////////////////////////////////////////
  750. inline __device__ constexpr uint32_t shfl_mask(int threads)
  751. {
  752. return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u;
  753. }
  754. ////////////////////////////////////////////////////////////////////////////////////////////////////
  755. template<
  756. // The type of the inputs. Supported types: float and half.
  757. typename T,
  758. // The hidden dimension per head.
  759. int Dh,
  760. int Dh_MAX,
  761. // The number of threads per key.
  762. int THREADS_PER_KEY,
  763. // The number of threads per value.
  764. int THREADS_PER_VALUE,
  765. // The number of threads in a threadblock.
  766. int THREADS_PER_BLOCK,
  767. bool DO_CROSS_ATTENTION>
  768. __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T, DO_CROSS_ATTENTION> params)
  769. {
  770. // Make sure the hidden dimension per head is a multiple of the number of threads per key.
  771. static_assert(Dh_MAX % THREADS_PER_KEY == 0, "");
  772. // Make sure the hidden dimension per head is a multiple of the number of threads per value.
  773. static_assert(Dh_MAX % THREADS_PER_VALUE == 0, "");
  774. // The size of a warp.
  775. constexpr int WARP_SIZE = 32;
  776. // The number of warps in a threadblock.
  777. constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE;
  778. // Use smem_size_in_bytes (above) to determine the amount of shared memory.
  779. extern __shared__ char smem_[];
  780. // The shared memory for the Q*K^T values and partial logits in softmax.
  781. float* qk_smem = reinterpret_cast<float*>(smem_);
  782. // The shared memory for the logits. For FP32, that's the same buffer as qk_smem.
  783. char* logits_smem_ = smem_;
  784. #ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
  785. if (sizeof(T) != 4) {
  786. // TODO - change to tlength
  787. const int max_timesteps = min(params.timestep, params.memory_max_len);
  788. logits_smem_ +=
  789. (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16;
  790. }
  791. T* logits_smem = reinterpret_cast<T*>(logits_smem_);
  792. #else
  793. float* logits_smem = reinterpret_cast<float*>(logits_smem_);
  794. #endif
  795. // The shared memory to do the final reduction for the output values. Reuse qk_smem.
  796. T* out_smem = reinterpret_cast<T*>(smem_);
  797. // The shared memory buffers for the block-wide reductions. One for max, one for sum.
  798. __shared__ float red_smem[WARPS_PER_BLOCK * 2];
  799. // A vector of Q or K elements for the current timestep.
  800. using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
  801. // Use alignment for safely casting the shared buffers as Qk_vec.
  802. // Shared memory to store Q inputs.
  803. __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX];
  804. // This is one of the reasons we should have a separate kernel for cross attention
  805. __shared__ __align__(sizeof(Qk_vec)) T bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1];
  806. // A vector of Q or K elements for the current timestep.
  807. using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;
  808. // The number of elements per vector.
  809. constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
  810. // Make sure the hidden size per head is a multiple of the vector size.
  811. static_assert(Dh_MAX % QK_VEC_SIZE == 0, "");
  812. // We will use block wide reduction if needed
  813. // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
  814. // The number of vectors per warp.
  815. constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE;
  816. // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread
  817. // owns x elements, we have to decompose the linear index into chunks of x values and the posi-
  818. // tion of the thread in that chunk.
  819. // The number of elements in a chunk of 16B (that's the x in the above formula).
  820. constexpr int QK_ELTS_IN_16B = 16 / sizeof(T);
  821. // The number of K vectors in 16B.
  822. constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec);
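// Worked example (added note): for FP16 with Dh_MAX = 128, Qk_vec is uint2 (4 halves), so
// QK_VEC_SIZE = 4, QK_VECS_PER_WARP = 32, QK_ELTS_IN_16B = 8 and QK_VECS_IN_16B = 2: threads
// 2*c and 2*c + 1 together fill the c-th 16B chunk of the current key (co = tidx / 2,
// ci = (tidx % 2) * 4 below).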
  823. // The batch/beam idx
  824. const int bi = blockIdx.y;
  825. if (params.finished != nullptr && params.finished[bi] == true) {
  826. return;
  827. }
  828. // The beam idx
  829. const int beami = bi % params.beam_width;
  830. // The "beam-aware" batch idx
  831. const int bbi = bi / params.beam_width;
  832. // The head.
  833. const int hi = blockIdx.x;
  834. // Combine the batch and the head indices.
  835. const int bhi = bi * params.num_heads + hi;
  836. // Combine the "beam-aware" batch idx and the head indices.
  837. const int bbhi = bbi * params.beam_width * params.num_heads + hi;
  838. // The thread in the block.
  839. const int tidx = threadIdx.x;
  840. const bool handle_kv = !DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0);
  841. // While doing the product Q*K^T for the different keys we track the max.
  842. float qk_max = -FLT_MAX;
  843. float qk = 0.0F;
  844. int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh;
  845. const size_t bi_seq_len_offset = bi * params.memory_max_len;
  846. // int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep;
  847. int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 :
  848. (params.length_per_sample == nullptr) ?
  849. params.timestep :
  850. params.length_per_sample[bi] + params.max_prefix_prompt_length;
  851. const int first_step = max(0, tlength + 1 - params.memory_max_len);
  852. const int tlength_circ = tlength % params.memory_max_len;
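// Worked example (added note): with memory_max_len = 1024 and tlength = 1500, first_step = 477
// and tlength_circ = 476, so only the most recent 1024 timesteps (477..1500) are attended and
// timestep t lives in slot t % 1024 of the circular K/V cache.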
  853. // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep.
  854. const bool is_masked = tidx >= QK_VECS_PER_WARP;
  855. // The offset in the Q and K buffer also accounts for the batch.
  856. int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE;
  857. // The offset in the bias buffer.
  858. int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
  859. const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr;
  860. const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0;
  861. // Trigger the loads from the Q and K buffers.
  862. Qk_vec q;
  863. zero(q);
  864. if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
  865. if (params.int8_mode == 2) {
  866. using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec>::value>::type;
  867. using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
  868. const auto q_scaling = params.qkv_scale_out[0];
  869. const auto q_quant =
  870. *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[qk_offset]);
  871. convert_from_float(q, mul<Packed_Float_t, float>(q_scaling, float_from_int8(q_quant)));
  872. }
  873. else {
  874. q = *reinterpret_cast<const Qk_vec*>(&params.q[qk_offset]);
  875. }
  876. }
  877. Qk_vec k;
  878. zero(k);
  879. if (DO_CROSS_ATTENTION) {
  880. // The 16B chunk written by the thread.
  881. int co = tidx / QK_VECS_IN_16B;
  882. // The position of the thread in that 16B chunk.
  883. int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
  884. // Two chunks are separated by L * x elements. A thread writes QK_VEC_SIZE elements.
  885. int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
  886. // params.timestep*QK_ELTS_IN_16B +
  887. tlength * QK_ELTS_IN_16B + ci;
  888. k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?
  889. *reinterpret_cast<const Qk_vec*>(&params.k_cache[offset]) :
  890. k;
  891. }
  892. else {
  893. if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
  894. if (params.int8_mode == 2) {
  895. using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec>::value>::type;
  896. using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
  897. const auto k_scaling = params.qkv_scale_out[1];
  898. const auto k_quant =
  899. *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[qk_offset]);
  900. convert_from_float(k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));
  901. }
  902. else {
  903. k = *reinterpret_cast<const Qk_vec*>(&params.k[qk_offset]);
  904. }
  905. }
  906. }
  907. // Trigger the loads from the Q and K bias buffers.
  908. Qk_vec q_bias;
  909. zero(q_bias);
  910. q_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ?
  911. *reinterpret_cast<const Qk_vec*>(&params.q_bias[qk_bias_offset]) :
  912. q_bias;
  913. Qk_vec k_bias;
  914. zero(k_bias);
  915. if (handle_kv) {
  916. k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ?
  917. *reinterpret_cast<const Qk_vec*>(&params.k_bias[qk_bias_offset]) :
  918. k_bias;
  919. }
  920. // Computes the Q/K values with bias.
  921. q = add(q, q_bias);
  922. if (handle_kv) {
  923. k = add(k, k_bias);
  924. }
  925. if (do_ia3 && !is_masked) {
  926. k = mul<Qk_vec, Qk_vec, Qk_vec>(
  927. k,
  928. *reinterpret_cast<const Qk_vec*>(
  929. &params.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE]));
  930. }
  931. // Padded len
  932. const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
  933. if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) {
  934. if (handle_kv) {
  935. apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, params.timestep - padd_len);
  936. }
  937. else {
  938. apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, params.timestep - padd_len);
  939. }
  940. }
  941. else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
  942. const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim;
  943. T* q_smem = reinterpret_cast<T*>(smem_);
  944. T* k_smem = q_smem + params.rotary_embedding_dim;
  945. const int half_rotary_dim = params.rotary_embedding_dim / 2;
  946. const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim;
  947. const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim;
  948. const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts
  949. assert(half_rotary_dim % QK_VEC_SIZE == 0);
  950. if (do_rotary) {
  951. *reinterpret_cast<Qk_vec*>(q_smem + half_idx * smem_pitch + intra_half_idx) = q;
  952. if (handle_kv) {
  953. *reinterpret_cast<Qk_vec*>(k_smem + half_idx * smem_pitch + intra_half_idx) = k;
  954. }
  955. }
  956. __syncthreads();
  957. const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2;
  958. constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1;
  959. if (do_rotary) {
  960. mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
  961. if (handle_kv) {
  962. mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
  963. mmha::apply_rotary_embedding(
  964. q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep - padd_len);
  965. mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
  966. }
  967. else {
  968. mmha::apply_rotary_embedding(
  969. q, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep);
  970. }
  971. mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
  972. }
  973. __syncthreads();
  974. if (do_rotary) {
  975. q = *reinterpret_cast<Qk_vec*>(q_smem + half_idx * smem_pitch + intra_half_idx);
  976. if (handle_kv) {
  977. k = *reinterpret_cast<Qk_vec*>(k_smem + half_idx * smem_pitch + intra_half_idx);
  978. }
  979. }
  980. __syncthreads();
  981. }
  982. if (!is_masked) {
  983. // Store the Q values to shared memory.
  984. *reinterpret_cast<Qk_vec*>(&q_smem[tidx * QK_VEC_SIZE]) = q;
  985. // Store Dh values of k_bias into smem, since we will need to add them later
  986. // if params.timestep == 0
  987. if (DO_CROSS_ATTENTION && params.timestep == 0) {
  988. *reinterpret_cast<Qk_vec*>(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias;
  989. }
  990. // Write the K values to the global memory cache.
  991. //
  992. // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
  993. // system. We designed it this way as it allows much better memory loads (and there are many
  994. // more loads) + the stores are really "write and forget" since we won't need the ack before
  995. // the end of the kernel. There's plenty of time for the transactions to complete.
  996. // The 16B chunk written by the thread.
  997. int co = tidx / QK_VECS_IN_16B;
  998. // The position of the thread in that 16B chunk.
  999. int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
  1000. // Two chunks are separated by L * x elements. A thread writes QK_VEC_SIZE elements.
  1001. int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
  1002. // params.timestep*QK_ELTS_IN_16B +
  1003. tlength_circ * QK_ELTS_IN_16B + ci;
  1004. if (handle_kv) {
  1005. // Trigger the stores to global memory.
  1006. if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
  1007. *reinterpret_cast<Qk_vec*>(&params.k_cache[offset]) = k;
  1008. }
  1009. }
  1010. // Compute \sum_i Q[i] * K^T[i] for the current timestep.
  1011. #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
  1012. using Qk_vec_acum = typename Qk_vec_acum_fp32_<Qk_vec>::Type;
  1013. #else
  1014. using Qk_vec_acum = Qk_vec;
  1015. #endif
  1016. qk = dot<Qk_vec_acum, Qk_vec>(q, k);
  1017. if (QK_VECS_PER_WARP <= WARP_SIZE) {
  1018. #pragma unroll
  1019. for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
  1020. qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
  1021. }
  1022. }
  1023. }
  1024. if (QK_VECS_PER_WARP > WARP_SIZE) {
  1025. constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE;
  1026. qk = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
  1027. }
  1028. // Store that value in shared memory. Keep the Q*K^T value in register for softmax.
  1029. if (tidx == 0) {
  1030. // Normalize qk.
  1031. qk *= params.inv_sqrt_dh;
  1032. if (params.relative_attention_bias != nullptr) {
  1033. qk = add(qk,
  1034. params.relative_attention_bias[hi * params.relative_attention_bias_stride
  1035. * params.relative_attention_bias_stride
  1036. + (tlength - padd_len) * params.relative_attention_bias_stride
  1037. + (tlength - padd_len)]);
  1038. }
  1039. // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0.
  1040. qk_max = qk;
  1041. qk_smem[tlength - first_step] = qk;
  1042. // qk_smem[params.timestep] = qk;
  1043. }
  1044. // Make sure the data is in shared memory.
  1045. __syncthreads();
  1046. // The type of queries and keys for the math in the Q*K^T product.
  1047. using K_vec = typename K_vec_<T, THREADS_PER_KEY>::Type;
  1048. // The number of elements per vector.
  1049. constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T);
  1050. // Make sure the hidden size per head is a multiple of the vector size.
  1051. static_assert(Dh_MAX % K_VEC_SIZE == 0, "");
  1052. // The number of elements per thread.
  1053. constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY;
  1054. // The number of vectors per thread.
  1055. constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE;
  1056. // The position of the first key loaded by each thread from the cache buffer (for this B * H).
  1057. int ko = tidx / THREADS_PER_KEY;
  1058. // The position of the thread in the chunk of keys.
  1059. int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE;
  1060. static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD);
  1061. // Load the Q values from shared memory. The values are reused during the loop on K.
  1062. K_vec q_vec[K_VECS_PER_THREAD];
  1063. #pragma unroll
  1064. for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
  1065. q_vec[ii] = *reinterpret_cast<const K_vec*>(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
  1066. }
  1067. K_vec k_bias_vec[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1];
  1068. if (DO_CROSS_ATTENTION && params.timestep == 0) {
  1069. #pragma unroll
  1070. for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
  1071. k_bias_vec[ii] = *reinterpret_cast<const K_vec*>(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
  1072. }
  1073. }
  1074. // The number of timesteps loaded per iteration.
  1075. constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY;
  1076. // The number of keys per warp.
  1077. constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
  1078. // The base pointer for the key in the cache buffer.
  1079. T* k_cache = &params.k_cache[bhi * params.memory_max_len * Dh + ki];
  1080. // Base pointer for the beam's batch, before offsetting with indirection buffer
  1081. T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
  1082. // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
  1083. // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
  1084. int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step;
  1085. // The prefix prompt length, if any.
  1086. const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi];
  1087. // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
  1088. const bool has_beams = params.cache_indir != nullptr;
  1089. const int* beam_indices = has_beams ? &params.cache_indir[bi_seq_len_offset] : nullptr;
  1090. for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) {
  1091. const int ti_circ = ti % params.memory_max_len;
  1092. // The keys loaded from the key cache.
  1093. K_vec k[K_VECS_PER_THREAD];
  1094. K_vec k_vec_zero;
  1095. zero(k_vec_zero);
  1096. #pragma unroll
  1097. for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
  1098. int jj = ii * params.memory_max_len + ti_circ;
  1099. // if( ti < params.timestep ) {
  1100. const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len);
  1101. if (ti < tlength) {
  1102. if (!within_bounds) {
  1103. k[ii] = k_vec_zero;
  1104. }
  1105. else {
  1106. if (has_beams) {
  1107. const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh;
  1108. k[ii] = *reinterpret_cast<const K_vec*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]);
  1109. }
  1110. else {
  1111. k[ii] = *reinterpret_cast<const K_vec*>(&k_cache_batch[jj * QK_ELTS_IN_16B]);
  1112. }
  1113. }
  1114. // add bias and update k_cache
  1115. if (DO_CROSS_ATTENTION && params.timestep == 0) {
  1116. k[ii] = add(k[ii], k_bias_vec[ii]);
  1117. if (do_ia3) {
  1118. k[ii] = mul<K_vec, K_vec, K_vec>(
  1119. k[ii],
  1120. *reinterpret_cast<const K_vec*>(
  1121. &params.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + ki
  1122. + ii * THREADS_PER_KEY * K_VEC_SIZE]));
  1123. }
  1124. if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len) {
  1125. *reinterpret_cast<K_vec*>(&k_cache[jj * QK_ELTS_IN_16B]) = k[ii];
  1126. }
  1127. }
  1128. }
  1129. }
  1130. // Perform the dot product and normalize qk.
  1131. //
  1132. // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
  1133. float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k) * params.inv_sqrt_dh;
  1134. bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
  1135. // Store the product to shared memory. There's one qk value per timestep. Update the max.
  1136. // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) {
  1137. if (ti < tlength && tidx % THREADS_PER_KEY == 0) {
  1138. if (params.relative_attention_bias != nullptr) {
  1139. qk = add(qk,
  1140. params.relative_attention_bias[hi * params.relative_attention_bias_stride
  1141. * params.relative_attention_bias_stride
  1142. + tlength * params.relative_attention_bias_stride + ti]);
  1143. }
  1144. if (params.linear_bias_slopes != nullptr) {
  1145. // Apply the linear position bias: (ki - qi) * slope[hi].
  1146. // The padding tokens sit between the input context and the generated tokens.
  1147. // We need to remove the number of padding tokens in the distance computation.
  1148. // ti : 0 1 2 3 4 5 6 7 8 9(tlength)
  1149. // token: i i i i p p p o o o where i=input, p=pad, o=output.
  1150. // e.g. ti = 2: the unpadded distance is (9 - 3) - 2 = 4, so dist = ki - qi = -4.
  1151. int max_context_length = params.max_prefix_prompt_length + params.max_input_length;
  1152. float dist = (ti < max_context_length ? ti + padd_len : ti) - tlength;
  1153. qk += mul<float, T, float>(params.linear_bias_slopes[hi], dist);
  1154. }
  1155. qk_max = is_mask ? qk_max : fmaxf(qk_max, qk);
  1156. qk_smem[ti - first_step] = qk;
  1157. }
  1158. }
  1159. // Perform the final reduction to compute the max inside each warp.
  1160. //
  1161. // NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the
  1162. // group so it's not needed to run the reduction inside the group (again).
  1163. #pragma unroll
  1164. for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) {
  1165. qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
  1166. }
  1167. // Decompose the thread index into warp and lane.
  1168. const int warp = tidx / WARP_SIZE;
  1169. const int lane = tidx % WARP_SIZE;
  1170. // The warp leader writes the max to shared memory.
  1171. if (lane == 0) {
  1172. red_smem[warp] = qk_max;
  1173. }
  1174. // Make sure the products are in shared memory.
  1175. __syncthreads();
  1176. // The warps finalize the reduction.
  1177. qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
  1178. #pragma unroll
  1179. for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
  1180. qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
  1181. }
  1182. // Broadcast to all the threads in the warp.
  1183. qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
  1184. // Compute the logits and start the sum.
  1185. float sum = 0.f;
  1186. // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
  1187. for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
  1188. bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
  1189. float logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max);
  1190. sum += logit;
  1191. qk_smem[ti - first_step] = logit;
  1192. }
  1193. // Compute the sum.
  1194. sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);
  1195. // Normalize the logits.
  1196. float inv_sum = __fdividef(1.f, sum + 1.e-6f);
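// In effect this is the numerically stable softmax: p_i = exp(qk_i - qk_max) / (sum_j exp(qk_j - qk_max) + 1e-6);
// subtracting qk_max keeps the exponentials in range and the 1e-6 term guards against a zero sum
// when every position is masked.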
  1197. // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
  1198. const size_t cross_attention_out_offset =
  1199. params.is_return_cross_attentions ?
  1200. bhi * params.max_decoder_seq_len * params.memory_max_len + params.timestep * params.memory_max_len :
  1201. 0;
  1202. for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
  1203. float logit = qk_smem[ti - first_step] * inv_sum;
  1204. if (params.is_return_cross_attentions) {
  1205. params.cross_attention_out[cross_attention_out_offset + ti] = logit;
  1206. }
  1207. convert_from_float(logits_smem[ti - first_step], logit);
  1208. }
  1209. // Put the Values part below so that we can leverage the __syncthreads
  1210. // from the previous step.
  1211. // The number of elements per vector.
  1212. constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
  1213. // A vector of V elements for the current timestep.
  1214. using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;
  1215. // The value computed by this thread.
  1216. int vo = tidx / THREADS_PER_VALUE;
  1217. // The hidden dimensions computed by this particular thread.
  1218. int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE;
  1219. // The base pointer for the value in the cache buffer.
  1220. T* v_cache = &params.v_cache[bhi * params.memory_max_len * Dh + vi];
  1221. // Base pointer for the beam's batch, before offsetting with indirection buffer
  1222. T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
  1223. // The number of values processed per iteration of the loop.
  1224. constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
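// Worked example (added note): with THREADS_PER_VALUE = 16 and THREADS_PER_BLOCK = 128 at
// Dh_MAX = 128, each group of 16 threads covers one timestep with V_VEC_SIZE = 8 elements per
// thread, and V_PER_ITER = 8 timesteps are accumulated per iteration of the loop below.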
  1225. // One group of threads computes the product(s) for the current timestep.
  1226. V_vec v_bias;
  1227. zero(v_bias);
  1228. // if( vo == params.timestep % V_PER_ITER ) {
  1229. if (Dh == Dh_MAX || vi < Dh) {
  1230. if (handle_kv) {
  1231. if (vo == tlength % V_PER_ITER) {
  1232. // Trigger the loads from the V bias buffer.
  1233. if (params.v_bias != nullptr) {
  1234. v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi * Dh + vi]);
  1235. }
  1236. if (DO_CROSS_ATTENTION) {
  1237. *reinterpret_cast<V_vec*>(&bias_smem[vi]) = v_bias;
  1238. }
  1239. }
  1240. }
  1241. }
  1242. // This barrier carries over from the previous (pre-values) step.
  1243. // It also makes sure the logits are in shared memory.
  1244. __syncthreads();
  1245. // Values continued
  1246. #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
  1247. using V_vec_acum = typename V_vec_acum_fp32_<V_vec>::Type;
  1248. #else
  1249. using V_vec_acum = V_vec;
  1250. #endif
  1251. // The partial outputs computed by each thread.
  1252. V_vec_acum out;
  1253. zero(out);
  1254. // Loop over the timesteps to compute the partial outputs.
  1255. // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) {
  1256. if (Dh == Dh_MAX || vi < Dh) {
  1257. for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) {
  1258. const int ti_circ = ti % params.memory_max_len;
  1259. // Fetch offset based on cache_indir when beam sampling
  1260. const int beam_src = (params.cache_indir != nullptr) ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0;
  1261. const int beam_offset = beam_src * params.num_heads * params.memory_max_len * Dh;
  1262. // Load the values from the cache.
  1263. V_vec v = *reinterpret_cast<const V_vec*>(&v_cache_batch[beam_offset + ti_circ * Dh]);
  1264. if (DO_CROSS_ATTENTION && params.timestep == 0) {
  1265. v = add(v, *reinterpret_cast<V_vec*>(&bias_smem[vi]));
  1266. if (do_ia3) {
  1267. v = mul<V_vec, V_vec, V_vec>(
  1268. v,
  1269. *reinterpret_cast<const V_vec*>(
  1270. &params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
  1271. }
  1272. *reinterpret_cast<V_vec*>(&v_cache[ti * Dh]) = v;
  1273. }
  1274. // Load the logits from shared memory.
  1275. #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
  1276. float logit = logits_smem[ti - first_step];
  1277. out = fma(logit, cast_to_float(v), out);
  1278. #else
  1279. T logit = logits_smem[ti - first_step];
  1280. // Update the partial sums.
  1281. out = fma(logit, v, out);
  1282. #endif
  1283. }
  1284. }
  1285. // One group of threads computes the product(s) for the current timestep.
  1286. // if( vo == params.timestep % V_PER_ITER ) {
  1287. if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) {
  1288. V_vec v;
  1289. if (DO_CROSS_ATTENTION) {
  1290. v = *reinterpret_cast<const V_vec*>(&v_cache[tlength * Dh]);
  1291. }
  1292. else {
  1293. // Trigger the loads from the V buffer.
  1294. const auto v_offset = qkv_base_offset + vi;
  1295. if (params.int8_mode == 2) {
  1296. using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec>::value>::type;
  1297. using Packed_Float_t = typename packed_type<float, num_elems<V_vec>::value>::type;
  1298. const auto v_scaling = params.qkv_scale_out[2];
  1299. const auto v_quant =
  1300. *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.v)[v_offset]);
  1301. convert_from_float(v, mul<Packed_Float_t, float>(v_scaling, float_from_int8(v_quant)));
  1302. }
  1303. else {
  1304. v = *reinterpret_cast<const V_vec*>(&params.v[v_offset]);
  1305. }
  1306. // Trigger the loads from the V bias buffer.
  1307. // V_vec v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi*Dh + vi]);
  1308. }
  1309. // Compute the V values with bias.
  1310. if (handle_kv) {
  1311. v = add(v, v_bias);
  1312. if (do_ia3) {
  1313. v = mul<V_vec, V_vec, V_vec>(
  1314. v,
  1315. *reinterpret_cast<const V_vec*>(
  1316. &params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
  1317. }
  1318. // Store the values with bias back to global memory in the cache for V.
  1319. //*reinterpret_cast<V_vec*>(&v_cache[params.timestep*Dh]) = v;
  1320. *reinterpret_cast<V_vec*>(&v_cache[tlength_circ * Dh]) = v;
  1321. }
  1322. // Initialize the output value with the current timestep.
  1323. #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
  1324. // out = fma(logits_smem[params.timestep], cast_to_float(v), out);
  1325. out = fma(logits_smem[tlength - first_step], cast_to_float(v), out);
  1326. #else
  1327. // out = fma(logits_smem[params.timestep], v, out);
  1328. out = fma(logits_smem[tlength - first_step], v, out);
  1329. #endif
  1330. }
  1331. // Make sure we can start writing to shared memory.
  1332. __syncthreads();
  1333. // Run the final reduction amongst the different groups computing different partial outputs.
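// Worked example (added note): with V_PER_ITER = 8 the tree below runs log2(8) = 3 rounds:
// groups 4..7 store and 0..3 accumulate, then 2..3 store and 0..1 accumulate, then group 1
// stores and group 0 accumulates, leaving vo == 0 with the complete weighted sum of the values.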
  1334. if (Dh == Dh_MAX || vi < Dh) {
  1335. #pragma unroll
  1336. for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) {
  1337. // The midpoint in the number of active groups.
  1338. int midpoint = active_groups / 2;
  1339. // The upper part of active threads store to shared memory.
  1340. if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
  1341. #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
  1342. convert_from_float(*reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]), out);
  1343. #else
  1344. *reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
  1345. #endif
  1346. }
  1347. __syncthreads();
  1348. // The bottom warps update their values.
  1349. if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
  1350. out = add(*reinterpret_cast<const V_vec*>(&out_smem[vo * Dh + vi]), out);
  1351. }
  1352. __syncthreads();
  1353. }
  1354. }
  1355. // Output the final values.
  1356. if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
  1357. #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
  1358. if (params.int8_mode == 2) {
  1359. using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_acum>::value>::type;
  1360. out = mul<V_vec_acum, float>(*params.attention_out_scale, out);
  1361. *reinterpret_cast<Packed_Int8_t*>(&(reinterpret_cast<int8_t*>(params.out)[bhi * Dh + vi])) =
  1362. cast_to_int8(out);
  1363. }
  1364. else {
  1365. convert_from_float(*reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]), out);
  1366. }
  1367. #else
  1368. // TODO: support int8_mode?
  1369. *reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]) = out;
  1370. #endif
  1371. }
  1372. }
  1373. ////////////////////////////////////////////////////////////////////////////////////////////////////
  1374. } // namespace mmha
  1375. ////////////////////////////////////////////////////////////////////////////////////////////////////
  1376. template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
  1377. void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream);