123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294 |
- // Adapted from turboderp exllama: https://github.com/turboderp/exllama
- #ifndef _matrix_cuh
- #define _matrix_cuh
- #include <cuda_runtime.h>
- #include <cuda_fp16.h>
// Read-only view over a row-major matrix of half values.
// Element (row, column) is located at data[row * width + column].
// No bounds checking is performed; height is stored but not used by any accessor.
class MatrixView_half
{
public:
    const half* data;   // base pointer; the view neither allocates nor frees it
    const int height;   // number of rows
    const int width;    // number of columns, used as the row stride in elements

    __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Returns the scalar element at (row, column).
    __device__ __forceinline__ half item(int row, int column) const
    {
        return data[row * width + column];
    }

    // Returns the half2 at index (row * width + column) / 2 of the buffer viewed as half2.
    // The cast keeps the const qualifier of the underlying storage.
    // NOTE(review): assumes data is suitably aligned for half2 loads - confirm against callers.
    __device__ __forceinline__ half2 item_half2(int row, int column) const
    {
        return ((const half2*)data)[(row * width + column) / 2];
    }

    // Returns the element at (row, column) duplicated into both lanes of a half2.
    __device__ __forceinline__ half2 item_half2half2(int row, int column) const
    {
        return __half2half2(data[row * width + column]);
    }

    // Returns a read-only pointer to the element at (row, column).
    __device__ __forceinline__ const half* item_ptr(int row, int column) const
    {
        return &data[row * width + column];
    }
};
// Read/write view over a row-major matrix of half values.
// Element (row, column) is located at data[row * width + column]; nothing is bounds-checked.
class MatrixView_half_rw
{
public:
    half* data;         // base pointer; the view neither allocates nor frees it
    const int height;   // number of rows
    const int width;    // number of columns, used as the row stride in elements

    __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Reads the scalar element at (row, column).
    __device__ __forceinline__ half item(int row, int column) const
    {
        const int idx = row * width + column;
        return data[idx];
    }

    // Reads the half2 at index (row * width + column) / 2 of the buffer viewed as half2.
    __device__ __forceinline__ half2 item_half2(int row, int column) const
    {
        half2* pairs = (half2*) data;
        const int pair_idx = (row * width + column) / 2;
        return pairs[pair_idx];
    }

    // Returns a writable pointer to the element at (row, column).
    __device__ __forceinline__ half* item_ptr(int row, int column)
    {
        return data + (row * width + column);
    }

    // Stores a scalar at (row, column).
    __device__ __forceinline__ void set(int row, int column, half value)
    {
        half* dst = data + (row * width + column);
        *dst = value;
    }

    // Stores a half2 at index (row * width + column) / 2 of the buffer viewed as half2.
    __device__ __forceinline__ void set_half2(int row, int column, half2 value)
    {
        half2* pairs = (half2*) data;
        const int pair_idx = (row * width + column) / 2;
        pairs[pair_idx] = value;
    }
};
// Read-only view over 4-bit values packed eight per uint32_t along each row:
// the value at (row, column) is nibble (column & 7) of data[row * width / 8 + column / 8],
// with the lowest nibble holding the lowest column. Nothing is bounds-checked.
class MatrixView_q4_row
{
public:
    const uint32_t* data;   // packed storage; the view neither allocates nor frees it
    const int height;       // number of rows
    const int width;        // number of (unpacked) columns

    __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Extracts the 4-bit value at (row, column) as an int in the range 0..15.
    __device__ __forceinline__ int item(int row, int column) const
    {
        const uint32_t packed = data[row * width / 8 + column / 8];
        const int nibble_shift = (column & 0x07) * 4;
        return (packed >> nibble_shift) & 0x0f;
    }
};
// Read-only view over 4-bit values packed eight per uint32_t along each column:
// the value at (row, column) is nibble (row & 7) of data[row / 8 * width + column],
// with the lowest nibble holding the lowest row. Nothing is bounds-checked.
class MatrixView_q4_column
{
public:
    const uint32_t* data;   // packed storage; the view neither allocates nor frees it
    const int height;       // number of (unpacked) rows
    const int width;        // number of columns, used as the stride between packed rows

    __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
        : data(data), height(height), width(width)
    { }

    // Extracts the 4-bit value at (row, column) as an int in the range 0..15.
    __device__ __forceinline__ int item(int row, int column) const
    {
        int shift = (row & 0x07) * 4;
        return (data[row / 8 * width + column] >> shift) & 0x0f;
    }

    // Returns the whole packed word that contains (row, column), i.e. eight consecutive rows
    // of that column. Const-qualified like item(), since it only reads.
    __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) const
    {
        return data[row / 8 * width + column];
    }

    // Returns a pointer to the packed word that contains (row, column).
    // Advancing the pointer by `width` moves to the next group of eight rows in the same column.
    __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) const
    {
        return &data[row / 8 * width + column];
    }
};
- // TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu
- // Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale
// half2 variant: accumulates `count` groups of 8 products between row h_row of h_
// (starting at h_column) and the 4-bit values of column v_column of v_ (starting at v_row).
//
// Per iteration one uint32_t is read from v_ (eight nibbles, lowest nibble first), v_zero is
// subtracted from every nibble, the results are multiplied lane-wise with four consecutive
// half2 values from h_, and v_scale_2 * partial is added to the running result.
//
// The returned half2 keeps two separate partial sums (even-indexed elements in the low lane,
// odd-indexed elements in the high lane); this function does not combine them.
//
// Preconditions (not checked here):
//   - h_column and v_row are divisible by 8 (see parameter annotations)
//   - the h_ element at (h_row, h_column) is suitably aligned for half2 loads
//   - v_zero fits in an int; it is annotated "+ 1 (!!)" in the original source, presumably
//     meaning the caller passes the stored zero point plus one - confirm against callers
__device__ __forceinline__ half2 dot_product_8
(
    const half2 acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,                 // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,                    // divisible by 8
    const int v_column,
    const half2 v_scale_2,
    const uint32_t v_zero,              // + 1 (!!)
    const int count
)
{
    const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
    const uint32_t* v_ptr = v_.item_uint32_ptr(v_row, v_column);

    // Convert once so that "nibble - zero" is computed in signed arithmetic. Subtracting the
    // uint32_t directly would wrap around for nibble < zero and depend on the implicit
    // unsigned -> int narrowing at the __int2half_rn call to recover the negative value.
    const int zero = (int) v_zero;

    half2 result = acc;

    for (int i = 0; i < count; i++)
    {
        const uint32_t v_read = *v_ptr;
        v_ptr += v_.width;              // next packed word (next 8 rows) of the same column

        const half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - zero);
        const half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - zero);
        const half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - zero);
        const half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - zero);
        const half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - zero);
        const half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - zero);
        const half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - zero);
        const half v_7 = __int2half_rn((int)((v_read >> 28)       ) - zero);

        // Note from the original author: a constant-memory lookup table (q4_table) was tried
        // for this unpacking and was apparently too slow.
        const half2 v_01 = __halves2half2(v_0, v_1);
        const half2 v_23 = __halves2half2(v_2, v_3);
        const half2 v_45 = __halves2half2(v_4, v_5);
        const half2 v_67 = __halves2half2(v_6, v_7);

        half2 tmp = __hmul2(*h_ptr++, v_01);
        tmp = __hfma2(*h_ptr++, v_23, tmp);
        tmp = __hfma2(*h_ptr++, v_45, tmp);
        tmp = __hfma2(*h_ptr++, v_67, tmp);

        // The scale is applied once per group of 8 rather than per element.
        result = __hfma2(v_scale_2, tmp, result);
    }

    return result;
}
// Scalar-half variant: accumulates `count` groups of 8 products between row h_row of h_
// (starting at h_column) and the 4-bit values of column v_column of v_ (starting at v_row).
//
// Per iteration one uint32_t is read from v_ (eight nibbles, lowest nibble first), v_zero is
// subtracted from every nibble, the eight results are multiplied with eight consecutive half
// values from h_ and summed, and v_scale * partial is added to the running result.
//
// Preconditions (not checked here):
//   - h_column and v_row are divisible by 8 (see parameter annotations)
//   - v_zero fits in an int; it is annotated "+ 1 (!!)" in the original source, presumably
//     meaning the caller passes the stored zero point plus one - confirm against callers
__device__ __forceinline__ half dot_product_8_h
(
    const half acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,                 // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,                    // divisible by 8
    const int v_column,
    const half v_scale,
    const uint32_t v_zero,              // + 1 (!!)
    const int count
)
{
    const half* h_ptr = h_.item_ptr(h_row, h_column);
    const uint32_t* v_ptr = v_.item_uint32_ptr(v_row, v_column);

    // Convert once so that "nibble - zero" is computed in signed arithmetic. Subtracting the
    // uint32_t directly would wrap around for nibble < zero and depend on the implicit
    // unsigned -> int narrowing at the __int2half_rn call to recover the negative value.
    const int zero = (int) v_zero;

    half result = acc;

    for (int i = 0; i < count; i++)
    {
        const uint32_t v_read = *v_ptr;
        v_ptr += v_.width;              // next packed word (next 8 rows) of the same column

        const half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - zero);
        const half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - zero);
        const half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - zero);
        const half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - zero);
        const half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - zero);
        const half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - zero);
        const half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - zero);
        const half v_7 = __int2half_rn((int)((v_read >> 28)       ) - zero);

        half tmp = __hmul(*h_ptr++, v_0);
        tmp = __hfma(*h_ptr++, v_1, tmp);
        tmp = __hfma(*h_ptr++, v_2, tmp);
        tmp = __hfma(*h_ptr++, v_3, tmp);
        tmp = __hfma(*h_ptr++, v_4, tmp);
        tmp = __hfma(*h_ptr++, v_5, tmp);
        tmp = __hfma(*h_ptr++, v_6, tmp);
        tmp = __hfma(*h_ptr++, v_7, tmp);

        // The scale is applied once per group of 8 rather than per element.
        result = __hfma(v_scale, tmp, result);
    }

    return result;
}
- // Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
// half2 variant with column remapping: like a plain 8-wide accumulated dot product, but the
// elements of row h_row of h_ are gathered through x_map instead of being read contiguously.
//
// Per iteration one uint32_t is read from v_ (eight nibbles, lowest nibble first) and v_zero
// is subtracted from every nibble. Eight half values are fetched from row h_row at the column
// indices x_map[h_column + 8 * i + 0..7], paired into half2, multiplied lane-wise with the
// unpacked values, and v_scale_2 * partial is added to the running result.
//
// The returned half2 keeps two separate partial sums (even-indexed elements in the low lane,
// odd-indexed elements in the high lane); this function does not combine them.
//
// Preconditions (not checked here):
//   - h_column and v_row are divisible by 8 (see parameter annotations)
//   - x_map holds at least h_column + 8 * count entries, each a valid column index of h_
//   - v_zero fits in an int; it is annotated "+ 1 (!!)" in the original source, presumably
//     meaning the caller passes the stored zero point plus one - confirm against callers
__device__ __forceinline__ half2 dot_product_8_x_map
(
    const half2 acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,                 // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,                    // divisible by 8
    const int v_column,
    const half2 v_scale_2,
    const uint32_t v_zero,              // + 1 (!!)
    const int count,
    const uint32_t* x_map
)
{
    // Row base pointer: the column offset comes from x_map, not from h_column directly.
    const half* h_ptr = h_.item_ptr(h_row, 0);
    const uint32_t* x_map_ptr = x_map + h_column;
    const uint32_t* v_ptr = v_.item_uint32_ptr(v_row, v_column);

    // Convert once so that "nibble - zero" is computed in signed arithmetic. Subtracting the
    // uint32_t directly would wrap around for nibble < zero and depend on the implicit
    // unsigned -> int narrowing at the __int2half_rn call to recover the negative value.
    const int zero = (int) v_zero;

    half2 result = acc;

    for (int i = 0; i < count; i++)
    {
        const uint32_t v_read = *v_ptr;
        v_ptr += v_.width;              // next packed word (next 8 rows) of the same column

        const half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - zero);
        const half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - zero);
        const half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - zero);
        const half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - zero);
        const half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - zero);
        const half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - zero);
        const half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - zero);
        const half v_7 = __int2half_rn((int)((v_read >> 28)       ) - zero);

        const half2 v_01 = __halves2half2(v_0, v_1);
        const half2 v_23 = __halves2half2(v_2, v_3);
        const half2 v_45 = __halves2half2(v_4, v_5);
        const half2 v_67 = __halves2half2(v_6, v_7);

        // Gather the eight h_ elements through the index map, in order.
        const half h_0 = h_ptr[*x_map_ptr++];
        const half h_1 = h_ptr[*x_map_ptr++];
        const half h_2 = h_ptr[*x_map_ptr++];
        const half h_3 = h_ptr[*x_map_ptr++];
        const half h_4 = h_ptr[*x_map_ptr++];
        const half h_5 = h_ptr[*x_map_ptr++];
        const half h_6 = h_ptr[*x_map_ptr++];
        const half h_7 = h_ptr[*x_map_ptr++];

        const half2 h_01 = __halves2half2(h_0, h_1);
        const half2 h_23 = __halves2half2(h_2, h_3);
        const half2 h_45 = __halves2half2(h_4, h_5);
        const half2 h_67 = __halves2half2(h_6, h_7);

        half2 tmp = __hmul2(h_01, v_01);
        tmp = __hfma2(h_23, v_23, tmp);
        tmp = __hfma2(h_45, v_45, tmp);
        tmp = __hfma2(h_67, v_67, tmp);

        // The scale is applied once per group of 8 rather than per element.
        result = __hfma2(v_scale_2, tmp, result);
    }

    return result;
}
// Scalar-half variant with column remapping: like a plain 8-wide accumulated dot product, but
// the elements of row h_row of h_ are gathered through x_map instead of being read contiguously.
//
// Per iteration one uint32_t is read from v_ (eight nibbles, lowest nibble first) and v_zero
// is subtracted from every nibble. The eight results are multiplied with the half values of
// row h_row at the column indices x_map[h_column + 8 * i + 0..7] and summed, and
// v_scale * partial is added to the running result.
//
// Preconditions (not checked here):
//   - h_column and v_row are divisible by 8 (see parameter annotations)
//   - x_map holds at least h_column + 8 * count entries, each a valid column index of h_
//   - v_zero fits in an int; it is annotated "+ 1 (!!)" in the original source, presumably
//     meaning the caller passes the stored zero point plus one - confirm against callers
__device__ __forceinline__ half dot_product_8_x_map_h
(
    const half acc,
    MatrixView_half& h_,
    const int h_row,
    const int h_column,                 // divisible by 8
    MatrixView_q4_column& v_,
    const int v_row,                    // divisible by 8
    const int v_column,
    const half v_scale,
    const uint32_t v_zero,              // + 1 (!!)
    const int count,
    const uint32_t* x_map
)
{
    // Row base pointer: the column offset comes from x_map, not from h_column directly.
    const half* h_ptr = h_.item_ptr(h_row, 0);
    const uint32_t* x_map_ptr = x_map + h_column;
    const uint32_t* v_ptr = v_.item_uint32_ptr(v_row, v_column);

    // Convert once so that "nibble - zero" is computed in signed arithmetic. Subtracting the
    // uint32_t directly would wrap around for nibble < zero and depend on the implicit
    // unsigned -> int narrowing at the __int2half_rn call to recover the negative value.
    const int zero = (int) v_zero;

    half result = acc;

    for (int i = 0; i < count; i++)
    {
        const uint32_t v_read = *v_ptr;
        v_ptr += v_.width;              // next packed word (next 8 rows) of the same column

        const half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - zero);
        const half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - zero);
        const half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - zero);
        const half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - zero);
        const half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - zero);
        const half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - zero);
        const half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - zero);
        const half v_7 = __int2half_rn((int)((v_read >> 28)       ) - zero);

        // Each h_ element is gathered through the index map, in order.
        half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
        tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);

        // The scale is applied once per group of 8 rather than per element.
        result = __hfma(v_scale, tmp, result);
    }

    return result;
}
- #endif
|