@@ -26,6 +26,7 @@

 #include "attention_dtypes.h"
 #include "attention_utils.cuh"
+#include "../quantization/int8_kvcache/quant_utils.cuh"
 #ifdef ENABLE_FP8_E5M2
 #include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
 #endif
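The new include pulls in the int8 dequantization helpers used below. For orientation, the call sites in this diff (int8::dequant, int8::vec_conversion) imply an interface roughly like the following sketch; this is inferred from usage rather than copied from quant_utils.cuh, and the real helpers operate element-wise on the packed vector types, not on scalars:

// Hypothetical sketch of the int8 helper interface assumed by this diff.
namespace int8 {
// Asymmetric dequantization with a per-tensor scale and zero point.
inline __device__ float dequant(int8_t q, const float scale, const float zp) {
  return (static_cast<float>(q) - zp) * scale;
}
// Element-wise conversion between vector types (e.g. a float vector to the
// fp16/bf16 compute vector type).
template<typename Tout, typename Tin>
__device__ Tout vec_conversion(const Tin& x);
}  // namespace int8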
@@ -41,6 +42,13 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

+enum kv_cache_dtype {
+  AUTO,
+#ifdef ENABLE_FP8_E5M2
+  FP8_E5M2,
+#endif
+  INT8};
+
 namespace aphrodite {

 // Utility function for attention softmax.
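Replacing the IS_FP8_E5M2_KV_CACHE boolean with a three-way enum lets each kernel instantiation select its cache load path at compile time via if constexpr, so no run-time branching or dead code is carried along. A minimal illustration of the pattern (simplified to a scalar load; not code from this diff):

// Each instantiation compiles exactly one branch of the if constexpr.
template<kv_cache_dtype KV_CACHE_DTYPE>
__device__ float load_cache_elem(const void* ptr, float scale, float zp) {
  if constexpr (KV_CACHE_DTYPE == INT8) {
    // Stored as int8; reverse the asymmetric quantization.
    return (static_cast<float>(*static_cast<const int8_t*>(ptr)) - zp) * scale;
  } else {
    // AUTO: the cache already holds the compute dtype.
    return *static_cast<const float*>(ptr);
  }
}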
@@ -87,7 +95,7 @@ template<
   int HEAD_SIZE,
   int BLOCK_SIZE,
   int NUM_THREADS,
-  bool IS_FP8_E5M2_KV_CACHE,
+  kv_cache_dtype KV_CACHE_DTYPE,
   int PARTITION_SIZE = 0> // Zero means no partitioning.
 __device__ void paged_attention_kernel(
   float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
@@ -104,7 +112,11 @@ __device__ void paged_attention_kernel(
   const float* __restrict__ alibi_slopes, // [num_heads]
   const int q_stride,
   const int kv_block_stride,
-  const int kv_head_stride) {
+  const int kv_head_stride,
+  const float k_scale = 1.0f,
+  const float k_zp = 0.0f,
+  const float v_scale = 1.0f,
+  const float v_zp = 0.0f) {
   const int seq_idx = blockIdx.y;
   const int partition_idx = blockIdx.z;
   const int max_num_partitions = gridDim.z;
@@ -151,9 +163,7 @@
   constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
   using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
   using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
-#ifdef ENABLE_FP8_E5M2
   using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
-#endif

   constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
   constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
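Note that Quant_vec is now declared unconditionally: both the int8 and fp8_e5m2 paths need it, and when KV_CACHE_DTYPE == AUTO the alias is simply never used, so nothing is lost by dropping the #ifdef guard.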
@@ -217,13 +227,16 @@
       const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
       const int offset1 = (vec_idx * VEC_SIZE) / x;
       const int offset2 = (vec_idx * VEC_SIZE) % x;
-      if constexpr (IS_FP8_E5M2_KV_CACHE) {
+      if constexpr (KV_CACHE_DTYPE == INT8) {
+        Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+        using Dequant_vec = typename FloatVec<Quant_vec>::Type;
+        Dequant_vec k_vec_dequant = int8::dequant(k_vec_quant, k_scale, k_zp);
+        k_vecs[j] = int8::vec_conversion<K_vec, Dequant_vec>(k_vec_dequant);
 #ifdef ENABLE_FP8_E5M2
+      } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
         Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
         // Vector conversion from Quant_vec to K_vec.
         k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
-#else
-        assert(false);
 #endif
       } else {
         k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
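The int8 key path reverses an asymmetric quantization, x = (q - k_zp) * k_scale, applied element-wise across the packed vector: with k_scale = 0.05 and k_zp = 8, for instance, a stored int8 value of 100 is reconstructed as (100 - 8) * 0.05 = 4.6. The fp8_e5m2 path stays unscaled, which is why it takes no scale or zero-point arguments.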
@@ -301,9 +314,7 @@
   constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
   using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
   using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
-#ifdef ENABLE_FP8_E5M2
   using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
-#endif
   using Float_L_vec = typename FloatVec<L_vec>::Type;

   constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
@@ -337,13 +348,17 @@
     if (row_idx < HEAD_SIZE) {
       const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
       V_vec v_vec;
-      if constexpr (IS_FP8_E5M2_KV_CACHE) {
+      if constexpr (KV_CACHE_DTYPE == INT8) {
+        // Dequantize the int8 value vector, then convert it to the compute type.
+        V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+        using V_dequant_vec = typename FloatVec<V_quant_vec>::Type;
+        V_dequant_vec v_vec_dequant = int8::dequant(v_quant_vec, v_scale, v_zp);
+        v_vec = int8::vec_conversion<V_vec, V_dequant_vec>(v_vec_dequant);
 #ifdef ENABLE_FP8_E5M2
+      } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
         V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
         // Vector conversion from V_quant_vec to V_vec.
         v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
-#else
-        assert(false);
 #endif
       } else {
         v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
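Nothing in this diff computes the scale/zero-point pairs; k_scale, k_zp, v_scale, and v_zp arrive from the host. As a rough sketch of how per-tensor parameters consistent with the x = (q - zp) * scale convention could be derived during calibration (hypothetical helper, not part of this change):

#include <utility>

// Maps the observed value range [min_val, max_val] onto int8's 256 levels,
// so that q = -128 dequantizes to min_val and q = 127 to max_val.
std::pair<float, float> asym_int8_params(float min_val, float max_val) {
  const float scale = (max_val - min_val) / 255.0f;
  const float zp = -128.0f - min_val / scale;
  return {scale, zp};
}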
@@ -431,7 +446,7 @@ template<
   int HEAD_SIZE,
   int BLOCK_SIZE,
   int NUM_THREADS,
-  bool IS_FP8_E5M2_KV_CACHE>
+  kv_cache_dtype KV_CACHE_DTYPE>
 __global__ void paged_attention_v1_kernel(
   scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
@@ -445,11 +460,15 @@ __global__ void paged_attention_v1_kernel(
   const float* __restrict__ alibi_slopes, // [num_heads]
   const int q_stride,
   const int kv_block_stride,
-  const int kv_head_stride) {
-  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
+  const int kv_head_stride,
+  const float k_scale,
+  const float k_zp,
+  const float v_scale,
+  const float v_zp) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE>(
     /* exp_sums */ nullptr, /* max_logits */ nullptr,
     out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
-    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
+    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
 }

 // Grid: (num_heads, num_seqs, max_num_partitions).
@@ -459,7 +478,7 @@ template<
   int HEAD_SIZE,
   int BLOCK_SIZE,
   int NUM_THREADS,
-  bool IS_FP8_E5M2_KV_CACHE,
+  kv_cache_dtype KV_CACHE_DTYPE,
   int PARTITION_SIZE>
 __global__ void paged_attention_v2_kernel(
   float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
@@ -476,11 +495,15 @@ __global__ void paged_attention_v2_kernel(
   const float* __restrict__ alibi_slopes, // [num_heads]
   const int q_stride,
   const int kv_block_stride,
-  const int kv_head_stride) {
-  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
+  const int kv_head_stride,
+  const float k_scale,
+  const float k_zp,
+  const float v_scale,
+  const float v_zp) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE, PARTITION_SIZE>(
     exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
     block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
-    q_stride, kv_block_stride, kv_head_stride);
+    q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
 }

 // Grid: (num_heads, num_seqs).
@@ -584,32 +607,36 @@ __global__ void paged_attention_v2_reduce_kernel(

 } // namespace aphrodite

-#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
-  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
-    ((void*)aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
-      IS_FP8_E5M2_KV_CACHE>), shared_mem_size); \
-  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
-    IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
-    out_ptr, \
-    query_ptr, \
-    key_cache_ptr, \
-    value_cache_ptr, \
-    num_kv_heads, \
-    scale, \
-    block_tables_ptr, \
-    context_lens_ptr, \
-    max_num_blocks_per_seq, \
-    alibi_slopes_ptr, \
-    q_stride, \
-    kv_block_stride, \
-    kv_head_stride);
+#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
+  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
+    ((void*)aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
+      KV_CACHE_DTYPE>), shared_mem_size); \
+  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
+    KV_CACHE_DTYPE><<<grid, block, shared_mem_size, stream>>>( \
+    out_ptr, \
+    query_ptr, \
+    key_cache_ptr, \
+    value_cache_ptr, \
+    num_kv_heads, \
+    scale, \
+    block_tables_ptr, \
+    context_lens_ptr, \
+    max_num_blocks_per_seq, \
+    alibi_slopes_ptr, \
+    q_stride, \
+    kv_block_stride, \
+    kv_head_stride, \
+    k_scale, \
+    k_zp, \
+    v_scale, \
+    v_zp);

 // TODO: Tune NUM_THREADS.
 template<
   typename T,
   typename CACHE_T,
   int BLOCK_SIZE,
-  bool IS_FP8_E5M2_KV_CACHE,
+  kv_cache_dtype KV_CACHE_DTYPE,
   int NUM_THREADS = 128>
 void paged_attention_v1_launcher(
   torch::Tensor& out,
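The launch macro expands inside paged_attention_v1_launcher, so k_scale, k_zp, v_scale, and v_zp refer to the launcher parameters added below; the macro simply forwards them to the kernel alongside the existing strides.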
@@ -621,7 +648,11 @@ void paged_attention_v1_launcher(
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
   int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes) {
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const float k_scale,
+  const float k_zp,
+  const float v_scale,
+  const float v_zp) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -685,8 +716,8 @@
   }
 }

-#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
-  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
+#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE) \
+  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>( \
     out, \
     query, \
     key_cache, \
@@ -696,20 +727,24 @@
     block_tables, \
     context_lens, \
     max_context_len, \
-    alibi_slopes);
+    alibi_slopes, \
+    k_scale, \
+    k_zp, \
+    v_scale, \
+    v_zp);

 // NOTE: To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE) \
   switch (block_size) { \
     case 8: \
-      CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
+      CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE); \
       break; \
     case 16: \
-      CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
+      CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE); \
       break; \
     case 32: \
-      CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
+      CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE); \
       break; \
     default: \
       TORCH_CHECK(false, "Unsupported block size: ", block_size); \
@@ -728,24 +763,40 @@ void paged_attention_v1(
   int block_size,
   int max_context_len,
   const c10::optional<torch::Tensor>& alibi_slopes,
-  const std::string& kv_cache_dtype) {
+  const std::string& kv_cache_dtype,
+  const float k_scale = 1.0f,
+  const float k_zp = 0.0f,
+  const float v_scale = 1.0f,
+  const float v_zp = 0.0f) {
   if (kv_cache_dtype == "auto") {
     if (query.dtype() == at::ScalarType::Float) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
+      CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, AUTO);
     } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
+      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, AUTO);
     } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
+      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, AUTO);
     } else {
       TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
     }
+#ifdef ENABLE_FP8_E5M2
   } else if (kv_cache_dtype == "fp8_e5m2") {
     if (query.dtype() == at::ScalarType::Float) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
+      CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, FP8_E5M2);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, FP8_E5M2);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, FP8_E5M2);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+#endif
+  } else if (kv_cache_dtype == "int8") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(float, int8_t, INT8);
     } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
+      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, int8_t, INT8);
     } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
+      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, int8_t, INT8);
     } else {
       TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
     }
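The host entry point now derives both the cache storage type and the dispatch enum from the kv_cache_dtype string. Summarizing the branches above:

//   "auto"     -> CACHE_T = compute dtype (float / uint16_t / __nv_bfloat16), AUTO
//   "fp8_e5m2" -> CACHE_T = uint8_t, FP8_E5M2 (only when ENABLE_FP8_E5M2 is defined)
//   "int8"     -> CACHE_T = int8_t,  INT8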
@@ -755,8 +806,8 @@ void paged_attention_v1(
 }

 #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \
-  aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
-    IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE> \
+  aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, \
+    KV_CACHE_DTYPE, PARTITION_SIZE> \
   <<<grid, block, shared_mem_size, stream>>>( \
     exp_sums_ptr, \
     max_logits_ptr, \
@@ -772,7 +823,11 @@ void paged_attention_v1(
     alibi_slopes_ptr, \
     q_stride, \
     kv_block_stride, \
-    kv_head_stride); \
+    kv_head_stride, \
+    k_scale, \
+    k_zp, \
+    v_scale, \
+    v_zp); \
   aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE> \
     <<<reduce_grid, block, reduce_shared_mem_size, stream>>>( \
       out_ptr, \
@@ -786,7 +841,7 @@ template<
   typename T,
   typename CACHE_T,
   int BLOCK_SIZE,
-  bool IS_FP8_E5M2_KV_CACHE,
+  kv_cache_dtype KV_CACHE_DTYPE,
   int NUM_THREADS = 128,
   int PARTITION_SIZE = 512>
 void paged_attention_v2_launcher(
@@ -802,7 +857,11 @@ void paged_attention_v2_launcher(
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
   int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes) {
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const float k_scale,
+  const float k_zp,
+  const float v_scale,
+  const float v_zp) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -872,8 +931,8 @@
   }
 }

-#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE) \
-  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
+#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE) \
+  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>( \
     out, \
     exp_sums, \
     max_logits, \
@@ -886,20 +945,24 @@
     block_tables, \
     context_lens, \
     max_context_len, \
-    alibi_slopes);
+    alibi_slopes, \
+    k_scale, \
+    k_zp, \
+    v_scale, \
+    v_zp);

 // NOTE: To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
-#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE) \
   switch (block_size) { \
     case 8: \
-      CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE); \
+      CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE); \
       break; \
     case 16: \
-      CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE); \
+      CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE); \
      break; \
    case 32: \
-      CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE); \
+      CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE); \
       break; \
     default: \
       TORCH_CHECK(false, "Unsupported block size: ", block_size); \
@@ -921,24 +984,40 @@ void paged_attention_v2(
   int block_size,
   int max_context_len,
   const c10::optional<torch::Tensor>& alibi_slopes,
-  const std::string& kv_cache_dtype) {
+  const std::string& kv_cache_dtype,
+  const float k_scale = 1.0f,
+  const float k_zp = 0.0f,
+  const float v_scale = 1.0f,
+  const float v_zp = 0.0f) {
   if (kv_cache_dtype == "auto") {
     if (query.dtype() == at::ScalarType::Float) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
+      CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, AUTO);
     } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
+      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, AUTO);
     } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
+      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, AUTO);
     } else {
       TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
     }
+#ifdef ENABLE_FP8_E5M2
   } else if (kv_cache_dtype == "fp8_e5m2") {
     if (query.dtype() == at::ScalarType::Float) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
+      CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, FP8_E5M2);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, FP8_E5M2);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, FP8_E5M2);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+#endif
+  } else if (kv_cache_dtype == "int8") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(float, int8_t, INT8);
     } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
+      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, int8_t, INT8);
     } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
+      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, int8_t, INT8);
     } else {
       TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
     }
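Because k_scale, k_zp, v_scale, and v_zp default to the identity transform (scale 1.0, zero point 0.0), existing call sites that stop at kv_cache_dtype continue to compile and behave unchanged; int8 callers append the four calibration parameters. Illustrative call tails (leading tensor arguments elided, as they are not shown in this diff):

//   ..., block_size, max_context_len, alibi_slopes, "auto");
//   ..., block_size, max_context_len, alibi_slopes, "int8",
//        k_scale, k_zp, v_scale, v_zp);

One caveat worth noting: if these entry points are also declared with default arguments in a header, C++ requires the defaults to appear in only one of the two places.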