/*
 * Copyright (c) 2024, The PygmalionAI team.
 * Copyright (c) 2024, The vLLM team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <hip/hip_bf16.h>
#include "cuda_compat.h"

#include <algorithm>

#include "../attention/dtype_fp8.cuh"
#include "../quantization/fp8/amd/quant_utils.cuh"

#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \
                           defined(__gfx941__) || defined(__gfx942__))
  #define __HIP__MI300_MI250__
#endif

#if defined(NDEBUG)
  #undef NDEBUG
  #include <assert.h>
  #define UNREACHABLE_CODE assert(false);
  #define NDEBUG
#else
  #define UNREACHABLE_CODE assert(false);
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

#if defined(__HIP__MI300_MI250__)  // TODO: Add NAVI support

#define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
#define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16

using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
using float16x4 =
    __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16;
typedef float16x4 _Half4;
typedef struct _Half8 {
  _Half4 xy[2];
} _Half8;

using bit16_t = uint16_t;
using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t;
typedef bit16x4 _B16x4;
typedef struct _B16x8 {
  _B16x4 xy[2];
} _B16x8;

using _B8x8 = uint2;
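
// The aliases above are raw per-lane bit containers: _B16x4 packs four 16-bit
// values (fp16 or bf16 bit patterns), _B16x8 packs eight, and _B8x8 packs
// eight fp8 cache bytes into a uint2. Keeping them typeless lets the same
// kernel body serve both 16-bit element types.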

////// Non temporal load stores ///////

template <typename T>
__device__ __forceinline__ T load(T* addr) {
  return addr[0];
}

template <typename T>
__device__ __forceinline__ void store(T value, T* addr) {
  addr[0] = value;
}

template <typename T, int absz, int cbid, int blgp>
__device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA,
                                                  const _B16x4& inpB,
                                                  const floatx4& inpC) {
  if constexpr (std::is_same<T, _Float16>::value) {
    return __builtin_amdgcn_mfma_f32_4x4x4f16(inpA, inpB, inpC, absz, cbid,
                                              blgp);
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    return __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(inpA, inpB, inpC, absz, cbid,
                                                  blgp);
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}
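
// gcn_mfma_instr picks the fp16 or bf16 4x4x4 MFMA builtin for T and forwards
// the three integer template parameters verbatim as the MFMA block-control
// modifiers of the builtin.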

template <typename T>
__device__ __forceinline__ float to_float(const T& inp) {
  if constexpr (std::is_same<T, _Float16>::value) {
    return (float)inp;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    return __bfloat162float(inp);
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}

template <typename T>
__device__ __forceinline__ T from_float(const float& inp) {
  if constexpr (std::is_same<T, _Float16>::value) {
    return (_Float16)inp;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    return __float2bfloat16(inp);
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}

template <typename T>
__device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) {
  union tmpcvt {
    uint16_t u;
    _Float16 f;
    __hip_bfloat16 b;
  } t16;
  _B16x4 ret;
  if constexpr (std::is_same<T, _Float16>::value) {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      t16.f = (_Float16)inp[i];
      ret[i] = t16.u;
    }
    return ret;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      t16.b = __float2bfloat16(inp[i]);
      ret[i] = t16.u;
    }
    return ret;
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}

template <typename T>
__device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1,
                                        const _B16x4& inp2) {
  union tmpcvt {
    uint16_t u;
    _Float16 f;
    __hip_bfloat16 b;
  } t1, t2, res;
  _B16x4 ret;
  if constexpr (std::is_same<T, _Float16>::value) {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      t1.u = inp1[i];
      t2.u = inp2[i];
      res.f = t1.f + t2.f;
      ret[i] = res.u;
    }
    return ret;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      t1.u = inp1[i];
      t2.u = inp2[i];
      res.b = t1.b + t2.b;
      ret[i] = res.u;
    }
    return ret;
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}

template <typename T, aphrodite::Fp8KVCacheDataType KV_DTYPE>
__device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input,
                                                      const float scale) {
  union alignas(16) {
    uint4 u4;
    _B16x8 u16x8;
    aphrodite::bf16_8_t b16x8;
  } tmp;
  if constexpr (std::is_same<T, _Float16>::value) {
    tmp.u4 =
        aphrodite::fp8::scaled_convert<uint4, _B8x8, KV_DTYPE>(input, scale);
    return tmp.u16x8;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    tmp.b16x8 =
        aphrodite::fp8::scaled_convert<aphrodite::bf16_8_t, _B8x8, KV_DTYPE>(
            input, scale);
    return tmp.u16x8;
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}
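
// scaled_convert_b8x8 dequantizes eight packed fp8 KV-cache bytes into eight
// 16-bit values (fp16 or bf16) using the per-tensor scale, returning the raw
// bit pattern as a _B16x8 fragment.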

///////////////////////////////////////

// grid (num_seqs, num_partitions, num_heads/gqa_ratio)
// block (partition size)
template <typename scalar_t, typename cache_t,
          aphrodite::Fp8KVCacheDataType KV_DTYPE, int BLOCK_SIZE, int HEAD_SIZE,
          int NUM_THREADS,
          int GQA_RATIO>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                          // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                          // head_size, block_size]
    const int num_kv_heads, const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    float* __restrict__ exp_sums,    // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                     // max_num_partitions]
    scalar_t* __restrict__ out,  // [num_seqs, num_heads, max_num_partitions,
                                 // head_size]
    scalar_t* __restrict__ final_out,  // [num_seqs, num_heads, head_size]
    int max_ctx_blocks, float k_scale, float v_scale) {
  constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
  const int warpid = threadIdx.x / WARP_SIZE;
  const int laneid = threadIdx.x % WARP_SIZE;
  const int lane4id = laneid % 4;

  const int seq_idx = blockIdx.x;
  const int partition_idx = blockIdx.y;
  const int partition_size = blockDim.x;
  const int max_num_partitions = gridDim.y;

  const int context_len = context_lens[seq_idx];
  const int partition_start_token_idx = partition_idx * partition_size;
  // exit if partition is out of context for seq
  if (partition_start_token_idx >= context_len) {
    return;
  }
  constexpr int QHLOOP =
      DIVIDE_ROUND_UP(GQA_RATIO, 4);  // each 4 lanes fetch 4 different qheads,
                                      // e.g. with 8 qheads, QHLOOP is 2
  constexpr int GQA_RATIO4 = 4 * QHLOOP;
  __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1];
  __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1];
  _B16x8 Qlocal[QHLOOP];
  constexpr int x = 16 / sizeof(scalar_t);
  constexpr int KHELOOP = HEAD_SIZE / x;
  _B16x8 Klocal[KHELOOP];
  _B8x8 Klocalb8[KHELOOP];
  constexpr int VHELOOP =
      HEAD_SIZE /
      WARP_SIZE;  // v head_size dimension is distributed across lanes
  constexpr int VTLOOP = 8;  // 16 groups of 4 tokens across the warp ->
                             // 8 groups of 8 tokens
  _B16x8 Vlocal[VHELOOP][VTLOOP];
  _B8x8 Vlocalb8[VHELOOP][VTLOOP];
  floatx4 dout[QHLOOP];
  float qk_max[QHLOOP];
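  // Per-lane register tile: Qlocal holds this lane's Q fragments (one
  // 8-element chunk per query-head group), Klocal/Klocalb8 hold one token's K
  // vector split into 8-element chunks, Vlocal/Vlocalb8 hold the V values for
  // the head elements this lane owns, and dout/qk_max accumulate the QK^T
  // partial results and the running row maxima per query-head group.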
#pragma unroll
  for (int h = 0; h < QHLOOP; h++) {
    dout[h] = {0};
    qk_max[h] = -FLT_MAX;
  }

  const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
  const int wg_start_kv_head_idx = blockIdx.z;

  const int warp_start_token_idx =
      partition_start_token_idx + warpid * WARP_SIZE;

  if (warp_start_token_idx >= context_len) {  // warp out of context
#pragma unroll
    for (int h = 0; h < GQA_RATIO4; h++) {
      shared_qk_max[warpid][h] = -FLT_MAX;
      shared_exp_sum[warpid][h] = 0.0f;
    }
  } else {  // warp within context
    const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
    const int last_ctx_block = num_context_blocks - 1;
    const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
    const int local_token_idx = threadIdx.x;
    const int global_token_idx = partition_start_token_idx + local_token_idx;
    const int block_idx = (global_token_idx < context_len)
                              ? global_token_idx / BLOCK_SIZE
                              : last_ctx_block;
    // fetch block number for q and k
    // int32 physical_block_number leads to overflow when multiplied with
    // kv_block_stride
    const int64_t physical_block_number =
        static_cast<int64_t>(block_table[block_idx]);

    // fetch vphysical block numbers up front
    constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE;
    int vphysical_blocks[VBLOCKS];
    const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE;
#pragma unroll
    for (int b = 0; b < VBLOCKS; b++) {
      const int vblock_idx = warp_start_block_idx + b;
      const int vblock_idx_ctx =
          (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
      vphysical_blocks[b] = block_table[vblock_idx_ctx];
    }

    // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems
    const scalar_t* q_ptr =
        q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE;
    const _B16x8* q_ptrh8 = reinterpret_cast<const _B16x8*>(q_ptr);
    const int qhead_elemh8 = laneid / 4;
#pragma unroll
    for (int h = 0; h < QHLOOP - 1; h++) {
      const int qhead_idx = h * 4 + lane4id;
      Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8];
    }
    const int final_qhead_idx = 4 * (QHLOOP - 1) + lane4id;
    if (final_qhead_idx < GQA_RATIO) {
      Qlocal[QHLOOP - 1] =
          q_ptrh8[final_qhead_idx * HEAD_SIZE / 8 + qhead_elemh8];
    } else {
      Qlocal[QHLOOP - 1].xy[0] = {0};
      Qlocal[QHLOOP - 1].xy[1] = {0};
    }

    const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride +
                           wg_start_kv_head_idx * kv_head_stride;

    const int physical_block_offset =
        local_token_idx % BLOCK_SIZE;  // since x=half8, physical_block_offset
                                       // is already cast as _H8
    if constexpr (KV_DTYPE == aphrodite::Fp8KVCacheDataType::kAuto) {
      const _B16x8* k_ptrh8 = reinterpret_cast<const _B16x8*>(k_ptr);
#pragma unroll
      for (int d = 0; d < KHELOOP; d++) {
        Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset];
      }
    } else {
      constexpr int X = 16 / sizeof(cache_t);
      const cache_t* k_ptr2 = k_ptr + physical_block_offset * X;
#pragma unroll
      for (int d = 0; d < KHELOOP; d++) {
        const int head_elem = d * 8;
        const int offset1 = head_elem / X;
        const int offset2 = head_elem % X;
        const cache_t* k_ptr3 = k_ptr2 + offset1 * BLOCK_SIZE * X + offset2;
        Klocalb8[d] = *reinterpret_cast<const _B8x8*>(k_ptr3);
      }
    }

    float alibi_slope[QHLOOP];
    if (alibi_slopes != nullptr) {
#pragma unroll
      for (int h = 0; h < QHLOOP; h++) {
        const int qhead_idx = h * 4 + lane4id;
        alibi_slope[h] = (qhead_idx < GQA_RATIO)
                             ? alibi_slopes[wg_start_head_idx + qhead_idx]
                             : 0.f;
      }
    }

    const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride;
    if constexpr (KV_DTYPE == aphrodite::Fp8KVCacheDataType::kAuto) {
      const _B16x8* v_ptrh8 = reinterpret_cast<const _B16x8*>(v_ptr);
      // iterate over each v block
#pragma unroll
      for (int b = 0; b < VBLOCKS; b++) {
        // int32 physical_block_number leads to overflow when multiplied with
        // kv_block_stride
        const int64_t vphysical_block_number =
            static_cast<int64_t>(vphysical_blocks[b]);
        const _B16x8* v_ptrh8b =
            v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
        // iterate over each head elem (within head_size)
#pragma unroll
        for (int h = 0; h < VHELOOP; h++) {
          const int head_size_elem = h * WARP_SIZE + laneid;
          const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
          // iterate over all velems within block
#pragma unroll
          for (int d = 0; d < BLOCK_SIZE / 8; d++) {
            Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
          }
        }
      }
    } else {
      const _B8x8* v_ptrh8 = reinterpret_cast<const _B8x8*>(v_ptr);
      // iterate over each v block
#pragma unroll
      for (int b = 0; b < VBLOCKS; b++) {
        // int32 physical_block_number leads to overflow when multiplied with
        // kv_block_stride
        const int64_t vphysical_block_number =
            static_cast<int64_t>(vphysical_blocks[b]);
        const _B8x8* v_ptrh8b =
            v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
        // iterate over each head elem (within head_size)
#pragma unroll
        for (int h = 0; h < VHELOOP; h++) {
          const int head_size_elem = h * WARP_SIZE + laneid;
          const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
          // iterate over all velems within block
#pragma unroll
          for (int d = 0; d < BLOCK_SIZE / 8; d++) {
            // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
            const _B8x8 Vlocalb8 = v_ptrh8be[d];
            Vlocal[h][b * BLOCK_SIZE / 8 + d] =
                scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, v_scale);
          }
        }
      }
    }

    if constexpr (KV_DTYPE != aphrodite::Fp8KVCacheDataType::kAuto) {
#pragma unroll
      for (int d = 0; d < KHELOOP; d++) {
        Klocal[d] =
            scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], k_scale);
      }
    }
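
    // QK^T: for each query-head group h, accumulate the dot product of the
    // lane's Q fragment with every 8-element K chunk. Each pair of MFMA calls
    // consumes one Klocal[d]; the varying integer template argument is
    // forwarded as an MFMA broadcast modifier and steps through the chunk
    // index d.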
#pragma unroll
    for (int h = 0; h < QHLOOP; h++) {
      dout[h] = gcn_mfma_instr<scalar_t, 4, 0, 0>(Qlocal[h].xy[0],
                                                  Klocal[0].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 0, 0>(Qlocal[h].xy[1],
                                                  Klocal[0].xy[1], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 1, 0>(Qlocal[h].xy[0],
                                                  Klocal[1].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 1, 0>(Qlocal[h].xy[1],
                                                  Klocal[1].xy[1], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 2, 0>(Qlocal[h].xy[0],
                                                  Klocal[2].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 2, 0>(Qlocal[h].xy[1],
                                                  Klocal[2].xy[1], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 3, 0>(Qlocal[h].xy[0],
                                                  Klocal[3].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 3, 0>(Qlocal[h].xy[1],
                                                  Klocal[3].xy[1], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 4, 0>(Qlocal[h].xy[0],
                                                  Klocal[4].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 4, 0>(Qlocal[h].xy[1],
                                                  Klocal[4].xy[1], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 5, 0>(Qlocal[h].xy[0],
                                                  Klocal[5].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 5, 0>(Qlocal[h].xy[1],
                                                  Klocal[5].xy[1], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 6, 0>(Qlocal[h].xy[0],
                                                  Klocal[6].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 6, 0>(Qlocal[h].xy[1],
                                                  Klocal[6].xy[1], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 7, 0>(Qlocal[h].xy[0],
                                                  Klocal[7].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 7, 0>(Qlocal[h].xy[1],
                                                  Klocal[7].xy[1], dout[h]);
      if constexpr (KHELOOP > 8) {
        dout[h] = gcn_mfma_instr<scalar_t, 4, 8, 0>(Qlocal[h].xy[0],
                                                    Klocal[8].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 8, 0>(Qlocal[h].xy[1],
                                                    Klocal[8].xy[1], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 9, 0>(Qlocal[h].xy[0],
                                                    Klocal[9].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 9, 0>(Qlocal[h].xy[1],
                                                    Klocal[9].xy[1], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 10, 0>(Qlocal[h].xy[0],
                                                     Klocal[10].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 10, 0>(Qlocal[h].xy[1],
                                                     Klocal[10].xy[1], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 11, 0>(Qlocal[h].xy[0],
                                                     Klocal[11].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 11, 0>(Qlocal[h].xy[1],
                                                     Klocal[11].xy[1], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 12, 0>(Qlocal[h].xy[0],
                                                     Klocal[12].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 12, 0>(Qlocal[h].xy[1],
                                                     Klocal[12].xy[1], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 13, 0>(Qlocal[h].xy[0],
                                                     Klocal[13].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 13, 0>(Qlocal[h].xy[1],
                                                     Klocal[13].xy[1], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 14, 0>(Qlocal[h].xy[0],
                                                     Klocal[14].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 14, 0>(Qlocal[h].xy[1],
                                                     Klocal[14].xy[1], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 15, 0>(Qlocal[h].xy[0],
                                                     Klocal[15].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 15, 0>(Qlocal[h].xy[1],
                                                     Klocal[15].xy[1], dout[h]);
      }  // KHELOOP>8
      dout[h] *= scale;
    }

    // transpose dout so that 4 token ids are in each lane, and 4 heads are
    // across 4 lanes
#pragma unroll
    for (int h = 0; h < QHLOOP; h++) {
      floatx4 tmp = {0};
#pragma unroll
      for (int i = 0; i < 4; i++) {
        const float B = (lane4id == i) ? 1.0f : 0.0f;
        // const float A = (global_token_idx < context_len) ? dout[h][i] : 0.0f;
        tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(dout[h][i], B, tmp, 0, 0, 0);
        // tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0);
      }
      dout[h] = tmp;
    }

    const int lane4_token_idx = 4 * (global_token_idx >> 2);
    const int alibi_offset = lane4_token_idx - context_len + 1;
    if (alibi_slopes != nullptr) {
#pragma unroll
      for (int h = 0; h < QHLOOP; h++) {
#pragma unroll
        for (int i = 0; i < 4; i++) {
          dout[h][i] += alibi_slope[h] * (alibi_offset + i);
        }
      }
    }
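
    // Softmax over this warp's tokens: compute the per-head max with a
    // shuffle reduction, exponentiate relative to that max, and reduce the
    // partial exp sums the same way. The reductions stop at stride 4 so that
    // lanes holding different heads are never mixed.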
#pragma unroll
    for (int h = 0; h < QHLOOP; h++) {
      qk_max[h] = -FLT_MAX;
#pragma unroll
      for (int i = 0; i < 4; i++) {
        qk_max[h] = (lane4_token_idx + i < context_len)
                        ? fmaxf(qk_max[h], dout[h][i])
                        : qk_max[h];
      }
#pragma unroll
      for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) {
        qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask));
      }
    }

    float exp_sum[QHLOOP];
#pragma unroll
    for (int h = 0; h < QHLOOP; h++) {
      exp_sum[h] = 0.0f;
#pragma unroll
      for (int i = 0; i < 4; i++) {
        dout[h][i] = (lane4_token_idx + i < context_len)
                         ? __expf(dout[h][i] - qk_max[h])
                         : 0.0f;
        exp_sum[h] += dout[h][i];
      }
#pragma unroll
      for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) {
        exp_sum[h] += __shfl_xor(exp_sum[h], mask);
      }
    }

#pragma unroll
    for (int h = 0; h < QHLOOP; h++) {
      const int head_idx = 4 * h + lane4id;
      shared_qk_max[warpid][head_idx] = qk_max[h];
      shared_exp_sum[warpid][head_idx] = exp_sum[h];
    }
  }  // warp within context

  __syncthreads();

  const int num_heads = gridDim.z * GQA_RATIO;
  float* max_logits_ptr =
      max_logits + seq_idx * num_heads * max_num_partitions + partition_idx;
  float* exp_sums_ptr =
      exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx;
#pragma unroll
  for (int h = 0; h < QHLOOP; h++) {
    float global_qk_max = -FLT_MAX;
    float warp_qk_max[NWARPS];
    const int head_idx = 4 * h + lane4id;
#pragma unroll
    for (int w = 0; w < NWARPS; w++) {
      warp_qk_max[w] = shared_qk_max[w][head_idx];
      global_qk_max = fmaxf(global_qk_max, warp_qk_max[w]);
    }
    float global_exp_sum = 0.0f;
#pragma unroll
    for (int w = 0; w < NWARPS; w++) {
      global_exp_sum +=
          shared_exp_sum[w][head_idx] * __expf(warp_qk_max[w] - global_qk_max);
    }
    if (head_idx < GQA_RATIO) {
      max_logits_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] =
          global_qk_max;
      exp_sums_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] =
          global_exp_sum;
    }
    const float global_inv_sum_scale = __fdividef(1.f, global_exp_sum + 1e-6f) *
                                       __expf(qk_max[h] - global_qk_max);
    dout[h] *= global_inv_sum_scale;
  }

  // logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there
  // are 4x16 tokens across warp
  _B16x4 logits[QHLOOP];
#pragma unroll
  for (int h = 0; h < QHLOOP; h++) {
    logits[h] = from_floatx4<scalar_t>(dout[h]);
  }
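
  // logits * V: each warp accumulates its tokens' contribution for the head
  // elements owned by each lane, then parks the result in shared memory so
  // warp 0 can sum across warps below.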
  __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1];

  if (warp_start_token_idx >= context_len) {  // warp out of context
#pragma unroll
    for (int qh = 0; qh < QHLOOP; qh++) {
#pragma unroll
      for (int vh = 0; vh < VHELOOP; vh++) {
        vout_shared[qh][vh][laneid][warpid] = {0};
      }
    }
  } else {  // warp in context
    // iterate across heads
#pragma unroll
    for (int qh = 0; qh < QHLOOP; qh++) {
      // iterate over each v head elem (within head_size)
#pragma unroll
      for (int vh = 0; vh < VHELOOP; vh++) {
        floatx4 acc = {0};
        // iterate over tokens
        acc = gcn_mfma_instr<scalar_t, 4, 0, 0>(logits[qh], Vlocal[vh][0].xy[0],
                                                acc);
        acc = gcn_mfma_instr<scalar_t, 4, 1, 0>(logits[qh], Vlocal[vh][0].xy[1],
                                                acc);
        acc = gcn_mfma_instr<scalar_t, 4, 2, 0>(logits[qh], Vlocal[vh][1].xy[0],
                                                acc);
        acc = gcn_mfma_instr<scalar_t, 4, 3, 0>(logits[qh], Vlocal[vh][1].xy[1],
                                                acc);
        acc = gcn_mfma_instr<scalar_t, 4, 4, 0>(logits[qh], Vlocal[vh][2].xy[0],
                                                acc);
        acc = gcn_mfma_instr<scalar_t, 4, 5, 0>(logits[qh], Vlocal[vh][2].xy[1],
                                                acc);
        acc = gcn_mfma_instr<scalar_t, 4, 6, 0>(logits[qh], Vlocal[vh][3].xy[0],
                                                acc);
        acc = gcn_mfma_instr<scalar_t, 4, 7, 0>(logits[qh], Vlocal[vh][3].xy[1],
                                                acc);
        acc = gcn_mfma_instr<scalar_t, 4, 8, 0>(logits[qh], Vlocal[vh][4].xy[0],
                                                acc);
        acc = gcn_mfma_instr<scalar_t, 4, 9, 0>(logits[qh], Vlocal[vh][4].xy[1],
                                                acc);
        acc = gcn_mfma_instr<scalar_t, 4, 10, 0>(logits[qh],
                                                 Vlocal[vh][5].xy[0], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 11, 0>(logits[qh],
                                                 Vlocal[vh][5].xy[1], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 12, 0>(logits[qh],
                                                 Vlocal[vh][6].xy[0], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 13, 0>(logits[qh],
                                                 Vlocal[vh][6].xy[1], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 14, 0>(logits[qh],
                                                 Vlocal[vh][7].xy[0], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 15, 0>(logits[qh],
                                                 Vlocal[vh][7].xy[1], acc);
        vout_shared[qh][vh][laneid][warpid] = from_floatx4<scalar_t>(acc);
      }
    }
  }  // warp in context

  __syncthreads();

  if (warpid == 0) {
    _B16x4 vout[QHLOOP][VHELOOP];
    // iterate across heads
    scalar_t* out_ptr;
    int out_num_partitions;
    if (context_len > partition_size) {
      out_num_partitions = max_num_partitions;
      out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
                partition_idx * HEAD_SIZE;
    } else {
      out_num_partitions = 1;
      out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE;
    }
#pragma unroll
    for (int qh = 0; qh < QHLOOP; qh++) {
      // iterate over each v head elem (within head_size)
#pragma unroll
      for (int vh = 0; vh < VHELOOP; vh++) {
        vout[qh][vh] = {0};
#pragma unroll
        for (int w = 0; w < NWARPS; w++) {
          vout[qh][vh] =
              addx4<scalar_t>(vout[qh][vh], vout_shared[qh][vh][laneid][w]);
        }
        const int head_size_elem = vh * WARP_SIZE + laneid;
        bit16_t* out_ptr_b16 = reinterpret_cast<bit16_t*>(out_ptr);
#pragma unroll
        for (int i = 0; i < 4; i++) {
          const int head_idx = 4 * qh + i;
          if (head_idx < GQA_RATIO) {
            out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions *
                            HEAD_SIZE +
                        head_size_elem] = vout[qh][vh][i];
          }
        }
      }
    }
  }
}

// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
          int PARTITION_SIZE>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums,    // [num_seqs, num_heads,
                                           // max_num_partitions]
    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                           // max_num_partitions]
    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                           // max_num_partitions, head_size]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int max_num_partitions) {
  const int num_heads = gridDim.x;
  const int head_idx = blockIdx.x;
  const int seq_idx = blockIdx.y;
  const int context_len = context_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
  if (num_partitions == 1) {
    // if num_partitions==1, main kernel will write to out directly, no work in
    // reduction kernel
    return;
  }
  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int warpid = threadIdx.x / WARP_SIZE;
  const int laneid = threadIdx.x % WARP_SIZE;

  __shared__ float shared_global_exp_sum;
  __shared__ float shared_exp_sums[2 * WARP_SIZE];
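
  // Warp 0 first recovers the global softmax statistics for this (seq, head):
  // each thread covers two partition slots (threadIdx.x and
  // threadIdx.x + WARP_SIZE), the warp finds the max logit across all
  // partitions, and each partition's exp sum is rescaled to that max. The
  // per-partition weights land in shared_exp_sums and their total in
  // shared_global_exp_sum.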
  if (warpid == 0) {
    const float* max_logits_ptr = max_logits +
                                  seq_idx * num_heads * max_num_partitions +
                                  head_idx * max_num_partitions;
    // valid partition is the last valid partition in case threadid > num
    // partitions
    const int valid_partition =
        (threadIdx.x < num_partitions) ? threadIdx.x : num_partitions - 1;
    const int valid_partition2 = (WARP_SIZE + threadIdx.x < num_partitions)
                                     ? WARP_SIZE + threadIdx.x
                                     : num_partitions - 1;

    float reg_max_logit = max_logits_ptr[valid_partition];
    float reg_max_logit2 = max_logits_ptr[valid_partition2];
    float max_logit = fmaxf(reg_max_logit, reg_max_logit2);
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
      max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask));
    }

    const float* exp_sums_ptr = exp_sums +
                                seq_idx * num_heads * max_num_partitions +
                                head_idx * max_num_partitions;
    float global_exp_sum = 0.0f;
    float rescaled_exp_sum = exp_sums_ptr[valid_partition];
    float rescaled_exp_sum2 = exp_sums_ptr[valid_partition2];
    rescaled_exp_sum *=
        (threadIdx.x < num_partitions) ? expf(reg_max_logit - max_logit) : 0.0f;
    rescaled_exp_sum2 *= (threadIdx.x + WARP_SIZE < num_partitions)
                             ? expf(reg_max_logit2 - max_logit)
                             : 0.0f;
    global_exp_sum += rescaled_exp_sum + rescaled_exp_sum2;
    shared_exp_sums[threadIdx.x] = rescaled_exp_sum;
    shared_exp_sums[threadIdx.x + WARP_SIZE] = rescaled_exp_sum2;
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
      global_exp_sum += __shfl_xor(global_exp_sum, mask);
    }
    if (threadIdx.x == 0) {
      shared_global_exp_sum = global_exp_sum;
    }
  }  // warpid == 0

  const scalar_t* tmp_out_ptr =
      tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
      head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x;
  constexpr int MAX_NPAR = 64;
  scalar_t tmps[MAX_NPAR];
  const float dzero = 0.0f;
#pragma unroll
  for (int j = 0; j < MAX_NPAR; j++) {
    tmps[j] = from_float<scalar_t>(dzero);
  }
  const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE;
  const int num_partition_offset = (num_partitions)*HEAD_SIZE;
  int idx = 0;

  constexpr int JCHUNK = 16;
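
  // Each thread owns one head element (threadIdx.x) and walks the partitions
  // in chunks of JCHUNK, staging tmp_out values in registers; later chunks are
  // only read when num_partitions is large enough. Out-of-range partitions are
  // clamped to the last valid one and contribute nothing, since their entry in
  // shared_exp_sums was set to zero above.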
#pragma unroll
  for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) {
    // lastj is last valid partition
    const int lastj_offset =
        (j < num_partition_offset) ? j : last_partition_offset;
    tmps[idx] = tmp_out_ptr[lastj_offset];
    idx++;
  }
  __syncthreads();

  if (num_partitions > JCHUNK) {
#pragma unroll
    for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE;
         j += HEAD_SIZE) {
      const int lastj_offset =
          (j < num_partition_offset) ? j : last_partition_offset;
      tmps[idx] = tmp_out_ptr[lastj_offset];
      idx++;
    }

    if (num_partitions > 2 * JCHUNK) {
#pragma unroll
      for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE;
           j += HEAD_SIZE) {
        const int lastj_offset =
            (j < num_partition_offset) ? j : last_partition_offset;
        tmps[idx] = tmp_out_ptr[lastj_offset];
        idx++;
      }
    }
  }  // num_partitions > JCHUNK

  // Aggregate tmp_out to out.
  float acc = 0.0f;
#pragma unroll
  for (int j = 0; j < JCHUNK; j++) {
    acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
  }
  if (num_partitions > JCHUNK) {
#pragma unroll
    for (int j = JCHUNK; j < 2 * JCHUNK; j++) {
      acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
    }
    if (num_partitions > 2 * JCHUNK) {
#pragma unroll
      for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) {
        acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
      }
    }
  }

  if (num_partitions > MAX_NPAR) {
    idx = 0;
#pragma unroll
    for (int j = MAX_NPAR * HEAD_SIZE; j < 2 * MAX_NPAR * HEAD_SIZE;
         j += HEAD_SIZE) {
      // lastj is last valid partition
      const int lastj_offset =
          (j < num_partition_offset) ? j : last_partition_offset;
      tmps[idx] = tmp_out_ptr[lastj_offset];
      idx++;
    }
#pragma unroll
    for (int j = 0; j < MAX_NPAR; j++) {
      acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j + MAX_NPAR];
    }
  }

  const float inv_global_exp_sum =
      __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
  acc *= inv_global_exp_sum;
  scalar_t* out_ptr =
      out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
  out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
}

#else  // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support

template <typename scalar_t, typename cache_t,
          aphrodite::Fp8KVCacheDataType KV_DTYPE, int BLOCK_SIZE, int HEAD_SIZE,
          int NUM_THREADS,
          int GQA_RATIO>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                          // head_size/x, block_size, x]
    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                          // head_size, block_size]
    const int num_kv_heads, const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    float* __restrict__ exp_sums,    // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                     // max_num_partitions]
    scalar_t* __restrict__ out,  // [num_seqs, num_heads, max_num_partitions,
                                 // head_size]
    scalar_t* __restrict__ final_out,  // [num_seqs, num_heads, head_size]
    int max_ctx_blocks, float k_scale, float v_scale) {
  UNREACHABLE_CODE
}

// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
          int PARTITION_SIZE>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums,    // [num_seqs, num_heads,
                                           // max_num_partitions]
    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                           // max_num_partitions]
    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                           // max_num_partitions, head_size]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int max_num_partitions){UNREACHABLE_CODE}

#endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support

#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO)                                    \
  paged_attention_ll4mi_QKV_kernel<T, KVT, KV_DTYPE, BLOCK_SIZE, HEAD_SIZE,   \
                                   NTHR, GQA_RATIO>                           \
      <<<grid, block, 0, stream>>>(                                           \
          query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale,     \
          block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq,         \
          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,        \
          exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
          k_scale, v_scale);

template <typename T, typename KVT, aphrodite::Fp8KVCacheDataType KV_DTYPE,
          int BLOCK_SIZE, int HEAD_SIZE, int PARTITION_SIZE = 512>
void paged_attention_custom_launcher(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, const int num_kv_heads, float scale,
    torch::Tensor& block_tables, torch::Tensor& context_lens,
    int max_context_len, const c10::optional<torch::Tensor>& alibi_slopes,
    float k_scale, float v_scale) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr =
      alibi_slopes
          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
          : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr());
  KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();

  const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
  const int max_num_partitions =
      DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
  const int gqa_ratio = num_heads / num_kv_heads;
  assert(num_heads % num_kv_heads == 0);
  assert(head_size == HEAD_SIZE);
  assert(max_num_partitions <= 128);
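  // The max_num_partitions <= 128 bound matches the reduce kernel, which
  // tracks at most 2 * WARP_SIZE partition exp-sums and 2 * MAX_NPAR register
  // chunks (both 128 with WARP_SIZE == 64).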

  constexpr int NTHR = PARTITION_SIZE;
  dim3 grid(num_seqs, max_num_partitions, num_kv_heads);
  dim3 block(NTHR);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (gqa_ratio) {
    case 1:
      LAUNCH_CUSTOM_ATTENTION(1);
      break;
    case 2:
      LAUNCH_CUSTOM_ATTENTION(2);
      break;
    case 3:
      LAUNCH_CUSTOM_ATTENTION(3);
      break;
    case 4:
      LAUNCH_CUSTOM_ATTENTION(4);
      break;
    case 5:
      LAUNCH_CUSTOM_ATTENTION(5);
      break;
    case 6:
      LAUNCH_CUSTOM_ATTENTION(6);
      break;
    case 7:
      LAUNCH_CUSTOM_ATTENTION(7);
      break;
    case 8:
      LAUNCH_CUSTOM_ATTENTION(8);
      break;
    case 9:
      LAUNCH_CUSTOM_ATTENTION(9);
      break;
    case 10:
      LAUNCH_CUSTOM_ATTENTION(10);
      break;
    case 11:
      LAUNCH_CUSTOM_ATTENTION(11);
      break;
    case 12:
      LAUNCH_CUSTOM_ATTENTION(12);
      break;
    case 13:
      LAUNCH_CUSTOM_ATTENTION(13);
      break;
    case 14:
      LAUNCH_CUSTOM_ATTENTION(14);
      break;
    case 15:
      LAUNCH_CUSTOM_ATTENTION(15);
      break;
    case 16:
      LAUNCH_CUSTOM_ATTENTION(16);
      break;
    default:
      TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio);
      break;
  }
  // dim3 grid2(num_heads,num_seqs,head_size/HEAD_ELEMS_PER_WG);
  // dim3 block2(1024);
  // LAUNCH_CUSTOM_ATTENTION2;

  // reduction kernel is only required if max_context_len > partition size,
  // otherwise main kernel writes directly to final output
  // note there are cases with graphing where max_context_len is the max
  // supported by graphing, not the actual max among all the sequences: in that
  // case reduction kernel will still run but return immediately
  if (max_context_len > PARTITION_SIZE) {
    dim3 reduce_grid(num_heads, num_seqs);
    dim3 reduce_block(head_size);
    paged_attention_ll4mi_reduce_kernel<T, HEAD_SIZE, HEAD_SIZE, PARTITION_SIZE>
        <<<reduce_grid, reduce_block, 0, stream>>>(
            out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,
            context_lens_ptr, max_num_partitions);
  }
}

#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE)       \
  paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE>( \
      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,  \
      num_kv_heads, scale, block_tables, context_lens, max_context_len,   \
      alibi_slopes, k_scale, v_scale);

#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE)         \
  switch (block_size) {                                               \
    case 16:                                                          \
      CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 16, HEAD_SIZE);          \
      break;                                                          \
    case 32:                                                          \
      CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 32, HEAD_SIZE);          \
      break;                                                          \
    default:                                                          \
      TORCH_CHECK(false, "Unsupported block size: ", block_size);     \
      break;                                                          \
  }

#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE)               \
  switch (head_size) {                                                \
    case 64:                                                          \
      CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64);                 \
      break;                                                          \
    case 128:                                                         \
      CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128);                \
      break;                                                          \
    default:                                                          \
      TORCH_CHECK(false, "Unsupported head size: ", head_size);       \
      break;                                                          \
  }

void paged_attention(
    torch::Tensor& out,         // [num_seqs, num_heads, head_size]
    torch::Tensor& exp_sums,    // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor& max_logits,  // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor&
        tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
    torch::Tensor& query,  // [num_seqs, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor&
        value_cache,  // [num_blocks, num_heads, head_size, block_size]
    int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& context_lens,  // [num_seqs]
    int64_t block_size, int64_t max_context_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype, double k_scale, double v_scale) {
  const int head_size = query.size(2);
  if (kv_cache_dtype == "auto") {
    if (query.dtype() == at::ScalarType::Half) {
      CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16,
                                    aphrodite::Fp8KVCacheDataType::kAuto);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
      CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16,
                                    aphrodite::Fp8KVCacheDataType::kAuto);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
    if (query.dtype() == at::ScalarType::Half) {
      CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t,
                                    aphrodite::Fp8KVCacheDataType::kFp8E4M3);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
      CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t,
                                    aphrodite::Fp8KVCacheDataType::kFp8E4M3);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  } else {
    TORCH_CHECK(false, "Unsupported KV cache dtype: ", kv_cache_dtype);
  }
}
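
// Illustrative host-side usage (a sketch, not part of the kernel; names such
// as num_seqs, opts, key_cache, block_tables, etc. are assumed to be set up by
// the caller, and the fp16/"auto" configuration is just one valid choice).
// The 512 below is the default PARTITION_SIZE of the launcher above.
//
//   auto query = torch::empty({num_seqs, num_heads, head_size}, opts);
//   auto out = torch::empty_like(query);
//   const int max_num_partitions = (max_context_len + 511) / 512;
//   auto exp_sums = torch::empty({num_seqs, num_heads, max_num_partitions},
//                                opts.dtype(torch::kFloat));
//   auto max_logits = torch::empty_like(exp_sums);
//   auto tmp_out = torch::empty(
//       {num_seqs, num_heads, max_num_partitions, head_size}, opts);
//   paged_attention(out, exp_sums, max_logits, tmp_out, query, key_cache,
//                   value_cache, num_kv_heads, scale, block_tables,
//                   context_lens, block_size, max_context_len,
//                   /*alibi_slopes=*/c10::nullopt, "auto",
//                   /*k_scale=*/1.0, /*v_scale=*/1.0);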

#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP