@@ -1,5 +1,6 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
  * Copyright (c) 2023, The PygmalionAI team.
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
@@ -17,7 +18,7 @@
  * limitations under the License.
  */
 #ifdef USE_ROCM
-#include <hip/hip_runtime.h>
+  #include <hip/hip_runtime.h>
 #endif
 
 #include <torch/extension.h>
@@ -28,15 +29,15 @@
 #include "attention_utils.cuh"
 #include "../quantization/int8_kvcache/quant_utils.cuh"
 #ifdef ENABLE_FP8_E5M2
-#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
+  #include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
 #endif
 
 #include <algorithm>
 
 #ifndef USE_ROCM
-#define WARP_SIZE 32
+  #define WARP_SIZE 32
 #else
-#define WARP_SIZE warpSize
+  #define WARP_SIZE warpSize
 #endif
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -47,12 +48,13 @@ enum kv_cache_dtype {
 #ifdef ENABLE_FP8_E5M2
   FP8_E5M2,
 #endif
-  INT8};
+  INT8
+};
 
 namespace aphrodite {
 
 // Utility function for attention softmax.
-template<int NUM_WARPS>
+template <int NUM_WARPS>
 inline __device__ float block_sum(float* red_smem, float sum) {
   // Decompose the thread index into warp / lane.
   int warp = threadIdx.x / WARP_SIZE;
@@ -89,34 +91,29 @@ inline __device__ float block_sum(float* red_smem, float sum) {
 
 // TODO: Merge the last two dimensions of the grid.
 // Grid: (num_heads, num_seqs, max_num_partitions).
-template<
-  typename scalar_t,
-  typename cache_t,
-  int HEAD_SIZE,
-  int BLOCK_SIZE,
-  int NUM_THREADS,
-  kv_cache_dtype KV_CACHE_DTYPE,
-  int PARTITION_SIZE = 0> // Zero means no partitioning.
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS, kv_cache_dtype KV_CACHE_DTYPE,
+          int PARTITION_SIZE = 0>  // Zero means no partitioning.
 __device__ void paged_attention_kernel(
-  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
-  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
-  scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
-  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
-  const int num_kv_heads,                 // [num_heads]
-  const float scale,
-  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ context_lens,   // [num_seqs]
-  const int max_num_blocks_per_seq,
-  const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride,
-  const int kv_block_stride,
-  const int kv_head_stride,
-  const float k_scale = 1.0f,
-  const float k_zp = 0.0f,
-  const float v_scale = 1.0f,
-  const float v_zp = 0.0f) {
+    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
+    float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                     // max_num_partitions]
+    scalar_t* __restrict__ out,  // [num_seqs, num_heads, max_num_partitions,
+                                 // head_size]
+    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
+    const int num_kv_heads,               // [num_heads]
+    const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ context_lens,  // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    const float k_scale = 1.0f, const float k_zp = 0.0f,
+    const float v_scale = 1.0f, const float v_zp = 0.0f) {
   const int seq_idx = blockIdx.y;
   const int partition_idx = blockIdx.z;
   const int max_num_partitions = gridDim.z;
@@ -128,22 +125,29 @@ __device__ void paged_attention_kernel(
   }
 
   const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
-  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
+  const int num_blocks_per_partition =
+      USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
 
   // [start_block_idx, end_block_idx) is the range of blocks to process.
-  const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
-  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
+  const int start_block_idx =
+      USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
+  const int end_block_idx =
+      MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
   const int num_blocks = end_block_idx - start_block_idx;
 
   // [start_token_idx, end_token_idx) is the range of tokens to process.
   const int start_token_idx = start_block_idx * BLOCK_SIZE;
-  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
+  const int end_token_idx =
+      MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
   const int num_tokens = end_token_idx - start_token_idx;
 
   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
-  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
+  constexpr int NUM_THREAD_GROUPS =
+      NUM_THREADS / THREAD_GROUP_SIZE;  // Note: This assumes THREAD_GROUP_SIZE
+                                        // divides NUM_THREADS
   assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
-  constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
+  constexpr int NUM_TOKENS_PER_THREAD_GROUP =
+      DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   const int thread_idx = threadIdx.x;
   const int warp_idx = thread_idx / WARP_SIZE;
@@ -153,13 +157,14 @@ __device__ void paged_attention_kernel(
   const int num_heads = gridDim.x;
   const int num_queries_per_kv = num_heads / num_kv_heads;
   const int kv_head_idx = head_idx / num_queries_per_kv;
-  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
+  const float alibi_slope =
+      alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
 
   // A vector type to store a part of a key or a query.
-  // The vector size is configured in such a way that the threads in a thread group
-  // fetch or compute 16 bytes at a time.
-  // For example, if the size of a thread group is 4 and the data type is half,
-  // then the vector size is 16 / (4 * sizeof(half)) == 2.
+  // The vector size is configured in such a way that the threads in a thread
+  // group fetch or compute 16 bytes at a time. For example, if the size of a
+  // thread group is 4 and the data type is half, then the vector size is 16 /
+  // (4 * sizeof(half)) == 2.
   constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
   using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
   using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
@@ -173,18 +178,21 @@ __device__ void paged_attention_kernel(
 
   // Load the query to registers.
   // Each thread in a thread group has a different part of the query.
-  // For example, if the the thread group size is 4, then the first thread in the group
-  // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
-  // th vectors of the query, and so on.
-  // NOTE: Because q is split from a qkv tensor, it may not be contiguous.
+  // For example, if the thread group size is 4, then the first thread in
+  // the group has 0, 4, 8, ... th vectors of the query, and the second thread
+  // has 1, 5, 9, ... th vectors of the query, and so on. NOTE: Because q is
+  // split from a qkv tensor, it may not be contiguous.
   const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
   __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
 #pragma unroll
-  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
+  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
+       i += NUM_THREAD_GROUPS) {
     const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
-    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
+    q_vecs[thread_group_offset][i] =
+        *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
   }
-  __syncthreads(); // TODO: possible speedup if this is replaced with a memory wall right before we use q_vecs
+  __syncthreads();  // TODO: possible speedup if this is replaced with a memory
+                    // fence right before we use q_vecs
 
   // Memory planning.
   extern __shared__ char shared_mem[];
@@ -203,51 +211,60 @@ __device__ void paged_attention_kernel(
   // Each thread group in a warp fetches a key from the block, and computes
   // dot product with the query.
   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
-  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
+       block_idx += NUM_WARPS) {
     // NOTE: The block number is stored in int32. However, we cast it to int64
-    // because int32 can lead to overflow when this variable is multiplied by large numbers
-    // (e.g., kv_block_stride).
-    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+    // because int32 can lead to overflow when this variable is multiplied by
+    // large numbers (e.g., kv_block_stride).
+    const int64_t physical_block_number =
+        static_cast<int64_t>(block_table[block_idx]);
 
     // Load a key to registers.
    // Each thread in a thread group has a different part of the key.
-    // For example, if the the thread group size is 4, then the first thread in the group
-    // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
-    // vectors of the key, and so on.
+    // For example, if the thread group size is 4, then the first thread in
+    // the group has 0, 4, 8, ... th vectors of the key, and the second thread
+    // has 1, 5, 9, ... th vectors of the key, and so on.
     for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
-      const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
+      const int physical_block_offset =
+          (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
       const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
       K_vec k_vecs[NUM_VECS_PER_THREAD];
 
 #pragma unroll
       for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
-        const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
-                               + kv_head_idx * kv_head_stride
-                               + physical_block_offset * x;
+        const cache_t* k_ptr =
+            k_cache + physical_block_number * kv_block_stride +
+            kv_head_idx * kv_head_stride + physical_block_offset * x;
         const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
         const int offset1 = (vec_idx * VEC_SIZE) / x;
         const int offset2 = (vec_idx * VEC_SIZE) % x;
         if constexpr (KV_CACHE_DTYPE == INT8) {
-          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
+              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
           using Dequant_vec = typename FloatVec<Quant_vec>::Type;
           Dequant_vec k_vec_dequant = int8::dequant(k_vec_quant, k_scale, k_zp);
           k_vecs[j] = int8::vec_conversion<K_vec, Dequant_vec>(k_vec_dequant);
 #ifdef ENABLE_FP8_E5M2
         } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
-          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
+              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
           // Vector conversion from Quant_vec to K_vec.
-          k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
+          k_vecs[j] =
+              fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
 #endif
         } else {
-          k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          k_vecs[j] = *reinterpret_cast<const K_vec*>(
+              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
         }
       }
 
       // Compute dot product.
      // This includes a reduction across the threads in the same thread group.
-      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
+      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
+                             q_vecs[thread_group_offset], k_vecs);
       // Add the ALiBi bias if slopes are given.
-      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
+      qk +=
+          (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
 
       if (thread_group_offset == 0) {
         // Store the partial reductions to shared memory.
@@ -300,13 +317,12 @@ __device__ void paged_attention_kernel(
 
   // If partitioning is enabled, store the max logit and exp_sum.
   if (USE_PARTITIONING && thread_idx == 0) {
-    float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
-                            + head_idx * max_num_partitions
-                            + partition_idx;
+    float* max_logits_ptr = max_logits +
+                            seq_idx * num_heads * max_num_partitions +
+                            head_idx * max_num_partitions + partition_idx;
     *max_logits_ptr = qk_max;
-    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
-                          + head_idx * max_num_partitions
-                          + partition_idx;
+    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
+                          head_idx * max_num_partitions + partition_idx;
     *exp_sums_ptr = exp_sum;
   }
 
@@ -319,7 +335,8 @@ __device__ void paged_attention_kernel(
 
   constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
   constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
-  constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
+  constexpr int NUM_ROWS_PER_THREAD =
+      DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
 
   // NOTE: We use FP32 for the accumulator for better accuracy.
   float accs[NUM_ROWS_PER_THREAD];
@@ -330,18 +347,21 @@ __device__ void paged_attention_kernel(
 
   scalar_t zero_value;
   zero(zero_value);
-  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
+       block_idx += NUM_WARPS) {
     // NOTE: The block number is stored in int32. However, we cast it to int64
-    // because int32 can lead to overflow when this variable is multiplied by large numbers
-    // (e.g., kv_block_stride).
-    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+    // because int32 can lead to overflow when this variable is multiplied by
+    // large numbers (e.g., kv_block_stride).
+    const int64_t physical_block_number =
+        static_cast<int64_t>(block_table[block_idx]);
     const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
     const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
     L_vec logits_vec;
-    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
+    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
+                                                           start_token_idx));
 
-    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
-                           + kv_head_idx * kv_head_stride;
+    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
+                           kv_head_idx * kv_head_stride;
 #pragma unroll
     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@@ -350,26 +370,32 @@ __device__ void paged_attention_kernel(
       V_vec v_vec;
       if constexpr (KV_CACHE_DTYPE == INT8) {
         // dequant and conversion
-        V_quant_vec v_vec_quant = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+        V_quant_vec v_vec_quant =
+            *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
         using V_dequant_vec = typename FloatVec<V_quant_vec>::Type;
-        V_dequant_vec v_vec_dequant = int8::dequant(v_vec_quant, v_scale, v_zp);
+        V_dequant_vec v_vec_dequant =
+            int8::dequant(v_vec_quant, v_scale, v_zp);
         v_vec = int8::vec_conversion<V_vec, V_dequant_vec>(v_vec_dequant);
 #ifdef ENABLE_FP8_E5M2
       } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
-        V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+        V_quant_vec v_quant_vec =
+            *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
         // Vector conversion from V_quant_vec to V_vec.
-        v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
+        v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(
+            v_quant_vec);
 #endif
       } else {
         v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
       }
       if (block_idx == num_context_blocks - 1) {
         // NOTE: When v_vec contains the tokens that are out of the context,
-        // we should explicitly zero out the values since they may contain NaNs.
+        // we should explicitly zero out the values since they may contain
+        // NaNs.
         scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
 #pragma unroll
         for (int j = 0; j < V_VEC_SIZE; j++) {
-          v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
+          v_vec_ptr[j] =
+              token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
         }
       }
       accs[i] += dot(logits_vec, v_vec);
@@ -426,9 +452,9 @@ __device__ void paged_attention_kernel(
 
   // Write the final output.
   if (warp_idx == 0) {
-    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
-                        + head_idx * max_num_partitions * HEAD_SIZE
-                        + partition_idx * HEAD_SIZE;
+    scalar_t* out_ptr =
+        out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+        head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
 #pragma unroll
     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@@ -440,85 +466,77 @@
 }
 
 // Grid: (num_heads, num_seqs, 1).
-template<
-  typename scalar_t,
-  typename cache_t,
-  int HEAD_SIZE,
-  int BLOCK_SIZE,
-  int NUM_THREADS,
-  kv_cache_dtype KV_CACHE_DTYPE>
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS, kv_cache_dtype KV_CACHE_DTYPE>
 __global__ void paged_attention_v1_kernel(
-  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
-  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
-  const int num_kv_heads,                 // [num_heads]
-  const float scale,
-  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ context_lens,   // [num_seqs]
-  const int max_num_blocks_per_seq,
-  const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride,
-  const int kv_block_stride,
-  const int kv_head_stride,
-  const float k_scale,
-  const float k_zp,
-  const float v_scale,
-  const float v_zp) {
-  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE>(
-    /* exp_sums */ nullptr, /* max_logits */ nullptr,
-    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
-    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
+    scalar_t* __restrict__ out,      // [num_seqs, num_heads, head_size]
+    const scalar_t* __restrict__ q,  // [num_seqs, num_heads, head_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
+    const int num_kv_heads,               // [num_heads]
+    const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ context_lens,  // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    const float k_scale, const float k_zp, const float v_scale,
+    const float v_zp) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
+                         KV_CACHE_DTYPE>(
+      /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
+      v_cache, num_kv_heads, scale, block_tables, context_lens,
+      max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
+      kv_head_stride, k_scale, k_zp, v_scale, v_zp);
 }
 
 // Grid: (num_heads, num_seqs, max_num_partitions).
-template<
-  typename scalar_t,
-  typename cache_t,
-  int HEAD_SIZE,
-  int BLOCK_SIZE,
-  int NUM_THREADS,
-  kv_cache_dtype KV_CACHE_DTYPE,
-  int PARTITION_SIZE>
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS, kv_cache_dtype KV_CACHE_DTYPE,
+          int PARTITION_SIZE>
 __global__ void paged_attention_v2_kernel(
-  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
-  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
-  scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
-  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
-  const int num_kv_heads,                 // [num_heads]
-  const float scale,
-  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ context_lens,   // [num_seqs]
-  const int max_num_blocks_per_seq,
-  const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride,
-  const int kv_block_stride,
-  const int kv_head_stride,
-  const float k_scale,
-  const float k_zp,
-  const float v_scale,
-  const float v_zp) {
-  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE, PARTITION_SIZE>(
-    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
-    block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
-    q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
+    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
+    float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                     // max_num_partitions]
+    scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
+                                     // max_num_partitions, head_size]
+    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
+    const int num_kv_heads,               // [num_heads]
+    const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ context_lens,  // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    const float k_scale, const float k_zp, const float v_scale,
+    const float v_zp) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
+                         KV_CACHE_DTYPE, PARTITION_SIZE>(
+      exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
+      block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
+      q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
 }
 
 // Grid: (num_heads, num_seqs).
-template<
-  typename scalar_t,
-  int HEAD_SIZE,
-  int NUM_THREADS,
-  int PARTITION_SIZE>
+template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
+          int PARTITION_SIZE>
 __global__ void paged_attention_v2_reduce_kernel(
-  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
-  const float* __restrict__ exp_sums,     // [num_seqs, num_heads, max_num_partitions]
-  const float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
-  const scalar_t* __restrict__ tmp_out,   // [num_seqs, num_heads, max_num_partitions, head_size]
-  const int* __restrict__ context_lens,   // [num_seqs]
-  const int max_num_partitions) {
+    scalar_t* __restrict__ out,          // [num_seqs, num_heads, head_size]
+    const float* __restrict__ exp_sums,  // [num_seqs, num_heads,
+                                         // max_num_partitions]
+    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                           // max_num_partitions]
+    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
+                                           // max_num_partitions, head_size]
+    const int* __restrict__ context_lens,  // [num_seqs]
+    const int max_num_partitions) {
   const int num_heads = gridDim.x;
   const int head_idx = blockIdx.x;
   const int seq_idx = blockIdx.y;
|
|
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
|
|
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
|
|
if (num_partitions == 1) {
|
|
if (num_partitions == 1) {
|
|
// No need to reduce. Only copy tmp_out to out.
|
|
// No need to reduce. Only copy tmp_out to out.
|
|
- scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
|
|
|
- const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
|
|
|
|
- + head_idx * max_num_partitions * HEAD_SIZE;
|
|
|
|
|
|
+ scalar_t* out_ptr =
|
|
|
|
+ out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
|
|
|
|
+ const scalar_t* tmp_out_ptr =
|
|
|
|
+ tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
|
|
|
|
+ head_idx * max_num_partitions * HEAD_SIZE;
|
|
for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
|
|
for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
|
|
out_ptr[i] = tmp_out_ptr[i];
|
|
out_ptr[i] = tmp_out_ptr[i];
|
|
}
|
|
}
|
|
@@ -547,8 +567,9 @@ __global__ void paged_attention_v2_reduce_kernel(
 
   // Load max logits to shared memory.
   float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
-  const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
-                                + head_idx * max_num_partitions;
+  const float* max_logits_ptr = max_logits +
+                                seq_idx * num_heads * max_num_partitions +
+                                head_idx * max_num_partitions;
   float max_logit = -FLT_MAX;
   for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
     const float l = max_logits_ptr[i];
@@ -577,9 +598,11 @@ __global__ void paged_attention_v2_reduce_kernel(
   max_logit = APHRODITE_SHFL_SYNC(max_logit, 0);
 
   // Load rescaled exp sums to shared memory.
-  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
-  const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
-                              + head_idx * max_num_partitions;
+  float* shared_exp_sums =
+      reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
+  const float* exp_sums_ptr = exp_sums +
+                              seq_idx * num_heads * max_num_partitions +
+                              head_idx * max_num_partitions;
   float global_exp_sum = 0.0f;
   for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
     float l = shared_max_logits[i];
@@ -592,67 +615,47 @@ __global__ void paged_attention_v2_reduce_kernel(
   const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
 
   // Aggregate tmp_out to out.
-  const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
-                                + head_idx * max_num_partitions * HEAD_SIZE;
-  scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+  const scalar_t* tmp_out_ptr =
+      tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+      head_idx * max_num_partitions * HEAD_SIZE;
+  scalar_t* out_ptr =
+      out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
 #pragma unroll
   for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
     float acc = 0.0f;
     for (int j = 0; j < num_partitions; ++j) {
-      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
+      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
+             inv_global_exp_sum;
     }
     from_float(out_ptr[i], acc);
   }
 }
 
-} // namespace aphrodite
-
-#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                                      \
-  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                                      \
-    ((void*)aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,  \
-      KV_CACHE_DTYPE>), shared_mem_size);                                                         \
-  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,            \
-  KV_CACHE_DTYPE><<<grid, block, shared_mem_size, stream>>>(                                      \
-    out_ptr,                                                                                      \
-    query_ptr,                                                                                    \
-    key_cache_ptr,                                                                                \
-    value_cache_ptr,                                                                              \
-    num_kv_heads,                                                                                 \
-    scale,                                                                                        \
-    block_tables_ptr,                                                                             \
-    context_lens_ptr,                                                                             \
-    max_num_blocks_per_seq,                                                                       \
-    alibi_slopes_ptr,                                                                             \
-    q_stride,                                                                                     \
-    kv_block_stride,                                                                              \
-    kv_head_stride,                                                                               \
-    k_scale,                                                                                      \
-    k_zp,                                                                                         \
-    v_scale,                                                                                      \
-    v_zp);
+}  // namespace aphrodite
+
+#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                 \
+  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                 \
+      ((void*)aphrodite::paged_attention_v1_kernel<                          \
+          T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE>),  \
+      shared_mem_size);                                                      \
+  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,    \
+                                       NUM_THREADS, KV_CACHE_DTYPE>          \
+      <<<grid, block, shared_mem_size, stream>>>(                            \
+          out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads,  \
+          scale, block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
+          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,       \
+          k_scale, k_zp, v_scale, v_zp);
 
 // TODO: Tune NUM_THREADS.
-template<
-  typename T,
-  typename CACHE_T,
-  int BLOCK_SIZE,
-  kv_cache_dtype KV_CACHE_DTYPE,
-  int NUM_THREADS = 128>
+template <typename T, typename CACHE_T, int BLOCK_SIZE,
+          kv_cache_dtype KV_CACHE_DTYPE, int NUM_THREADS = 128>
 void paged_attention_v1_launcher(
-  torch::Tensor& out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  int num_kv_heads,
-  float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const float k_scale,
-  const float k_zp,
-  const float v_scale,
-  const float v_zp) {
+    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& context_lens,
+    int max_context_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const float k_scale, const float k_zp, const float v_scale,
+    const float v_zp) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -665,9 +668,10 @@ void paged_attention_v1_launcher(
   assert(head_size % thread_group_size == 0);
 
   // NOTE: alibi_slopes is optional.
-  const float* alibi_slopes_ptr = alibi_slopes ?
-    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
-    : nullptr;
+  const float* alibi_slopes_ptr =
+      alibi_slopes
+          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+          : nullptr;
 
   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
@@ -677,11 +681,13 @@ void paged_attention_v1_launcher(
   int* context_lens_ptr = context_lens.data_ptr<int>();
 
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
+  int padded_max_context_len =
+      DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
   int logits_size = padded_max_context_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
-  // Python-side check in aphrodite.task_handler.worker._check_if_can_support_max_seq_len
-  // Keep that in sync with the logic here!
+  // Python-side check in
+  // aphrodite.task_handler.worker._check_if_can_support_max_seq_len.
+  // Keep that in sync with the logic here!
   int shared_mem_size = std::max(logits_size, outputs_size);
 
   dim3 grid(num_heads, num_seqs, 1);
@@ -718,56 +724,44 @@ void paged_attention_v1_launcher(
 
 #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE)             \
   paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>(       \
-    out,                                                                     \
-    query,                                                                   \
-    key_cache,                                                               \
-    value_cache,                                                             \
-    num_kv_heads,                                                            \
-    scale,                                                                   \
-    block_tables,                                                            \
-    context_lens,                                                            \
-    max_context_len,                                                         \
-    alibi_slopes,                                                            \
-    k_scale,                                                                 \
-    k_zp,                                                                    \
-    v_scale,                                                                 \
-    v_zp);
+      out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
+      context_lens, max_context_len, alibi_slopes, k_scale, k_zp, v_scale,   \
+      v_zp);
 
 // NOTE: To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)       \
-  switch (block_size) {                                               \
-    case 8:                                                           \
-      CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);                \
-      break;                                                          \
-    case 16:                                                          \
-      CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);               \
-      break;                                                          \
-    case 32:                                                          \
-      CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);               \
-      break;                                                          \
-    default:                                                          \
-      TORCH_CHECK(false, "Unsupported block size: ", block_size);     \
-      break;                                                          \
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)   \
+  switch (block_size) {                                           \
+    case 8:                                                       \
+      CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);            \
+      break;                                                      \
+    case 16:                                                      \
+      CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);           \
+      break;                                                      \
+    case 32:                                                      \
+      CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);           \
+      break;                                                      \
+    default:                                                      \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
+      break;                                                      \
   }
 
|
|
void paged_attention_v1(
|
|
- torch::Tensor& out, // [num_seqs, num_heads, head_size]
|
|
|
|
- torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
|
|
|
- torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
|
|
|
- torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
|
|
|
- int num_kv_heads, // [num_heads]
|
|
|
|
- float scale,
|
|
|
|
- torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
|
|
|
- torch::Tensor& context_lens, // [num_seqs]
|
|
|
|
- int block_size,
|
|
|
|
- int max_context_len,
|
|
|
|
- const c10::optional<torch::Tensor>& alibi_slopes,
|
|
|
|
- const std::string& kv_cache_dtype,
|
|
|
|
- const float k_scale = 1.0f,
|
|
|
|
- const float k_zp = 0.0f,
|
|
|
|
- const float v_scale = 1.0f,
|
|
|
|
- const float v_zp = 0.0f) {
|
|
|
|
|
|
+ torch::Tensor& out, // [num_seqs, num_heads, head_size]
|
|
|
|
+ torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
|
|
|
+ torch::Tensor&
|
|
|
|
+ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
|
|
|
+ torch::Tensor&
|
|
|
|
+ value_cache, // [num_blocks, num_heads, head_size, block_size]
|
|
|
|
+ int num_kv_heads, // [num_heads]
|
|
|
|
+ float scale,
|
|
|
|
+ torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
|
|
|
+ torch::Tensor& context_lens, // [num_seqs]
|
|
|
|
+ int block_size, int max_context_len,
|
|
|
|
+ const c10::optional<torch::Tensor>& alibi_slopes,
|
|
|
|
+ const std::string& kv_cache_dtype, const float k_scale = 1.0f,
|
|
|
|
+ const float k_zp = 0.0f, const float v_scale = 1.0f,
|
|
|
|
+ const float v_zp = 0.0f) {
|
|
if (kv_cache_dtype == "auto") {
|
|
if (kv_cache_dtype == "auto") {
|
|
if (query.dtype() == at::ScalarType::Float) {
|
|
if (query.dtype() == at::ScalarType::Float) {
|
|
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, AUTO);
|
|
CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, AUTO);
|
|
@@ -805,63 +799,33 @@ void paged_attention_v1(
   }
 }
 
-#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                              \
-  aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,    \
-  KV_CACHE_DTYPE, PARTITION_SIZE>                                                         \
-  <<<grid, block, shared_mem_size, stream>>>(                                             \
-    exp_sums_ptr,                                                                         \
-    max_logits_ptr,                                                                       \
-    tmp_out_ptr,                                                                          \
-    query_ptr,                                                                            \
-    key_cache_ptr,                                                                        \
-    value_cache_ptr,                                                                      \
-    num_kv_heads,                                                                         \
-    scale,                                                                                \
-    block_tables_ptr,                                                                     \
-    context_lens_ptr,                                                                     \
-    max_num_blocks_per_seq,                                                               \
-    alibi_slopes_ptr,                                                                     \
-    q_stride,                                                                             \
-    kv_block_stride,                                                                      \
-    kv_head_stride,                                                                       \
-    k_scale,                                                                              \
-    k_zp,                                                                                 \
-    v_scale,                                                                              \
-    v_zp);                                                                                \
-  aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE>  \
-    <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                             \
-    out_ptr,                                                                              \
-    exp_sums_ptr,                                                                         \
-    max_logits_ptr,                                                                       \
-    tmp_out_ptr,                                                                          \
-    context_lens_ptr,                                                                     \
-    max_num_partitions);
-
-template<
-  typename T,
-  typename CACHE_T,
-  int BLOCK_SIZE,
-  kv_cache_dtype KV_CACHE_DTYPE,
-  int NUM_THREADS = 128,
-  int PARTITION_SIZE = 512>
+#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                   \
+  aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,      \
+                                       NUM_THREADS, KV_CACHE_DTYPE,            \
+                                       PARTITION_SIZE>                         \
+      <<<grid, block, shared_mem_size, stream>>>(                              \
+          exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
+          value_cache_ptr, num_kv_heads, scale, block_tables_ptr,              \
+          context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr,          \
+          q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale,   \
+          v_zp);                                                               \
+  aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS,       \
+                                              PARTITION_SIZE>                  \
+      <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                \
+          out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,                  \
+          context_lens_ptr, max_num_partitions);
+
+template <typename T, typename CACHE_T, int BLOCK_SIZE,
+          kv_cache_dtype KV_CACHE_DTYPE, int NUM_THREADS = 128,
+          int PARTITION_SIZE = 512>
 void paged_attention_v2_launcher(
-  torch::Tensor& out,
-  torch::Tensor& exp_sums,
-  torch::Tensor& max_logits,
-  torch::Tensor& tmp_out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  int num_kv_heads,
-  float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const float k_scale,
-  const float k_zp,
-  const float v_scale,
-  const float v_zp) {
+    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
+    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& context_lens,
+    int max_context_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const float k_scale, const float k_zp, const float v_scale,
+    const float v_zp) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -874,9 +838,10 @@ void paged_attention_v2_launcher(
   assert(head_size % thread_group_size == 0);
 
   // NOTE: alibi_slopes is optional.
-  const float* alibi_slopes_ptr = alibi_slopes ?
-    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
-    : nullptr;
+  const float* alibi_slopes_ptr =
+      alibi_slopes
+          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+          : nullptr;
 
   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
   float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
@@ -931,64 +896,50 @@ void paged_attention_v2_launcher(
   }
 }
 
-#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE)           \
-  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>(     \
-    out,                                                                   \
-    exp_sums,                                                              \
-    max_logits,                                                            \
-    tmp_out,                                                               \
-    query,                                                                 \
-    key_cache,                                                             \
-    value_cache,                                                           \
-    num_kv_heads,                                                          \
-    scale,                                                                 \
-    block_tables,                                                          \
-    context_lens,                                                          \
-    max_context_len,                                                       \
-    alibi_slopes,                                                          \
-    k_scale,                                                               \
-    k_zp,                                                                  \
-    v_scale,                                                               \
-    v_zp);
+#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE)         \
+  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>(   \
+      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
+      num_kv_heads, scale, block_tables, context_lens, max_context_len,  \
+      alibi_slopes, k_scale, k_zp, v_scale, v_zp);
 
 // NOTE: To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
-#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)       \
-  switch (block_size) {                                               \
-    case 8:                                                           \
-      CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);                \
-      break;                                                          \
-    case 16:                                                          \
-      CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);               \
-      break;                                                          \
-    case 32:                                                          \
-      CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);               \
-      break;                                                          \
-    default:                                                          \
-      TORCH_CHECK(false, "Unsupported block size: ", block_size);     \
-      break;                                                          \
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)   \
+  switch (block_size) {                                           \
+    case 8:                                                       \
+      CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);            \
+      break;                                                      \
+    case 16:                                                      \
+      CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);           \
+      break;                                                      \
+    case 32:                                                      \
+      CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);           \
+      break;                                                      \
+    default:                                                      \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
+      break;                                                      \
   }
 
 void paged_attention_v2(
-  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
-  torch::Tensor& exp_sums,        // [num_seqs, num_heads, max_num_partitions]
-  torch::Tensor& max_logits,      // [num_seqs, num_heads, max_num_partitions]
-  torch::Tensor& tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
-  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
-  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
-  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
-  int num_kv_heads,               // [num_heads]
-  float scale,
-  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
-  torch::Tensor& context_lens,    // [num_seqs]
-  int block_size,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const std::string& kv_cache_dtype,
-  const float k_scale = 1.0f,
-  const float k_zp = 0.0f,
-  const float v_scale = 1.0f,
-  const float v_zp = 0.0f) {
+    torch::Tensor& out,         // [num_seqs, num_heads, head_size]
+    torch::Tensor& exp_sums,    // [num_seqs, num_heads, max_num_partitions]
+    torch::Tensor& max_logits,  // [num_seqs, num_heads, max_num_partitions]
+    torch::Tensor&
+        tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
+    torch::Tensor& query,  // [num_seqs, num_heads, head_size]
+    torch::Tensor&
+        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
+    torch::Tensor&
+        value_cache,       // [num_blocks, num_heads, head_size, block_size]
+    int num_kv_heads,      // [num_heads]
+    float scale,
+    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    torch::Tensor& context_lens,  // [num_seqs]
+    int block_size, int max_context_len,
+    const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, const float k_scale = 1.0f,
+    const float k_zp = 0.0f, const float v_scale = 1.0f,
+    const float v_zp = 0.0f) {
   if (kv_cache_dtype == "auto") {
     if (query.dtype() == at::ScalarType::Float) {
       CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, AUTO);