@@ -106,9 +106,9 @@ __device__ void paged_attention_kernel(
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
-    const int blocksparse_vert_stride, const int blocksparse_block_size,
-    const int blocksparse_head_sliding_step) {
+    const float k_scale, const float v_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  const int seq_idx = blockIdx.y;
  const int partition_idx = blockIdx.z;
  const int max_num_partitions = gridDim.z;
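The hunk above is the core API change: the single kv_scale parameter becomes independent k_scale and v_scale parameters, so an FP8 KV cache can dequantize keys and values with separate per-tensor scaling factors. A minimal standalone sketch of the idea, not the kernel's actual fp8::scaled_convert machinery (the fp8 type and conversion here assume CUDA 11.8+'s cuda_fp8.h):

```cpp
#include <cuda_fp8.h>  // __nv_fp8_e4m3; an assumption, not the kernel's own header

// Hypothetical stand-in for fp8::scaled_convert: dequantize one cached
// element with its own per-tensor scale. Keys would be converted with
// k_scale and values with v_scale.
__device__ inline float dequant(__nv_fp8_e4m3 quant, float scale) {
  return static_cast<float>(quant) * scale;
}
```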
@@ -175,7 +175,7 @@ __device__ void paged_attention_kernel(
  // Each thread in a thread group has a different part of the query.
  // For example, if the thread group size is 4, then the first thread in
  // the group has 0, 4, 8, ... th vectors of the query, and the second thread
-  // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
+  // has 1, 5, 9, ... th vectors of the query, and so on. NOTE: Because
  // q is split from a qkv tensor, it may not be contiguous.
  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
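The striping described in that comment can be checked with a tiny standalone demo (host-only illustration with example constants; not kernel code):

```cpp
#include <cstdio>

// Prints which query vectors each lane of a 4-wide thread group owns,
// matching the comment above: lane 0 -> 0, 4, 8, ...; lane 1 -> 1, 5, 9, ...
int main() {
  const int THREAD_GROUP_SIZE = 4;    // group size from the example above
  const int NUM_VECS_PER_THREAD = 4;  // arbitrary small value for the demo
  for (int lane = 0; lane < THREAD_GROUP_SIZE; ++lane) {
    std::printf("lane %d owns query vectors:", lane);
    for (int i = 0; i < NUM_VECS_PER_THREAD; ++i) {
      std::printf(" %d", lane + i * THREAD_GROUP_SIZE);
    }
    std::printf("\n");
  }
  return 0;
}
```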
@@ -286,7 +286,7 @@ __device__ void paged_attention_kernel(
          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
          k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
-              k_vec_quant, kv_scale);
+              k_vec_quant, k_scale);
        }
      }

@@ -416,7 +416,7 @@ __device__ void paged_attention_kernel(
            *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
        // Vector conversion from V_quant_vec to V_vec.
        v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
-                                                                  kv_scale);
+                                                                  v_scale);
      }
      if (block_idx == num_seq_blocks - 1) {
        // NOTE: When v_vec contains the tokens that are out of the
@@ -513,15 +513,15 @@ __global__ void paged_attention_v1_kernel(
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
-    const int blocksparse_vert_stride, const int blocksparse_block_size,
-    const int blocksparse_head_sliding_step) {
+    const float k_scale, const float v_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
                         KV_DTYPE, IS_BLOCK_SPARSE>(
      /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
      v_cache, num_kv_heads, scale, block_tables, seq_lens,
      max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
-      kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks,
+      kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
      blocksparse_vert_stride, blocksparse_block_size,
      blocksparse_head_sliding_step);
}
@@ -549,14 +549,14 @@ __global__ void paged_attention_v2_kernel(
    const int max_num_blocks_per_seq,
    const float* __restrict__ alibi_slopes,  // [num_heads]
    const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
-    const int blocksparse_vert_stride, const int blocksparse_block_size,
-    const int blocksparse_head_sliding_step) {
+    const float k_scale, const float v_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
                         KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
      exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
      block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
-      kv_block_stride, kv_head_stride, kv_scale, tp_rank,
+      kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
      blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
      blocksparse_head_sliding_step);
}
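The v2 wrapper above runs one kernel invocation per partition and records per-partition softmax statistics (max_logits, exp_sums) that paged_attention_v2_reduce_kernel, touched in the next hunk, merges into a single normalization. A standalone sketch of that log-sum-exp merge, with made-up example numbers rather than the kernel's actual memory layout:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

// Partition i reports its local max logit m_i and s_i = sum_j exp(l_ij - m_i).
// Rescaling each s_i to the global max recovers the global softmax denominator.
int main() {
  const int NUM_PARTITIONS = 3;
  const float max_logits[NUM_PARTITIONS] = {1.5f, 2.0f, 0.5f};  // example values
  const float exp_sums[NUM_PARTITIONS] = {3.0f, 4.0f, 2.0f};    // example values
  float global_max = max_logits[0];
  for (int i = 1; i < NUM_PARTITIONS; ++i) {
    global_max = std::max(global_max, max_logits[i]);
  }
  float global_sum = 0.0f;
  for (int i = 0; i < NUM_PARTITIONS; ++i) {
    global_sum += exp_sums[i] * std::exp(max_logits[i] - global_max);
  }
  std::printf("global max = %f, global exp-sum = %f\n", global_max, global_sum);
  return 0;
}
```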
@@ -682,11 +682,11 @@ __global__ void paged_attention_v2_reduce_kernel(
          out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
          scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
-          kv_scale, tp_rank, blocksparse_local_blocks, \
+          k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
          blocksparse_vert_stride, blocksparse_block_size, \
          blocksparse_head_sliding_step);

-// TODO: Tune NUM_THREADS.
+// TODO(woosuk): Tune NUM_THREADS.
template <typename T, typename CACHE_T, int BLOCK_SIZE,
          aphrodite::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
          int NUM_THREADS = 128>
@@ -694,8 +694,8 @@ void paged_attention_v1_launcher(
    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, int num_kv_heads, float scale,
    torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
-    const int tp_rank, const int blocksparse_local_blocks,
+    const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
+    float v_scale, const int tp_rank, const int blocksparse_local_blocks,
    const int blocksparse_vert_stride, const int blocksparse_block_size,
    const int blocksparse_head_sliding_step) {
  int num_seqs = query.size(0);
@@ -771,7 +771,7 @@ void paged_attention_v1_launcher(
  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
                              IS_BLOCK_SPARSE>( \
      out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
-      seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank, \
+      seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
      blocksparse_local_blocks, blocksparse_vert_stride, \
      blocksparse_block_size, blocksparse_head_sliding_step);

@@ -816,8 +816,8 @@ void paged_attention_v1(
|
|
|
torch::Tensor& seq_lens, // [num_seqs]
|
|
|
int64_t block_size, int64_t max_seq_len,
|
|
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
|
|
- const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
|
|
- const int64_t blocksparse_local_blocks,
|
|
|
+ const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
|
|
+ const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
|
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
|
|
const int64_t blocksparse_head_sliding_step) {
|
|
|
const bool is_block_sparse = (blocksparse_vert_stride > 1);
|
|
@@ -834,7 +834,7 @@ void paged_attention_v1(
|
|
|
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
|
|
|
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
|
|
|
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
|
|
|
- kv_block_stride, kv_head_stride, kv_scale, tp_rank, \
|
|
|
+ kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
|
|
|
blocksparse_local_blocks, blocksparse_vert_stride, \
|
|
|
blocksparse_block_size, blocksparse_head_sliding_step); \
|
|
|
aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
|
|
@@ -851,8 +851,8 @@ void paged_attention_v2_launcher(
|
|
|
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
|
|
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
|
|
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
|
|
- const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
|
|
|
- const int tp_rank, const int blocksparse_local_blocks,
|
|
|
+ const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
|
|
|
+ float v_scale, const int tp_rank, const int blocksparse_local_blocks,
|
|
|
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
|
|
const int blocksparse_head_sliding_step) {
|
|
|
int num_seqs = query.size(0);
|
|
@@ -917,7 +917,7 @@ void paged_attention_v2_launcher(
|
|
|
LAUNCH_PAGED_ATTENTION_V2(128);
|
|
|
break;
|
|
|
case 192:
|
|
|
- LAUNCH_PAGED_ATTENTION_V1(192);
|
|
|
+ LAUNCH_PAGED_ATTENTION_V2(192);
|
|
|
break;
|
|
|
case 256:
|
|
|
LAUNCH_PAGED_ATTENTION_V2(256);
|
|
@@ -933,8 +933,9 @@ void paged_attention_v2_launcher(
|
|
|
IS_BLOCK_SPARSE>( \
|
|
|
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
|
|
|
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
|
|
|
- kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \
|
|
|
- blocksparse_block_size, blocksparse_head_sliding_step);
|
|
|
+ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
|
|
|
+ blocksparse_vert_stride, blocksparse_block_size, \
|
|
|
+ blocksparse_head_sliding_step);
|
|
|
|
|
|
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
|
|
|
switch (is_block_sparse) { \
|
|
@@ -981,8 +982,8 @@ void paged_attention_v2(
|
|
|
torch::Tensor& seq_lens, // [num_seqs]
|
|
|
int64_t block_size, int64_t max_seq_len,
|
|
|
const c10::optional<torch::Tensor>& alibi_slopes,
|
|
|
- const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
|
|
- const int64_t blocksparse_local_blocks,
|
|
|
+ const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
|
|
+ const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
|
|
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
|
|
const int64_t blocksparse_head_sliding_step) {
|
|
|
const bool is_block_sparse = (blocksparse_vert_stride > 1);
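For callers, the net effect of this diff is one extra argument everywhere: each site that passed kv_scale now passes k_scale and v_scale. A hypothetical call-site migration for paged_attention_v1, assuming the leading arguments keep their existing order (this diff only shows the tail of the signature; variable names are illustrative):

```cpp
// Before: paged_attention_v1(..., kv_cache_dtype, kv_scale, tp_rank, ...);
// After, when only a single shared scale is available, pass it for both
// keys and values:
paged_attention_v1(out, query, key_cache, value_cache, num_kv_heads, scale,
                   block_tables, seq_lens, block_size, max_seq_len,
                   alibi_slopes, kv_cache_dtype,
                   /*k_scale=*/kv_scale, /*v_scale=*/kv_scale, tp_rank,
                   blocksparse_local_blocks, blocksparse_vert_stride,
                   blocksparse_block_size, blocksparse_head_sliding_step);
```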