/*
 * Adapted from https://github.com/turboderp/exllamav2
 * Copyright (c) 2024 turboderp
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "compat.cuh"

namespace aphrodite {
namespace exl2 {

#define MAX_Q_GEMM_WEIGHTS 4
#define EXL2_BLOCK_KN_SIZE 64
#define EXL2_BLOCK_M_SIZE_MAX 8
#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32)
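
// Each thread block covers an EXL2_BLOCK_KN_SIZE slice of K and 4 * EXL2_BLOCK_KN_SIZE
// columns of N, for up to EXL2_BLOCK_M_SIZE_MAX rows of A. EXL2_MAX_GROUPS_IN_BLOCK is
// the number of quantization groups that can start inside one such K slice.

// dot22_* helpers: dot product of 8/16/32 dequantized weights (packed as half2 pairs)
// with a slice of one activation row, scaled by the group scale and added to the running
// result. The _f variants carry the running result in FP32, the _h variants in FP16.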

__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
{
    half2 result = {};
    const half2* a2_ptr = (const half2*)a_ptr;
    #pragma unroll
    for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
    return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}

__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h)
{
    half2 result = {};
    const half2* a2_ptr = (const half2*)a_ptr;
    #pragma unroll
    for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
    return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}

__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h)
{
    half2 result = {};
    const half2* a2_ptr = (const half2*)a_ptr;
    #pragma unroll
    for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
    return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
}

__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f)
{
    half2 result = {};
    const half2* a2_ptr = (const half2*)a_ptr;
    #pragma unroll
    for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
    float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
    return fma(result_f, qs_f, g_result);
}

__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f)
{
    half2 result = {};
    const half2* a2_ptr = (const half2*)a_ptr;
    #pragma unroll
    for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
    float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
    return fma(result_f, qs_f, g_result);
}

__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f)
{
    half2 result = {};
    const half2* a2_ptr = (const half2*)a_ptr;
    #pragma unroll
    for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
    float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
    return fma(result_f, qs_f, g_result);
}

__forceinline__ __device__ half dot22_8_h(half2(&dq)[4], const half* a_ptr, const half g_result, const half qs_h)
{
    // Use FP32 accumulator to avoid potential overflow since unscaled weights are in the range -128..127
    float result = {};
    #pragma unroll
    for (int i = 0; i < 4; i++)
    {
        half2 w01 = dq[i];
        float w0 = __low2float(w01);
        float w1 = __high2float(w01);
        float x0 = __half2float(*a_ptr++);
        float x1 = __half2float(*a_ptr++);
        result = fma(w0, x0, result);
        result = fma(w1, x1, result);
    }
    float qs = __half2float(qs_h);
    result *= qs;
    half result_h = __float2half_rn(result);
    return __hadd(result_h, g_result);
}

__forceinline__ __device__ half dot22_16_h(half2(&dq)[8], const half* a_ptr, const half g_result, const half qs_h)
{
    half2 result = {};
    const half2* a2_ptr = (const half2*)a_ptr;
    #pragma unroll
    for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
    half result_h = __hadd(__low2half(result), __high2half(result));
    return __hfma(result_h, qs_h, g_result);
}

__forceinline__ __device__ half dot22_32_h(half2(&dq)[16], const half* a_ptr, const half g_result, const half qs_h)
{
    half2 result = {};
    const half2* a2_ptr = (const half2*)a_ptr;
    #pragma unroll
    for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
    half result_h = __hadd(__low2half(result), __high2half(result));
    return __hfma(result_h, qs_h, g_result);
}
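
// All instantiations of gemm_half_q_half_kernel<m_count> share this signature;
// pick_gemm_half_q_half_kernel() at the bottom of the file returns the matching one.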

typedef void (*fp_gemm_half_q_half_kernel)
(
    const half*,
    const uint32_t*,
    const uint32_t*,
    const half*,
    half*,
    const int,
    const int,
    const int,
    const int,
    const int,
    const uint16_t*,
    const uint16_t*,
    const int,
    const int,
    const int,
    const int,
    const int,
    const int,
    const bool
);

template <int m_count>
__global__ void gemm_half_q_half_kernel
(
    const half* __restrict__ a,
    const uint32_t* __restrict__ b_q_weight,
    const uint32_t* __restrict__ b_q_scale,
    const half* __restrict__ b_q_scale_max,
    half* __restrict__ c,
    const int size_m,
    const int size_n,
    const int size_k,
    const int height,
    const int groups,
    const uint16_t* __restrict__ b_q_group_map,
    const uint16_t* __restrict__ b_q_perm,
    const int rows_8,
    const int rows_6,
    const int rows_5,
    const int rows_4,
    const int rows_3,
    const int rows_2,
    const bool clear
)
{
    MatrixView_half a_(a, size_m, size_k);
    MatrixView_half_rw c_(c, size_m, size_n);
    MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);

    int t = threadIdx.x;

    // Block
    int offset_n = blockIdx.x * EXL2_BLOCK_KN_SIZE * 4;
    int offset_m = blockIdx.y * m_count;
    int offset_k = blockIdx.z * EXL2_BLOCK_KN_SIZE;

    int end_n = min(offset_n + EXL2_BLOCK_KN_SIZE * 4, size_n);
    int end_m = min(offset_m + m_count, size_m);
    int end_k = min(offset_k + EXL2_BLOCK_KN_SIZE, height);
    int n = offset_n + t * 4;

    // Read weights
    half_uint16 weights[MAX_Q_GEMM_WEIGHTS];

    // Preload block_a
    __shared__ half block_a[m_count][EXL2_BLOCK_KN_SIZE];

    if (offset_k + t < end_k)
    {
        for (int m = 0; m < m_count; ++m)
        {
            const half* a_ptr = a_.item_ptr(offset_m + m, 0);
            half* block_a_ptr = block_a[m];
            half a0 = a_ptr[b_q_perm[offset_k + t]];
            // half a0 = a_ptr[offset_k + t];
            block_a_ptr[t] = a0;
        }
    }
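
    // Each blockIdx.z processes one EXL2_BLOCK_KN_SIZE slice of K. When `clear` is set,
    // the first slice zeroes the output tile so that all slices can accumulate into it.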

    // Clear
    if (n >= size_n) return;

    if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
    {
        for (int m = 0; m < m_count; m++)
            *((uint64_t*) c_.item_ptr(offset_m + m, n)) = 0;
    }
    __syncthreads();

    // Find initial group
    //int group = offset_k / groupsize;
    int group = b_q_group_map[offset_k * 2];

    // if (offset_m == 0 && t == 0)
    //     DBGI2(offset_k, group);

    // Preload scales
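    // The effective scale of a group is (stored 4-bit value + 1)^2 * b_q_scale_max[group].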
    half scales[EXL2_MAX_GROUPS_IN_BLOCK][4];

    //int groups_in_block = DIVIDE((end_k - offset_k), groupsize);

    int temp_k = offset_k;
    for (int g = 0; temp_k < end_k; g++)
    {
        int qscales[4];
        b_q_scale_.item4(qscales, group + g, n);
        qscales[0]++;
        qscales[1]++;
        qscales[2]++;
        qscales[3]++;
        half maxscale = b_q_scale_max[group + g];
        scales[g][0] = __hmul(__int2half_rn(qscales[0] * qscales[0]), maxscale);
        scales[g][1] = __hmul(__int2half_rn(qscales[1] * qscales[1]), maxscale);
        scales[g][2] = __hmul(__int2half_rn(qscales[2] * qscales[2]), maxscale);
        scales[g][3] = __hmul(__int2half_rn(qscales[3] * qscales[3]), maxscale);
        temp_k += b_q_group_map[temp_k * 2 + 1];
    }

    // a, b offset
    int pre_rows_8 = min(rows_8, offset_k);
    int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
    int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
    int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
    int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
    int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
    int qk = 0;
    qk += pre_rows_8 / 32 * 8;
    qk += pre_rows_6 / 32 * 6;
    qk += pre_rows_5 / 32 * 5;
    qk += pre_rows_4 / 32 * 4;
    qk += pre_rows_3 / 32 * 3;
    qk += pre_rows_2 / 32 * 2;

    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
    const half* a_ptr = &block_a[0][0];
    int a_stride = EXL2_BLOCK_KN_SIZE;

    // Initial group
    int scales_idx = 0;
    half qs_h0 = scales[scales_idx][0];
    half qs_h1 = scales[scales_idx][1];
    half qs_h2 = scales[scales_idx][2];
    half qs_h3 = scales[scales_idx][3];
    int nextgroup = offset_k + b_q_group_map[offset_k * 2 + 1];

    // Column result
    half block_c[m_count][4] = {};

    // Dequantize groups
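    // Rows of B are stored sorted by bit width: the first rows_8 rows are 8-bit, followed
    // by 6-, 5-, 4-, 3- and 2-bit segments. Each loop below handles one segment, advancing
    // 32 rows of K per iteration (16 for the 2-bit segment) across four N columns.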
    int k = offset_k;

    while (k < rows_8 && k < end_k)
    {
        if (k == nextgroup)
        {
            group++;
            scales_idx++;
            qs_h0 = scales[scales_idx][0];
            qs_h1 = scales[scales_idx][1];
            qs_h2 = scales[scales_idx][2];
            qs_h3 = scales[scales_idx][3];
            nextgroup += b_q_group_map[k * 2 + 1];
        }

        #pragma unroll
        for (int j = 0; j < 4; j++)
        {
            int4 load_int4[2];
            load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
            load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;

            half2 dq[4][4];
            dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n);
            dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n);
            dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n);
            dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n);

            for (int m = 0; m < m_count; m++)
            {
                block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
                block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
                block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
                block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
            }
            a_ptr += 8;
        }
        k += 32;
    }

    while (k < rows_6 && k < end_k)
    {
        if (k == nextgroup)
        {
            group++;
            scales_idx++;
            qs_h0 = scales[scales_idx][0];
            qs_h1 = scales[scales_idx][1];
            qs_h2 = scales[scales_idx][2];
            qs_h3 = scales[scales_idx][3];
            nextgroup += b_q_group_map[k * 2 + 1];
        }

        #pragma unroll
        for (int j = 0; j < 2; j++)
        {
            int4 load_int4[3];
            load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
            load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
            load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;

            half2 dq[4][8];
            dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
            dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
            dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
            dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);

            for (int m = 0; m < m_count; m++)
            {
                block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
                block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
                block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
                block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
            }
            a_ptr += 16;
        }
        k += 32;
    }

    while (k < rows_5 && k < end_k)
    {
        if (k == nextgroup)
        {
            group++;
            scales_idx++;
            qs_h0 = scales[scales_idx][0];
            qs_h1 = scales[scales_idx][1];
            qs_h2 = scales[scales_idx][2];
            qs_h3 = scales[scales_idx][3];
            nextgroup += b_q_group_map[k * 2 + 1];
        }

        #pragma unroll
        for (int j = 0; j < 1; j++)
        {
            int4 load_int4[5];
            load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
            load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
            load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
            load_int4[3] = *((int4*) b_ptr); b_ptr += size_n;
            load_int4[4] = *((int4*) b_ptr); b_ptr += size_n;

            half2 dq[4][16];
            dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, load_int4[3].x, load_int4[4].x, dq[0], size_n);
            dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, load_int4[3].y, load_int4[4].y, dq[1], size_n);
            dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, load_int4[3].z, load_int4[4].z, dq[2], size_n);
            dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, load_int4[3].w, load_int4[4].w, dq[3], size_n);

            for (int m = 0; m < m_count; m++)
            {
                block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
                block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
                block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
                block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
            }
            a_ptr += 32;
        }
        k += 32;
    }

    while (k < rows_4 && k < end_k)
    {
        if (k == nextgroup)
        {
            group++;
            scales_idx++;
            qs_h0 = scales[scales_idx][0];
            qs_h1 = scales[scales_idx][1];
            qs_h2 = scales[scales_idx][2];
            qs_h3 = scales[scales_idx][3];
            nextgroup += b_q_group_map[k * 2 + 1];
        }

        #pragma unroll
        for (int j = 0; j < 4; j++)
        {
            int4 load_int4[1];
            load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;

            half2 dq[4][4];
            dequant_4bit_8(load_int4[0].x, dq[0], size_n);
            dequant_4bit_8(load_int4[0].y, dq[1], size_n);
            dequant_4bit_8(load_int4[0].z, dq[2], size_n);
            dequant_4bit_8(load_int4[0].w, dq[3], size_n);

            for (int m = 0; m < m_count; m++)
            {
                block_c[m][0] = dot22_8_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
                block_c[m][1] = dot22_8_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
                block_c[m][2] = dot22_8_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
                block_c[m][3] = dot22_8_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
            }
            a_ptr += 8;
        }
        k += 32;
    }

    while (k < rows_3 && k < end_k)
    {
        if (k == nextgroup)
        {
            group++;
            scales_idx++;
            qs_h0 = scales[scales_idx][0];
            qs_h1 = scales[scales_idx][1];
            qs_h2 = scales[scales_idx][2];
            qs_h3 = scales[scales_idx][3];
            nextgroup += b_q_group_map[k * 2 + 1];
        }

        #pragma unroll
        for (int j = 0; j < 1; j++)
        {
            int4 load_int4[3];
            load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
            load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
            load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;

            half2 dq[4][16];
            dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
            dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
            dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
            dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);

            for (int m = 0; m < m_count; m++)
            {
                block_c[m][0] = dot22_32_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
                block_c[m][1] = dot22_32_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
                block_c[m][2] = dot22_32_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
                block_c[m][3] = dot22_32_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
            }
            a_ptr += 32;
        }
        k += 32;
    }

    while (k < rows_2 && k < end_k)
    {
        if (k == nextgroup)
        {
            group++;
            scales_idx++;
            qs_h0 = scales[scales_idx][0];
            qs_h1 = scales[scales_idx][1];
            qs_h2 = scales[scales_idx][2];
            qs_h3 = scales[scales_idx][3];
            nextgroup += b_q_group_map[k * 2 + 1];
        }

        #pragma unroll
        for (int j = 0; j < 1; j++)
        {
            int4 load_int4[1];
            load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;

            half2 dq[4][8];
            dequant_2bit_16(load_int4[0].x, dq[0], size_n);
            dequant_2bit_16(load_int4[0].y, dq[1], size_n);
            dequant_2bit_16(load_int4[0].z, dq[2], size_n);
            dequant_2bit_16(load_int4[0].w, dq[3], size_n);

            for (int m = 0; m < m_count; m++)
            {
                block_c[m][0] = dot22_16_h(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_h0);
                block_c[m][1] = dot22_16_h(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_h1);
                block_c[m][2] = dot22_16_h(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_h2);
                block_c[m][3] = dot22_16_h(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_h3);
            }
            a_ptr += 16;
        }
        k += 16;
    }

    // Accumulate column sums in c
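    // Each thread writes its four consecutive output columns as two half2 values, merged
    // with atomicAdd since every K slice (blockIdx.z) contributes to the same elements.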
    for (int m = 0; m < m_count; m++)
    {
        half2* out = (half2*)c_.item_ptr(offset_m + m, n);
        half2 result01 = __halves2half2(block_c[m][0], block_c[m][1]);
        half2 result23 = __halves2half2(block_c[m][2], block_c[m][3]);
        atomicAdd(out    , result01);
        atomicAdd(out + 1, result23);
        // *out = result01;
        // *(out + 1) = result23;
    }
}
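
// Runtime dispatch: map m_count (1..EXL2_BLOCK_M_SIZE_MAX) to the matching template
// instantiation of the kernel.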

struct map_m_count_exl2 {
    static constexpr fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count)
    {
        #if EXL2_BLOCK_M_SIZE_MAX >= 1
        if (m_count == 1) return gemm_half_q_half_kernel<1>;
        #endif
        #if EXL2_BLOCK_M_SIZE_MAX >= 2
        if (m_count == 2) return gemm_half_q_half_kernel<2>;
        #endif
        #if EXL2_BLOCK_M_SIZE_MAX >= 3
        if (m_count == 3) return gemm_half_q_half_kernel<3>;
        #endif
        #if EXL2_BLOCK_M_SIZE_MAX >= 4
        if (m_count == 4) return gemm_half_q_half_kernel<4>;
        #endif
        #if EXL2_BLOCK_M_SIZE_MAX >= 5
        if (m_count == 5) return gemm_half_q_half_kernel<5>;
        #endif
        #if EXL2_BLOCK_M_SIZE_MAX >= 6
        if (m_count == 6) return gemm_half_q_half_kernel<6>;
        #endif
        #if EXL2_BLOCK_M_SIZE_MAX >= 7
        if (m_count == 7) return gemm_half_q_half_kernel<7>;
        #endif
        #if EXL2_BLOCK_M_SIZE_MAX >= 8
        if (m_count == 8) return gemm_half_q_half_kernel<8>;
        #endif
        return NULL;
    }
};

fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(const int m_count)
{
    return map_m_count_exl2::pick_gemm_half_q_half_kernel(m_count);
}

} // namespace exl2
} // namespace aphrodite