Browse Source

[Gen] Pass qkv_stride to ft_attention kernel for batched generation

Tri Dao 2 years ago
parent
commit
f1e01c27ba
2 changed files with 15 additions and 21 deletions
  1. 14 20
      csrc/ft_attention/ft_attention.cpp
  2. 1 1
      tests/models/test_gpt_generation.py

+ 14 - 20
csrc/ft_attention/ft_attention.cpp

@@ -23,17 +23,6 @@
     AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \
   }
 
-// #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...)                  \
-//   if (TYPE == at::ScalarType::Half) {                                      \
-//     using scalar_t = at::Half;                                             \
-//     __VA_ARGS__();                                                         \
-//   } else if (TYPE == at::ScalarType::Float)  {                             \
-//     using scalar_t = float;                                                \
-//     __VA_ARGS__();                                                         \
-//   } else {                                                                 \
-//     AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \
-//   }
-
 template<typename T>
 void masked_multihead_attention(const Masked_multihead_attention_params<T>& params,
                                 const cudaStream_t& stream);
@@ -66,6 +55,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
                 const int timestep,
                 const int rotary_embedding_dim,
                 const bool neox_rotary_style,
+                const int qkv_batch_stride,
                 T *q_ptr,
                 T *k_ptr,
                 T *v_ptr,
@@ -85,7 +75,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
     params.v_cache = v_cache_ptr;
     params.out = out_ptr;
     params.cache_indir = nullptr;
-    params.stride = 0;
+    params.stride = qkv_batch_stride;
     params.batch_size = batch_size;
     params.beam_width = 1;
     params.memory_max_len = memory_max_seqlen;
@@ -98,8 +88,7 @@ void set_params(Masked_multihead_attention_params<T> &params,
     params.total_padding_tokens = nullptr;
     params.masked_tokens = nullptr;
     params.prefix_prompt_lengths = nullptr;
-    // params.max_prefix_prompt_length = memory_max_seqlen;  // TODO: waht should this be?
-    params.max_prefix_prompt_length = 0;  // TODO: waht should this be?
+    params.max_prefix_prompt_length = 0;
     params.relative_attention_bias = nullptr;
     params.relative_attention_bias_stride = 0;
     params.cross_attention_out = nullptr;
@@ -127,10 +116,15 @@ torch::Tensor single_query_attention(const torch::Tensor q,
     CHECK_SHAPE(q, batch_size, nheads, headdim);
     CHECK_SHAPE(k, batch_size, nheads, headdim);
     CHECK_SHAPE(v, batch_size, nheads, headdim);
-    // TODO: Check shape of k_cache: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
-    // TODO: avoid contiguous requirment by storing the stride
-    CHECK_CONTIGUOUS(q); CHECK_CONTIGUOUS(k); CHECK_CONTIGUOUS(v);
-    CHECK_CONTIGUOUS(v_cache);
+    CHECK_SHAPE(v_cache, batch_size, nheads, memory_max_seqlen, headdim);
+    // k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
+    int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8;
+    CHECK_SHAPE(k_cache, batch_size, nheads, headdim / packsize, memory_max_seqlen, packsize);
+    TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim);
+    TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim);
+    TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim);
+    TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0));
+    CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache);
 
     if (length_per_sample_.has_value()) {
         auto length_per_sample = length_per_sample_.value();
@@ -146,11 +140,11 @@ torch::Tensor single_query_attention(const torch::Tensor q,
 
     torch::Tensor out = torch::empty_like(q);
 
-    DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), out.scalar_type(), "single_query_attention", [&] {
+    DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] {
         using DataType = typename SATypeConverter<scalar_t>::Type;
         Masked_multihead_attention_params<DataType> params;
         set_params(params, batch_size, nheads, memory_max_seqlen, headdim, timestep,
-                   rotary_embedding_dim, neox_rotary_style,
+                   rotary_embedding_dim, neox_rotary_style, q.stride(0),
                    reinterpret_cast<DataType*>(q.data_ptr()),
                    reinterpret_cast<DataType*>(k.data_ptr()),
                    reinterpret_cast<DataType*>(v.data_ptr()),

+ 1 - 1
tests/models/test_gpt_generation.py

@@ -57,7 +57,7 @@ def test_greedy_decode(model_name, rotary, optimized, fused_ft_kernel):
     input_ids = tokenizer("Hello, my dog is cute and ",
                           return_tensors="pt").input_ids.to(device=device)
     max_length = 30
-    # input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
+    # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
     # max_length = input_ids.shape[1] + 40
 
     # Slow generation for reference