decoder_masked_multihead_attention_template.hpp 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619
  // Downloaded from FasterTransformer v5.2.1
  2. // https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
  3. /*
  4. * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. #pragma once
  19. #include "decoder_masked_multihead_attention.h"
  20. #include "decoder_masked_multihead_attention_utils.h"
  21. #include "cuda_bf16_wrapper.h"
  22. #include "cuda_bf16_fallbacks.cuh"
  23. #include <assert.h>
  24. #include <float.h>
  25. #include <type_traits>
  26. // #define MMHA_USE_HMMA_FOR_REDUCTION
  27. // Below are knobs to extend FP32 accumulation for higher FP16 accuracy
  28. // Does not seem to affect the accuracy that much
  29. #define MMHA_USE_FP32_ACUM_FOR_FMA
  30. // Seems to slightly improve the accuracy
  31. #define MMHA_USE_FP32_ACUM_FOR_OUT
  32. #if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT)
  33. // Does not seem to improve the accuracy
  34. //#define MMHA_USE_FP32_ACUM_FOR_LOGITS
  35. #endif
  36. namespace mmha {
  37. ////////////////////////////////////////////////////////////////////////////////////////////////////
  38. //
  39. // We use the following terminology to describe the different dimensions.
  40. //
  41. // B: Batch size (number of sequences),
  42. // L: Sequence length,
  43. // D: Hidden dimension,
  44. // H: Number of heads,
  45. // Dh: Hidden dimension per head - Dh = D / H.
  46. //
  // The different kernels assign one threadblock to each B x H pair. The grid has size (1, B, H). We use
  48. // 64, 128 and 256 threads per block.
  49. //
  50. // Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to
  51. // compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The
  52. // cache buffer helps with memory accesses and contains keys with bias.
  53. //
  54. // The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and
  55. // x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The
  56. // values for x are chosen to create chunks of 16 bytes.
  57. //
  58. // The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs
  59. // depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At
  // the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using warp
  // shuffles (or, for FP16 with 4 threads per key and MMHA_USE_HMMA_FOR_REDUCTION defined, an HMMA
  // Tensor Core instruction). Each Q * K^T value is stored in shared memory in FP32.
  62. //
  63. // After that loop, a parallel softmax is computed across the different Q * K^T values stored in
  64. // shared memory.
  65. //
  66. // The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many
  67. // timesteps are computed by loop iteration. As with the keys, the values are read from a cache
  68. // except for the current timestep. The layout of the cache buffer for the values is much simpler
  69. // as it is [B, H, L, Dh].
  70. //
  71. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Qk_vec_<T, Dh>::Type: vector type used for the Q*K^T operands of a head of size Dh
  // whose elements have type T (presumably the per-thread load width for Q; the kernel
  // code that consumes it is outside this view). Unsupported (T, Dh) pairs have no Type
  // member, so naming Qk_vec_<T, Dh>::Type for them is a compile-time error.
  template<typename T, int Dh>
  struct Qk_vec_ {
  };

  // FP32 elements.
  template<> struct Qk_vec_<float, 32>  { typedef float  Type; };
  template<> struct Qk_vec_<float, 64>  { typedef float2 Type; };
  template<> struct Qk_vec_<float, 128> { typedef float4 Type; };
  template<> struct Qk_vec_<float, 256> { typedef float4 Type; };

  // FP16 elements held as raw uint16_t bit patterns, packed two per 32-bit word.
  template<> struct Qk_vec_<uint16_t, 32>  { typedef uint32_t Type; };
  template<> struct Qk_vec_<uint16_t, 64>  { typedef uint32_t Type; };
  template<> struct Qk_vec_<uint16_t, 128> { typedef uint2    Type; };
  template<> struct Qk_vec_<uint16_t, 256> { typedef uint4    Type; };

  #ifdef ENABLE_BF16
  // BF16 elements, packed into __nv_bfloat162 pairs and the bf16_4_t / bf16_8_t aggregates.
  template<> struct Qk_vec_<__nv_bfloat16, 32>  { typedef __nv_bfloat162 Type; };
  template<> struct Qk_vec_<__nv_bfloat16, 64>  { typedef __nv_bfloat162 Type; };
  template<> struct Qk_vec_<__nv_bfloat16, 128> { typedef bf16_4_t       Type; };
  template<> struct Qk_vec_<__nv_bfloat16, 256> { typedef bf16_8_t       Type; };
  #endif // ENABLE_BF16
  125. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // K_vec_<T, THREADS_PER_KEY>::Type: vector type each thread uses for its slice of a key
  // when THREADS_PER_KEY threads cooperate on one key. Fewer threads per key means a wider
  // vector per thread; for the FP32 and FP16 entries that is 4 / 8 / 16 bytes for
  // 4 / 2 / 1 threads per key. Unsupported combinations have no Type member.
  template<typename T, int THREADS_PER_KEY>
  struct K_vec_ {
  };

  // FP32 elements.
  template<> struct K_vec_<float, 4> { typedef float  Type; };
  template<> struct K_vec_<float, 2> { typedef float2 Type; };
  template<> struct K_vec_<float, 1> { typedef float4 Type; };

  // FP16 elements (uint16_t bit patterns) packed into 32-bit words.
  template<> struct K_vec_<uint16_t, 4> { typedef uint32_t Type; };
  template<> struct K_vec_<uint16_t, 2> { typedef uint2    Type; };
  template<> struct K_vec_<uint16_t, 1> { typedef uint4    Type; };

  #ifdef ENABLE_BF16
  // BF16 elements.
  template<> struct K_vec_<__nv_bfloat16, 4> { typedef __nv_bfloat162 Type; };
  template<> struct K_vec_<__nv_bfloat16, 2> { typedef bf16_4_t       Type; };
  template<> struct K_vec_<__nv_bfloat16, 1> { typedef bf16_8_t       Type; };
  #endif // ENABLE_BF16
  167. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // V_vec_<T, V_VEC_SIZE>::Type: vector type holding V_VEC_SIZE elements of T, used for
  // the value-cache accesses. Unsupported combinations have no Type member.
  template<typename T, int V_VEC_SIZE>
  struct V_vec_ {
  };

  // FP32 elements: 1, 2 or 4 per vector.
  template<> struct V_vec_<float, 1> { typedef float  Type; };
  template<> struct V_vec_<float, 2> { typedef float2 Type; };
  template<> struct V_vec_<float, 4> { typedef float4 Type; };

  // FP16 elements (uint16_t bit patterns): 2, 4 or 8 per vector.
  template<> struct V_vec_<uint16_t, 2> { typedef uint32_t Type; };
  template<> struct V_vec_<uint16_t, 4> { typedef uint2    Type; };
  template<> struct V_vec_<uint16_t, 8> { typedef uint4    Type; };

  #ifdef ENABLE_BF16
  // BF16 elements: 2, 4 or 8 per vector.
  template<> struct V_vec_<__nv_bfloat16, 2> { typedef __nv_bfloat162 Type; };
  template<> struct V_vec_<__nv_bfloat16, 4> { typedef bf16_4_t       Type; };
  template<> struct V_vec_<__nv_bfloat16, 8> { typedef bf16_8_t       Type; };
  #endif // ENABLE_BF16
  209. ////////////////////////////////////////////////////////////////////////////////////////////////////
  #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
  // Qk_vec_acum_fp32_<V>::Type: FP32 accumulator type matching a Qk_vec_ type V, used when
  // MMHA_USE_FP32_ACUM_FOR_FMA asks for FP32 accumulation of the Q*K^T products.
  // FP32 vectors map to themselves. Each packed 16-bit pair widens to a float2, and
  // Float4_ / Float8_ group two / four such float2 values.
  template<typename T>
  struct Qk_vec_acum_fp32_ {
  };
  template<>
  struct Qk_vec_acum_fp32_<float> {
      using Type = float;
  };
  template<>
  struct Qk_vec_acum_fp32_<float2> {
      using Type = float2;
  };
  template<>
  struct Qk_vec_acum_fp32_<float4> {
      using Type = float4;
  };
  // template<> struct Qk_vec_acum_fp32_<uint16_t> { using Type = float; };
  template<>
  struct Qk_vec_acum_fp32_<uint32_t> {
      using Type = float2;
  };
  template<>
  struct Qk_vec_acum_fp32_<uint2> {
      using Type = Float4_;
  };
  template<>
  struct Qk_vec_acum_fp32_<uint4> {
      using Type = Float8_;
  };
  #ifdef ENABLE_BF16
  // Every other BF16 declaration in this file is guarded by ENABLE_BF16. Guard these too
  // so the header does not name __nv_bfloat16 / bf16_*_t when BF16 support is compiled out.
  template<>
  struct Qk_vec_acum_fp32_<__nv_bfloat16> {
      using Type = float;
  };
  template<>
  struct Qk_vec_acum_fp32_<__nv_bfloat162> {
      using Type = float2;
  };
  template<>
  struct Qk_vec_acum_fp32_<bf16_4_t> {
      using Type = Float4_;
  };
  template<>
  struct Qk_vec_acum_fp32_<bf16_8_t> {
      using Type = Float8_;
  };
  #endif // ENABLE_BF16
  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // K_vec_acum_fp32_<V>::Type: FP32 accumulator type matching a K_vec_ type V. qk_dot_ and
  // qk_hmma_dot_ use it as their accumulator when MMHA_USE_FP32_ACUM_FOR_FMA is defined.
  template<typename T>
  struct K_vec_acum_fp32_ {
  };
  template<>
  struct K_vec_acum_fp32_<float> {
      using Type = float;
  };
  template<>
  struct K_vec_acum_fp32_<float2> {
      using Type = float2;
  };
  template<>
  struct K_vec_acum_fp32_<float4> {
      using Type = float4;
  };
  template<>
  struct K_vec_acum_fp32_<uint32_t> {
      using Type = float2;
  };
  template<>
  struct K_vec_acum_fp32_<uint2> {
      using Type = Float4_;
  };
  template<>
  struct K_vec_acum_fp32_<uint4> {
      using Type = Float8_;
  };
  #ifdef ENABLE_BF16
  template<>
  struct K_vec_acum_fp32_<__nv_bfloat16> {
      using Type = float;
  };
  template<>
  struct K_vec_acum_fp32_<__nv_bfloat162> {
      using Type = float2;
  };
  template<>
  struct K_vec_acum_fp32_<bf16_4_t> {
      using Type = Float4_;
  };
  template<>
  struct K_vec_acum_fp32_<bf16_8_t> {
      using Type = Float8_;
  };
  #endif // ENABLE_BF16
  #endif // MMHA_USE_FP32_ACUM_FOR_FMA
  300. ////////////////////////////////////////////////////////////////////////////////////////////////////
  #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
  // V_vec_acum_fp32_<V>::Type: FP32 accumulator type for a V_vec_ type V, used when
  // MMHA_USE_FP32_ACUM_FOR_OUT requests FP32 accumulation of the output. FP32 vectors map
  // to themselves. Packed 16-bit pairs widen to float2, and Float4_ / Float8_ group two /
  // four such pairs.
  template<typename T>
  struct V_vec_acum_fp32_ {
  };

  template<> struct V_vec_acum_fp32_<float>    { typedef float   Type; };
  template<> struct V_vec_acum_fp32_<float2>   { typedef float2  Type; };
  template<> struct V_vec_acum_fp32_<float4>   { typedef float4  Type; };

  template<> struct V_vec_acum_fp32_<uint32_t> { typedef float2  Type; };
  template<> struct V_vec_acum_fp32_<uint2>    { typedef Float4_ Type; };
  template<> struct V_vec_acum_fp32_<uint4>    { typedef Float8_ Type; };

  #ifdef ENABLE_BF16
  template<> struct V_vec_acum_fp32_<__nv_bfloat162> { typedef float2  Type; };
  template<> struct V_vec_acum_fp32_<bf16_4_t>       { typedef Float4_ Type; };
  template<> struct V_vec_acum_fp32_<bf16_8_t>       { typedef Float8_ Type; };
  #endif // ENABLE_BF16
  #endif // MMHA_USE_FP32_ACUM_FOR_OUT
  344. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Dot product of a query and a key, combined across the THREADS_PER_KEY lanes that share
  // one key.
  //   q, k : this thread's N vector slices of the query and of the key.
  // Products are accumulated lane-wise in K_vec_acum (the FP32 type from K_vec_acum_fp32_
  // when MMHA_USE_FP32_ACUM_FOR_FMA is defined) and collapsed to a scalar with sum(). They
  // are then combined with an XOR butterfly over offsets THREADS_PER_KEY/2, ..., 1, so every
  // lane of a group returns the group total.
  // Assumes THREADS_PER_KEY is a power of two and that lanes sharing a key are contiguous and
  // aligned to THREADS_PER_KEY. The shuffles use the full-warp mask, so all 32 lanes must call
  // this together.
  template<int THREADS_PER_KEY, typename K_vec, int N>
  inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N])
  {
  #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
      using K_vec_acum = typename K_vec_acum_fp32_<K_vec>::Type;
  #else
      using K_vec_acum = K_vec;
  #endif
      // The first product seeds the accumulator; the rest are fused multiply-adds.
      K_vec_acum acc = mul<K_vec_acum, K_vec, K_vec>(q[0], k[0]);
  #pragma unroll
      for (int i = 1; i < N; i++) {
          acc = fma(q[i], k[i], acc);
      }

      // Horizontal sum of the vector lanes, then the cross-thread butterfly.
      float dot = sum(acc);
  #pragma unroll
      for (int offset = THREADS_PER_KEY >> 1; offset > 0; offset >>= 1) {
          dot += __shfl_xor_sync(uint32_t(-1), dot, offset);
      }
      return dot;
  }
  367. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Default Q*K^T dot-product policy for element type T. It forwards to qk_dot_, which
  // reduces across THREADS_PER_KEY lanes with warp shuffles. Specializations (such as the
  // FP16 / 4-threads-per-key one) may choose a different reduction.
  template<typename T, int THREADS_PER_KEY>
  struct Qk_dot {
      // q, k: this thread's N vector slices of the query and of the key.
      template<typename K_vec, int N>
      static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N])
      {
          return qk_dot_<THREADS_PER_KEY, K_vec, N>(q, k);
      }
  };
  376. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Issues one warp-wide mma.sync m16n8k8 (FP16 A/B, FP32 accumulate) with C = 0 and
  // returns this lane's four FP32 D-fragment values.
  //   a : this lane's A fragment (two 32-bit registers, each holding two packed FP16 values).
  //   b : this lane's B fragment (one 32-bit register holding two packed FP16 values).
  // mma.sync.aligned must be executed by all 32 lanes of the warp together. The only caller
  // in this file (qk_hmma_dot_) reaches it only when __CUDA_ARCH__ >= 750.
  inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b)
  {
      float4 c;
      float zero = 0.f;
      // Operands: %0-%3 = D, %4-%5 = A registers, %6 = B register, %7 = zero broadcast into
      // all four C registers. Each input needs its own comma-separated constraint entry.
      asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
                   " {%0, %1, %2, %3}, \n"
                   " {%4, %5}, \n"
                   " {%6}, \n"
                   " {%7, %7, %7, %7}; \n"
                   : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
                   : "r"(a.x), "r"(a.y), "r"(b), "f"(zero));
      return c;
  }
  390. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // FP16 Q.K dot product for 4 threads per key, using a Tensor Core HMMA for the cross-lane
  // reduction instead of shuffles.
  //   q, k : this thread's N packed half2 slices of the query and of the key.
  // Per-lane products are accumulated (in FP32 when MMHA_USE_FP32_ACUM_FOR_FMA is defined,
  // then rounded back to a half2). The resulting half2 is placed in the first A register and
  // multiplied by an all-ones B (0x3c00 is 1.0 in FP16). Each D element is therefore a row
  // sum of A; per the PTX m16n8k8 fragment layout a row is spread over 4 consecutive lanes,
  // which matches the THREADS_PER_KEY == 4 specialization that calls this.
  // On the host pass or pre-sm_75 devices the HMMA is unavailable and 0.f is returned.
  template<int N>
  inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N])
  {
  #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 750
      return 0.f;
  #else
  #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
      using Acc = typename K_vec_acum_fp32_<uint32_t>::Type;
  #else
      using Acc = uint32_t;
  #endif
      Acc acc = mul<Acc, uint32_t, uint32_t>(q[0], k[0]);
  #pragma unroll
      for (int i = 1; i < N; ++i) {
          acc = fma(q[i], k[i], acc);
      }
  #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
      const uint32_t packed = float2_to_half2(acc);
  #else
      const uint32_t packed = acc;
  #endif
      const float4 d = hmma_fp32(make_uint2(packed, 0u), 0x3c003c00u);
      return d.x;
  #endif
  }
  415. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // FP16 specialization for 4 threads per key. When MMHA_USE_HMMA_FOR_REDUCTION is defined
  // and the device pass targets sm_75+, the cross-lane reduction uses an HMMA
  // (qk_hmma_dot_). Otherwise it falls back to the shuffle-based qk_dot_<4>.
  template<>
  struct Qk_dot<uint16_t, 4> {
      template<int N>
      static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N])
      {
  #if defined(MMHA_USE_HMMA_FOR_REDUCTION) && __CUDA_ARCH__ >= 750
          return qk_hmma_dot_(q, k);
  #else
          return qk_dot_<4>(q, k);
  #endif // defined MMHA_USE_HMMA_FOR_REDUCTION
      }
  };
  428. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Sums `sum` over every thread of the block and returns the total to all threads.
  //   red_smem : shared-memory scratch with room for at least WARPS_PER_BLOCK floats.
  //   sum      : this thread's contribution.
  // Steps:
  //   1. Each warp reduces its values with an XOR butterfly.
  //   2. Lane 0 of each warp publishes the warp partial to red_smem[warp].
  //   3. After a __syncthreads(), every warp re-reduces the WARPS_PER_BLOCK partials and
  //      broadcasts lane 0's result.
  // Assumes blockDim.x == WARPS_PER_BLOCK * WARP_SIZE, so every red_smem slot read is
  // written, and that all threads of the block call it (it contains __syncthreads() and
  // full-warp shuffles).
  // NOTE: there is no barrier after red_smem is read. A caller reusing the same red_smem for
  // another block_sum must __syncthreads() in between.
  template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
  inline __device__ float block_sum(float* red_smem, float sum)
  {
      // The second stage starts its butterfly at WARPS_PER_BLOCK / 2 inside a single warp.
      // That only covers every partial when WARPS_PER_BLOCK is a power of two <= WARP_SIZE.
      static_assert(WARPS_PER_BLOCK >= 1 && WARPS_PER_BLOCK <= WARP_SIZE
                        && (WARPS_PER_BLOCK & (WARPS_PER_BLOCK - 1)) == 0,
                    "block_sum requires WARPS_PER_BLOCK to be a power of two no larger than WARP_SIZE");
      // Decompose the thread index into warp / lane.
      int warp = threadIdx.x / WARP_SIZE;
      int lane = threadIdx.x % WARP_SIZE;
      // Compute the sum per warp.
  #pragma unroll
      for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
          sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
      }
      // Warp leaders store the data to shared memory.
      if (lane == 0) {
          red_smem[warp] = sum;
      }
      // Make sure the data is in shared memory.
      __syncthreads();
      // The first WARPS_PER_BLOCK lanes of every warp load the partials. The other lanes keep
      // stale values, but the butterfly below never mixes them into lane 0.
      if (lane < WARPS_PER_BLOCK) {
          sum = red_smem[lane];
      }
      // Parallel reduction inside the warp.
  #pragma unroll
      for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
          sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
      }
      // Broadcast lane 0's total to the other lanes.
      return __shfl_sync(uint32_t(-1), sum, 0);
  }
  458. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // convert_from_float overloads: store an FP32 value (or pair) into the destination type.

  // FP32 -> FP32: plain copy.
  inline __device__ void convert_from_float(float& dst, float src) { dst = src; }

  // FP32 -> FP16 bit pattern, via float_to_half.
  inline __device__ void convert_from_float(uint16_t& dst, float src) { dst = float_to_half(src); }

  // float2 -> two FP16 values packed into one 32-bit word, via float2_to_half2.
  inline __device__ void convert_from_float(uint32_t& dst, float2 src) { dst = float2_to_half2(src); }
  473. ////////////////////////////////////////////////////////////////////////////////////////////////////
  #ifdef ENABLE_BF16
  // FP32 -> BF16 via __float2bfloat16.
  inline __device__ void convert_from_float(__nv_bfloat16& dst, float src) { dst = __float2bfloat16(src); }

  // float2 -> __nv_bfloat162 with round-to-nearest. Before sm_80 (and on the host pass) each
  // component is converted through __floats2bfloat162_rn. On sm_80+ the pair is converted
  // with __float22bfloat162_rn.
  inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src)
  {
  #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
      dst = __floats2bfloat162_rn(src.x, src.y);
  #else
      dst = __float22bfloat162_rn(src);
  #endif
  }
  #endif // ENABLE_BF16
  489. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // FP32 vectors -> packed FP16 vectors. Each float2 becomes one 32-bit half2 word via
  // float2_to_half2, keeping the component order.

  // Float4_ (two float2) -> uint2 (two half2 words).
  inline __device__ void convert_from_float(uint2& dst, Float4_ src)
  {
      const uint32_t lo = float2_to_half2(src.x);
      const uint32_t hi = float2_to_half2(src.y);
      dst = make_uint2(lo, hi);
  }

  // float4 -> uint2: regroup as two float2 pairs, then convert as above.
  inline __device__ void convert_from_float(uint2& dst, float4 src)
  {
      Float4_ pairs;
      pairs.x = make_float2(src.x, src.y);
      pairs.y = make_float2(src.z, src.w);
      convert_from_float(dst, pairs);
  }

  // Float8_ (four float2) -> uint4 (four half2 words).
  inline __device__ void convert_from_float(uint4& dst, Float8_ src)
  {
      const uint32_t h0 = float2_to_half2(src.x);
      const uint32_t h1 = float2_to_half2(src.y);
      const uint32_t h2 = float2_to_half2(src.z);
      const uint32_t h3 = float2_to_half2(src.w);
      dst = make_uint4(h0, h1, h2, h3);
  }
  508. ////////////////////////////////////////////////////////////////////////////////////////////////////
  #ifdef ENABLE_BF16
  // FP32 vectors -> packed BF16 vectors. Each float2 pair is routed through the
  // convert_from_float(__nv_bfloat162&, float2) overload, which makes the per-architecture
  // choice between __float22bfloat162_rn (sm_80+) and __floats2bfloat162_rn.

  // Float4_ -> bf16_4_t.
  inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src)
  {
      convert_from_float(dst.x, src.x);
      convert_from_float(dst.y, src.y);
  }

  // float4 -> bf16_4_t: regroup as two float2 pairs, then convert as above.
  inline __device__ void convert_from_float(bf16_4_t& dst, float4 src)
  {
      Float4_ pairs;
      pairs.x = make_float2(src.x, src.y);
      pairs.y = make_float2(src.z, src.w);
      convert_from_float(dst, pairs);
  }

  // Float8_ -> bf16_8_t.
  inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src)
  {
      convert_from_float(dst.x, src.x);
      convert_from_float(dst.y, src.y);
      convert_from_float(dst.z, src.z);
      convert_from_float(dst.w, src.w);
  }
  #endif // ENABLE_BF16
  541. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // FP32 vector -> same FP32 vector: plain copies, so generic code can call
  // convert_from_float regardless of whether the destination is already FP32.
  inline __device__ void convert_from_float(float2& dst, float2 src) { dst = src; }
  inline __device__ void convert_from_float(float4& dst, float4 src) { dst = src; }
  551. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // convert_to_float: returns only the FIRST element of the vector, as FP32. The remaining
  // elements are ignored.
  inline __device__ float convert_to_float(float4 u) { return u.x; }

  // For packed FP16, the first element is the low half of the first half2 word.
  inline __device__ float convert_to_float(uint4 u) { return half2_to_float2(u.x).x; }
  #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
  // cast_to_float: widen a logits vector to its FP32 counterpart. FP32 inputs are returned
  // unchanged. Packed FP16 words are widened pairwise with half2_to_float2, preserving
  // element order. Only compiled when MMHA_USE_FP32_ACUM_FOR_LOGITS is defined.
  inline __device__ float cast_to_float(float u) { return u; }
  inline __device__ float2 cast_to_float(float2 u) { return u; }
  inline __device__ float4 cast_to_float(float4 u) { return u; }
  inline __device__ Float4_ cast_to_float(Float4_ u) { return u; }
  inline __device__ Float8_ cast_to_float(Float8_ u) { return u; }

  inline __device__ float2 cast_to_float(uint32_t u) { return half2_to_float2(u); }

  inline __device__ Float4_ cast_to_float(uint2 u)
  {
      return Float4_{half2_to_float2(u.x), half2_to_float2(u.y)};
  }

  inline __device__ Float8_ cast_to_float(uint4 u)
  {
      return Float8_{half2_to_float2(u.x), half2_to_float2(u.y), half2_to_float2(u.z), half2_to_float2(u.w)};
  }
  #endif
  612. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // float_from_int8: unpack signed 8-bit integers packed into a wider integer and widen each
  // to FP32. The lowest byte is the first element. Bytes are selected with shifts and
  // truncating casts, which on the (little-endian) device picks the same bytes as the byte
  // view of the integer.
  inline __device__ float float_from_int8(int8_t u)
  {
      return static_cast<float>(u);
  }

  inline __device__ float2 float_from_int8(int16_t u)
  {
      const int8_t b0 = static_cast<int8_t>(u);
      const int8_t b1 = static_cast<int8_t>(u >> 8);
      return make_float2(b0, b1);
  }

  inline __device__ float4 float_from_int8(int32_t u)
  {
      const int8_t b0 = static_cast<int8_t>(u);
      const int8_t b1 = static_cast<int8_t>(u >> 8);
      const int8_t b2 = static_cast<int8_t>(u >> 16);
      const int8_t b3 = static_cast<int8_t>(u >> 24);
      return make_float4(b0, b1, b2, b3);
  }

  // Eight packed int8 values: split into four 16-bit pairs (low pair first) and reuse the
  // int16_t overload for each.
  // clang-format off
  inline __device__ Float8_ float_from_int8(int64_t u)
  {
      return Float8_ {float_from_int8(static_cast<int16_t>(u)),
                      float_from_int8(static_cast<int16_t>(u >> 16)),
                      float_from_int8(static_cast<int16_t>(u >> 32)),
                      float_from_int8(static_cast<int16_t>(u >> 48))};
  }
  // clang-format on
  652. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // cast_to_int8: convert FP32 values to saturated signed 8-bit integers and pack them into
  // an integer of matching width, with the first element in the lowest byte.
  //
  // Scalar form: PTX cvt.rni.sat.s8.f32 rounds to the nearest integer (ties to even) and
  // clamps to [-128, 127]. The result lands in a 16-bit register whose low byte is returned.
  inline __device__ int8_t cast_to_int8(float val)
  {
      int16_t wide;
      asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(wide) : "f"(val));
      return static_cast<int8_t>(wide);
  }

  // Four values: x in bits 0-7, y in 8-15, z in 16-23, w in 24-31.
  inline __device__ int32_t cast_to_int8(float4 val)
  {
      const uint32_t b0 = static_cast<uint8_t>(cast_to_int8(val.x));
      const uint32_t b1 = static_cast<uint8_t>(cast_to_int8(val.y));
      const uint32_t b2 = static_cast<uint8_t>(cast_to_int8(val.z));
      const uint32_t b3 = static_cast<uint8_t>(cast_to_int8(val.w));
      return static_cast<int32_t>(b0 | (b1 << 8) | (b2 << 16) | (b3 << 24));
  }

  // Eight values in the order x.x, x.y, y.x, y.y, z.x, z.y, w.x, w.y, lowest byte first.
  inline __device__ int64_t cast_to_int8(Float8_ val)
  {
      // Braced initialization evaluates the conversions left to right, in the same order as
      // the byte positions.
      const int8_t bytes[8] = {cast_to_int8(val.x.x),
                               cast_to_int8(val.x.y),
                               cast_to_int8(val.y.x),
                               cast_to_int8(val.y.y),
                               cast_to_int8(val.z.x),
                               cast_to_int8(val.z.y),
                               cast_to_int8(val.w.x),
                               cast_to_int8(val.w.y)};
      uint64_t packed = 0;
  #pragma unroll
      for (int i = 0; i < 8; ++i) {
          packed |= static_cast<uint64_t>(static_cast<uint8_t>(bytes[i])) << (8 * i);
      }
      return static_cast<int64_t>(packed);
  }
  692. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Integer ceiling division: the smallest q with q * n >= m, for non-negative m, positive n
  // and no overflow in m + n - 1. Callable from host and device.
  template<typename T>
  inline __device__ __host__ T div_up(T m, T n)
  {
      // `auto` keeps the promoted arithmetic type, so narrowing to T happens only on return.
      const auto biased = m + n - 1;
      return biased / n;
  }
  698. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Host-side helper returning the number of bytes of dynamic shared memory that
  // masked_multihead_attention_kernel needs for `params`.
  //
  // One buffer is reused for several phases, so the result is the largest of:
  //   softmax   - one float per Q*K^T logit, rounded up to a multiple of 4 entries (16 bytes).
  //               When T is not 4 bytes wide, a second copy of the logits in T is added,
  //               unless MMHA_USE_FP32_ACUM_FOR_LOGITS is defined.
  //   reduction - rows_per_red * hidden_size_per_head * sizeof(T) / 2 bytes, where
  //               rows_per_red = threads_per_block / threads_per_value.
  //   rotary    - 2 * rotary_embedding_dim * sizeof(T) bytes, only when
  //               rotary_embedding_dim > 0 and neox_rotary_style is set.
  // The logit count is memory_max_len + 1 for cross-attention and
  // min(timestep, memory_max_len) + 1 otherwise.
  template<typename T, bool DO_CROSS_ATTENTION>
  inline size_t smem_size_in_bytes(const Multihead_attention_params<T, DO_CROSS_ATTENTION>& params,
                                   int threads_per_value,
                                   int threads_per_block)
  {
      // Cap on the number of past timesteps for self-attention.
      const int max_timesteps = min(params.timestep, params.memory_max_len);
      // Logit count, excluding the +1 for the current step, for this attention flavor.
      const int past_len = DO_CROSS_ATTENTION ? params.memory_max_len : max_timesteps;

      // Q*K^T values kept in float, padded to groups of 4 entries (16 bytes).
      const size_t qk_sz = div_up(past_len + 1, 4) * 16;

      // Extra room for the logits stored in T when they are not kept in float.
      size_t logits_sz = 0;
  #ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
      if (sizeof(T) != 4) {
          // TODO
          logits_sz = div_up(past_len + 1, 4) * 4 * sizeof(T);
      }
  #endif
      const size_t softmax_sz = qk_sz + logits_sz;

      // Storage for the final reduction over the partial output rows.
      const int rows_per_red = threads_per_block / threads_per_value;
      const size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(T) / 2;

      // Scratch for the NeoX-style rotary embedding transpose.
      const size_t transpose_rotary_size = (params.rotary_embedding_dim > 0 && params.neox_rotary_style) ?
                                               2 * params.rotary_embedding_dim * sizeof(T) :
                                               0;

      return max(max(softmax_sz, red_sz), transpose_rotary_size);
  }
  729. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Lane mask with the lowest `threads` bits set, for *_sync warp intrinsics over the first
  // `threads` lanes. 32 is handled separately because shifting a 32-bit value by 32 is undefined.
  inline __device__ constexpr uint32_t shfl_mask(int threads)
  {
      return threads != 32 ? (1u << threads) - 1u : 0xffffffffu;
  }
  734. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Single-step ("masked" / decoder) multi-head attention for the current query token.
  //
  // Grid/block layout (from the index math below):
  //   * blockIdx.x selects the query head hi (remapped through params.nnz_head_idx when it is non-null);
  //   * blockIdx.y selects the batch*beam row bi;
  //   * tidx = threadIdx.x and the loops stride by THREADS_PER_BLOCK, so the kernel presumably expects
  //     blockDim.x == THREADS_PER_BLOCK (TODO confirm at the launch site).
  // Dynamic shared memory must be sized with smem_size_in_bytes(). It is reused for the NeoX rotary
  // staging, the Q*K^T scores / logits and the final output reduction.
  // Side effects:
  //   * when handle_kv is true (self attention, or cross attention at timestep 0), the current K/V are
  //     written into params.k_cache / params.v_cache;
  //   * the attention output for (bi, hi) is written to params.out.
  // Every __syncthreads() below must be reached by all threads of the block.
  template<
      // The type of the inputs. Supported types: float and half.
      typename T,
      // The hidden dimension per head.
      int Dh,
      // The compile-time (padded) head dimension. Accesses are guarded with `Dh == Dh_MAX || ... < Dh`,
      // so Dh <= Dh_MAX is assumed.
      int Dh_MAX,
      // The number of threads per key.
      int THREADS_PER_KEY,
      // The number of threads per value.
      int THREADS_PER_VALUE,
      // The number of threads in a threadblock.
      int THREADS_PER_BLOCK,
      // Whether this instantiation performs cross attention.
      bool DO_CROSS_ATTENTION>
  __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T, DO_CROSS_ATTENTION> params)
  {
      // Make sure the hidden dimension per head is a multiple of the number of threads per key.
      static_assert(Dh_MAX % THREADS_PER_KEY == 0, "");
      // Make sure the hidden dimension per head is a multiple of the number of threads per value.
      static_assert(Dh_MAX % THREADS_PER_VALUE == 0, "");

      // The size of a warp.
      constexpr int WARP_SIZE = 32;
      // The number of warps in a threadblock.
      constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE;

      // Use smem_size_in_bytes (above) to determine the amount of shared memory.
      extern __shared__ char smem_[];

      // The shared memory for the Q*K^T values and partial logits in softmax.
      float* qk_smem = reinterpret_cast<float*>(smem_);

      // The shared memory for the logits. For FP32, that's the same buffer as qk_smem.
      char* logits_smem_ = smem_;
  #ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
      if (sizeof(T) != 4) {
          // TODO - change to tlength
          // This offset must stay in sync with the score region sized in smem_size_in_bytes.
          const int max_timesteps = min(params.timestep, params.memory_max_len);
          logits_smem_ +=
              (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16;
      }
      T* logits_smem = reinterpret_cast<T*>(logits_smem_);
  #else
      float* logits_smem = reinterpret_cast<float*>(logits_smem_);
  #endif

      // The shared memory to do the final reduction for the output values. Reuse qk_smem.
      T* out_smem = reinterpret_cast<T*>(smem_);

      // The shared memory buffers for the block-wide reductions. One for max, one for sum.
      __shared__ float red_smem[WARPS_PER_BLOCK * 2];

      // A vector of Q or K elements for the current timestep.
      using Qk_vec = typename Qk_vec_<T, Dh_MAX>::Type;

      // Use alignment for safely casting the shared buffers as Qk_vec.
      // Shared memory to store Q inputs.
      __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX];

      // This is one of the reasons we should have a separate kernel for cross attention
      __shared__ __align__(sizeof(Qk_vec)) T bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1];

      // The number of elements per vector.
      constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
      // Make sure the hidden size per head is a multiple of the vector size.
      static_assert(Dh_MAX % QK_VEC_SIZE == 0, "");
      // We will use block wide reduction if needed
      // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
      // The number of vectors per warp.
      constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE;

      // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread
      // owns x elements, we have to decompose the linear index into chunks of x values and the position
      // of the thread in that chunk.

      // The number of elements in a chunk of 16B (that's the x in the above formula).
      constexpr int QK_ELTS_IN_16B = 16 / sizeof(T);
      // The number of K vectors in 16B.
      constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec);

      // The batch/beam idx
      const int bi = blockIdx.y;
      if (params.finished != nullptr && params.finished[bi] == true) {
          return;
      }
      // The "beam-aware" batch idx
      const int bbi = bi / params.beam_width;
      // The head.
      const int hi = params.nnz_head_idx == nullptr ? blockIdx.x : params.nnz_head_idx[blockIdx.x];
      const int hi_kv = hi / params.num_heads_q_kv_ratio;
      // Combine the batch and the head indices.
      const int bhi = bi * params.num_heads + hi;
      const int bhi_kv = bi * params.num_heads_kv + hi_kv;
      // Combine the "beam-aware" batch idx and the head indices.
      const int bbhi = bbi * params.beam_width * params.num_heads_kv + hi_kv;
      // The thread in the block.
      const int tidx = threadIdx.x;

      const bool handle_kv = !DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0);

      // While doing the product Q*K^T for the different keys we track the max.
      float qk_max = -FLT_MAX;

      float qk = 0.0F;

      int q_base_offset = (params.stride_q == 0) ? bhi * Dh : bi * params.stride_q + hi * Dh;
      int k_base_offset = (params.stride_k == 0) ? bhi_kv * Dh : bi * params.stride_k + hi_kv * Dh;
      int v_base_offset = (params.stride_v == 0) ? bhi_kv * Dh : bi * params.stride_v + hi_kv * Dh;

      const size_t bi_seq_len_offset = static_cast<size_t>(bi) * params.memory_max_len;

      // int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep;
      int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 :
                    (params.length_per_sample == nullptr) ?
                                           params.timestep :
                                           params.length_per_sample[bi] + params.max_prefix_prompt_length;
      const int first_step = max(0, tlength + 1 - params.memory_max_len);
      const int tlength_circ = tlength % params.memory_max_len;

      // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep.
      const bool is_masked = tidx >= QK_VECS_PER_WARP;
      // Lanes that own Q/K elements inside the real head dimension Dh (Dh may be smaller than Dh_MAX).
      const bool has_qk_elts = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh);

      // The offset in the Q and K buffer also accounts for the batch.
      int q_offset = q_base_offset + tidx * QK_VEC_SIZE;
      int k_offset = k_base_offset + tidx * QK_VEC_SIZE;
      // The offset in the bias buffer.
      int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
      int k_bias_offset = hi_kv * Dh + tidx * QK_VEC_SIZE;

      const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr;
      const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0;

      // Trigger the loads from the Q and K buffers.
      Qk_vec q;
      zero(q);
      if (has_qk_elts) {
          if (params.int8_mode == 2) {
              using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec>::value>::type;
              using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
              const auto q_scaling = params.qkv_scale_out[0];
              const auto q_quant =
                  *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[q_offset]);

              convert_from_float(q, mul<Packed_Float_t, float>(q_scaling, float_from_int8(q_quant)));
          }
          else {
              q = *reinterpret_cast<const Qk_vec*>(&params.q[q_offset]);
          }
      }

      Qk_vec k;
      zero(k);
      if (DO_CROSS_ATTENTION) {
          // The 16B chunk written by the thread.
          int co = tidx / QK_VECS_IN_16B;
          // The position of the thread in that 16B chunk.
          int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;

          // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
          // 64-bit base: B * H_kv * L * Dh can exceed INT_MAX for large caches.
          const size_t offset = static_cast<size_t>(bhi_kv) * params.memory_max_len * Dh
                                + co * params.memory_max_len * QK_ELTS_IN_16B +
                                // params.timestep*QK_ELTS_IN_16B +
                                tlength * QK_ELTS_IN_16B + ci;
          k = has_qk_elts ? *reinterpret_cast<const Qk_vec*>(&params.k_cache[offset]) : k;
      }
      else {
          if (has_qk_elts) {
              if (params.int8_mode == 2) {
                  using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec>::value>::type;
                  using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
                  const auto k_scaling = params.qkv_scale_out[1];
                  const auto k_quant =
                      *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[k_offset]);

                  convert_from_float(k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));
              }
              else {
                  k = *reinterpret_cast<const Qk_vec*>(&params.k[k_offset]);
              }
          }
      }

      // Trigger the loads from the Q and K bias buffers.
      // The guard is parenthesized like the other loads (the old form relied on && / || precedence).
      Qk_vec q_bias;
      zero(q_bias);
      q_bias = has_qk_elts && params.q_bias != nullptr ?
                   *reinterpret_cast<const Qk_vec*>(&params.q_bias[q_bias_offset]) :
                   q_bias;

      Qk_vec k_bias;
      zero(k_bias);
      if (handle_kv) {
          k_bias = has_qk_elts && params.k_bias != nullptr ?
                       *reinterpret_cast<const Qk_vec*>(&params.k_bias[k_bias_offset]) :
                       k_bias;
      }

      // Computes the Q/K values with bias.
      q = add(q, q_bias);
      if (handle_kv) {
          k = add(k, k_bias);
      }
      // Lanes past Dh must not read IA3 weights: that index belongs to the next head (or lies past the
      // buffer), and non-finite garbage would turn the zero k into NaN.
      if (do_ia3 && has_qk_elts) {
          k = mul<Qk_vec, Qk_vec, Qk_vec>(
              k,
              *reinterpret_cast<const Qk_vec*>(
                  &params.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE]));
      }

      // Padded len
      const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
      if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) {
          if (handle_kv) {
              if (params.rotary_cos == nullptr) {
                  apply_rotary_embedding(
                      q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
              }
              else {
                  apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len,
                                         params.rotary_cos + bi * params.rotary_embedding_dim / 2,
                                         params.rotary_sin + bi * params.rotary_embedding_dim / 2);
              }
          }
          else {
              if (params.rotary_cos == nullptr) {
                  apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
              }
              else {
                  apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len,
                                         params.rotary_cos + bi * params.rotary_embedding_dim / 2,
                                         params.rotary_sin + bi * params.rotary_embedding_dim / 2);
              }
          }
      }
      else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
          const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim;

          // Q and K are staged through the start of the dynamic shared memory buffer (sized for this in
          // smem_size_in_bytes). These are distinct from the static q_smem used after this block.
          T* q_rot_smem = reinterpret_cast<T*>(smem_);
          T* k_rot_smem = q_rot_smem + params.rotary_embedding_dim;

          const int half_rotary_dim = params.rotary_embedding_dim / 2;
          const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim;
          const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim;
          const int smem_pitch = half_rotary_dim;  // TODO: adjust for bank conflicts

          assert(half_rotary_dim % QK_VEC_SIZE == 0);

          if (do_rotary) {
              *reinterpret_cast<Qk_vec*>(q_rot_smem + half_idx * smem_pitch + intra_half_idx) = q;

              if (handle_kv) {
                  *reinterpret_cast<Qk_vec*>(k_rot_smem + half_idx * smem_pitch + intra_half_idx) = k;
              }
          }

          __syncthreads();

          const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2;
          constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1;
          if (do_rotary) {
              mmha::vec_from_smem_transpose(q, q_rot_smem, transpose_idx, smem_pitch);

              if (handle_kv) {
                  mmha::vec_from_smem_transpose(k, k_rot_smem, transpose_idx, smem_pitch);

                  if (params.rotary_cos == nullptr) {
                      mmha::apply_rotary_embedding(q,
                                                   k,
                                                   transpose_idx / tidx_factor,
                                                   params.rotary_embedding_dim,
                                                   tlength - padd_len,
                                                   params.rotary_base);
                  }
                  else {
                      mmha::apply_rotary_embedding(q,
                                                   k,
                                                   transpose_idx / tidx_factor,
                                                   params.rotary_embedding_dim,
                                                   tlength - padd_len,
                                                   params.rotary_cos + bi * params.rotary_embedding_dim / 2,
                                                   params.rotary_sin + bi * params.rotary_embedding_dim / 2);
                  }

                  mmha::write_smem_transpose(k, k_rot_smem, transpose_idx, smem_pitch);
              }
              else {
                  // NOTE(review): unlike every other rotary call in this kernel, this path uses tlength
                  // without subtracting padd_len (inherited behavior) — confirm this is intended.
                  if (params.rotary_cos == nullptr) {
                      mmha::apply_rotary_embedding(
                          q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base);
                  }
                  else {
                      mmha::apply_rotary_embedding(q,
                                                   transpose_idx / tidx_factor,
                                                   params.rotary_embedding_dim,
                                                   tlength,
                                                   params.rotary_cos + bi * params.rotary_embedding_dim / 2,
                                                   params.rotary_sin + bi * params.rotary_embedding_dim / 2);
                  }
              }
              mmha::write_smem_transpose(q, q_rot_smem, transpose_idx, smem_pitch);
          }

          __syncthreads();

          if (do_rotary) {
              q = *reinterpret_cast<Qk_vec*>(q_rot_smem + half_idx * smem_pitch + intra_half_idx);
              if (handle_kv) {
                  k = *reinterpret_cast<Qk_vec*>(k_rot_smem + half_idx * smem_pitch + intra_half_idx);
              }
          }

          __syncthreads();
      }

      if (!is_masked) {
          // Store the Q values to shared memory.
          *reinterpret_cast<Qk_vec*>(&q_smem[tidx * QK_VEC_SIZE]) = q;

          // Store Dh values of k_bias into smem, since will need to add later
          // if params.timestep == 0
          if (DO_CROSS_ATTENTION && params.timestep == 0) {
              *reinterpret_cast<Qk_vec*>(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias;
          }

          // Write the K values to the global memory cache.
          //
          // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
          // system. We designed it this way as it allows much better memory loads (and there are many
          // more loads) + the stores are really "write and forget" since we won't need the ack before
          // the end of the kernel. There's plenty of time for the transactions to complete.

          // The 16B chunk written by the thread.
          int co = tidx / QK_VECS_IN_16B;
          // The position of the thread in that 16B chunk.
          int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;

          // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
          const size_t offset = static_cast<size_t>(bhi_kv) * params.memory_max_len * Dh
                                + co * params.memory_max_len * QK_ELTS_IN_16B +
                                // params.timestep*QK_ELTS_IN_16B +
                                tlength_circ * QK_ELTS_IN_16B + ci;

          // Only the first query head of each KV group writes the shared KV entry.
          if (handle_kv && hi % params.num_heads_q_kv_ratio == 0) {
              // Trigger the stores to global memory.
              if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
                  *reinterpret_cast<Qk_vec*>(&params.k_cache[offset]) = k;
              }
          }

          // Compute \sum_i Q[i] * K^T[i] for the current timestep.
  #ifdef MMHA_USE_FP32_ACUM_FOR_FMA
          using Qk_vec_acum = typename Qk_vec_acum_fp32_<Qk_vec>::Type;
  #else
          using Qk_vec_acum = Qk_vec;
  #endif
          qk = dot<Qk_vec_acum, Qk_vec>(q, k);
          if (QK_VECS_PER_WARP <= WARP_SIZE) {
  #pragma unroll
              for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
                  qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
              }
          }
      }

      if (QK_VECS_PER_WARP > WARP_SIZE) {
          constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE;
          qk = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
      }

      // Store that value in shared memory. Keep the Q*K^T value in register for softmax.
      if (tidx == 0) {
          // Normalize qk.
          qk *= params.inv_sqrt_dh;
          if (params.relative_attention_bias != nullptr) {
              qk = add(qk,
                       params.relative_attention_bias[hi * params.relative_attention_bias_stride
                                                          * params.relative_attention_bias_stride
                                                      + (tlength - padd_len) * params.relative_attention_bias_stride
                                                      + (tlength - padd_len)]);
          }
          // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0.

          qk_max = qk;
          qk_smem[tlength - first_step] = qk;
          // qk_smem[params.timestep] = qk;
      }

      // Make sure the data is in shared memory.
      __syncthreads();

      // The type of queries and keys for the math in the Q*K^T product.
      using K_vec = typename K_vec_<T, THREADS_PER_KEY>::Type;
      // The number of elements per vector.
      constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T);
      // Make sure the hidden size per head is a multiple of the vector size.
      static_assert(Dh_MAX % K_VEC_SIZE == 0, "");
      // The number of elements per thread.
      constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY;
      // The number of vectors per thread.
      constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE;

      // The position the first key loaded by each thread from the cache buffer (for this B * H).
      int ko = tidx / THREADS_PER_KEY;
      // The position of the thread in the chunk of keys.
      int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE;

      static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD);

      // Load the Q values from shared memory. The values are reused during the loop on K.
      K_vec q_vec[K_VECS_PER_THREAD];
  #pragma unroll
      for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
          q_vec[ii] = *reinterpret_cast<const K_vec*>(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
      }

      K_vec k_bias_vec[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1];
      if (DO_CROSS_ATTENTION && params.timestep == 0) {
  #pragma unroll
          for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
              k_bias_vec[ii] = *reinterpret_cast<const K_vec*>(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
          }
      }

      // The number of timesteps loaded per iteration.
      constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY;
      // The number of keys per warp.
      constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;

      // The base pointer for the key in the cache buffer.
      T* k_cache = &params.k_cache[static_cast<size_t>(bhi_kv) * params.memory_max_len * Dh + ki];
      // Base pointer for the beam's batch, before offsetting with indirection buffer
      T* k_cache_batch = &params.k_cache[static_cast<size_t>(bbhi) * params.memory_max_len * Dh + ki];

      // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
      // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
      int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step;

      // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
      const bool has_beams = params.cache_indir != nullptr;
      const int* beam_indices = has_beams ? &params.cache_indir[bi_seq_len_offset] : nullptr;

      for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) {
          const int ti_circ = ti % params.memory_max_len;

          // The keys loaded from the key cache.
          K_vec k[K_VECS_PER_THREAD];
          K_vec k_vec_zero;
          zero(k_vec_zero);
  #pragma unroll
          for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
              int jj = ii * params.memory_max_len + ti_circ;
              // if( ti < params.timestep ) {
              const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len);
              if (ti < tlength) {
                  if (!within_bounds) {
                      k[ii] = k_vec_zero;
                  }
                  else {
                      if (has_beams) {
                          // Beams of one batch entry sit num_heads_kv cache slots apart (see bbhi),
                          // not num_heads, which would be wrong under grouped/multi-query attention.
                          const size_t beam_offset = static_cast<size_t>(beam_indices[ti_circ])
                                                     * params.num_heads_kv * params.memory_max_len * Dh;
                          k[ii] = *reinterpret_cast<const K_vec*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]);
                      }
                      else {
                          k[ii] = *reinterpret_cast<const K_vec*>(&k_cache_batch[jj * QK_ELTS_IN_16B]);
                      }
                  }
                  // add bias and update k_cache
                  if (DO_CROSS_ATTENTION && params.timestep == 0) {
                      k[ii] = add(k[ii], k_bias_vec[ii]);

                      if (do_ia3) {
                          k[ii] = mul<K_vec, K_vec, K_vec>(
                              k[ii],
                              *reinterpret_cast<const K_vec*>(
                                  &params.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + ki
                                                          + ii * THREADS_PER_KEY * K_VEC_SIZE]));
                      }

                      if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len) {
                          *reinterpret_cast<K_vec*>(&k_cache[jj * QK_ELTS_IN_16B]) = k[ii];
                      }
                  }
              }
          }

          // Perform the dot product and normalize qk.
          //
          // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
          float qk_ti = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k) * params.inv_sqrt_dh;

          // Store the product to shared memory. There's one qk value per timestep. Update the max.
          // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) {
          if (ti < tlength && tidx % THREADS_PER_KEY == 0) {
              // Read the mask only for valid timesteps. ti can run up to ti_end > tlength, where
              // bi_seq_len_offset + ti may fall outside this row's memory_max_len entries.
              const bool is_mask =
                  (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
              if (params.relative_attention_bias != nullptr) {
                  qk_ti = add(qk_ti,
                              params.relative_attention_bias[hi * params.relative_attention_bias_stride
                                                                 * params.relative_attention_bias_stride
                                                             + tlength * params.relative_attention_bias_stride + ti]);
              }
              if (params.linear_bias_slopes != nullptr) {
                  // Apply the linear position bias: (ki - qi) * slope[hi].
                  // The padding token locates between the input context and the generated tokens.
                  // We need to remove the number of padding tokens in the distance computation.
                  //   ti   : 0 1 2 3 4 5 6 7 8 9(tlength)
                  //   token: i i i i p p p o o o  where i=input, p=pad, o=output.
                  // e.g. ti = 2, dist = (9 - 3) - 2 = 4.
                  int max_context_length = params.max_prefix_prompt_length + params.max_input_length;
                  float dist = (ti < max_context_length ? ti + padd_len : ti) - tlength;

                  qk_ti += mul<float, T, float>(params.linear_bias_slopes[hi], dist);
              }
              qk_max = is_mask ? qk_max : fmaxf(qk_max, qk_ti);
              qk_smem[ti - first_step] = qk_ti;
          }
      }

      // Perform the final reduction to compute the max inside each warp.
      //
      // NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the
      // group so it's not needed to run the reduction inside the group (again).
  #pragma unroll
      for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) {
          qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
      }

      // Decompose the thread index into warp and lane.
      const int warp = tidx / WARP_SIZE;
      const int lane = tidx % WARP_SIZE;

      // The warp leader writes the max to shared memory.
      if (lane == 0) {
          red_smem[warp] = qk_max;
      }

      // Make sure the products are in shared memory.
      __syncthreads();

      // The warps finalize the reduction.
      qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
  #pragma unroll
      for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
          qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
      }

      // Broadcast to all the threads in the warp.
      qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);

      // Compute the logits and start the sum.
      float sum = 0.f;
      // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
      for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
          bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
          float logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max);
          sum += logit;
          qk_smem[ti - first_step] = logit;
      }

      // Compute the sum.
      sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);

      // Normalize the logits.
      float inv_sum = __fdividef(1.f, sum + 1.e-6f);

      // 64-bit offset: bhi * max_decoder_seq_len * memory_max_len can exceed INT_MAX.
      const size_t cross_attention_out_offset =
          params.is_return_cross_attentions ?
              static_cast<size_t>(bhi) * params.max_decoder_seq_len * params.memory_max_len
                  + static_cast<size_t>(params.timestep) * params.memory_max_len :
              0;
      // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
      for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
          float logit = qk_smem[ti - first_step] * inv_sum;
          if (params.is_return_cross_attentions) {
              params.cross_attention_out[cross_attention_out_offset + ti] = logit;
          }
          convert_from_float(logits_smem[ti - first_step], logit);
      }

      // Put Values part below so we leverage __syncthreads
      // from the previous step

      // The number of elements per vector.
      constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
      // A vector of V elements for the current timestep.
      using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;

      // The value computed by this thread.
      int vo = tidx / THREADS_PER_VALUE;
      // The hidden dimensions computed by this particular thread.
      int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE;
      // Whether this thread owns V elements inside the real head dimension Dh.
      const bool has_v_elts = Dh == Dh_MAX || vi < Dh;

      // The base pointer for the value in the cache buffer.
      T* v_cache = &params.v_cache[static_cast<size_t>(bhi_kv) * params.memory_max_len * Dh + vi];
      // Base pointer for the beam's batch, before offsetting with indirection buffer
      T* v_cache_batch = &params.v_cache[static_cast<size_t>(bbhi) * params.memory_max_len * Dh + vi];

      // The number of values processed per iteration of the loop.
      constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;

      // One group of threads computes the product(s) for the current timestep.
      V_vec v_bias;
      zero(v_bias);
      // if( vo == params.timestep % V_PER_ITER ) {
      if (has_v_elts && handle_kv && vo == tlength % V_PER_ITER) {
          // Trigger the loads from the V bias buffer.
          if (params.v_bias != nullptr) {
              v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi_kv * Dh + vi]);
          }
          if (DO_CROSS_ATTENTION) {
              *reinterpret_cast<V_vec*>(&bias_smem[vi]) = v_bias;
          }
      }

      // From previous, before values, step
      // Also make sure the logits are in shared memory.
      __syncthreads();

      // Values continued
  #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
      using V_vec_acum = typename V_vec_acum_fp32_<V_vec>::Type;
  #else
      using V_vec_acum = V_vec;
  #endif
      // The partial outputs computed by each thread.
      V_vec_acum out;
      zero(out);

      // Loop over the timesteps to compute the partial outputs.
      // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) {
      if (has_v_elts) {
          for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) {
              const int ti_circ = ti % params.memory_max_len;

              // Fetch offset based on cache_indir when beam sampling.
              // Beams of one batch entry sit num_heads_kv cache slots apart (see bbhi).
              const int beam_src = (params.cache_indir != nullptr) ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0;
              const size_t beam_offset =
                  static_cast<size_t>(beam_src) * params.num_heads_kv * params.memory_max_len * Dh;
              // Load the values from the cache.
              V_vec v = *reinterpret_cast<const V_vec*>(&v_cache_batch[beam_offset + ti_circ * Dh]);
              if (DO_CROSS_ATTENTION && params.timestep == 0) {
                  v = add(v, *reinterpret_cast<V_vec*>(&bias_smem[vi]));
                  if (do_ia3) {
                      v = mul<V_vec, V_vec, V_vec>(
                          v,
                          *reinterpret_cast<const V_vec*>(
                              &params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
                  }
                  // Write back to the same (circular) slot the value was read from.
                  *reinterpret_cast<V_vec*>(&v_cache[ti_circ * Dh]) = v;
              }
              // Load the logits from shared memory.
  #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
              float logit = logits_smem[ti - first_step];
              out = fma(logit, cast_to_float(v), out);
  #else
              T logit = logits_smem[ti - first_step];

              // Update the partial sums.
              out = fma(logit, v, out);
  #endif
          }
      }

      // One group of threads computes the product(s) for the current timestep.
      // if( vo == params.timestep % V_PER_ITER ) {
      if (vo == tlength % V_PER_ITER && has_v_elts) {
          V_vec v;
          if (DO_CROSS_ATTENTION) {
              v = *reinterpret_cast<const V_vec*>(&v_cache[tlength * Dh]);
          }
          else {
              // Trigger the loads from the V buffer.
              const auto v_offset = v_base_offset + vi;
              if (params.int8_mode == 2) {
                  using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec>::value>::type;
                  using Packed_Float_t = typename packed_type<float, num_elems<V_vec>::value>::type;
                  const auto v_scaling = params.qkv_scale_out[2];
                  const auto v_quant =
                      *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.v)[v_offset]);

                  convert_from_float(v, mul<Packed_Float_t, float>(v_scaling, float_from_int8(v_quant)));
              }
              else {
                  v = *reinterpret_cast<const V_vec*>(&params.v[v_offset]);
              }
              // Trigger the loads from the V bias buffer.
              // V_vec v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi*Dh + vi]);
          }

          // Compute the V values with bias.
          if (handle_kv) {
              v = add(v, v_bias);

              if (do_ia3) {
                  v = mul<V_vec, V_vec, V_vec>(
                      v,
                      *reinterpret_cast<const V_vec*>(
                          &params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
              }

              // Store the values with bias back to global memory in the cache for V.
              // Only the first query head of each KV group writes the shared KV entry.
              if (hi % params.num_heads_q_kv_ratio == 0) {
                  //*reinterpret_cast<V_vec*>(&v_cache[params.timestep*Dh]) = v;
                  *reinterpret_cast<V_vec*>(&v_cache[tlength_circ * Dh]) = v;
              }
          }

          // Initialize the output value with the current timestep.
  #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
          // out = fma(logits_smem[params.timestep], cast_to_float(v), out);
          out = fma(logits_smem[tlength - first_step], cast_to_float(v), out);
  #else
          // out = fma(logits_smem[params.timestep], v, out);
          out = fma(logits_smem[tlength - first_step], v, out);
  #endif
      }

      // Make sure we can start writing to shared memory.
      __syncthreads();

      // Run the final reduction amongst the different groups computing different partial outputs.
      // Every thread runs this loop, including threads past Dh (they only skip the loads and stores),
      // because the __syncthreads() inside must be reached by the whole block. The old outer
      // `if (Dh == Dh_MAX || vi < Dh)` made the barrier divergent whenever Dh < Dh_MAX.
  #pragma unroll
      for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) {

          // The midpoint in the number of active groups.
          int midpoint = active_groups / 2;

          // The upper part of active threads store to shared memory.
          if (vo >= midpoint && vo < active_groups && has_v_elts) {
  #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
              convert_from_float(*reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]), out);
  #else
              *reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
  #endif
          }
          __syncthreads();

          // The bottom warps update their values.
          if (vo < midpoint && has_v_elts) {
              out = add(*reinterpret_cast<const V_vec*>(&out_smem[vo * Dh + vi]), out);
          }
          __syncthreads();
      }

      // Output the final values.
      if (vo == 0 && has_v_elts) {
  #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
          if (params.int8_mode == 2) {
              using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_acum>::value>::type;
              out = mul<V_vec_acum, float>(*params.attention_out_scale, out);
              *reinterpret_cast<Packed_Int8_t*>(&(reinterpret_cast<int8_t*>(params.out)[bhi * Dh + vi])) =
                  cast_to_int8(out);
          }
          else {
              convert_from_float(*reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]), out);
          }
  #else
          // TODO: support int8_mode?
          *reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]) = out;
  #endif
      }
  }
  1388. ////////////////////////////////////////////////////////////////////////////////////////////////////
  1389. } // namespace mmha
  1390. ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Host-side launcher for the masked multi-head attention kernel (declaration only; the definition
  // is not part of this section). The work is presumably enqueued on `stream` — confirm at the definition.
  template<typename T,
           int Dh,
           int Dh_MAX,
           typename KERNEL_PARAMS_TYPE>
  void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params,
                          const cudaStream_t& stream);