123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242 |
- #include <cuda_fp16.h>
- #include <cuda_runtime.h>
- #include <torch/all.h>
- #include <c10/cuda/CUDAGuard.h>
- #include "ggml-common.h"
- #include "vecdotq.cuh"
- #include "dequantize.cuh"
- #include "mmvq.cuh"
- #include "mmq.cuh"
// Q8 gemv / gemm support: Q8_1 quantization of the fp16 activations that are
// fed to the MMVQ (mul_mat_vec_*) and MMQ (ggml_mul_mat_*) kernels below.

// Quantizes a row-major fp16 matrix `x` (rows of `kx` values) into block_q8_1
// blocks written to `vy`, one block per QK8_1 consecutive values of each row
// after the row has been zero-padded to `kx_padded` values.
//
// Expected launch layout (NOTE(review): confirm against every launcher):
//   - gridDim.y == number of rows, blockDim.y == 1;
//   - gridDim.x * blockDim.x >= kx_padded, blockDim.x a multiple of 32;
//   - kx_padded a multiple of 32 (the launcher in this file pads to 512).
// Each warp of 32 lanes quantizes exactly one block: the block-wide max |x| and
// sum are reduced with full-mask __shfl_xor_sync, which assumes QK8_1 == 32.
// Because kx_padded is a multiple of 32, the early `ix >= kx_padded` return is
// taken by whole warps only, so no lane named in the 0xffffffff mask is absent.
//
// Per block: qs[i] = round(x_i / d) with d = max|x| / 127 (0 when the block is
// all zeros), ds.x = d, ds.y = sum of the unquantized values. Positions
// ix >= kx are padding and quantize to 0.
static __global__ void quantize_q8_1(const half* __restrict__ x,
                                     void* __restrict__ vy, const int kx,
                                     const int kx_padded) {
  const int ix = blockDim.x * blockIdx.x + threadIdx.x;
  if (ix >= kx_padded) {
    return;
  }
  const int iy = blockDim.y * blockIdx.y + threadIdx.y;

  // 64-bit flat indices: rows * padded row length can exceed INT_MAX for large
  // batches of wide activations.
  const int64_t row_in = static_cast<int64_t>(iy) * kx;
  const int64_t i_padded = static_cast<int64_t>(iy) * kx_padded + ix;

  block_q8_1* y = static_cast<block_q8_1*>(vy);
  const int64_t ib = i_padded / QK8_1;                 // block index
  const int iqs = static_cast<int>(i_padded % QK8_1);  // quant index in block

  const float xi = ix < kx ? __half2float(x[row_in + ix]) : 0.0f;
  float amax = fabsf(xi);
  float sum = xi;
  // Butterfly reduction across the warp: afterwards every lane holds the
  // block's max |x| and the block's sum.
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1) {
    amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
    sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
  }
  const float d = amax / 127;
  const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
  y[ib].qs[iqs] = q;
  // Only the first lane of each block writes the shared scale/sum pair.
  if (iqs > 0) {
    return;
  }
  y[ib].ds.x = __float2half(d);
  y[ib].ds.y = __float2half(sum);
}
// Host launcher for quantize_q8_1. Quantizes `ky` rows of `kx` fp16 values
// from device pointer `x` into block_q8_1 blocks at `vy`, asynchronously on
// `stream`. Each row is zero-padded to a multiple of 512 values, so `vy` must
// hold at least ky * (kx_padded / QK8_1) block_q8_1 entries. The callers in
// this file size their buffers with the same 512 padding.
// NOTE(review): ky becomes gridDim.y, which CUDA limits to 65535 rows.
static void quantize_row_q8_1_cuda(const half* x, void* vy, const int kx,
                                   const int ky, cudaStream_t stream) {
  // Nothing to quantize; a zero-sized grid would be an invalid launch.
  if (kx <= 0 || ky <= 0) {
    return;
  }
  const int64_t kx_padded = (kx + 512 - 1) / 512 * 512;
  // The grid size and the block size must come from the same constant;
  // otherwise gridDim.x * blockDim.x may not cover kx_padded.
  constexpr int kThreads = CUDA_QUANTIZE_BLOCK_SIZE;
  const int block_num_x = (kx_padded + kThreads - 1) / kThreads;
  const dim3 num_blocks(block_num_x, ky, 1);
  const dim3 block_size(kThreads, 1, 1);
  quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
}
// Dequantizes the GGML-quantized weight `W` (quantization type `type`) into a
// newly allocated fp16 tensor of shape [m, n] on W's device. The conversion is
// enqueued on the current CUDA stream.
// Throws (TORCH_CHECK) if no fp16 converter exists for `type`.
torch::Tensor ggml_dequantize(torch::Tensor W,  // quant weight
                              int8_t type, int64_t m, int64_t n) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
  auto options =
      torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
  at::Tensor DW = torch::empty({m, n}, options);

  const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(type);
  // NOTE(review): ggml_get_to_fp16_cuda is assumed to return nullptr for
  // types it does not handle. Fail loudly instead of calling through it.
  TORCH_CHECK(to_fp16_cuda != nullptr,
              "ggml_dequantize: unsupported quantization type ",
              static_cast<int>(type));

  // Empty output: nothing to convert, so skip the kernel launch entirely.
  if (DW.numel() == 0) {
    return DW;
  }
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  to_fp16_cuda((void*)W.data_ptr(), (half*)DW.data_ptr(), m * n, stream);
  return DW;
}
// Quantized matrix-vector product through the MMVQ kernels.
//
// The first row of the fp16 activation `X` is quantized to Q8_1, with its
// `col` values zero-padded to a multiple of 512. It is then multiplied against
// the GGML-quantized weight `W` (quantization `type`, `row` output rows). The
// result is an fp16 tensor `Y` of shape [1, row], presumably X @ W^T.
// NOTE(review): only row 0 of X is used; callers presumably pass X of shape
// [1, col]. Confirm against call sites.
//
// `type` selects the kernel, as the dispatched function names show: 2=q4_0,
// 3=q4_1, 6=q5_0, 7=q5_1, 8=q8_0, 10..14=q2_K..q6_K, 16..23=iq* variants.
// Throws (TORCH_CHECK) for any other value instead of returning an
// uninitialized Y.
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W,  // quant weight
                                  torch::Tensor X,  // input
                                  int8_t type, int64_t row) {
  int col = X.sizes()[1];
  const int padded = (col + 512 - 1) / 512 * 512;
  const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
  auto options =
      torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
  at::Tensor Y = torch::empty({1, row}, options);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
  // 9 int32 words per 32 padded inputs. This presumably equals one 36-byte
  // block_q8_1 (32 int8 quants plus a half2 scale/sum); verify in
  // ggml-common.h.
  at::Tensor quant_X = torch::empty({1, padded / 32 * 9}, options);
  quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col, 1,
                         stream);

  void* w = W.data_ptr();
  void* qx = quant_X.data_ptr();
  half* y = (half*)Y.data_ptr();
  switch (type) {
    case 2:
      mul_mat_vec_q4_0_q8_1_cuda(w, qx, y, col, row, stream);
      break;
    case 3:
      mul_mat_vec_q4_1_q8_1_cuda(w, qx, y, col, row, stream);
      break;
    case 6:
      mul_mat_vec_q5_0_q8_1_cuda(w, qx, y, col, row, stream);
      break;
    case 7:
      mul_mat_vec_q5_1_q8_1_cuda(w, qx, y, col, row, stream);
      break;
    case 8:
      mul_mat_vec_q8_0_q8_1_cuda(w, qx, y, col, row, stream);
      break;
    case 10:
      mul_mat_vec_q2_K_q8_1_cuda(w, qx, y, col, row, stream);
      break;
    case 11:
      mul_mat_vec_q3_K_q8_1_cuda(w, qx, y, col, row, stream);
      break;
    case 12:
      mul_mat_vec_q4_K_q8_1_cuda(w, qx, y, col, row, stream);
      break;
    case 13:
      mul_mat_vec_q5_K_q8_1_cuda(w, qx, y, col, row, stream);
      break;
    case 14:
      mul_mat_vec_q6_K_q8_1_cuda(w, qx, y, col, row, stream);
      break;
    case 16:
      mul_mat_vec_iq2_xxs_q8_1_cuda(w, qx, y, col, row, stream);
      break;
    case 17:
      mul_mat_vec_iq2_xs_q8_1_cuda(w, qx, y, col, row, stream);
      break;
    case 18:
      mul_mat_vec_iq3_xxs_q8_1_cuda(w, qx, y, col, row, stream);
      break;
    case 19:
      mul_mat_vec_iq1_s_q8_1_cuda(w, qx, y, col, row, stream);
      break;
    case 20:
      mul_mat_vec_iq4_nl_q8_1_cuda(w, qx, y, col, row, stream);
      break;
    case 21:
      mul_mat_vec_iq3_s_q8_1_cuda(w, qx, y, col, row, stream);
      break;
    case 22:
      mul_mat_vec_iq2_s_q8_1_cuda(w, qx, y, col, row, stream);
      break;
    case 23:
      mul_mat_vec_iq4_xs_q8_1_cuda(w, qx, y, col, row, stream);
      break;
    default:
      // Previously fell through and returned an uninitialized Y.
      TORCH_CHECK(false, "ggml_mul_mat_vec_a8: unsupported quantization type ",
                  static_cast<int>(type));
  }
  return Y;
}
// Quantized matrix-matrix product through the MMQ kernels.
//
// All `batch` rows of the fp16 activation `X` (each with `col` values,
// zero-padded to a multiple of 512) are quantized to Q8_1. They are then
// multiplied against the GGML-quantized weight `W` (quantization `type`, `row`
// output rows). The result is an fp16 tensor `Y` of shape [batch, row],
// presumably X @ W^T.
//
// The MMQ arguments are passed as (col, row, batch, padded, row). These are
// assumed to be (ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); confirm
// against mmq.cuh.
// `type` selects the kernel: 2=q4_0, 3=q4_1, 6=q5_0, 7=q5_1, 8=q8_0,
// 10..14=q2_K..q6_K. Throws (TORCH_CHECK) for any other value instead of
// returning an uninitialized Y.
torch::Tensor ggml_mul_mat_a8(torch::Tensor W,  // quant weight
                              torch::Tensor X,  // input
                              int8_t type, int64_t row) {
  int col = X.sizes()[1];
  int padded = (col + 512 - 1) / 512 * 512;
  int batch = X.sizes()[0];
  const at::cuda::OptionalCUDAGuard device_guard(device_of(X));
  auto options =
      torch::TensorOptions().dtype(torch::kFloat16).device(W.device());
  at::Tensor Y = torch::empty({batch, row}, options);
  // No input rows: the [0, row] result is already complete, and launching the
  // quantization / MMQ kernels with a zero-sized grid would be invalid.
  if (batch == 0) {
    return Y;
  }
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  options = torch::TensorOptions().dtype(torch::kInt32).device(W.device());
  // 9 int32 words per 32 padded inputs per row. This presumably equals one
  // 36-byte block_q8_1; verify in ggml-common.h.
  at::Tensor quant_X = torch::empty({batch, padded / 32 * 9}, options);
  quantize_row_q8_1_cuda((half*)X.data_ptr(), (void*)quant_X.data_ptr(), col,
                         batch, stream);

  void* w = W.data_ptr();
  void* qx = quant_X.data_ptr();
  half* y = (half*)Y.data_ptr();
  switch (type) {
    case 2:
      ggml_mul_mat_q4_0_q8_1_cuda(w, qx, y, col, row, batch, padded, row,
                                  stream);
      break;
    case 3:
      ggml_mul_mat_q4_1_q8_1_cuda(w, qx, y, col, row, batch, padded, row,
                                  stream);
      break;
    case 6:
      ggml_mul_mat_q5_0_q8_1_cuda(w, qx, y, col, row, batch, padded, row,
                                  stream);
      break;
    case 7:
      ggml_mul_mat_q5_1_q8_1_cuda(w, qx, y, col, row, batch, padded, row,
                                  stream);
      break;
    case 8:
      ggml_mul_mat_q8_0_q8_1_cuda(w, qx, y, col, row, batch, padded, row,
                                  stream);
      break;
    case 10:
      ggml_mul_mat_q2_K_q8_1_cuda(w, qx, y, col, row, batch, padded, row,
                                  stream);
      break;
    case 11:
      ggml_mul_mat_q3_K_q8_1_cuda(w, qx, y, col, row, batch, padded, row,
                                  stream);
      break;
    case 12:
      ggml_mul_mat_q4_K_q8_1_cuda(w, qx, y, col, row, batch, padded, row,
                                  stream);
      break;
    case 13:
      ggml_mul_mat_q5_K_q8_1_cuda(w, qx, y, col, row, batch, padded, row,
                                  stream);
      break;
    case 14:
      ggml_mul_mat_q6_K_q8_1_cuda(w, qx, y, col, row, batch, padded, row,
                                  stream);
      break;
    default:
      // Previously fell through and returned an uninitialized Y.
      TORCH_CHECK(false, "ggml_mul_mat_a8: unsupported quantization type ",
                  static_cast<int>(type));
  }
  return Y;
}
|