
fix: attention kernel attribute (#52)

* fix: attention kernel attribute

* fix backslashes
AlpinDale 1 year ago
commit b7918ad45f
1 changed file with 6 additions and 1 deletion:
  kernels/attention/attention_kernels.cu  (+6 -1)

kernels/attention/attention_kernels.cu  (+6 -1)

@@ -269,7 +269,7 @@ __global__ void single_query_cached_kv_attention_kernel(
          // we should explicitly zero out the values since they may contain NaNs.
          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
 #pragma unroll
-          for (int j = 0; j <= V_VEC_SIZE; j++) {
+          for (int j = 0; j < V_VEC_SIZE; j++) {
            v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
          }
        }
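This hunk is an off-by-one fix: with j <= V_VEC_SIZE the masking loop touched V_VEC_SIZE + 1 elements of a V_VEC_SIZE-wide vector, so the last iteration read and wrote one element past the end of v_vec. A minimal sketch of the corrected pattern in isolation follows; Vec4 and mask_tail_kernel are illustrative stand-ins, not names from the kernel.

// Sketch only: the masking pattern from the hunk above, isolated so it can be
// compiled on its own. Vec4 is an illustrative stand-in for the kernel's
// vector type; the real kernel works on scalar_t and its vector types.
constexpr int V_VEC_SIZE = 4;
struct Vec4 { float data[V_VEC_SIZE]; };

__global__ void mask_tail_kernel(Vec4* v, int token_idx, int context_len) {
  float* v_vec_ptr = reinterpret_cast<float*>(v);
#pragma unroll
  for (int j = 0; j < V_VEC_SIZE; j++) {  // '<', not '<=': exactly V_VEC_SIZE lanes
    // lanes whose token index lies past the end of the context are zeroed,
    // so stale or NaN cache values never enter the attention reduction
    v_vec_ptr[j] = (token_idx + j < context_len) ? v_vec_ptr[j] : 0.0f;
  }
}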
@@ -341,6 +341,9 @@ __global__ void single_query_cached_kv_attention_kernel(
 } // namespace aphrodite

 #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS)                        \
+  cudaFuncSetAttribute(                                                                       \
+      aphrodite::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>,   \
+      cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size);                          \
   aphrodite::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>        \
   <<<grid, block, shared_mem_size, stream>>>(                                                 \
     out_ptr,                                                                                  \
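The cudaFuncSetAttribute call added to the launch macro is what lets the kernel request more than the default 48 KB of dynamic shared memory per block: without raising cudaFuncAttributeMaxDynamicSharedMemorySize before the launch, a launch asking for a larger shared_mem_size fails with an invalid-value error even on GPUs whose hardware supports it. A standalone sketch of the same opt-in, using an illustrative dummy kernel rather than the attention kernel:

// Sketch: opting a kernel into >48 KB of dynamic shared memory before launch.
// The kernel and sizes are illustrative, not the launcher's actual values.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void dummy_kernel(float* out) {
  extern __shared__ float smem[];       // dynamic shared memory
  smem[threadIdx.x] = float(threadIdx.x);
  __syncthreads();
  out[threadIdx.x] = smem[threadIdx.x];
}

int main() {
  int shared_mem_size = 64 * 1024;      // 64 KB, above the 48 KB default cap
  // Same pattern as the macro: raise the per-kernel cap, then launch with
  // that amount of dynamic shared memory.
  cudaFuncSetAttribute(dummy_kernel,
                       cudaFuncAttributeMaxDynamicSharedMemorySize,
                       shared_mem_size);
  float* out;
  cudaMalloc(&out, 1024 * sizeof(float));
  dummy_kernel<<<1, 1024, shared_mem_size>>>(out);
  printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(out);
  return 0;
}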
@@ -401,6 +404,8 @@ void single_query_cached_kv_attention_launcher(
   int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
   int logits_size = padded_max_context_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+  // the python-side check is in aphrodite.task_handler.worker._check_if_can_support_max_seq_len
+  // keep that in sync with the logic here!
   int shared_mem_size = std::max(logits_size, outputs_size);

   dim3 grid(num_heads, num_seqs);
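The comment added in the last hunk states the contract between this launcher and the Python worker: shared_mem_size is the larger of the logits scratch buffer and the warp-reduction output buffer, and aphrodite.task_handler.worker._check_if_can_support_max_seq_len has to reproduce that formula so over-long sequences are rejected up front instead of failing at kernel launch. A host-side sketch of the same arithmetic, checked against the device's opt-in shared memory limit; the constants are illustrative, not values from the launcher:

// Sketch: the shared-memory sizing the python-side check has to mirror.
// BLOCK_SIZE / NUM_WARPS / head_size / max_context_len are illustrative.
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdio>

int main() {
  const int BLOCK_SIZE = 16;         // KV-cache block size
  const int NUM_WARPS = 4;           // NUM_THREADS / WARP_SIZE
  const int head_size = 128;
  const int max_context_len = 8192;

  // Same arithmetic as the launcher: pad the context length to whole cache
  // blocks, then take the larger of the two scratch buffers.
  int padded_max_context_len =
      ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
  int logits_size = padded_max_context_len * sizeof(float);
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
  int shared_mem_size = std::max(logits_size, outputs_size);

  // The launch can only succeed if the device allows this much dynamic shared
  // memory per block once the opt-in attribute has been set.
  int max_opt_in = 0;
  cudaDeviceGetAttribute(&max_opt_in, cudaDevAttrMaxSharedMemoryPerBlockOptin, 0);
  printf("need %d bytes, device opt-in max %d bytes: %s\n",
         shared_mem_size, max_opt_in,
         shared_mem_size <= max_opt_in ? "fits" : "sequence too long");
  return 0;
}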