@@ -943,10 +943,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     // The head.
     // const int hi = blockIdx.x;
     const int hi = params.nnz_head_idx == nullptr ? blockIdx.x : params.nnz_head_idx[blockIdx.x];
+    const int hi_kv = hi / params.num_heads_q_kv_ratio;
     // Combine the batch and the head indices.
     const int bhi = bi * params.num_heads + hi;
+    const int bhi_kv = bi * params.num_heads_kv + hi_kv;
     // Combine the "beam-aware" batch idx and the head indices.
-    const int bbhi = bbi * params.beam_width * params.num_heads + hi;
+    const int bbhi = bbi * params.beam_width * params.num_heads_kv + hi_kv;
     // The thread in the block.
     const int tidx = threadIdx.x;
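The hunk above derives the grouped-query indices: each query head hi maps onto KV head hi_kv = hi / num_heads_q_kv_ratio, and the combined batch/head indices bhi_kv and bbhi now count num_heads_kv rows instead of num_heads. A minimal host-side sketch of the same arithmetic (the struct is a stand-in for the relevant Multihead_attention_params fields, not the real type):

#include <cassert>
#include <cstdio>

struct GqaIndexParams {
    int num_heads;            // query heads
    int num_heads_kv;         // key/value heads
    int num_heads_q_kv_ratio; // num_heads / num_heads_kv
    int beam_width;
};

// Mirrors the kernel's index math: query heads hi in the same group of
// num_heads_q_kv_ratio share one KV head and therefore one cache row.
void print_head_mapping(const GqaIndexParams& p, int bi, int bbi) {
    assert(p.num_heads == p.num_heads_kv * p.num_heads_q_kv_ratio);
    for (int hi = 0; hi < p.num_heads; ++hi) {
        const int hi_kv  = hi / p.num_heads_q_kv_ratio;
        const int bhi    = bi * p.num_heads + hi;
        const int bhi_kv = bi * p.num_heads_kv + hi_kv;
        const int bbhi   = bbi * p.beam_width * p.num_heads_kv + hi_kv;
        std::printf("hi=%d -> hi_kv=%d (bhi=%d, bhi_kv=%d, bbhi=%d)\n", hi, hi_kv, bhi, bhi_kv, bbhi);
    }
}

int main() {
    // 8 query heads over 2 KV heads: heads 0-3 share KV head 0, heads 4-7 share KV head 1.
    print_head_mapping({8, 2, 4, 1}, 0, 0);
    return 0;
}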
@@ -957,7 +959,9 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     float qk = 0.0F;
 
-    int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh;
+    int q_base_offset = (params.stride_q == 0) ? bhi * Dh : bi * params.stride_q + hi * Dh;
+    int k_base_offset = (params.stride_k == 0) ? bhi_kv * Dh : bi * params.stride_k + hi_kv * Dh;
+    int v_base_offset = (params.stride_v == 0) ? bhi_kv * Dh : bi * params.stride_v + hi_kv * Dh;
 
     const size_t bi_seq_len_offset = bi * params.memory_max_len;
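With separate Q and KV head counts the single qkv_base_offset no longer works, hence the three per-tensor offsets. A sketch of the same ternaries as a free function, assuming stride_* == 0 denotes a packed [batch, heads, Dh] layout and a non-zero stride the per-batch row width of a fused buffer (that reading is inferred from the old stride logic, not stated in the diff):

struct QkvBaseOffsets { int q, k, v; };

// Reproduces the kernel's offset selection: Q indexes num_heads rows,
// while K and V index the smaller num_heads_kv rows via hi_kv/bhi_kv.
QkvBaseOffsets qkv_base_offsets(int bi, int hi, int hi_kv, int Dh,
                                int num_heads, int num_heads_kv,
                                int stride_q, int stride_k, int stride_v) {
    const int bhi    = bi * num_heads + hi;
    const int bhi_kv = bi * num_heads_kv + hi_kv;
    return {
        (stride_q == 0) ? bhi * Dh : bi * stride_q + hi * Dh,
        (stride_k == 0) ? bhi_kv * Dh : bi * stride_k + hi_kv * Dh,
        (stride_v == 0) ? bhi_kv * Dh : bi * stride_v + hi_kv * Dh,
    };
}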
@@ -973,9 +977,11 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     const bool is_masked = tidx >= QK_VECS_PER_WARP;
 
     // The offset in the Q and K buffer also accounts for the batch.
-    int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE;
+    int q_offset = q_base_offset + tidx * QK_VEC_SIZE;
+    int k_offset = k_base_offset + tidx * QK_VEC_SIZE;
     // The offset in the bias buffer.
-    int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
+    int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
+    int k_bias_offset = hi_kv * Dh + tidx * QK_VEC_SIZE;
 
     const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr;
     const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0;
@@ -989,12 +995,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
             using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
             const auto q_scaling = params.qkv_scale_out[0];
             const auto q_quant =
-                *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[qk_offset]);
+                *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[q_offset]);
 
             convert_from_float(q, mul<Packed_Float_t, float>(q_scaling, float_from_int8(q_quant)));
         }
         else {
-            q = *reinterpret_cast<const Qk_vec*>(&params.q[qk_offset]);
+            q = *reinterpret_cast<const Qk_vec*>(&params.q[q_offset]);
         }
     }
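For int8_mode == 2 the Q vector is stored quantized and dequantized on load with the per-tensor scale qkv_scale_out[0]; only the indexing changes here, from qk_offset to q_offset. A scalar sketch of that dequantization step (float_from_int8 and mul are FasterTransformer helpers operating on packed vectors; this stand-in works on plain floats):

#include <cstdint>

// Scalar stand-in for the packed dequantization above: load an int8
// value and scale it back to float with the per-tensor output scale.
float dequantize_q(const int8_t* q_int8, int q_offset, float q_scaling) {
    return q_scaling * static_cast<float>(q_int8[q_offset]);
}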
@@ -1007,7 +1013,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
         int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
 
         // Two chunks are separated by L * x elements. A thread writes QK_VEC_SIZE elements.
-        int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
+        int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
                      // params.timestep*QK_ELTS_IN_16B +
                      tlength * QK_ELTS_IN_16B + ci;
         k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?
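The cache offset arithmetic reads as a [batch * kv_heads, Dh/x, memory_max_len, x] layout, with x = QK_ELTS_IN_16B elements per 16-byte chunk; the only change is that rows are now counted per KV head (bhi_kv). A sketch of the same offset as a function, under that layout assumption:

// One timestep's Dh-wide key vector is split into Dh/x chunks of x elements;
// consecutive chunks of the same head are memory_max_len * x elements apart.
int k_cache_offset(int bhi_kv, int memory_max_len, int Dh,
                   int co, int ci, int tlength, int elts_in_16B) {
    return bhi_kv * memory_max_len * Dh       // select the (batch, kv head) plane
         + co * memory_max_len * elts_in_16B  // select the 16B chunk column
         + tlength * elts_in_16B              // select the timestep
         + ci;                                // position inside the chunk
}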
@@ -1021,12 +1027,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
                 using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
                 const auto k_scaling = params.qkv_scale_out[1];
                 const auto k_quant =
-                    *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[qk_offset]);
+                    *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[k_offset]);
 
                 convert_from_float(k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));
             }
             else {
-                k = *reinterpret_cast<const Qk_vec*>(&params.k[qk_offset]);
+                k = *reinterpret_cast<const Qk_vec*>(&params.k[k_offset]);
             }
         }
     }
@@ -1035,14 +1041,14 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     Qk_vec q_bias;
     zero(q_bias);
     q_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ?
-                 *reinterpret_cast<const Qk_vec*>(&params.q_bias[qk_bias_offset]) :
+                 *reinterpret_cast<const Qk_vec*>(&params.q_bias[q_bias_offset]) :
                  q_bias;
 
     Qk_vec k_bias;
     zero(k_bias);
     if (handle_kv) {
         k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ?
-                     *reinterpret_cast<const Qk_vec*>(&params.k_bias[qk_bias_offset]) :
+                     *reinterpret_cast<const Qk_vec*>(&params.k_bias[k_bias_offset]) :
                      k_bias;
     }
@@ -1172,11 +1178,11 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
         int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
 
         // Two chunks are separated by L * x elements. A thread writes QK_VEC_SIZE elements.
-        int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
+        int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
                      // params.timestep*QK_ELTS_IN_16B +
                      tlength_circ * QK_ELTS_IN_16B + ci;
 
-        if (handle_kv) {
+        if (handle_kv && hi % params.num_heads_q_kv_ratio == 0) {
             // Trigger the stores to global memory.
             if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
                 *reinterpret_cast<Qk_vec*>(&params.k_cache[offset]) = k;
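Gating the store on hi % num_heads_q_kv_ratio == 0 elects exactly one writer per KV group: without it, every query head in a group would store the same K vector to the same bhi_kv cache row (redundant, and a racy duplicate write across blocks). The guard in isolation:

// Only the first query head of each group writes its (shared) K/V vector
// to the cache; the other ratio-1 heads would target the same offset.
inline bool is_kv_cache_writer(int hi, int num_heads_q_kv_ratio) {
    return hi % num_heads_q_kv_ratio == 0;
}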
@@ -1263,7 +1269,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
 
     // The base pointer for the key in the cache buffer.
-    T* k_cache = &params.k_cache[bhi * params.memory_max_len * Dh + ki];
+    T* k_cache = &params.k_cache[bhi_kv * params.memory_max_len * Dh + ki];
     // Base pointer for the beam's batch, before offsetting with indirection buffer
     T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
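Here k_cache points at this beam's own row while k_cache_batch is the base that later gets offset through the beam indirection buffer; both now step in num_heads_kv-sized strides, since bhi_kv and the new bbhi are both built from num_heads_kv. A pointer-arithmetic sketch with T left generic:

template <typename T>
struct KCachePtrs { T* own; T* batch_base; };

// Mirrors the two base pointers above: `own` addresses this beam's row,
// `batch_base` is offset later through the beam indirection buffer.
template <typename T>
KCachePtrs<T> k_cache_ptrs(T* k_cache, int bhi_kv, int bbhi,
                           int memory_max_len, int Dh, int ki) {
    return {&k_cache[bhi_kv * memory_max_len * Dh + ki],
            &k_cache[bbhi * memory_max_len * Dh + ki]};
}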
@@ -1427,7 +1433,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE;
 
     // The base pointer for the value in the cache buffer.
-    T* v_cache = &params.v_cache[bhi * params.memory_max_len * Dh + vi];
+    T* v_cache = &params.v_cache[bhi_kv * params.memory_max_len * Dh + vi];
     // Base pointer for the beam's batch, before offsetting with indirection buffer
     T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
@@ -1443,7 +1449,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     if (vo == tlength % V_PER_ITER) {
         // Trigger the loads from the V bias buffer.
         if (params.v_bias != nullptr) {
-            v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi * Dh + vi]);
+            v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi_kv * Dh + vi]);
         }
         if (DO_CROSS_ATTENTION) {
             *reinterpret_cast<V_vec*>(&bias_smem[vi]) = v_bias;
@@ -1510,7 +1516,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
         }
         else {
             // Trigger the loads from the V buffer.
-            const auto v_offset = qkv_base_offset + vi;
+            const auto v_offset = v_base_offset + vi;
             if (params.int8_mode == 2) {
                 using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec>::value>::type;
                 using Packed_Float_t = typename packed_type<float, num_elems<V_vec>::value>::type;
@@ -1539,8 +1545,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
             }
 
             // Store the values with bias back to global memory in the cache for V.
-            //*reinterpret_cast<V_vec*>(&v_cache[params.timestep*Dh]) = v;
-            *reinterpret_cast<V_vec*>(&v_cache[tlength_circ * Dh]) = v;
+            if (hi % params.num_heads_q_kv_ratio == 0) {
+                //*reinterpret_cast<V_vec*>(&v_cache[params.timestep*Dh]) = v;
+                *reinterpret_cast<V_vec*>(&v_cache[tlength_circ * Dh]) = v;
+            }
         }
 
         // Initialize the output value with the current timestep.
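Taken together, the diff threads num_heads_kv and the Q:KV ratio through every K/V index while leaving the Q path untouched. A hypothetical host-side helper showing how the new fields would be populated consistently (the struct and the fused-QKV stride formula are assumptions for illustration, not the real Multihead_attention_params API):

#include <cassert>

// Hypothetical stand-in for the new GQA-related fields.
struct GqaAttentionParams {
    int num_heads, num_heads_kv, num_heads_q_kv_ratio;
    int stride_q, stride_k, stride_v;
};

GqaAttentionParams make_gqa_params(int num_q_heads, int num_kv_heads, int Dh) {
    // The kernel divides hi by the ratio, so the grouping must be exact.
    assert(num_q_heads % num_kv_heads == 0);
    GqaAttentionParams p{};
    p.num_heads            = num_q_heads;
    p.num_heads_kv         = num_kv_heads;
    p.num_heads_q_kv_ratio = num_q_heads / num_kv_heads;
    // Assuming a fused [Q | K | V] row per token, the per-batch stride is
    // (num_q + 2 * num_kv) * Dh for all three tensors.
    p.stride_q = p.stride_k = p.stride_v = (num_q_heads + 2 * num_kv_heads) * Dh;
    return p;
}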