- #include <torch/all.h>
- #include <ATen/cuda/CUDAContext.h>
- #include <c10/cuda/CUDAGuard.h>
- #include <algorithm>
- #include "attention_dtypes.h"
- #include "attention_utils.cuh"
- #ifdef USE_ROCM
- #include <hip/hip_bf16.h>
- #include "../quantization/fp8/amd/quant_utils.cuh"
- typedef __hip_bfloat16 __nv_bfloat16;
- #else
- #include "../quantization/fp8/nvidia/quant_utils.cuh"
- #endif
- #ifndef USE_ROCM
- #define WARP_SIZE 32
- #else
- #define WARP_SIZE warpSize
- #endif
- #define MAX(a, b) ((a) > (b) ? (a) : (b))
- #define MIN(a, b) ((a) < (b) ? (a) : (b))
- #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
- namespace aphrodite {
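- // Utility for the attention softmax: sums a per-thread float across the
- // whole thread block and broadcasts the result.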
- template <int NUM_WARPS>
- inline __device__ float block_sum(float* red_smem, float sum) {
-
- int warp = threadIdx.x / WARP_SIZE;
- int lane = threadIdx.x % WARP_SIZE;
-
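- // Compute the partial sum within each warp.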
- #pragma unroll
- for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
- sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
- }
-
- if (lane == 0) {
- red_smem[warp] = sum;
- }
-
- __syncthreads();
-
- if (lane < NUM_WARPS) {
- sum = red_smem[lane];
- }
-
- #pragma unroll
- for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
- sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
- }
-
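- // Broadcast the final sum from lane 0 to every thread.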
- return APHRODITE_SHFL_SYNC(sum, 0);
- }
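- // Shared body for paged attention V1 (single partition) and V2 (partitioned).
- // Grid: (num_heads, num_seqs, max_num_partitions); each thread block handles
- // one (head, sequence, partition) triple.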
- template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
- int NUM_THREADS, aphrodite::Fp8KVCacheDataType KV_DTYPE,
- bool IS_BLOCK_SPARSE,
- int PARTITION_SIZE = 0>
- __device__ void paged_attention_kernel(
- float* __restrict__ exp_sums,    // [num_seqs, num_heads, max_num_partitions]
- float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
- scalar_t* __restrict__ out,      // [num_seqs, num_heads, max_num_partitions, head_size]
- const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]
- const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads, head_size/x, block_size, x]
- const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads, head_size, block_size]
- const int num_kv_heads,
- const float scale,
- const int* __restrict__ block_tables,
- const int* __restrict__ seq_lens,
- const int max_num_blocks_per_seq,
- const float* __restrict__ alibi_slopes,
- const int q_stride, const int kv_block_stride, const int kv_head_stride,
- const float k_scale, const float v_scale, const int tp_rank,
- const int blocksparse_local_blocks, const int blocksparse_vert_stride,
- const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
- const int seq_idx = blockIdx.y;
- const int partition_idx = blockIdx.z;
- const int max_num_partitions = gridDim.z;
- constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
- const int seq_len = seq_lens[seq_idx];
- if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
-
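- // No work to do for this partition. Terminate the thread block.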
- return;
- }
- const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
- const int num_blocks_per_partition =
- USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
-
- const int start_block_idx =
- USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
- const int end_block_idx =
- MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
- const int num_blocks = end_block_idx - start_block_idx;
-
- const int start_token_idx = start_block_idx * BLOCK_SIZE;
- const int end_token_idx =
- MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
- const int num_tokens = end_token_idx - start_token_idx;
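- // Threads are organized into groups of THREAD_GROUP_SIZE; each group
- // processes one key token at a time and splits the head dimension among its
- // threads.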
- constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
- constexpr int NUM_THREAD_GROUPS =
- NUM_THREADS / THREAD_GROUP_SIZE;
-
- assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
- constexpr int NUM_TOKENS_PER_THREAD_GROUP =
- DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
- constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
- const int thread_idx = threadIdx.x;
- const int warp_idx = thread_idx / WARP_SIZE;
- const int lane = thread_idx % WARP_SIZE;
- const int head_idx = blockIdx.x;
- const int num_heads = gridDim.x;
- const int num_queries_per_kv = num_heads / num_kv_heads;
- const int kv_head_idx = head_idx / num_queries_per_kv;
- const float alibi_slope =
- alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
- 
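- // The vector size is chosen so that a thread group fetches or computes
- // 16 bytes at a time; e.g. a group of 4 threads with half precision gives
- // 16 / (4 * 2) = 2 elements per vector.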
- constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
- using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
- using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
- using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
- constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
- constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
- const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
- const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
- 
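- // Cooperatively load the query head into shared memory (q_vecs): thread
- // group i loads vectors i, i + NUM_THREAD_GROUPS, and so on.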
- const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
- __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
- #pragma unroll
- for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
- i += NUM_THREAD_GROUPS) {
- const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
- q_vecs[thread_group_offset][i] =
- *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
- }
- __syncthreads();
-
-
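- // Memory planning: the dynamic shared memory buffer holds the softmax
- // logits and is reused later for the cross-warp output reduction.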
- extern __shared__ char shared_mem[];
-
- float* logits = reinterpret_cast<float*>(shared_mem);
-
- __shared__ float red_smem[2 * NUM_WARPS];
-
-
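- // x: number of cache_t elements in a 16-byte chunk; the innermost dimension
- // of the key cache layout.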
- constexpr int x = 16 / sizeof(cache_t);
- float qk_max = -FLT_MAX;
- 
- const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
-
- [[maybe_unused]] int bs_block_offset;
- [[maybe_unused]] int q_bs_block_id;
- if constexpr (IS_BLOCK_SPARSE) {
-
-
- q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
- if (blocksparse_head_sliding_step >= 0)
-
- bs_block_offset =
- (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
- else
-
- bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
- (-blocksparse_head_sliding_step) +
- 1;
- }
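- // Iterate over the key blocks of this partition. Each warp fetches one KV
- // block per iteration; each thread group computes the QK dot product for
- // one token of that block.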
- for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
- block_idx += NUM_WARPS) {
- 
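- // Block-sparse attention: skip KV blocks that are neither on the vertical
- // stride pattern (remote) nor within the local window of the query block.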
- if constexpr (IS_BLOCK_SPARSE) {
- const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
- const bool is_remote =
- ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
- const bool is_local =
- (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
- if (!is_remote && !is_local) {
- for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
- const int physical_block_offset =
- (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
- const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
- if (thread_group_offset == 0) {
- 
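- // Mask the skipped tokens so the softmax gives them zero weight.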
- logits[token_idx - start_token_idx] = -FLT_MAX;
- }
- }
- continue;
- }
- }
- const int64_t physical_block_number =
- static_cast<int64_t>(block_table[block_idx]);
- 
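- // Load the keys for this block. Each thread in a group reads a different
- // slice of the key, mirroring the query layout loaded above.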
- for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
- const int physical_block_offset =
- (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
- const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
- K_vec k_vecs[NUM_VECS_PER_THREAD];
- #pragma unroll
- for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
- const cache_t* k_ptr =
- k_cache + physical_block_number * kv_block_stride +
- kv_head_idx * kv_head_stride + physical_block_offset * x;
- const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
- const int offset1 = (vec_idx * VEC_SIZE) / x;
- const int offset2 = (vec_idx * VEC_SIZE) % x;
- if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
- k_vecs[j] = *reinterpret_cast<const K_vec*>(
- k_ptr + offset1 * BLOCK_SIZE * x + offset2);
- } else {
-
- Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
- k_ptr + offset1 * BLOCK_SIZE * x + offset2);
- k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
- k_vec_quant, k_scale);
- }
- }
-
-
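- // Compute the scaled dot product, reduced across the thread group, and add
- // the ALiBi bias if slopes were provided.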
- float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
- q_vecs[thread_group_offset], k_vecs);
-
- qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
- if (thread_group_offset == 0) {
-
-
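- // Only the group leader stores the logit. Out-of-bound tokens are stored as
- // 0 and excluded from the running qk_max.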
- const bool mask = token_idx >= seq_len;
- logits[token_idx - start_token_idx] = mask ? 0.f : qk;
-
- qk_max = mask ? qk_max : fmaxf(qk_max, qk);
- }
- }
- }
- 
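- // Reduce qk_max within each warp, then across warps through shared memory,
- // and broadcast the result to every thread.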
- #pragma unroll
- for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
- qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
- }
- if (lane == 0) {
- red_smem[warp_idx] = qk_max;
- }
- __syncthreads();
-
-
- qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
- #pragma unroll
- for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
- qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
- }
-
- qk_max = APHRODITE_SHFL_SYNC(qk_max, 0);
-
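- // Softmax: exponentiate the logits relative to qk_max and accumulate the sum.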
- float exp_sum = 0.f;
- for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
- float val = __expf(logits[i] - qk_max);
- logits[i] = val;
- exp_sum += val;
- }
- exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
-
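- // Normalize the logits in place; they now hold the softmax weights.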
- const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
- for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
- logits[i] *= inv_sum;
- }
- __syncthreads();
-
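- // For V2, record this partition's max logit and exp sum for the reduce kernel.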
- if (USE_PARTITIONING && thread_idx == 0) {
- float* max_logits_ptr = max_logits +
- seq_idx * num_heads * max_num_partitions +
- head_idx * max_num_partitions + partition_idx;
- *max_logits_ptr = qk_max;
- float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
- head_idx * max_num_partitions + partition_idx;
- *exp_sums_ptr = exp_sum;
- }
-
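- // Value aggregation: each thread owns NUM_ROWS_PER_THREAD rows of the head
- // dimension and accumulates softmax-weighted values into accs[].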
- constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
- using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
- using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
- using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
- using Float_L_vec = typename FloatVec<L_vec>::Type;
- constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
- constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
- constexpr int NUM_ROWS_PER_THREAD =
- DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
-
- float accs[NUM_ROWS_PER_THREAD];
- #pragma unroll
- for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
- accs[i] = 0.f;
- }
- scalar_t zero_value;
- zero(zero_value);
- for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
- block_idx += NUM_WARPS) {
- 
- if constexpr (IS_BLOCK_SPARSE) {
- int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
- if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
- !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
- continue;
- }
- }
- const int64_t physical_block_number =
- static_cast<int64_t>(block_table[block_idx]);
- const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
- const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
- L_vec logits_vec;
- from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
- start_token_idx));
- const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
- kv_head_idx * kv_head_stride;
- #pragma unroll
- for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
- const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
- if (row_idx < HEAD_SIZE) {
- const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
- V_vec v_vec;
- if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
- v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
- } else {
- V_quant_vec v_quant_vec =
- *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
-
- v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
- v_scale);
- }
- if (block_idx == num_seq_blocks - 1) {
- 
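- // In the last block, zero the value lanes that fall beyond the sequence
- // length so they do not contribute to the accumulators.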
- scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
- #pragma unroll
- for (int j = 0; j < V_VEC_SIZE; j++) {
- v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
- }
- }
- accs[i] += dot(logits_vec, v_vec);
- }
- }
- }
-
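- // Within each warp, sum the partial accumulators of the lanes that cover the
- // same head row.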
- #pragma unroll
- for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
- float acc = accs[i];
- #pragma unroll
- for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
- acc += APHRODITE_SHFL_XOR_SYNC(acc, mask);
- }
- accs[i] = acc;
- }
-
-
- __syncthreads();
-
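- // Cross-warp reduction: the upper half of the warps writes its accumulators
- // to shared memory and the lower half adds them, halving each iteration.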
- float* out_smem = reinterpret_cast<float*>(shared_mem);
- #pragma unroll
- for (int i = NUM_WARPS; i > 1; i /= 2) {
- int mid = i / 2;
-
- if (warp_idx >= mid && warp_idx < i) {
- float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
- #pragma unroll
- for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
- const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
- if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
- dst[row_idx] = accs[i];
- }
- }
- }
- __syncthreads();
-
- if (warp_idx < mid) {
- const float* src = &out_smem[warp_idx * HEAD_SIZE];
- #pragma unroll
- for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
- const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
- if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
- accs[i] += src[row_idx];
- }
- }
- }
- __syncthreads();
- }
-
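- // Warp 0 writes the final output for this (head, sequence, partition).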
- if (warp_idx == 0) {
- scalar_t* out_ptr =
- out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
- head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
- #pragma unroll
- for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
- const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
- if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
- from_float(*(out_ptr + row_idx), accs[i]);
- }
- }
- }
- }
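- // V1: single-partition kernel. Grid: (num_heads, num_seqs, 1).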
- template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
- int NUM_THREADS, aphrodite::Fp8KVCacheDataType KV_DTYPE,
- bool IS_BLOCK_SPARSE>
- __global__ void paged_attention_v1_kernel(
- scalar_t* __restrict__ out,       // [num_seqs, num_heads, head_size]
- const scalar_t* __restrict__ q,   // [num_seqs, num_heads, head_size]
- const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads, head_size/x, block_size, x]
- const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads, head_size, block_size]
- const int num_kv_heads,
- const float scale,
- const int* __restrict__ block_tables,
- const int* __restrict__ seq_lens,
- const int max_num_blocks_per_seq,
- const float* __restrict__ alibi_slopes,
- const int q_stride, const int kv_block_stride, const int kv_head_stride,
- const float k_scale, const float v_scale, const int tp_rank,
- const int blocksparse_local_blocks, const int blocksparse_vert_stride,
- const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
- paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
- KV_DTYPE, IS_BLOCK_SPARSE>(
- nullptr, nullptr, out, q, k_cache,
- v_cache, num_kv_heads, scale, block_tables, seq_lens,
- max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
- kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
- blocksparse_vert_stride, blocksparse_block_size,
- blocksparse_head_sliding_step);
- }
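- // V2: partitioned kernel. Grid: (num_heads, num_seqs, max_num_partitions).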
- template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
- int NUM_THREADS, aphrodite::Fp8KVCacheDataType KV_DTYPE,
- bool IS_BLOCK_SPARSE,
- int PARTITION_SIZE>
- __global__ void paged_attention_v2_kernel(
- float* __restrict__ exp_sums,    // [num_seqs, num_heads, max_num_partitions]
- float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
- scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
- const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]
- const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads, head_size/x, block_size, x]
- const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads, head_size, block_size]
- const int num_kv_heads,
- const float scale,
- const int* __restrict__ block_tables,
- const int* __restrict__ seq_lens,
- const int max_num_blocks_per_seq,
- const float* __restrict__ alibi_slopes,
- const int q_stride, const int kv_block_stride, const int kv_head_stride,
- const float k_scale, const float v_scale, const int tp_rank,
- const int blocksparse_local_blocks, const int blocksparse_vert_stride,
- const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
- paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
- KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
- exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
- block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
- kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
- blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
- blocksparse_head_sliding_step);
- }
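- // V2 reduce kernel: combines the per-partition results. Grid: (num_heads, num_seqs).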
- template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
- int PARTITION_SIZE>
- __global__ void paged_attention_v2_reduce_kernel(
- scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
- const float* __restrict__ exp_sums,    // [num_seqs, num_heads, max_num_partitions]
- const float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
- const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
- const int* __restrict__ seq_lens,      // [num_seqs]
- const int max_num_partitions) {
- const int num_heads = gridDim.x;
- const int head_idx = blockIdx.x;
- const int seq_idx = blockIdx.y;
- const int seq_len = seq_lens[seq_idx];
- const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
- if (num_partitions == 1) {
-
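- // Only one partition: no reduction needed, just copy tmp_out to out.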
- scalar_t* out_ptr =
- out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
- const scalar_t* tmp_out_ptr =
- tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
- head_idx * max_num_partitions * HEAD_SIZE;
- for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
- out_ptr[i] = tmp_out_ptr[i];
- }
-
- return;
- }
- constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
- const int warp_idx = threadIdx.x / WARP_SIZE;
- const int lane = threadIdx.x % WARP_SIZE;
-
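- // Workspace: shared memory holds the per-partition max logits followed by
- // the rescaled exp sums.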
- extern __shared__ char shared_mem[];
-
- __shared__ float red_smem[2 * NUM_WARPS];
-
- float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
- const float* max_logits_ptr = max_logits +
- seq_idx * num_heads * max_num_partitions +
- head_idx * max_num_partitions;
- float max_logit = -FLT_MAX;
- for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
- const float l = max_logits_ptr[i];
- shared_max_logits[i] = l;
- max_logit = fmaxf(max_logit, l);
- }
- __syncthreads();
-
-
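- // Reduce the max logit within each warp, then across warps, and broadcast it.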
- #pragma unroll
- for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
- max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
- }
- if (lane == 0) {
- red_smem[warp_idx] = max_logit;
- }
- __syncthreads();
-
- max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
- #pragma unroll
- for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
- max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
- }
-
- max_logit = APHRODITE_SHFL_SYNC(max_logit, 0);
-
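- // Rescale each partition's exp sum to the global max and accumulate the total.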
- float* shared_exp_sums =
- reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
- const float* exp_sums_ptr = exp_sums +
- seq_idx * num_heads * max_num_partitions +
- head_idx * max_num_partitions;
- float global_exp_sum = 0.0f;
- for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
- float l = shared_max_logits[i];
- float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
- global_exp_sum += rescaled_exp_sum;
- shared_exp_sums[i] = rescaled_exp_sum;
- }
- __syncthreads();
- global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
- const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
-
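- // Aggregate tmp_out across partitions, weighting each partition by its
- // normalized exp sum.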
- const scalar_t* tmp_out_ptr =
- tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
- head_idx * max_num_partitions * HEAD_SIZE;
- scalar_t* out_ptr =
- out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
- #pragma unroll
- for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
- float acc = 0.0f;
- for (int j = 0; j < num_partitions; ++j) {
- acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
- inv_global_exp_sum;
- }
- from_float(out_ptr[i], acc);
- }
- }
- }  // namespace aphrodite
- #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
- APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
- ((void*)aphrodite::paged_attention_v1_kernel< \
- T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \
- IS_BLOCK_SPARSE>), \
- shared_mem_size); \
- aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
- NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \
- <<<grid, block, shared_mem_size, stream>>>( \
- out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
- scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
- alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
- k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
- blocksparse_vert_stride, blocksparse_block_size, \
- blocksparse_head_sliding_step);
- template <typename T, typename CACHE_T, int BLOCK_SIZE,
- aphrodite::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
- int NUM_THREADS = 128>
- void paged_attention_v1_launcher(
- torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
- torch::Tensor& value_cache, int num_kv_heads, float scale,
- torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
- const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
- float v_scale, const int tp_rank, const int blocksparse_local_blocks,
- const int blocksparse_vert_stride, const int blocksparse_block_size,
- const int blocksparse_head_sliding_step) {
- int num_seqs = query.size(0);
- int num_heads = query.size(1);
- int head_size = query.size(2);
- int max_num_blocks_per_seq = block_tables.size(1);
- int q_stride = query.stride(0);
- int kv_block_stride = key_cache.stride(0);
- int kv_head_stride = key_cache.stride(1);
- [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
- assert(head_size % thread_group_size == 0);
-
- const float* alibi_slopes_ptr =
- alibi_slopes
- ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
- : nullptr;
- T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
- T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
- CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
- CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
- int* block_tables_ptr = block_tables.data_ptr<int>();
- int* seq_lens_ptr = seq_lens.data_ptr<int>();
- constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
- int padded_max_seq_len =
- DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
- int logits_size = padded_max_seq_len * sizeof(float);
- int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
- 
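- // The dynamic shared memory must fit either the logits for a padded
- // sequence or the cross-warp output buffer, whichever is larger.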
- int shared_mem_size = std::max(logits_size, outputs_size);
- dim3 grid(num_heads, num_seqs, 1);
- dim3 block(NUM_THREADS);
- const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- switch (head_size) {
- 
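- // Only a fixed set of head sizes is instantiated to keep compile time and
- // binary size down; unsupported sizes fall through to TORCH_CHECK below.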
- case 64:
- LAUNCH_PAGED_ATTENTION_V1(64);
- break;
- case 80:
- LAUNCH_PAGED_ATTENTION_V1(80);
- break;
- case 96:
- LAUNCH_PAGED_ATTENTION_V1(96);
- break;
- case 112:
- LAUNCH_PAGED_ATTENTION_V1(112);
- break;
- case 120:
- LAUNCH_PAGED_ATTENTION_V1(120);
- break;
- case 128:
- LAUNCH_PAGED_ATTENTION_V1(128);
- break;
- case 192:
- LAUNCH_PAGED_ATTENTION_V1(192);
- break;
- case 256:
- LAUNCH_PAGED_ATTENTION_V1(256);
- break;
- default:
- TORCH_CHECK(false, "Unsupported head size: ", head_size);
- break;
- }
- }
- #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
- paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
- IS_BLOCK_SPARSE>( \
- out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
- seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
- blocksparse_local_blocks, blocksparse_vert_stride, \
- blocksparse_block_size, blocksparse_head_sliding_step);
- #define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
- switch (is_block_sparse) { \
- case true: \
- CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
- break; \
- case false: \
- CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
- break; \
- }
- #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
- switch (block_size) { \
- case 8: \
- CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
- break; \
- case 16: \
- CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
- break; \
- case 32: \
- CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
- break; \
- default: \
- TORCH_CHECK(false, "Unsupported block size: ", block_size); \
- break; \
- }
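- // Public entry point for paged attention V1 (no sequence partitioning).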
- void paged_attention_v1(
- torch::Tensor& out,
- torch::Tensor& query,
- torch::Tensor& key_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
- torch::Tensor& value_cache,  // [num_blocks, num_kv_heads, head_size, block_size]
- int64_t num_kv_heads,
- double scale,
- torch::Tensor& block_tables,
- torch::Tensor& seq_lens,
- int64_t block_size, int64_t max_seq_len,
- const c10::optional<torch::Tensor>& alibi_slopes,
- const std::string& kv_cache_dtype, double k_scale, double v_scale,
- const int64_t tp_rank, const int64_t blocksparse_local_blocks,
- const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
- const int64_t blocksparse_head_sliding_step) {
- const bool is_block_sparse = (blocksparse_vert_stride > 1);
- DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
- CALL_V1_LAUNCHER_BLOCK_SIZE)
- }
- #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
- aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, \
- NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
- PARTITION_SIZE> \
- <<<grid, block, shared_mem_size, stream>>>( \
- exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
- value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
- seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
- kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
- blocksparse_local_blocks, blocksparse_vert_stride, \
- blocksparse_block_size, blocksparse_head_sliding_step); \
- aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
- PARTITION_SIZE> \
- <<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
- out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
- max_num_partitions);
- template <typename T, typename CACHE_T, int BLOCK_SIZE,
- aphrodite::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
- int NUM_THREADS = 128, int PARTITION_SIZE = 512>
- void paged_attention_v2_launcher(
- torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
- torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
- torch::Tensor& value_cache, int num_kv_heads, float scale,
- torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
- const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
- float v_scale, const int tp_rank, const int blocksparse_local_blocks,
- const int blocksparse_vert_stride, const int blocksparse_block_size,
- const int blocksparse_head_sliding_step) {
- int num_seqs = query.size(0);
- int num_heads = query.size(1);
- int head_size = query.size(2);
- int max_num_blocks_per_seq = block_tables.size(1);
- int q_stride = query.stride(0);
- int kv_block_stride = key_cache.stride(0);
- int kv_head_stride = key_cache.stride(1);
- [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
- assert(head_size % thread_group_size == 0);
-
- const float* alibi_slopes_ptr =
- alibi_slopes
- ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
- : nullptr;
- T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
- float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
- float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
- T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
- T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
- CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
- CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
- int* block_tables_ptr = block_tables.data_ptr<int>();
- int* seq_lens_ptr = seq_lens.data_ptr<int>();
- constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
- int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
- int logits_size = PARTITION_SIZE * sizeof(float);
- int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
-
- dim3 grid(num_heads, num_seqs, max_num_partitions);
- int shared_mem_size = std::max(logits_size, outputs_size);
-
- dim3 reduce_grid(num_heads, num_seqs);
- int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
- dim3 block(NUM_THREADS);
- const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
- const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
- switch (head_size) {
- 
- case 64:
- LAUNCH_PAGED_ATTENTION_V2(64);
- break;
- case 80:
- LAUNCH_PAGED_ATTENTION_V2(80);
- break;
- case 96:
- LAUNCH_PAGED_ATTENTION_V2(96);
- break;
- case 112:
- LAUNCH_PAGED_ATTENTION_V2(112);
- break;
- case 120:
- LAUNCH_PAGED_ATTENTION_V2(120);
- break;
- case 128:
- LAUNCH_PAGED_ATTENTION_V2(128);
- break;
- case 192:
- LAUNCH_PAGED_ATTENTION_V2(192);
- break;
- case 256:
- LAUNCH_PAGED_ATTENTION_V2(256);
- break;
- default:
- TORCH_CHECK(false, "Unsupported head size: ", head_size);
- break;
- }
- }
- #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \
- paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
- IS_BLOCK_SPARSE>( \
- out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
- num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
- k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
- blocksparse_vert_stride, blocksparse_block_size, \
- blocksparse_head_sliding_step);
- #define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
- switch (is_block_sparse) { \
- case true: \
- CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \
- break; \
- case false: \
- CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \
- break; \
- }
- #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \
- switch (block_size) { \
- case 8: \
- CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \
- break; \
- case 16: \
- CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \
- break; \
- case 32: \
- CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \
- break; \
- default: \
- TORCH_CHECK(false, "Unsupported block size: ", block_size); \
- break; \
- }
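- // Public entry point for paged attention V2: long sequences are split into
- // partitions and combined by the reduce kernel.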
- void paged_attention_v2(
- torch::Tensor& out,
- torch::Tensor& exp_sums,
- torch::Tensor& max_logits,
- torch::Tensor& tmp_out,      // [num_seqs, num_heads, max_num_partitions, head_size]
- torch::Tensor& query,        // [num_seqs, num_heads, head_size]
- torch::Tensor& key_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
- torch::Tensor& value_cache,  // [num_blocks, num_kv_heads, head_size, block_size]
- int64_t num_kv_heads,
- double scale,
- torch::Tensor& block_tables,
- torch::Tensor& seq_lens,
- int64_t block_size, int64_t max_seq_len,
- const c10::optional<torch::Tensor>& alibi_slopes,
- const std::string& kv_cache_dtype, double k_scale, double v_scale,
- const int64_t tp_rank, const int64_t blocksparse_local_blocks,
- const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
- const int64_t blocksparse_head_sliding_step) {
- const bool is_block_sparse = (blocksparse_vert_stride > 1);
- DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
- CALL_V2_LAUNCHER_BLOCK_SIZE)
- }
- #undef WARP_SIZE
- #undef MAX
- #undef MIN
- #undef DIVIDE_ROUND_UP