@@ -86,6 +86,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
 // Grid: (num_heads, num_seqs, max_num_partitions).
 template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
           int NUM_THREADS, aphrodite::Fp8KVCacheDataType KV_DTYPE,
+          bool IS_BLOCK_SPARSE,
           int PARTITION_SIZE = 0>  // Zero means no partitioning.
 __device__ void paged_attention_kernel(
     float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
@@ -105,7 +106,9 @@ __device__ void paged_attention_kernel(
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float kv_scale) {
+    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
+    const int blocksparse_vert_stride, const int blocksparse_block_size,
+    const int blocksparse_head_sliding_step) {
   const int seq_idx = blockIdx.y;
   const int partition_idx = blockIdx.z;
   const int max_num_partitions = gridDim.z;
@@ -172,8 +175,8 @@ __device__ void paged_attention_kernel(
   // Each thread in a thread group has a different part of the query.
   // For example, if the the thread group size is 4, then the first thread in
   // the group has 0, 4, 8, ... th vectors of the query, and the second thread
-  // has 1, 5, 9, ... th vectors of the query, and so on. NOTE: Because q is
-  // split from a qkv tensor, it may not be contiguous.
+  // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
+  // q is split from a qkv tensor, it may not be contiguous.
   const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
   __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
 #pragma unroll
@@ -183,8 +186,8 @@ __device__ void paged_attention_kernel(
     q_vecs[thread_group_offset][i] =
         *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
   }
-  __syncthreads();  // TODO: possible speedup if this is replaced with a memory
-                    // wall right before we use q_vecs
+  __syncthreads();  // TODO: possible speedup if this is replaced with a
+                    // memory wall right before we use q_vecs
 
   // Memory planning.
   extern __shared__ char shared_mem[];
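
The strided assignment described in the comment above (thread i of a group loads query vectors i, i + THREAD_GROUP_SIZE, i + 2 * THREAD_GROUP_SIZE, ...) can be checked with a small host-side sketch; the sizes below are illustrative values, not the kernel's actual configuration:

// Illustrative only: which query vectors each thread in a group loads.
// With THREAD_GROUP_SIZE = 4 and 8 vectors per head, thread 0 loads
// vectors 0, 4; thread 1 loads 1, 5; and so on.
#include <cstdio>

int main() {
  const int THREAD_GROUP_SIZE = 4;  // threads cooperating on one query
  const int NUM_VECS = 8;           // HEAD_SIZE / VEC_SIZE in the kernel
  const int NUM_VECS_PER_THREAD = NUM_VECS / THREAD_GROUP_SIZE;
  for (int thread_group_offset = 0; thread_group_offset < THREAD_GROUP_SIZE;
       ++thread_group_offset) {
    printf("thread %d loads vectors:", thread_group_offset);
    for (int i = 0; i < NUM_VECS_PER_THREAD; ++i) {
      // Same strided pattern as the q_vecs load in the kernel.
      const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
      printf(" %d", vec_idx);
    }
    printf("\n");
  }
  return 0;
}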
@@ -203,11 +206,55 @@ __device__ void paged_attention_kernel(
   // Each thread group in a warp fetches a key from the block, and computes
   // dot product with the query.
   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+
+  // blocksparse specific vars
+  int bs_block_offset;
+  int q_bs_block_id;
+  if constexpr (IS_BLOCK_SPARSE) {
+    // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
+    // blocksparse_block_size);
+    q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
+    if (blocksparse_head_sliding_step >= 0)
+      // sliding on q heads
+      bs_block_offset =
+          (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
+    else
+      // sliding on kv heads
+      bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
+                            (-blocksparse_head_sliding_step) +
+                        1;
+  }
+
   for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
        block_idx += NUM_WARPS) {
-    // NOTE: The block number is stored in int32. However, we cast it to int64
-    // because int32 can lead to overflow when this variable is multiplied by
-    // large numbers (e.g., kv_block_stride).
+    // NOTE: The block number is stored in int32. However, we cast it to
+    // int64 because int32 can lead to overflow when this variable is multiplied
+    // by large numbers (e.g., kv_block_stride).
+    // For blocksparse attention: skip computation on blocks that are not
+    // attended
+    if constexpr (IS_BLOCK_SPARSE) {
+      const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
+      const bool is_remote =
+          ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
+      const bool is_local =
+          (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
+      if (!is_remote && !is_local) {
+        for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
+          const int physical_block_offset =
+              (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
+          const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+
+          if (thread_group_offset == 0) {
+            // NOTE: assign very large number to skipped tokens to
+            // avoid contribution to the sumexp softmax normalizer. This will
+            // not be used at computing sum(softmax*v) as the blocks will be
+            // skipped.
+            logits[token_idx - start_token_idx] = -FLT_MAX;
+          }
+        }
+        continue;
+      }
+    }
     const int64_t physical_block_number =
         static_cast<int64_t>(block_table[block_idx]);
 
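
The skip test above is the core of the block-sparse pattern: a key block is visited iff it is "remote" (its blocksparse block id, shifted by the per-head offset, lands on the vertical stride) or "local" (within blocksparse_local_blocks of the query's own blocksparse block). A standalone restatement of the predicate, using the same names as the kernel:

// Standalone restatement of the block-sparse visit predicate used above.
// A KV block is processed iff it is "remote" (its blocksparse block id,
// shifted by the per-head offset, lands on the vertical stride) or "local"
// (within blocksparse_local_blocks of the query's blocksparse block).
inline bool is_block_attended(int k_bs_block_id, int q_bs_block_id,
                              int bs_block_offset, int blocksparse_vert_stride,
                              int blocksparse_local_blocks) {
  const bool is_remote =
      ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
  const bool is_local =
      (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
  return is_remote || is_local;
}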
@@ -333,9 +380,18 @@ __device__ void paged_attention_kernel(
   zero(zero_value);
   for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
        block_idx += NUM_WARPS) {
-    // NOTE: The block number is stored in int32. However, we cast it to int64
-    // because int32 can lead to overflow when this variable is multiplied by
-    // large numbers (e.g., kv_block_stride).
+    // NOTE: The block number is stored in int32. However, we cast it to
+    // int64 because int32 can lead to overflow when this variable is multiplied
+    // by large numbers (e.g., kv_block_stride).
+    // For blocksparse attention: skip computation on blocks that are not
+    // attended
+    if constexpr (IS_BLOCK_SPARSE) {
+      int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
+      if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
+          !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
+        continue;
+      }
+    }
     const int64_t physical_block_number =
         static_cast<int64_t>(block_table[block_idx]);
     const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
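
The K and V loops handle skipped blocks differently: the K loop must still write -FLT_MAX into the skipped tokens' logits so the softmax normalizer ignores them, while the V loop can simply continue, because those tokens' softmax weights are already zero. A minimal sketch of the numerics under the usual max-subtracted softmax:

#include <cfloat>
#include <cmath>
#include <cstdio>

int main() {
  // Two real logits and one skipped token masked with -FLT_MAX.
  float logits[3] = {2.0f, 1.0f, -FLT_MAX};
  float qk_max = fmaxf(fmaxf(logits[0], logits[1]), logits[2]);
  float exp_sum = 0.f;
  for (int i = 0; i < 3; ++i) {
    // expf(-FLT_MAX - qk_max) underflows to 0, so the masked token
    // contributes nothing to the normalizer...
    logits[i] = expf(logits[i] - qk_max);
    exp_sum += logits[i];
  }
  // ...and its softmax weight is exactly 0, so skipping its value rows
  // in the V pass cannot change sum(softmax * v).
  for (int i = 0; i < 3; ++i) {
    printf("softmax[%d] = %f\n", i, logits[i] / exp_sum);
  }
  return 0;
}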
@@ -363,9 +419,9 @@ __device__ void paged_attention_kernel(
                           kv_scale);
     }
     if (block_idx == num_seq_blocks - 1) {
-      // NOTE: When v_vec contains the tokens that are out of the context,
-      // we should explicitly zero out the values since they may contain
-      // NaNs.
+      // NOTE: When v_vec contains the tokens that are out of the
+      // context, we should explicitly zero out the values since they may
+      // contain NaNs.
       scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
 #pragma unroll
       for (int j = 0; j < V_VEC_SIZE; j++) {
@@ -388,8 +444,8 @@ __device__ void paged_attention_kernel(
     accs[i] = acc;
   }
 
-  // NOTE: A barrier is required because the shared memory space for logits
-  // is reused for the output.
+  // NOTE: A barrier is required because the shared memory space for
+  // logits is reused for the output.
   __syncthreads();
 
   // Perform reduction across warps.
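
The barrier noted above exists because the dynamic shared-memory buffer that held the logits is about to be reused as the output staging area. A minimal standalone kernel (not the attention kernel itself) showing the same reuse hazard:

// Minimal sketch of the hazard the barrier above guards against: a dynamic
// shared-memory buffer is first read by every thread, then rewritten for a
// second purpose. Without the middle __syncthreads(), a fast thread could
// overwrite entries a slower thread is still reading.
__global__ void reuse_smem_example(const float* in, float* out, int n) {
  extern __shared__ float buf[];
  int tid = threadIdx.x;
  if (tid < n) buf[tid] = in[tid];
  __syncthreads();
  float acc = 0.f;
  for (int i = 0; i < n; ++i) acc += buf[i];  // every thread reads all of buf
  __syncthreads();  // required: buf is reused for per-thread results next
  if (tid < n) buf[tid] = acc * in[tid];
  __syncthreads();
  if (tid < n) out[tid] = buf[tid];
}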
@@ -441,8 +497,8 @@ __device__ void paged_attention_kernel(
 
 // Grid: (num_heads, num_seqs, 1).
 template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
-          int NUM_THREADS,
-          aphrodite::Fp8KVCacheDataType KV_DTYPE>
+          int NUM_THREADS, aphrodite::Fp8KVCacheDataType KV_DTYPE,
+          bool IS_BLOCK_SPARSE>
 __global__ void paged_attention_v1_kernel(
     scalar_t* __restrict__ out,      // [num_seqs, num_heads, head_size]
     const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]
@@ -457,18 +513,23 @@ __global__ void paged_attention_v1_kernel(
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float kv_scale) {
+    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
+    const int blocksparse_vert_stride, const int blocksparse_block_size,
+    const int blocksparse_head_sliding_step) {
   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
-                         KV_DTYPE>(
+                         KV_DTYPE, IS_BLOCK_SPARSE>(
       /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
       v_cache, num_kv_heads, scale, block_tables, seq_lens,
       max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
-      kv_head_stride, kv_scale);
+      kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks,
+      blocksparse_vert_stride, blocksparse_block_size,
+      blocksparse_head_sliding_step);
 }
 
 // Grid: (num_heads, num_seqs, max_num_partitions).
 template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
           int NUM_THREADS, aphrodite::Fp8KVCacheDataType KV_DTYPE,
+          bool IS_BLOCK_SPARSE,
           int PARTITION_SIZE>
 __global__ void paged_attention_v2_kernel(
     float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
@@ -488,12 +549,16 @@ __global__ void paged_attention_v2_kernel(
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float kv_scale) {
+    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
+    const int blocksparse_vert_stride, const int blocksparse_block_size,
+    const int blocksparse_head_sliding_step) {
   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
-                         KV_DTYPE, PARTITION_SIZE>(
+                         KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
       exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
       block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
-      kv_block_stride, kv_head_stride, kv_scale);
+      kv_block_stride, kv_head_stride, kv_scale, tp_rank,
+      blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
+      blocksparse_head_sliding_step);
 }
 
 // Grid: (num_heads, num_seqs).
@@ -605,27 +670,34 @@ __global__ void paged_attention_v2_reduce_kernel(
 
 }  // namespace aphrodite
 
-#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                  \
-  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                  \
-      ((void*)aphrodite::paged_attention_v1_kernel<                           \
-          T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>),         \
-      shared_mem_size);                                                       \
-  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,     \
-                                       NUM_THREADS, KV_DTYPE>                 \
-      <<<grid, block, shared_mem_size, stream>>>(                             \
-          out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads,   \
-          scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq,      \
-          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,        \
-          kv_scale);
+#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                  \
+  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                  \
+      ((void*)aphrodite::paged_attention_v1_kernel<                           \
+          T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE,           \
+          IS_BLOCK_SPARSE>),                                                  \
+      shared_mem_size);                                                       \
+  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,     \
+                                       NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE>\
+      <<<grid, block, shared_mem_size, stream>>>(                             \
+          out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads,   \
+          scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq,      \
+          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,        \
+          kv_scale, tp_rank, blocksparse_local_blocks,                        \
+          blocksparse_vert_stride, blocksparse_block_size,                    \
+          blocksparse_head_sliding_step);
 
 // TODO: Tune NUM_THREADS.
 template <typename T, typename CACHE_T, int BLOCK_SIZE,
-          aphrodite::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128>
+          aphrodite::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
+          int NUM_THREADS = 128>
 void paged_attention_v1_launcher(
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
+    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
+    const int tp_rank, const int blocksparse_local_blocks,
+    const int blocksparse_vert_stride, const int blocksparse_block_size,
+    const int blocksparse_head_sliding_step) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
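
APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize is defined outside this diff; on CUDA builds it presumably reduces to a cudaFuncSetAttribute call roughly like the sketch below, since kernels that request more than the default 48 KB of dynamic shared memory must opt in before launch:

#include <cuda_runtime.h>

// Hedged sketch: the real macro lives elsewhere in the tree; on CUDA it
// plausibly wraps cudaFuncSetAttribute like this.
template <typename KernelT>
cudaError_t set_max_dynamic_smem(KernelT kernel, int shared_mem_size) {
  return cudaFuncSetAttribute((void*)kernel,
                              cudaFuncAttributeMaxDynamicSharedMemorySize,
                              shared_mem_size);
}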
@@ -692,23 +764,36 @@ void paged_attention_v1_launcher(
   }
 }
 
-#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE)                   \
-  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>(             \
+#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE)  \
+  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,              \
+                              IS_BLOCK_SPARSE>(                              \
       out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
-      seq_lens, max_seq_len, alibi_slopes, kv_scale);
+      seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank,                \
+      blocksparse_local_blocks, blocksparse_vert_stride,                     \
+      blocksparse_block_size, blocksparse_head_sliding_step);
+
+#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE)   \
+  switch (is_block_sparse) {                                                 \
+    case true:                                                               \
+      CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true);       \
+      break;                                                                 \
+    case false:                                                              \
+      CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false);      \
+      break;                                                                 \
+  }
 
 // NOTE: To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
 #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE)          \
   switch (block_size) {                                            \
     case 8:                                                        \
-      CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_DTYPE);                   \
+      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE);          \
       break;                                                       \
     case 16:                                                       \
-      CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_DTYPE);                  \
+      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE);         \
       break;                                                       \
     case 32:                                                       \
-      CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_DTYPE);                  \
+      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE);         \
       break;                                                       \
     default:                                                       \
       TORCH_CHECK(false, "Unsupported block size: ", block_size);  \
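
CALL_V1_LAUNCHER_SPARSITY uses a (legal, if unusual) switch on a bool with case true/case false to turn the runtime flag is_block_sparse into the compile-time template argument IS_BLOCK_SPARSE, so the if constexpr branches inside the kernel cost nothing at run time. The same idiom without macros, as a self-contained sketch:

#include <cstdio>

template <bool IS_BLOCK_SPARSE>
void launch_impl() {
  if constexpr (IS_BLOCK_SPARSE) {
    printf("block-sparse path (dead code eliminated otherwise)\n");
  } else {
    printf("dense path\n");
  }
}

// One runtime branch at the call site; zero branches inside the kernel body.
inline void launch(bool is_block_sparse) {
  if (is_block_sparse) {
    launch_impl<true>();
  } else {
    launch_impl<false>();
  }
}

int main() {
  launch(true);
  launch(false);
  return 0;
}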
@@ -728,18 +813,26 @@ void paged_attention_v1(
     torch::Tensor& seq_lens,  // [num_seqs]
     int block_size, int max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, float kv_scale){
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+  const bool is_block_sparse = (blocksparse_vert_stride > 1);
+
+  DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
+                             CALL_V1_LAUNCHER_BLOCK_SIZE)
+}
 
-  DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
-                             CALL_V1_LAUNCHER_BLOCK_SIZE)}
 #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                  \
   aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,     \
-                                       NUM_THREADS, KV_DTYPE, PARTITION_SIZE> \
+                                       NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE,\
+                                       PARTITION_SIZE>                        \
       <<<grid, block, shared_mem_size, stream>>>(                             \
           exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr,\
           value_cache_ptr, num_kv_heads, scale, block_tables_ptr,             \
           seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride,   \
-          kv_block_stride, kv_head_stride, kv_scale);                         \
+          kv_block_stride, kv_head_stride, kv_scale, tp_rank,                 \
+          blocksparse_local_blocks, blocksparse_vert_stride,                  \
+          blocksparse_block_size, blocksparse_head_sliding_step);             \
   aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS,      \
                                               PARTITION_SIZE>                 \
       <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(               \
@@ -747,14 +840,17 @@ void paged_attention_v1(
           max_num_partitions);
 
 template <typename T, typename CACHE_T, int BLOCK_SIZE,
-          aphrodite::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128,
-          int PARTITION_SIZE = 512>
+          aphrodite::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
+          int NUM_THREADS = 128, int PARTITION_SIZE = 512>
 void paged_attention_v2_launcher(
     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
+    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
+    const int tp_rank, const int blocksparse_local_blocks,
+    const int blocksparse_vert_stride, const int blocksparse_block_size,
+    const int blocksparse_head_sliding_step) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -825,24 +921,36 @@ void paged_attention_v2_launcher(
   }
 }
 
-#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE)                    \
-  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>(              \
+#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE)   \
+  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,               \
+                              IS_BLOCK_SPARSE>(                               \
       out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,      \
       num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
-      kv_scale);
+      kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,   \
+      blocksparse_block_size, blocksparse_head_sliding_step);
+
+#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE)    \
+  switch (is_block_sparse) {                                                  \
+    case true:                                                                \
+      CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true);        \
+      break;                                                                  \
+    case false:                                                               \
+      CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false);       \
+      break;                                                                  \
+  }
 
 // NOTE: To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
 #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE)          \
   switch (block_size) {                                            \
     case 8:                                                        \
-      CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE);                   \
+      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE);          \
       break;                                                       \
     case 16:                                                       \
-      CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE);                  \
+      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE);         \
       break;                                                       \
     case 32:                                                       \
-      CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE);                  \
+      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE);         \
       break;                                                       \
     default:                                                       \
       TORCH_CHECK(false, "Unsupported block size: ", block_size);  \
@@ -866,7 +974,10 @@ void paged_attention_v2(
     torch::Tensor& seq_lens,  // [num_seqs]
     int block_size, int max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, float kv_scale) {
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+  const bool is_block_sparse = (blocksparse_vert_stride > 1);
   DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
                              CALL_V2_LAUNCHER_BLOCK_SIZE)
 }
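
Taken together, the new flags encode a vertical-strided-plus-local attention layout; is_block_sparse above also shows that blocksparse_vert_stride <= 1 selects the dense instantiation, so existing callers can pass neutral values. An illustrative sketch (parameter values are examples only) of which key/value blocksparse blocks the final query block attends, using the same is_remote/is_local tests as the kernel:

#include <cstdio>

int main() {
  const int blocksparse_vert_stride = 4;   // every 4th block is "remote"
  const int blocksparse_local_blocks = 2;  // last 2 blocks are "local"
  const int bs_block_offset = 1;           // head 0 on tp_rank 0 gives offset 1
  const int num_bs_blocks = 16;
  const int q_bs_block_id = num_bs_blocks - 1;
  for (int k = 0; k < num_bs_blocks; ++k) {
    const bool is_remote =
        ((k + bs_block_offset) % blocksparse_vert_stride == 0);
    const bool is_local = (k > q_bs_block_id - blocksparse_local_blocks);
    printf("%c", (is_remote || is_local) ? 'x' : '.');
  }
  printf("\n");  // prints "...x...x...x..xx": strided remotes plus local tail
  return 0;
}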