
optimization: multi-query attention kernel

AlpinDale committed 1 year ago
commit 24c78e7306
2 changed files with 35 additions and 21 deletions:
  1. kernels/attention.cpp (+1, -0)
  2. kernels/attention/attention_kernels.cu (+34, -21)

kernels/attention.cpp (+1, -0)

@@ -6,6 +6,7 @@ void single_query_cached_kv_attention(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
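
The new head_mapping argument is what makes this a multi-query/grouped-query attention kernel: it tells the kernel which key/value head each query head should read, so several query heads can share one KV head and the KV cache shrinks accordingly. The commit does not show how the mapping is built on the host side; below is a minimal sketch of one plausible construction, assuming num_heads query heads evenly divided over num_kv_heads KV heads (the function name and layout are illustrative, not from this commit):

    #include <vector>

    // Sketch only: query head i reads KV head i / (num_heads / num_kv_heads).
    // For pure multi-query attention (num_kv_heads == 1) every entry is 0.
    std::vector<int> make_head_mapping(int num_heads, int num_kv_heads) {
      std::vector<int> mapping(num_heads);
      const int queries_per_kv = num_heads / num_kv_heads;
      for (int i = 0; i < num_heads; ++i) {
        mapping[i] = i / queries_per_kv;
      }
      return mapping;
    }

With num_heads = 8 and num_kv_heads = 2 this yields {0, 0, 0, 0, 1, 1, 1, 1}: the first four query heads attend over KV head 0, the rest over KV head 1.
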

kernels/attention/attention_kernels.cu (+34, -21)

@@ -75,14 +75,17 @@ template<
 __global__ void single_query_cached_kv_attention_kernel(
   scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
-  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_heads, head_size, block_size]
+  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
+  const int* __restrict__ head_mapping,   // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
   const int* __restrict__ context_lens,   // [num_seqs]
   const int max_num_blocks_per_seq,
   const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride) {
+  const int q_stride,
+  const int kv_block_stride,
+  const int kv_head_stride) {
   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
@@ -92,6 +95,7 @@ __global__ void single_query_cached_kv_attention_kernel(
 
   const int head_idx = blockIdx.x;
   const int num_heads = gridDim.x;
+  const int kv_head_idx = head_mapping[head_idx];
   const int seq_idx = blockIdx.y;
   const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
 
@@ -115,7 +119,7 @@ __global__ void single_query_cached_kv_attention_kernel(
 // For example, if the thread group size is 4, then the first thread in the group
   // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
   // th vectors of the query, and so on.
-  // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
+  // NOTE: Because q is split from a qkv tensor, it may not be contiguous.
   const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
   Q_vec q_vecs[NUM_VECS_PER_THREAD];
 #pragma unroll
@@ -126,7 +130,7 @@ __global__ void single_query_cached_kv_attention_kernel(
 
   // Memory planning.
   extern __shared__ char shared_mem[];
-  // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
+  // NOTE: We use FP32 for the softmax logits for better accuracy.
   float* logits = reinterpret_cast<float*>(shared_mem);
   // Workspace for reduction.
   __shared__ float red_smem[2 * NUM_WARPS];
@@ -159,8 +163,8 @@ __global__ void single_query_cached_kv_attention_kernel(
 
 #pragma unroll
       for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
-        const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
-                                        + head_idx * HEAD_SIZE * BLOCK_SIZE
+        const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+                                        + kv_head_idx * kv_head_stride
                                         + physical_block_offset * x;
         const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
         const int offset1 = (vec_idx * VEC_SIZE) / x;
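
The key-pointer change above is the core of the optimization: the old offset hard-coded num_heads * HEAD_SIZE * BLOCK_SIZE, which only works when the cache stores one KV head per query head, while the new code indexes by kv_head_idx and the cache's actual strides. A sketch of why the two agree in the old layout and diverge when KV heads are shared, assuming a contiguous k_cache of shape [num_blocks, num_kv_heads, head_size/x, block_size, x] (helper name illustrative, not from this commit):

    // For a contiguous cache:
    //   kv_head_stride  = (HEAD_SIZE / x) * BLOCK_SIZE * x = HEAD_SIZE * BLOCK_SIZE
    //   kv_block_stride = num_kv_heads * kv_head_stride
    __device__ inline int k_base_offset(int physical_block_number, int kv_head_idx,
                                        int kv_block_stride, int kv_head_stride) {
      return physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride;
    }
    // When num_kv_heads == num_heads and kv_head_idx == head_idx this reproduces the old
    // expression; with fewer KV heads, several head_idx values map onto one kv_head_idx.
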
@@ -176,7 +180,7 @@ __global__ void single_query_cached_kv_attention_kernel(
 
       if (thread_group_offset == 0) {
         // Store the partial reductions to shared memory.
-        // NOTE(woosuk): It is required to zero out the masked logits.
+        // NOTE: It is required to zero out the masked logits.
         const bool mask = token_idx >= context_len;
         logits[token_idx] = mask ? 0.f : qk;
         // Update the max value.
@@ -233,7 +237,7 @@ __global__ void single_query_cached_kv_attention_kernel(
   constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
   constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER;
 
-  // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
+  // NOTE: We use FP32 for the accumulator for better accuracy.
   float accs[NUM_ROWS_PER_THREAD];
 #pragma unroll
   for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
@@ -247,8 +251,8 @@ __global__ void single_query_cached_kv_attention_kernel(
     L_vec logits_vec;
     from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx));
 
-    const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE
-                                    + head_idx * HEAD_SIZE * BLOCK_SIZE;
+    const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+                                    + kv_head_idx * kv_head_stride;
 #pragma unroll
     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@@ -271,7 +275,7 @@ __global__ void single_query_cached_kv_attention_kernel(
     accs[i] = acc;
   }
 
-  // NOTE(woosuk): A barrier is required because the shared memory space for logits
+  // NOTE: A barrier is required because the shared memory space for logits
   // is reused for the output.
   __syncthreads();
 
@@ -329,12 +333,15 @@ __global__ void single_query_cached_kv_attention_kernel(
     query_ptr,                                                                                \
     key_cache_ptr,                                                                            \
     value_cache_ptr,                                                                          \
+    head_mapping_ptr,                                                                         \
     scale,                                                                                    \
     block_tables_ptr,                                                                         \
     context_lens_ptr,                                                                         \
     max_num_blocks_per_seq,                                                                   \
     alibi_slopes_ptr,                                                                         \
-    query_stride);
+    q_stride,                                                                                 \
+    kv_block_stride,                                                                          \
+    kv_head_stride);
 
 // TODO: Tune NUM_THREADS.
 template<
@@ -346,6 +353,7 @@ void single_query_cached_kv_attention_launcher(
   torch::Tensor& query,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
+  torch::Tensor& head_mapping,
   float scale,
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
@@ -355,7 +363,9 @@ void single_query_cached_kv_attention_launcher(
   int num_heads = query.size(1);
   int head_size = query.size(2);
   int max_num_blocks_per_seq = block_tables.size(1);
-  int query_stride = query.stride(0);
+  int q_stride = query.stride(0);
+  int kv_block_stride = key_cache.stride(0);
+  int kv_head_stride = key_cache.stride(1);
 
   int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
   assert(head_size % thread_group_size == 0);
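
As a concrete check of the new stride arguments (the numbers below are illustrative, not taken from the commit): for a contiguous key_cache laid out as [num_blocks, num_kv_heads, head_size/x, block_size, x], the strides read here are just products of the trailing dimensions.

    // Example with num_kv_heads = 8, head_size = 128, block_size = 16, x = 8:
    //   key_cache shape  : [num_blocks, 8, 16, 16, 8]
    //   kv_head_stride   = key_cache.stride(1) = 16 * 16 * 8 = 2048 elements
    //   kv_block_stride  = key_cache.stride(0) = 8 * 2048    = 16384 elements
    // q_stride is taken from query.stride(0) rather than num_heads * head_size because,
    // as the kernel comment notes, query may be a non-contiguous slice of a fused qkv tensor.
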
@@ -369,6 +379,7 @@ void single_query_cached_kv_attention_launcher(
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
   T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
+  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();
 
@@ -382,8 +393,8 @@ void single_query_cached_kv_attention_launcher(
   dim3 block(NUM_THREADS);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   switch (head_size) {
-    // NOTE(woosuk): To reduce the compilation time, we omitted head sizes
-    // 32, 160, 192, 256.
+    // NOTE: To reduce the compilation time, we're omitting head sizes
+    // 32, 160, 192.
     // case 32:
     //   LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS);
     //   break;
@@ -408,9 +419,9 @@ void single_query_cached_kv_attention_launcher(
     // case 192:
     //   LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS);
     //   break;
-    // case 256:
-    //   LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
-    //   break;
+    case 256:
+      LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS);
+      break;
     default:
       TORCH_CHECK(false, "Unsupported head size: ", head_size);
       break;
@@ -423,13 +434,14 @@ void single_query_cached_kv_attention_launcher(
     query,                                                          \
     key_cache,                                                      \
     value_cache,                                                    \
+    head_mapping,                                                   \
     scale,                                                          \
     block_tables,                                                   \
     context_lens,                                                   \
     max_context_len,                                                \
     alibi_slopes);
 
-// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
+// NOTE: To reduce the compilation time, we're omitting block sizes
 // 1, 2, 4, 64, 128, 256.
 #define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T)                          \
   switch (block_size) {                                             \
@@ -470,6 +482,7 @@ void single_query_cached_kv_attention(
   torch::Tensor& query,           // [num_seqs, num_heads, head_size]
   torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& head_mapping,    // [num_heads]
   float scale,
   torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
   torch::Tensor& context_lens,    // [num_seqs]
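
Callers of the public op now pass the mapping alongside the caches. A host-side sketch of building a grouped-query mapping with the ATen API and handing it to the new parameter (head counts and device are assumptions; only the parameter's position and its int32 [num_heads] layout come from this diff):

    #include <torch/torch.h>

    // Illustrative: 32 query heads sharing 8 KV heads (4 query heads per KV head).
    torch::Tensor head_mapping =
        torch::arange(/*end=*/8, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA))
            .repeat_interleave(/*repeats=*/4);   // -> {0,0,0,0, 1,1,1,1, ..., 7,7,7,7}
    // head_mapping is then supplied as the new argument between value_cache and scale
    // when calling single_query_cached_kv_attention.
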
@@ -489,4 +502,4 @@ void single_query_cached_kv_attention(
 
 #undef WARP_SIZE
 #undef MAX
-#undef MIN
+#undef MIN