@@ -23,17 +23,6 @@
AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \
}
 
-// #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...) \
-// if (TYPE == at::ScalarType::Half) { \
-// using scalar_t = at::Half; \
-// __VA_ARGS__(); \
-// } else if (TYPE == at::ScalarType::Float) { \
-// using scalar_t = float; \
-// __VA_ARGS__(); \
-// } else { \
-// AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \
-// }
-
template<typename T>
void masked_multihead_attention(const Masked_multihead_attention_params<T>& params,
const cudaStream_t& stream);
@@ -66,6 +55,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
const int timestep,
const int rotary_embedding_dim,
const bool neox_rotary_style,
+ const int qkv_batch_stride,
T *q_ptr,
T *k_ptr,
T *v_ptr,
@@ -85,7 +75,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
params.v_cache = v_cache_ptr;
params.out = out_ptr;
params.cache_indir = nullptr;
- params.stride = 0;
+ params.stride = qkv_batch_stride;
params.batch_size = batch_size;
params.beam_width = 1;
params.memory_max_len = memory_max_seqlen;
@@ -98,8 +88,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
params.total_padding_tokens = nullptr;
params.masked_tokens = nullptr;
params.prefix_prompt_lengths = nullptr;
- // params.max_prefix_prompt_length = memory_max_seqlen; // TODO: waht should this be?
- params.max_prefix_prompt_length = 0; // TODO: waht should this be?
+ params.max_prefix_prompt_length = 0;
params.relative_attention_bias = nullptr;
params.relative_attention_bias_stride = 0;
params.cross_attention_out = nullptr;
@@ -127,10 +116,15 @@ torch::Tensor single_query_attention(const torch::Tensor q,
CHECK_SHAPE(q, batch_size, nheads, headdim);
CHECK_SHAPE(k, batch_size, nheads, headdim);
CHECK_SHAPE(v, batch_size, nheads, headdim);
- // TODO: Check shape of k_cache: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
- // TODO: avoid contiguous requirment by storing the stride
- CHECK_CONTIGUOUS(q); CHECK_CONTIGUOUS(k); CHECK_CONTIGUOUS(v);
- CHECK_CONTIGUOUS(v_cache);
+ CHECK_SHAPE(v_cache, batch_size, nheads, memory_max_seqlen, headdim);
+ // k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
+ int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8;
+ CHECK_SHAPE(k_cache, batch_size, nheads, headdim / packsize, memory_max_seqlen, packsize);
+ TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim);
+ TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim);
+ TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim);
+ TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0));
+ CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache);
 
if (length_per_sample_.has_value()) {
auto length_per_sample = length_per_sample_.value();
@@ -146,11 +140,11 @@ torch::Tensor single_query_attention(const torch::Tensor q,
 
torch::Tensor out = torch::empty_like(q);
 
- DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), out.scalar_type(), "single_query_attention", [&] {
+ DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] {
using DataType = typename SATypeConverter<scalar_t>::Type;
Masked_multihead_attention_params<DataType> params;
set_params(params, batch_size, nheads, memory_max_seqlen, headdim, timestep,
- rotary_embedding_dim, neox_rotary_style,
+ rotary_embedding_dim, neox_rotary_style, q.stride(0),
reinterpret_cast<DataType*>(q.data_ptr()),
reinterpret_cast<DataType*>(k.data_ptr()),
reinterpret_cast<DataType*>(v.data_ptr()),
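
For reference, the stride checks introduced in this patch let q, k and v be strided views (for example, slices of a packed qkv buffer) instead of requiring .contiguous() copies, with q.stride(0) forwarded as qkv_batch_stride into params.stride. The following is a minimal libtorch sketch, assuming a hypothetical packed qkv tensor of shape [B, 3, H, Dh] and made-up sizes, of tensors that satisfy the new checks; it is illustrative only and not code from this repository.

#include <torch/torch.h>

int main() {
  // Sizes are made up for illustration.
  const int64_t batch_size = 2, nheads = 8, headdim = 64, memory_max_seqlen = 128;

  // Packed qkv activations [B, 3, H, Dh]; q/k/v are views into it, not copies.
  auto qkv = torch::zeros({batch_size, 3, nheads, headdim}, torch::kFloat16);
  auto q = qkv.select(1, 0);   // [B, H, Dh], stride(0) = 3 * H * Dh
  auto k = qkv.select(1, 1);
  auto v = qkv.select(1, 2);

  // The invariants enforced by the new checks: dense last dim, head stride
  // equal to headdim, and a shared batch stride across q/k/v. That shared
  // stride is what gets passed as qkv_batch_stride (params.stride).
  TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim);
  TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0));
  TORCH_CHECK(q.stride(0) == 3 * nheads * headdim);

  // Cache layouts expected by the kernel: k_cache is [B, H, Dh/x, L, x] with
  // x = 4 for fp32 and 8 otherwise; v_cache is [B, H, L, Dh]. Both contiguous.
  const int64_t packsize = 8;  // fp16
  auto k_cache = torch::zeros({batch_size, nheads, headdim / packsize,
                               memory_max_seqlen, packsize}, torch::kFloat16);
  auto v_cache = torch::zeros({batch_size, nheads, memory_max_seqlen, headdim},
                              torch::kFloat16);
  TORCH_CHECK(k_cache.is_contiguous() && v_cache.is_contiguous());
  return 0;
}

Before this change, the q/k/v slices above would have been rejected by the removed CHECK_CONTIGUOUS calls and would have needed an extra copy.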