Bläddra i källkod

[FT] Implement MQA/GQA

Tri Dao 1 år sedan
förälder
incheckning
a157cc8c9b

+ 5 - 1
csrc/ft_attention/decoder_masked_multihead_attention.h

@@ -69,7 +69,9 @@ struct Multihead_attention_params_base {
     const int* cache_indir = nullptr;
 
     // Stride to handle the case when KQV is a single buffer
-    int stride = 0;
+    int stride_q = 0;
+    int stride_k = 0;
+    int stride_v = 0;
 
     // The batch size.
     int batch_size = 0;
@@ -79,6 +81,8 @@ struct Multihead_attention_params_base {
     int memory_max_len = 0;
     // The number of heads (H).
     int num_heads = 0;
+    int num_heads_kv = 0;
+    int num_heads_q_kv_ratio = 0;
     // The hidden dimension per head (Dh).
     int hidden_size_per_head = 0;
     // The per-head latent space reserved for rotary embeddings.

+ 27 - 19
csrc/ft_attention/decoder_masked_multihead_attention_template.hpp

@@ -943,10 +943,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     // The head.
     // const int hi = blockIdx.x;
     const int hi = params.nnz_head_idx == nullptr ? blockIdx.x : params.nnz_head_idx[blockIdx.x];
+    const int hi_kv = hi / params.num_heads_q_kv_ratio;
     // Combine the batch and the head indices.
     const int bhi = bi * params.num_heads + hi;
+    const int bhi_kv = bi * params.num_heads_kv + hi_kv;
     // Combine the "beam-aware" batch idx and the head indices.
-    const int bbhi = bbi * params.beam_width * params.num_heads + hi;
+    const int bbhi = bbi * params.beam_width * params.num_heads_kv + hi_kv;
     // The thread in the block.
     const int tidx = threadIdx.x;
 
@@ -957,7 +959,9 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
 
     float qk = 0.0F;
 
-    int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh;
+    int q_base_offset = (params.stride_q == 0) ? bhi * Dh : bi * params.stride_q + hi * Dh;
+    int k_base_offset = (params.stride_k == 0) ? bhi_kv * Dh : bi * params.stride_k + hi_kv * Dh;
+    int v_base_offset = (params.stride_v == 0) ? bhi_kv * Dh : bi * params.stride_v + hi_kv * Dh;
 
     const size_t bi_seq_len_offset = bi * params.memory_max_len;
 
@@ -973,9 +977,11 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     const bool is_masked = tidx >= QK_VECS_PER_WARP;
 
     // The offset in the Q and K buffer also accounts for the batch.
-    int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE;
+    int q_offset = q_base_offset + tidx * QK_VEC_SIZE;
+    int k_offset = k_base_offset + tidx * QK_VEC_SIZE;
     // The offset in the bias buffer.
-    int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
+    int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
+    int k_bias_offset = hi_kv * Dh + tidx * QK_VEC_SIZE;
 
     const bool do_ia3      = handle_kv && params.ia3_tasks != nullptr;
     const int  ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0;
@@ -989,12 +995,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
             using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
             const auto q_scaling = params.qkv_scale_out[0];
             const auto q_quant =
-                *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[qk_offset]);
+                *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[q_offset]);
 
             convert_from_float(q, mul<Packed_Float_t, float>(q_scaling, float_from_int8(q_quant)));
         }
         else {
-            q = *reinterpret_cast<const Qk_vec*>(&params.q[qk_offset]);
+            q = *reinterpret_cast<const Qk_vec*>(&params.q[q_offset]);
         }
     }
 
@@ -1007,7 +1013,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
         int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
 
         // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
-        int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
+        int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
                      // params.timestep*QK_ELTS_IN_16B +
                      tlength * QK_ELTS_IN_16B + ci;
         k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?
@@ -1021,12 +1027,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
                 using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
                 const auto k_scaling = params.qkv_scale_out[1];
                 const auto k_quant =
-                    *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[qk_offset]);
+                    *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[k_offset]);
 
                 convert_from_float(k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));
             }
             else {
-                k = *reinterpret_cast<const Qk_vec*>(&params.k[qk_offset]);
+                k = *reinterpret_cast<const Qk_vec*>(&params.k[k_offset]);
             }
         }
     }
@@ -1035,14 +1041,14 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     Qk_vec q_bias;
     zero(q_bias);
     q_bias = (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ?
-                 *reinterpret_cast<const Qk_vec*>(&params.q_bias[qk_bias_offset]) :
+                 *reinterpret_cast<const Qk_vec*>(&params.q_bias[q_bias_offset]) :
                  q_bias;
 
     Qk_vec k_bias;
     zero(k_bias);
     if (handle_kv) {
         k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ?
-                     *reinterpret_cast<const Qk_vec*>(&params.k_bias[qk_bias_offset]) :
+                     *reinterpret_cast<const Qk_vec*>(&params.k_bias[k_bias_offset]) :
                      k_bias;
     }
 
@@ -1172,11 +1178,11 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
         int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
 
         // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
-        int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
+        int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
                      // params.timestep*QK_ELTS_IN_16B +
                      tlength_circ * QK_ELTS_IN_16B + ci;
 
-        if (handle_kv) {
+        if (handle_kv && hi % params.num_heads_q_kv_ratio == 0) {
             // Trigger the stores to global memory.
             if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
                 *reinterpret_cast<Qk_vec*>(&params.k_cache[offset]) = k;
@@ -1263,7 +1269,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
 
     // The base pointer for the key in the cache buffer.
-    T* k_cache = &params.k_cache[bhi * params.memory_max_len * Dh + ki];
+    T* k_cache = &params.k_cache[bhi_kv * params.memory_max_len * Dh + ki];
     // Base pointer for the beam's batch, before offsetting with indirection buffer
     T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
 
@@ -1427,7 +1433,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE;
 
     // The base pointer for the value in the cache buffer.
-    T* v_cache = &params.v_cache[bhi * params.memory_max_len * Dh + vi];
+    T* v_cache = &params.v_cache[bhi_kv * params.memory_max_len * Dh + vi];
     // Base pointer for the beam's batch, before offsetting with indirection buffer
     T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
 
@@ -1443,7 +1449,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
             if (vo == tlength % V_PER_ITER) {
                 // Trigger the loads from the V bias buffer.
                 if (params.v_bias != nullptr) {
-                    v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi * Dh + vi]);
+                    v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi_kv * Dh + vi]);
                 }
                 if (DO_CROSS_ATTENTION) {
                     *reinterpret_cast<V_vec*>(&bias_smem[vi]) = v_bias;
@@ -1510,7 +1516,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
         }
         else {
             // Trigger the loads from the V buffer.
-            const auto v_offset = qkv_base_offset + vi;
+            const auto v_offset = v_base_offset + vi;
             if (params.int8_mode == 2) {
                 using Packed_Int8_t  = typename packed_type<int8_t, num_elems<V_vec>::value>::type;
                 using Packed_Float_t = typename packed_type<float, num_elems<V_vec>::value>::type;
@@ -1539,8 +1545,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
             }
 
             // Store the values with bias back to global memory in the cache for V.
-            //*reinterpret_cast<V_vec*>(&v_cache[params.timestep*Dh]) = v;
-            *reinterpret_cast<V_vec*>(&v_cache[tlength_circ * Dh]) = v;
+            if (hi % params.num_heads_q_kv_ratio == 0) {
+                //*reinterpret_cast<V_vec*>(&v_cache[params.timestep*Dh]) = v;
+                *reinterpret_cast<V_vec*>(&v_cache[tlength_circ * Dh]) = v;
+            }
         }
 
         // Initialize the output value with the current timestep.

+ 18 - 10
csrc/ft_attention/ft_attention.cpp

@@ -50,13 +50,16 @@ template <typename T>
 void set_params(Masked_multihead_attention_params<T> &params,
                 const size_t batch_size,
                 const size_t nheads,
+                const size_t nheads_kv,
                 const size_t memory_max_seqlen,
                 const size_t headdim,
                 const int timestep,
                 const int rotary_embedding_dim,
                 const float rotary_base,
                 const bool neox_rotary_style,
-                const int qkv_batch_stride,
+                const int q_batch_stride,
+                const int k_batch_stride,
+                const int v_batch_stride,
                 const int nnz_heads,
                 T *q_ptr,
                 T *k_ptr,
@@ -80,11 +83,15 @@ void set_params(Masked_multihead_attention_params<T> &params,
     params.v_cache = v_cache_ptr;
     params.out = out_ptr;
     params.cache_indir = nullptr;
-    params.stride = qkv_batch_stride;
+    params.stride_q = q_batch_stride;
+    params.stride_k = k_batch_stride;
+    params.stride_v = v_batch_stride;
     params.batch_size = batch_size;
     params.beam_width = 1;
     params.memory_max_len = memory_max_seqlen;
     params.num_heads = nheads;
+    params.num_heads_kv = nheads_kv;
+    params.num_heads_q_kv_ratio = nheads / nheads_kv;
     params.nnz_heads = nnz_heads;
     params.hidden_size_per_head = headdim;
     params.rotary_embedding_dim = rotary_embedding_dim;
@@ -124,23 +131,23 @@ torch::Tensor single_query_attention(const torch::Tensor q,
                                      const bool neox_rotary_style=true) {
     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
     int batch_size = v_cache.size(0);
-    int nheads = v_cache.size(1);
+    int nheads = q.size(1);
+    int nheads_kv = v_cache.size(1);
     int memory_max_seqlen = v_cache.size(2);
     int headdim = v_cache.size(3);
     auto input_type = q.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
 
     CHECK_SHAPE(q, batch_size, nheads, headdim);
-    CHECK_SHAPE(k, batch_size, nheads, headdim);
-    CHECK_SHAPE(v, batch_size, nheads, headdim);
-    CHECK_SHAPE(v_cache, batch_size, nheads, memory_max_seqlen, headdim);
+    CHECK_SHAPE(k, batch_size, nheads_kv, headdim);
+    CHECK_SHAPE(v, batch_size, nheads_kv, headdim);
+    CHECK_SHAPE(v_cache, batch_size, nheads_kv, memory_max_seqlen, headdim);
     // k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
     int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8;
-    CHECK_SHAPE(k_cache, batch_size, nheads, headdim / packsize, memory_max_seqlen, packsize);
+    CHECK_SHAPE(k_cache, batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize);
     TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim);
     TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim);
     TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim);
-    TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0));
     CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache);
 
     TORCH_CHECK(q.scalar_type() == input_type);
@@ -191,8 +198,9 @@ torch::Tensor single_query_attention(const torch::Tensor q,
     DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] {
         using DataType = typename SATypeConverter<scalar_t>::Type;
         Masked_multihead_attention_params<DataType> params;
-        set_params(params, batch_size, nheads, memory_max_seqlen, headdim, timestep,
-                   rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0),
+        set_params(params, batch_size, nheads, nheads_kv, memory_max_seqlen, headdim, timestep,
+                   rotary_embedding_dim, rotary_base, neox_rotary_style,
+                   q.stride(0), k.stride(0), v.stride(0),
                    nnz_head_idx_.has_value() ? nnz_head_idx_.value().size(0) : 0,
                    reinterpret_cast<DataType*>(q.data_ptr()),
                    reinterpret_cast<DataType*>(k.data_ptr()),