attention.cu
/*
 * Copyright (c) 2024, The vLLM team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <hip/hip_bf16.h>

#include <algorithm>

#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \
                           defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300_MI250__
#endif

#if defined(NDEBUG)
#undef NDEBUG
#include <assert.h>
#define UNREACHABLE_CODE assert(false);
#define NDEBUG
#else
#define UNREACHABLE_CODE assert(false);
#endif

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

#define WARP_SIZE 64

#if defined(__HIP__MI300_MI250__)  // TODO: Add NAVI support

#define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
#define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16

using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
using float16x4 =
    __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16;
typedef float16x4 _Half4;
typedef struct _Half8 {
  _Half4 xy[2];
} _Half8;

using bit16_t = uint16_t;
using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t;
typedef bit16x4 _B16x4;
typedef struct _B16x8 {
  _B16x4 xy[2];
} _B16x8;
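
// _B16x4 / _B16x8 carry fp16 or bf16 data as raw 16-bit lanes: the MFMA
// builtins below take 4-wide vector operands, and packing two of them lets a
// single 16-byte load fetch 8 values at once.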

////// Non temporal load stores ///////

template <typename T>
__device__ __forceinline__ T load(T* addr) {
  return addr[0];
}

template <typename T>
__device__ __forceinline__ void store(T value, T* addr) {
  addr[0] = value;
}

template <typename T, int absz, int cbid, int blgp>
__device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA,
                                                  const _B16x4& inpB,
                                                  const floatx4& inpC) {
  if constexpr (std::is_same<T, _Float16>::value) {
    return __builtin_amdgcn_mfma_f32_4x4x4f16(inpA, inpB, inpC, absz, cbid,
                                              blgp);
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    return __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(inpA, inpB, inpC, absz, cbid,
                                                  blgp);
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}

template <typename T>
__device__ __forceinline__ float to_float(const T& inp) {
  if constexpr (std::is_same<T, _Float16>::value) {
    return (float)inp;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    return __bfloat162float(inp);
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}

template <typename T>
__device__ __forceinline__ T from_float(const float& inp) {
  if constexpr (std::is_same<T, _Float16>::value) {
    return (_Float16)inp;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
    return __float2bfloat16(inp);
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}

template <typename T>
__device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) {
  union tmpcvt {
    uint16_t u;
    _Float16 f;
    __hip_bfloat16 b;
  } t16;
  _B16x4 ret;
  if constexpr (std::is_same<T, _Float16>::value) {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      t16.f = (_Float16)inp[i];
      ret[i] = t16.u;
    }
    return ret;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      t16.b = __float2bfloat16(inp[i]);
      ret[i] = t16.u;
    }
    return ret;
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}

template <typename T>
__device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1,
                                        const _B16x4& inp2) {
  union tmpcvt {
    uint16_t u;
    _Float16 f;
    __hip_bfloat16 b;
  } t1, t2, res;
  _B16x4 ret;
  if constexpr (std::is_same<T, _Float16>::value) {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      t1.u = inp1[i];
      t2.u = inp2[i];
      res.f = t1.f + t2.f;
      ret[i] = res.u;
    }
    return ret;
  } else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
#pragma unroll
    for (int i = 0; i < 4; i++) {
      t1.u = inp1[i];
      t2.u = inp2[i];
      res.b = t1.b + t2.b;
      ret[i] = res.u;
    }
    return ret;
  } else {
    static_assert(false, "unsupported 16b dtype");
  }
}

///////////////////////////////////////
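
// Overview of the QKV kernel below: each thread block handles one
// (sequence, partition, kv-head) triple. Every lane owns one token of the
// partition, and groups of 4 lanes cover 4 query heads of the GQA group.
// QK^T is accumulated with 4x4x4 MFMA instructions, a partition-local softmax
// is formed via warp shuffles plus shared memory, and the softmax-weighted V
// products are written out together with per-partition exp_sums / max_logits
// that the reduction kernel consumes.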
// grid (num_seqs, num_partitions, num_heads/gqa_ratio)
// block (partition size)
template <typename scalar_t, int BLOCK_SIZE, int HEAD_SIZE, int NUM_THREADS,
          int GQA_RATIO>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    const scalar_t* __restrict__ q,        // [num_seqs, num_heads, head_size]
    const scalar_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                           // head_size/x, block_size, x]
    const scalar_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                           // head_size, block_size]
    const int num_kv_heads, const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                     // max_num_partitions]
    scalar_t* __restrict__ out,    // [num_seqs, num_heads, max_num_partitions,
                                   // head_size]
    scalar_t* __restrict__ final_out,  // [num_seqs, num_heads, head_size]
#if 0
    scalar_t* __restrict__ qk_out,  // [num_heads, num_seqs, max_ctx_blocks, block_size]
#endif
    int max_ctx_blocks) {
  constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
  const int warpid = threadIdx.x / WARP_SIZE;
  const int laneid = threadIdx.x % WARP_SIZE;
  const int lane4id = laneid % 4;

  const int seq_idx = blockIdx.x;
  const int partition_idx = blockIdx.y;
  const int partition_size = blockDim.x;
  const int max_num_partitions = gridDim.y;

  const int context_len = context_lens[seq_idx];
  const int partition_start_token_idx = partition_idx * partition_size;
  // exit if partition is out of context for seq
  if (partition_start_token_idx >= context_len) {
    return;
  }
  constexpr int QHLOOP =
      DIVIDE_ROUND_UP(GQA_RATIO, 4);  // each group of 4 lanes fetches 4
                                      // different qheads; e.g. for a total of
                                      // 8 qheads, QHLOOP is 2
  constexpr int GQA_RATIO4 = 4 * QHLOOP;
  __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1];
  __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1];
  _B16x8 Qlocal[QHLOOP];
  constexpr int x = 16 / sizeof(scalar_t);
  constexpr int KHELOOP = HEAD_SIZE / x;
  _B16x8 Klocal[KHELOOP];
  constexpr int VHELOOP =
      HEAD_SIZE /
      WARP_SIZE;  // v head_size dimension is distributed across lanes
  constexpr int VTLOOP = 8;  // each lane covers the warp's 64 tokens as
                             // 8 groups of 8 tokens (one _B16x8 per group)
  _B16x8 Vlocal[VHELOOP][VTLOOP];
  floatx4 dout[QHLOOP];
  float qk_max[QHLOOP];
#pragma unroll
  for (int h = 0; h < QHLOOP; h++) {
    dout[h] = {0};
    qk_max[h] = -FLT_MAX;
  }

  const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
  const int wg_start_kv_head_idx = blockIdx.z;

  const int warp_start_token_idx =
      partition_start_token_idx + warpid * WARP_SIZE;

  if (warp_start_token_idx >= context_len) {  // warp out of context
#pragma unroll
    for (int h = 0; h < GQA_RATIO4; h++) {
      shared_qk_max[warpid][h] = -FLT_MAX;
      shared_exp_sum[warpid][h] = 0.0f;
    }
  } else {  // warp within context
    const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
    const int last_ctx_block = num_context_blocks - 1;

    const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;

    const int local_token_idx = threadIdx.x;
    const int global_token_idx = partition_start_token_idx + local_token_idx;

    const int block_idx = (global_token_idx < context_len)
                              ? global_token_idx / BLOCK_SIZE
                              : last_ctx_block;
    // fetch block number for q and k
    // int32 physical_block_number leads to overflow when multiplied with
    // kv_block_stride
    const int64_t physical_block_number =
        static_cast<int64_t>(block_table[block_idx]);

    // fetch vphysical block numbers up front
    constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE;
    int vphysical_blocks[VBLOCKS];

    const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE;
#pragma unroll
    for (int b = 0; b < VBLOCKS; b++) {
      const int vblock_idx = warp_start_block_idx + b;
      const int vblock_idx_ctx =
          (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
      vphysical_blocks[b] = block_table[vblock_idx_ctx];
    }

    // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems
    const scalar_t* q_ptr =
        q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE;
    const _B16x8* q_ptrh8 = reinterpret_cast<const _B16x8*>(q_ptr);
    const int qhead_elemh8 = laneid / 4;
#pragma unroll
    for (int h = 0; h < QHLOOP - 1; h++) {
      const int qhead_idx = h * 4 + lane4id;
      Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8];
    }
    const int final_qhead_idx = 4 * (QHLOOP - 1) + lane4id;
    if (final_qhead_idx < GQA_RATIO) {
      Qlocal[QHLOOP - 1] =
          q_ptrh8[final_qhead_idx * HEAD_SIZE / 8 + qhead_elemh8];
    } else {
      Qlocal[QHLOOP - 1].xy[0] = {0};
      Qlocal[QHLOOP - 1].xy[1] = {0};
    }

    const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride +
                            wg_start_kv_head_idx * kv_head_stride;

    const int physical_block_offset =
        local_token_idx % BLOCK_SIZE;  // since x=half8, physical_block_offset
                                       // is already cast as _H8
    const _B16x8* k_ptrh8 = reinterpret_cast<const _B16x8*>(k_ptr);
#pragma unroll
    for (int d = 0; d < KHELOOP; d++) {
      Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset];
    }
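    // After this loop each lane holds the full HEAD_SIZE of its own token's
    // key, split into KHELOOP groups of 8 contiguous 16-bit elements (x = 8).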

    float alibi_slope[QHLOOP];
    if (alibi_slopes != nullptr) {
#pragma unroll
      for (int h = 0; h < QHLOOP; h++) {
        const int qhead_idx = h * 4 + lane4id;
        alibi_slope[h] = (qhead_idx < GQA_RATIO)
                             ? alibi_slopes[wg_start_head_idx + qhead_idx]
                             : 0.f;
      }
    }

    const scalar_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride;
    const _B16x8* v_ptrh8 = reinterpret_cast<const _B16x8*>(v_ptr);
    // iterate over each v block
#pragma unroll
    for (int b = 0; b < VBLOCKS; b++) {
      // int32 physical_block_number leads to overflow when multiplied with
      // kv_block_stride
      const int64_t vphysical_block_number =
          static_cast<int64_t>(vphysical_blocks[b]);
      const _B16x8* v_ptrh8b =
          v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
      // iterate over each head elem (within head_size)
#pragma unroll
      for (int h = 0; h < VHELOOP; h++) {
        const int head_size_elem = h * WARP_SIZE + laneid;
        const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
        // iterate over all velems within block
#pragma unroll
        for (int d = 0; d < BLOCK_SIZE / 8; d++) {
          Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
        }
      }
    }
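    // Vlocal layout: for head element (h * WARP_SIZE + laneid) each lane now
    // holds the values of the warp's 64 tokens, as 8 groups of 8 tokens.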

#pragma unroll
    for (int h = 0; h < QHLOOP; h++) {
      dout[h] = gcn_mfma_instr<scalar_t, 4, 0, 0>(Qlocal[h].xy[0],
                                                  Klocal[0].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 0, 0>(Qlocal[h].xy[1],
                                                  Klocal[0].xy[1], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 1, 0>(Qlocal[h].xy[0],
                                                  Klocal[1].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 1, 0>(Qlocal[h].xy[1],
                                                  Klocal[1].xy[1], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 2, 0>(Qlocal[h].xy[0],
                                                  Klocal[2].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 2, 0>(Qlocal[h].xy[1],
                                                  Klocal[2].xy[1], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 3, 0>(Qlocal[h].xy[0],
                                                  Klocal[3].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 3, 0>(Qlocal[h].xy[1],
                                                  Klocal[3].xy[1], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 4, 0>(Qlocal[h].xy[0],
                                                  Klocal[4].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 4, 0>(Qlocal[h].xy[1],
                                                  Klocal[4].xy[1], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 5, 0>(Qlocal[h].xy[0],
                                                  Klocal[5].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 5, 0>(Qlocal[h].xy[1],
                                                  Klocal[5].xy[1], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 6, 0>(Qlocal[h].xy[0],
                                                  Klocal[6].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 6, 0>(Qlocal[h].xy[1],
                                                  Klocal[6].xy[1], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 7, 0>(Qlocal[h].xy[0],
                                                  Klocal[7].xy[0], dout[h]);
      dout[h] = gcn_mfma_instr<scalar_t, 4, 7, 0>(Qlocal[h].xy[1],
                                                  Klocal[7].xy[1], dout[h]);
      if constexpr (KHELOOP > 8) {
        dout[h] = gcn_mfma_instr<scalar_t, 4, 8, 0>(Qlocal[h].xy[0],
                                                    Klocal[8].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 8, 0>(Qlocal[h].xy[1],
                                                    Klocal[8].xy[1], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 9, 0>(Qlocal[h].xy[0],
                                                    Klocal[9].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 9, 0>(Qlocal[h].xy[1],
                                                    Klocal[9].xy[1], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 10, 0>(Qlocal[h].xy[0],
                                                     Klocal[10].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 10, 0>(Qlocal[h].xy[1],
                                                     Klocal[10].xy[1], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 11, 0>(Qlocal[h].xy[0],
                                                     Klocal[11].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 11, 0>(Qlocal[h].xy[1],
                                                     Klocal[11].xy[1], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 12, 0>(Qlocal[h].xy[0],
                                                     Klocal[12].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 12, 0>(Qlocal[h].xy[1],
                                                     Klocal[12].xy[1], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 13, 0>(Qlocal[h].xy[0],
                                                     Klocal[13].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 13, 0>(Qlocal[h].xy[1],
                                                     Klocal[13].xy[1], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 14, 0>(Qlocal[h].xy[0],
                                                     Klocal[14].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 14, 0>(Qlocal[h].xy[1],
                                                     Klocal[14].xy[1], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 15, 0>(Qlocal[h].xy[0],
                                                     Klocal[15].xy[0], dout[h]);
        dout[h] = gcn_mfma_instr<scalar_t, 4, 15, 0>(Qlocal[h].xy[1],
                                                     Klocal[15].xy[1], dout[h]);
      }  // KHELOOP>8
      dout[h] *= scale;
    }
    // transpose dout so that 4 token ids are in each lane, and 4 heads are
    // across 4 lanes
#pragma unroll
    for (int h = 0; h < QHLOOP; h++) {
      floatx4 tmp = {0};
#pragma unroll
      for (int i = 0; i < 4; i++) {
        const float B = (lane4id == i) ? 1.0f : 0.0f;
        // const float A = (global_token_idx < context_len) ? dout[h][i] : 0.0f;
        tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(dout[h][i], B, tmp, 0, 0, 0);
        // tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0);
      }
      dout[h] = tmp;
    }
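    // The one-hot B operand turns mfma_f32_4x4x1f32 into a 4x4 transpose
    // within each quad of lanes: before the loop a lane held its own token's
    // scores for 4 query heads; afterwards it holds one head's scores for the
    // quad's 4 tokens.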

    const int lane4_token_idx = 4 * (global_token_idx >> 2);
    const int alibi_offset = lane4_token_idx - context_len + 1;
    if (alibi_slopes != nullptr) {
#pragma unroll
      for (int h = 0; h < QHLOOP; h++) {
#pragma unroll
        for (int i = 0; i < 4; i++) {
          dout[h][i] += alibi_slope[h] * (alibi_offset + i);
        }
      }
    }

#pragma unroll
    for (int h = 0; h < QHLOOP; h++) {
      qk_max[h] = -FLT_MAX;
#pragma unroll
      for (int i = 0; i < 4; i++) {
        qk_max[h] = (lane4_token_idx + i < context_len)
                        ? fmaxf(qk_max[h], dout[h][i])
                        : qk_max[h];
      }
#pragma unroll
      for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) {
        qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask));
      }
    }

    float exp_sum[QHLOOP];
#pragma unroll
    for (int h = 0; h < QHLOOP; h++) {
      exp_sum[h] = 0.0f;
#pragma unroll
      for (int i = 0; i < 4; i++) {
        dout[h][i] = (lane4_token_idx + i < context_len)
                         ? __expf(dout[h][i] - qk_max[h])
                         : 0.0f;
        exp_sum[h] += dout[h][i];
      }
#pragma unroll
      for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) {
        exp_sum[h] += __shfl_xor(exp_sum[h], mask);
      }
    }
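    // dout now holds exp(qk - qk_max) for this lane's 4 tokens; qk_max and
    // exp_sum are reduced across the warp's 64 tokens (the xor shuffle stops
    // at a stride of 4, so lanes with the same lane4id, i.e. the same query
    // head, share one set of statistics).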

#pragma unroll
    for (int h = 0; h < QHLOOP; h++) {
      const int head_idx = 4 * h + lane4id;
      shared_qk_max[warpid][head_idx] = qk_max[h];
      shared_exp_sum[warpid][head_idx] = exp_sum[h];
    }
  }  // warp within context

  __syncthreads();

  const int num_heads = gridDim.z * GQA_RATIO;
  float* max_logits_ptr =
      max_logits + seq_idx * num_heads * max_num_partitions + partition_idx;
  float* exp_sums_ptr =
      exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx;
#pragma unroll
  for (int h = 0; h < QHLOOP; h++) {
    float global_qk_max = -FLT_MAX;
    float warp_qk_max[NWARPS];
    const int head_idx = 4 * h + lane4id;
#pragma unroll
    for (int w = 0; w < NWARPS; w++) {
      warp_qk_max[w] = shared_qk_max[w][head_idx];
      global_qk_max = fmaxf(global_qk_max, warp_qk_max[w]);
    }
    float global_exp_sum = 0.0f;
#pragma unroll
    for (int w = 0; w < NWARPS; w++) {
      global_exp_sum +=
          shared_exp_sum[w][head_idx] * __expf(warp_qk_max[w] - global_qk_max);
    }
    if (head_idx < GQA_RATIO) {
      max_logits_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] =
          global_qk_max;
      exp_sums_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] =
          global_exp_sum;
    }
    const float global_inv_sum_scale = __fdividef(1.f, global_exp_sum + 1e-6f) *
                                       __expf(qk_max[h] - global_qk_max);
    dout[h] *= global_inv_sum_scale;
  }
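  // Cross-warp softmax merge within the partition:
  // global_exp_sum = sum_w exp_sum_w * exp(max_w - global_max), and each
  // warp's probabilities are rescaled by exp(qk_max - global_max) /
  // global_exp_sum so the weighted V accumulation below directly yields the
  // partition output.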
  // logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there
  // are 4x16 tokens across warp
  _B16x4 logits[QHLOOP];
#pragma unroll
  for (int h = 0; h < QHLOOP; h++) {
    logits[h] = from_floatx4<scalar_t>(dout[h]);
  }

  __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1];

  if (warp_start_token_idx >= context_len) {  // warp out of context
#pragma unroll
    for (int qh = 0; qh < QHLOOP; qh++) {
#pragma unroll
      for (int vh = 0; vh < VHELOOP; vh++) {
        vout_shared[qh][vh][laneid][warpid] = {0};
      }
    }
  } else {  // warp in context
    // iterate across heads
#pragma unroll
    for (int qh = 0; qh < QHLOOP; qh++) {
      // iterate over each v head elem (within head_size)
#pragma unroll
      for (int vh = 0; vh < VHELOOP; vh++) {
        floatx4 acc = {0};
        // iterate over tokens
        acc = gcn_mfma_instr<scalar_t, 4, 0, 0>(logits[qh], Vlocal[vh][0].xy[0],
                                                acc);
        acc = gcn_mfma_instr<scalar_t, 4, 1, 0>(logits[qh], Vlocal[vh][0].xy[1],
                                                acc);
        acc = gcn_mfma_instr<scalar_t, 4, 2, 0>(logits[qh], Vlocal[vh][1].xy[0],
                                                acc);
        acc = gcn_mfma_instr<scalar_t, 4, 3, 0>(logits[qh], Vlocal[vh][1].xy[1],
                                                acc);
        acc = gcn_mfma_instr<scalar_t, 4, 4, 0>(logits[qh], Vlocal[vh][2].xy[0],
                                                acc);
        acc = gcn_mfma_instr<scalar_t, 4, 5, 0>(logits[qh], Vlocal[vh][2].xy[1],
                                                acc);
        acc = gcn_mfma_instr<scalar_t, 4, 6, 0>(logits[qh], Vlocal[vh][3].xy[0],
                                                acc);
        acc = gcn_mfma_instr<scalar_t, 4, 7, 0>(logits[qh], Vlocal[vh][3].xy[1],
                                                acc);
        acc = gcn_mfma_instr<scalar_t, 4, 8, 0>(logits[qh], Vlocal[vh][4].xy[0],
                                                acc);
        acc = gcn_mfma_instr<scalar_t, 4, 9, 0>(logits[qh], Vlocal[vh][4].xy[1],
                                                acc);
        acc = gcn_mfma_instr<scalar_t, 4, 10, 0>(logits[qh],
                                                 Vlocal[vh][5].xy[0], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 11, 0>(logits[qh],
                                                 Vlocal[vh][5].xy[1], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 12, 0>(logits[qh],
                                                 Vlocal[vh][6].xy[0], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 13, 0>(logits[qh],
                                                 Vlocal[vh][6].xy[1], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 14, 0>(logits[qh],
                                                 Vlocal[vh][7].xy[0], acc);
        acc = gcn_mfma_instr<scalar_t, 4, 15, 0>(logits[qh],
                                                 Vlocal[vh][7].xy[1], acc);
        vout_shared[qh][vh][laneid][warpid] = from_floatx4<scalar_t>(acc);
      }
    }
  }  // warp in context

  __syncthreads();

  if (warpid == 0) {
    _B16x4 vout[QHLOOP][VHELOOP];
    // iterate across heads
    scalar_t* out_ptr;
    int out_num_partitions;
    if (context_len > partition_size) {
      out_num_partitions = max_num_partitions;
      out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
                partition_idx * HEAD_SIZE;
    } else {
      out_num_partitions = 1;
      out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE;
    }
#pragma unroll
    for (int qh = 0; qh < QHLOOP; qh++) {
      // iterate over each v head elem (within head_size)
#pragma unroll
      for (int vh = 0; vh < VHELOOP; vh++) {
        vout[qh][vh] = {0};
#pragma unroll
        for (int w = 0; w < NWARPS; w++) {
          vout[qh][vh] =
              addx4<scalar_t>(vout[qh][vh], vout_shared[qh][vh][laneid][w]);
        }
        const int head_size_elem = vh * WARP_SIZE + laneid;
        bit16_t* out_ptr_b16 = reinterpret_cast<bit16_t*>(out_ptr);
#pragma unroll
        for (int i = 0; i < 4; i++) {
          const int head_idx = 4 * qh + i;
          if (head_idx < GQA_RATIO) {
            out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions *
                            HEAD_SIZE +
                        head_size_elem] = vout[qh][vh][i];
          }
        }
      }
    }
  }
}

// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
          int PARTITION_SIZE>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums,    // [num_seqs, num_heads,
                                           // max_num_partitions]
    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                           // max_num_partitions]
    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                           // max_num_partitions, head_size]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int max_num_partitions) {
  const int num_heads = gridDim.x;
  const int head_idx = blockIdx.x;
  const int seq_idx = blockIdx.y;
  const int context_len = context_lens[seq_idx];
  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
  if (num_partitions == 1) {
    // if num_partitions==1, main kernel will write to out directly, no work in
    // reduction kernel
    return;
  }

  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
  const int warpid = threadIdx.x / WARP_SIZE;
  const int laneid = threadIdx.x % WARP_SIZE;

  __shared__ float shared_global_exp_sum;
  __shared__ float shared_exp_sums[2 * WARP_SIZE];

  if (warpid == 0) {
    const float* max_logits_ptr = max_logits +
                                  seq_idx * num_heads * max_num_partitions +
                                  head_idx * max_num_partitions;

    // clamp to the last valid partition when threadIdx.x >= num_partitions
    const int valid_partition =
        (threadIdx.x < num_partitions) ? threadIdx.x : num_partitions - 1;
    const int valid_partition2 = (WARP_SIZE + threadIdx.x < num_partitions)
                                     ? WARP_SIZE + threadIdx.x
                                     : num_partitions - 1;

    float reg_max_logit = max_logits_ptr[valid_partition];
    float reg_max_logit2 = max_logits_ptr[valid_partition2];
    float max_logit = fmaxf(reg_max_logit, reg_max_logit2);
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
      max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask));
    }

    const float* exp_sums_ptr = exp_sums +
                                seq_idx * num_heads * max_num_partitions +
                                head_idx * max_num_partitions;

    float global_exp_sum = 0.0f;
    float rescaled_exp_sum = exp_sums_ptr[valid_partition];
    float rescaled_exp_sum2 = exp_sums_ptr[valid_partition2];
    rescaled_exp_sum *=
        (threadIdx.x < num_partitions) ? expf(reg_max_logit - max_logit) : 0.0f;
    rescaled_exp_sum2 *= (threadIdx.x + WARP_SIZE < num_partitions)
                             ? expf(reg_max_logit2 - max_logit)
                             : 0.0f;
    global_exp_sum += rescaled_exp_sum + rescaled_exp_sum2;
    shared_exp_sums[threadIdx.x] = rescaled_exp_sum;
    shared_exp_sums[threadIdx.x + WARP_SIZE] = rescaled_exp_sum2;

#pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
      global_exp_sum += __shfl_xor(global_exp_sum, mask);
    }
    if (threadIdx.x == 0) {
      shared_global_exp_sum = global_exp_sum;
    }
  }  // warpid == 0
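  // shared_exp_sums[p] now holds exp_sum_p * exp(max_logit_p - global_max),
  // i.e. the weight of partition p in the merged softmax;
  // shared_global_exp_sum is their sum over all (up to 2 * WARP_SIZE = 128)
  // partitions.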

  const scalar_t* tmp_out_ptr =
      tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
      head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x;
  constexpr int MAX_NPAR = 64;
  scalar_t tmps[MAX_NPAR];
  const float dzero = 0.0f;
#pragma unroll
  for (int j = 0; j < MAX_NPAR; j++) {
    tmps[j] = from_float<scalar_t>(dzero);
  }
  const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE;
  const int num_partition_offset = num_partitions * HEAD_SIZE;
  int idx = 0;

  constexpr int JCHUNK = 16;

#pragma unroll
  for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) {
    // lastj is last valid partition
    const int lastj_offset =
        (j < num_partition_offset) ? j : last_partition_offset;
    tmps[idx] = tmp_out_ptr[lastj_offset];
    idx++;
  }
  __syncthreads();

  if (num_partitions > JCHUNK) {
#pragma unroll
    for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE;
         j += HEAD_SIZE) {
      const int lastj_offset =
          (j < num_partition_offset) ? j : last_partition_offset;
      tmps[idx] = tmp_out_ptr[lastj_offset];
      idx++;
    }

    if (num_partitions > 2 * JCHUNK) {
#pragma unroll
      for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE;
           j += HEAD_SIZE) {
        const int lastj_offset =
            (j < num_partition_offset) ? j : last_partition_offset;
        tmps[idx] = tmp_out_ptr[lastj_offset];
        idx++;
      }
    }
  }  // num_partitions > JCHUNK

  // Aggregate tmp_out to out.
  float acc = 0.0f;
#pragma unroll
  for (int j = 0; j < JCHUNK; j++) {
    acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
  }
  if (num_partitions > JCHUNK) {
#pragma unroll
    for (int j = JCHUNK; j < 2 * JCHUNK; j++) {
      acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
    }
    if (num_partitions > 2 * JCHUNK) {
#pragma unroll
      for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) {
        acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
      }
    }
  }

  if (num_partitions > MAX_NPAR) {
    idx = 0;
#pragma unroll
    for (int j = MAX_NPAR * HEAD_SIZE; j < 2 * MAX_NPAR * HEAD_SIZE;
         j += HEAD_SIZE) {
      // lastj is last valid partition
      const int lastj_offset =
          (j < num_partition_offset) ? j : last_partition_offset;
      tmps[idx] = tmp_out_ptr[lastj_offset];
      idx++;
    }

#pragma unroll
    for (int j = 0; j < MAX_NPAR; j++) {
      acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j + MAX_NPAR];
    }
  }

  const float inv_global_exp_sum =
      __fdividef(1.0f, shared_global_exp_sum + 1e-6f);
  acc *= inv_global_exp_sum;
  scalar_t* out_ptr =
      out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
  out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
}

#else  // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support

template <typename scalar_t, int BLOCK_SIZE, int HEAD_SIZE, int NUM_THREADS,
          int GQA_RATIO>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
    const scalar_t* __restrict__ q,        // [num_seqs, num_heads, head_size]
    const scalar_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
                                           // head_size/x, block_size, x]
    const scalar_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
                                           // head_size, block_size]
    const int num_kv_heads, const float scale,
    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
    float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                     // max_num_partitions]
    scalar_t* __restrict__ out,    // [num_seqs, num_heads, max_num_partitions,
                                   // head_size]
    scalar_t* __restrict__ final_out,  // [num_seqs, num_heads, head_size]
#if 0
    scalar_t* __restrict__ qk_out,  // [num_heads, num_seqs, max_ctx_blocks, block_size]
#endif
    int max_ctx_blocks) {
  UNREACHABLE_CODE
}

// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
          int PARTITION_SIZE>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
    const float* __restrict__ exp_sums,    // [num_seqs, num_heads,
                                           // max_num_partitions]
    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
                                           // max_num_partitions]
    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
                                           // max_num_partitions, head_size]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int max_num_partitions) {
  UNREACHABLE_CODE
}

#endif  // defined(__HIP__MI300_MI250__) TODO: Add NAVI support

#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO)                                    \
  paged_attention_ll4mi_QKV_kernel<T, BLOCK_SIZE, HEAD_SIZE, NTHR, GQA_RATIO> \
      <<<grid, block, 0, stream>>>(                                           \
          query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale,     \
          block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq,         \
          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,        \
          exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks);

template <typename T, int BLOCK_SIZE, int HEAD_SIZE, int PARTITION_SIZE = 256>
void paged_attention_custom_launcher(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, const int num_kv_heads, float scale,
    torch::Tensor& block_tables, torch::Tensor& context_lens,
    int max_context_len,
#if 0
    torch::Tensor& qk_out,
    torch::Tensor& softmax_out,
#endif
    const c10::optional<torch::Tensor>& alibi_slopes) {
  int num_seqs = query.size(0);
  int num_heads = query.size(1);
  int head_size = query.size(2);
  int max_num_blocks_per_seq = block_tables.size(1);
  int q_stride = query.stride(0);
  int kv_block_stride = key_cache.stride(0);
  int kv_head_stride = key_cache.stride(1);

  // NOTE: alibi_slopes is optional.
  const float* alibi_slopes_ptr =
      alibi_slopes
          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
          : nullptr;

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();
#if 0
  T* qk_out_ptr = reinterpret_cast<T*>(qk_out.data_ptr());
  T* softmax_out_ptr = reinterpret_cast<T*>(softmax_out.data_ptr());
#endif

  const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
  const int max_num_partitions =
      DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
  const int gqa_ratio = num_heads / num_kv_heads;
  assert(num_heads % num_kv_heads == 0);
  assert(head_size == HEAD_SIZE);
  assert(max_num_partitions <= 128);
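
  // Launch configuration: one thread block per (sequence, partition, kv head);
  // each of the PARTITION_SIZE threads owns one token of its partition, so a
  // block of 256 threads is 4 warps of 64 tokens each.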
  constexpr int NTHR = PARTITION_SIZE;
  dim3 grid(num_seqs, max_num_partitions, num_kv_heads);
  dim3 block(NTHR);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (gqa_ratio) {
    case 1:
      LAUNCH_CUSTOM_ATTENTION(1);
      break;
    case 2:
      LAUNCH_CUSTOM_ATTENTION(2);
      break;
    case 3:
      LAUNCH_CUSTOM_ATTENTION(3);
      break;
    case 4:
      LAUNCH_CUSTOM_ATTENTION(4);
      break;
    case 5:
      LAUNCH_CUSTOM_ATTENTION(5);
      break;
    case 6:
      LAUNCH_CUSTOM_ATTENTION(6);
      break;
    case 7:
      LAUNCH_CUSTOM_ATTENTION(7);
      break;
    case 8:
      LAUNCH_CUSTOM_ATTENTION(8);
      break;
    case 9:
      LAUNCH_CUSTOM_ATTENTION(9);
      break;
    case 10:
      LAUNCH_CUSTOM_ATTENTION(10);
      break;
    case 11:
      LAUNCH_CUSTOM_ATTENTION(11);
      break;
    case 12:
      LAUNCH_CUSTOM_ATTENTION(12);
      break;
    case 13:
      LAUNCH_CUSTOM_ATTENTION(13);
      break;
    case 14:
      LAUNCH_CUSTOM_ATTENTION(14);
      break;
    case 15:
      LAUNCH_CUSTOM_ATTENTION(15);
      break;
    case 16:
      LAUNCH_CUSTOM_ATTENTION(16);
      break;
    default:
      TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio);
      break;
  }
  // dim3 grid2(num_heads,num_seqs,head_size/HEAD_ELEMS_PER_WG);
  // dim3 block2(1024);
  // LAUNCH_CUSTOM_ATTENTION2;

  // reduction kernel is only required if max_context_len > partition size,
  // otherwise main kernel writes directly to final output
  // note there are cases with graphing where max_context_len is the max
  // supported by graphing, not the actual max among all the sequences: in that
  // case reduction kernel will still run but return immediately
  if (max_context_len > PARTITION_SIZE) {
    dim3 reduce_grid(num_heads, num_seqs);
    dim3 reduce_block(head_size);
    // NUM_THREADS for the reduction kernel equals HEAD_SIZE: one thread per
    // output element.
    paged_attention_ll4mi_reduce_kernel<T, HEAD_SIZE, HEAD_SIZE, PARTITION_SIZE>
        <<<reduce_grid, reduce_block, 0, stream>>>(
            out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,
            context_lens_ptr, max_num_partitions);
  }
}

#define CALL_CUSTOM_LAUNCHER(T, BLK_SIZE, HEAD_SIZE)                     \
  paged_attention_custom_launcher<T, BLK_SIZE, HEAD_SIZE>(               \
      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
      num_kv_heads, scale, block_tables, context_lens, max_context_len,  \
      alibi_slopes);

#define CALL_CUSTOM_LAUNCHER_BLK(T, HEAD_SIZE)                        \
  switch (block_size) {                                               \
    case 16:                                                          \
      CALL_CUSTOM_LAUNCHER(T, 16, HEAD_SIZE);                         \
      break;                                                          \
    case 32:                                                          \
      CALL_CUSTOM_LAUNCHER(T, 32, HEAD_SIZE);                         \
      break;                                                          \
    default:                                                          \
      TORCH_CHECK(false, "Unsupported block size: ", block_size);     \
      break;                                                          \
  }

#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T)                              \
  switch (head_size) {                                                \
    case 64:                                                          \
      CALL_CUSTOM_LAUNCHER_BLK(T, 64);                                \
      break;                                                          \
    case 128:                                                         \
      CALL_CUSTOM_LAUNCHER_BLK(T, 128);                               \
      break;                                                          \
    default:                                                          \
      TORCH_CHECK(false, "Unsupported head size: ", head_size);       \
      break;                                                          \
  }

void paged_attention(
    torch::Tensor& out,         // [num_seqs, num_heads, head_size]
    torch::Tensor& exp_sums,    // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor& max_logits,  // [num_seqs, num_heads, max_num_partitions]
    torch::Tensor&
        tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
    torch::Tensor& query,  // [num_seqs, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor&
        value_cache,  // [num_blocks, num_heads, head_size, block_size]
    int64_t num_kv_heads, double scale,
    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor& context_lens,  // [num_seqs]
    int64_t block_size, int64_t max_context_len,
    const c10::optional<torch::Tensor>& alibi_slopes,
    const std::string& kv_cache_dtype) {
  assert(kv_cache_dtype == "auto");
  const int head_size = query.size(2);
  if (query.dtype() == at::ScalarType::Half) {
    CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16);
  } else if (query.dtype() == at::ScalarType::BFloat16) {
    CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16);
  } else {
    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
  }
}
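
// Illustrative host-side call (a sketch, not part of this file's API): tensor
// shapes follow the parameter comments on paged_attention() above, and the
// partition size is assumed to be the launcher default of 256. q_opts stands
// for the query tensor's options and is hypothetical.
//
//   int max_num_partitions = (max_context_len + 255) / 256;
//   auto out        = torch::empty({num_seqs, num_heads, head_size}, q_opts);
//   auto tmp_out    = torch::empty(
//       {num_seqs, num_heads, max_num_partitions, head_size}, q_opts);
//   auto exp_sums   = torch::empty({num_seqs, num_heads, max_num_partitions},
//                                  q_opts.dtype(torch::kFloat));
//   auto max_logits = torch::empty_like(exp_sums);
//   paged_attention(out, exp_sums, max_logits, tmp_out, query, key_cache,
//                   value_cache, num_kv_heads, scale, block_tables,
//                   context_lens, block_size, max_context_len,
//                   /*alibi_slopes=*/c10::nullopt, /*kv_cache_dtype=*/"auto");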

#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP