|
@@ -75,14 +75,17 @@ template<
|
|
|
__global__ void single_query_cached_kv_attention_kernel(
|
|
|
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
|
|
|
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
|
|
- const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
|
|
- const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size]
|
|
|
+ const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
|
|
+ const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
|
|
+ const int* __restrict__ head_mapping, // [num_heads]
|
|
|
const float scale,
|
|
|
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
|
|
const int* __restrict__ context_lens, // [num_seqs]
|
|
|
const int max_num_blocks_per_seq,
|
|
|
const float* __restrict__ alibi_slopes, // [num_heads]
|
|
|
- const int q_stride) {
|
|
|
+ const int q_stride,
|
|
|
+ const int kv_block_stride,
|
|
|
+ const int kv_head_stride) {
|
|
|
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
|
|
constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
|
|
|
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
|
|
@@ -92,6 +95,7 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|
|
|
|
|
const int head_idx = blockIdx.x;
|
|
|
const int num_heads = gridDim.x;
|
|
|
+ const int kv_head_idx = head_mapping[head_idx];
|
|
|
const int seq_idx = blockIdx.y;
|
|
|
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
|
|
|
|
|
@@ -115,7 +119,7 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|
|
// For example, if the the thread group size is 4, then the first thread in the group
|
|
|
// has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
|
|
|
// th vectors of the query, and so on.
|
|
|
- // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
|
|
|
+ // NOTE: Because q is split from a qkv tensor, it may not be contiguous.
|
|
|
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
|
|
|
Q_vec q_vecs[NUM_VECS_PER_THREAD];
|
|
|
#pragma unroll
|
|
@@ -126,7 +130,7 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|
|
|
|
|
// Memory planning.
|
|
|
extern __shared__ char shared_mem[];
|
|
|
- // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
|
|
|
+ // NOTE: We use FP32 for the softmax logits for better accuracy.
|
|
|
float* logits = reinterpret_cast<float*>(shared_mem);
|
|
|
// Workspace for reduction.
|
|
|
__shared__ float red_smem[2 * NUM_WARPS];
|
|
@@ -159,8 +163,8 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|
|
|
|
|
#pragma unroll
|
|
|
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
|
|
|
- const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
|
|
|
- + head_idx * HEAD_SIZE * BLOCK_SIZE
|
|
|
+ const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
|
|
|
+ + kv_head_idx * kv_head_stride
|
|
|
+ physical_block_offset * x;
|
|
|
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
|
|
|
const int offset1 = (vec_idx * VEC_SIZE) / x;
|
|
@@ -176,7 +180,7 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|
|
|
|
|
if (thread_group_offset == 0) {
|
|
|
// Store the partial reductions to shared memory.
|
|
|
- // NOTE(woosuk): It is required to zero out the masked logits.
|
|
|
+ // NOTE: It is required to zero out the masked logits.
|
|
|
const bool mask = token_idx >= context_len;
|
|
|
logits[token_idx] = mask ? 0.f : qk;
|
|
|
// Update the max value.
|
|
@@ -233,7 +237,7 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|
|
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
|
|
|
constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;
|
|
|
|
|
|
- // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
|
|
|
+ // NOTE: We use FP32 for the accumulator for better accuracy.
|
|
|
float accs[NUM_ROWS_PER_THREAD];
|
|
|
#pragma unroll
|
|
|
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
|
@@ -247,8 +251,8 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|
|
L_vec logits_vec;
|
|
|
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));
|
|
|
|
|
|
- const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
|
|
|
- + head_idx * HEAD_SIZE * BLOCK_SIZE;
|
|
|
+ const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
|
|
|
+ + kv_head_idx * kv_head_stride;
|
|
|
#pragma unroll
|
|
|
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
|
|
|
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
|
|
@@ -271,7 +275,7 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|
|
accs[i] = acc;
|
|
|
}
|
|
|
|
|
|
- // NOTE(woosuk): A barrier is required because the shared memory space for logits
|
|
|
+ // NOTE: A barrier is required because the shared memory space for logits
|
|
|
// is reused for the output.
|
|
|
__syncthreads();
|
|
|
|
|
@@ -329,12 +333,15 @@ __global__ void single_query_cached_kv_attention_kernel(
|
|
|
query_ptr, \
|
|
|
key_cache_ptr, \
|
|
|
value_cache_ptr, \
|
|
|
+ head_mapping_ptr, \
|
|
|
scale, \
|
|
|
block_tables_ptr, \
|
|
|
context_lens_ptr, \
|
|
|
max_num_blocks_per_seq, \
|
|
|
alibi_slopes_ptr, \
|
|
|
- query_stride);
|
|
|
+ q_stride, \
|
|
|
+ kv_block_stride, \
|
|
|
+ kv_head_stride);
|
|
|
|
|
|
// TODO: Tune NUM_THREADS.
|
|
|
template<
|
|
@@ -346,6 +353,7 @@ void single_query_cached_kv_attention_launcher(
|
|
|
torch::Tensor& query,
|
|
|
torch::Tensor& key_cache,
|
|
|
torch::Tensor& value_cache,
|
|
|
+ torch::Tensor& head_mapping,
|
|
|
float scale,
|
|
|
torch::Tensor& block_tables,
|
|
|
torch::Tensor& context_lens,
|
|
@@ -355,7 +363,9 @@ void single_query_cached_kv_attention_launcher(
|
|
|
int num_heads = query.size(1);
|
|
|
int head_size = query.size(2);
|
|
|
int max_num_blocks_per_seq = block_tables.size(1);
|
|
|
- int query_stride = query.stride(0);
|
|
|
+ int q_stride = query.stride(0);
|
|
|
+ int kv_block_stride = key_cache.stride(0);
|
|
|
+ int kv_head_stride = key_cache.stride(1);
|
|
|
|
|
|
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
|
|
assert(head_size % thread_group_size == 0);
|
|
@@ -369,6 +379,7 @@ void single_query_cached_kv_attention_launcher(
|
|
|
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
|
|
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
|
|
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
|
|
+ int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
|
|
|
int* block_tables_ptr = block_tables.data_ptr<int>();
|
|
|
int* context_lens_ptr = context_lens.data_ptr<int>();
|
|
|
|
|
@@ -382,8 +393,8 @@ void single_query_cached_kv_attention_launcher(
|
|
|
dim3 block(NUM_THREADS);
|
|
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
|
|
switch (head_size) {
|
|
|
- // NOTE(woosuk): To reduce the compilation time, we omitted head sizes
|
|
|
- // 32, 160, 192, 256.
|
|
|
+ // NOTE: To reduce the compilation time, we're omitting head sizes
|
|
|
+ // 32, 160, 192.
|
|
|
// case 32:
|
|
|
// LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
|
|
|
// break;
|
|
@@ -408,9 +419,9 @@ void single_query_cached_kv_attention_launcher(
|
|
|
// case 192:
|
|
|
// LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
|
|
|
// break;
|
|
|
- // case 256:
|
|
|
- // LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
|
|
|
- // break;
|
|
|
+ case 256:
|
|
|
+ LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
|
|
|
+ break;
|
|
|
default:
|
|
|
TORCH_CHECK(false, "Unsupported head size: ", head_size);
|
|
|
break;
|
|
@@ -423,13 +434,14 @@ void single_query_cached_kv_attention_launcher(
|
|
|
query, \
|
|
|
key_cache, \
|
|
|
value_cache, \
|
|
|
+ head_mapping, \
|
|
|
scale, \
|
|
|
block_tables, \
|
|
|
context_lens, \
|
|
|
max_context_len, \
|
|
|
alibi_slopes);
|
|
|
|
|
|
-// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
|
|
|
+// NOTE: To reduce the compilation time, we're omitting block sizes
|
|
|
// 1, 2, 4, 64, 128, 256.
|
|
|
#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \
|
|
|
switch (block_size) { \
|
|
@@ -470,6 +482,7 @@ void single_query_cached_kv_attention(
|
|
|
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
|
|
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
|
|
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
|
|
+ torch::Tensor& head_mapping, // [num_heads]
|
|
|
float scale,
|
|
|
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
|
|
torch::Tensor& context_lens, // [num_seqs]
|
|
@@ -489,4 +502,4 @@ void single_query_cached_kv_attention(
|
|
|
|
|
|
#undef WARP_SIZE
|
|
|
#undef MAX
|
|
|
-#undef MIN
|
|
|
+#undef MIN
|