@@ -269,7 +269,7 @@ __global__ void single_query_cached_kv_attention_kernel(
       // we should explicitly zero out the values since they may contain NaNs.
       scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
 #pragma unroll
-      for (int j = 0; j <= V_VEC_SIZE; j++) {
+      for (int j = 0; j < V_VEC_SIZE; j++) {
         v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
       }
     }
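For context on the first hunk: v_vec holds exactly V_VEC_SIZE scalars, so the valid indices are 0 .. V_VEC_SIZE - 1, and the old j <= V_VEC_SIZE bound read and wrote one element past the end of the vector. A minimal host-side sketch of the corrected masking logic; the values of V_VEC_SIZE, token_idx, and context_len are made up for illustration, not taken from the kernel:

// Host-side sketch of the zero-masking above; all values are
// illustrative assumptions, not taken from the patch.
#include <cstdio>

int main() {
  constexpr int V_VEC_SIZE = 4;  // scalars per vector load
  const int context_len = 6;     // tokens actually present in the context
  const int token_idx = 4;       // token index of the vector's first element
  float v_vec[V_VEC_SIZE] = {0.1f, 0.2f, 0.3f, 0.4f};

  // Correct bound: j < V_VEC_SIZE touches exactly V_VEC_SIZE elements;
  // j <= V_VEC_SIZE would index v_vec[V_VEC_SIZE], one past the end.
  for (int j = 0; j < V_VEC_SIZE; j++) {
    v_vec[j] = token_idx + j < context_len ? v_vec[j] : 0.0f;
  }
  for (int j = 0; j < V_VEC_SIZE; j++) {
    printf("%.1f ", v_vec[j]);  // prints: 0.1 0.2 0.0 0.0
  }
  printf("\n");
}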
@@ -341,6 +341,9 @@ __global__ void single_query_cached_kv_attention_kernel(
 } // namespace aphrodite

 #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS)                         \
+  cudaFuncSetAttribute(                                                                        \
+    aphrodite::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
+    cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size);                             \
   aphrodite::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>    \
   <<<grid, block, shared_mem_size, stream>>>(                                                  \
     out_ptr,                                                                                   \
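On the launch macro change: CUDA caps dynamic shared memory per block at 48 KB by default, and a launch requesting more fails unless the kernel's cudaFuncAttributeMaxDynamicSharedMemorySize attribute is raised first, which is what the added cudaFuncSetAttribute call opts into. A standalone sketch of the same pattern; my_kernel, smem_bytes, and launch are placeholder names, not from this patch:

#include <cuda_runtime.h>

__global__ void my_kernel(float* out) {
  extern __shared__ float smem[];  // dynamic shared memory
  smem[threadIdx.x] = static_cast<float>(threadIdx.x);
  __syncthreads();
  out[threadIdx.x] = smem[threadIdx.x];
}

void launch(float* out, cudaStream_t stream) {
  int smem_bytes = 64 * 1024;  // 64 KB, above the 48 KB default cap
  // Without this call, the launch below would fail with an invalid-value
  // error on devices where 64 KB exceeds the default per-block limit.
  cudaFuncSetAttribute(my_kernel,
                       cudaFuncAttributeMaxDynamicSharedMemorySize,
                       smem_bytes);
  my_kernel<<<1, 256, smem_bytes, stream>>>(out);
}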
@@ -401,6 +404,8 @@ void single_query_cached_kv_attention_launcher(
   int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
   int logits_size = padded_max_context_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+  // the python-side check is in aphrodite.task_handler.worker._check_if_can_support_max_seq_len
+  // keep that in sync with the logic here!
   int shared_mem_size = std::max(logits_size, outputs_size);

   dim3 grid(num_heads, num_seqs);
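On the launcher arithmetic: the logits buffer and the per-warp partial-output buffer occupy the same dynamic shared memory region during different phases of the kernel, so only the larger of the two sizes is allocated, after padding the context length up to a whole number of cache blocks. A worked example under assumed values (BLOCK_SIZE = 16, max_context_len = 4090, NUM_WARPS = 4, head_size = 128; none of these come from the patch):

#include <algorithm>
#include <cstdio>

int main() {
  // Assumed values for illustration only.
  const int BLOCK_SIZE = 16, max_context_len = 4090;
  const int NUM_WARPS = 4, head_size = 128;

  // Round the context length up to a multiple of BLOCK_SIZE.
  int padded = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
  int logits_size = padded * sizeof(float);                        // 4096 * 4 = 16384
  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);  // 2 * 128 * 4 = 1024
  // One region serves both uses, so allocate the max, not the sum.
  printf("shared_mem_size = %d\n", std::max(logits_size, outputs_size));  // 16384
}

This shared_mem_size is the quantity the python-side _check_if_can_support_max_seq_len guard mentioned in the new comment has to reproduce.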