#include <torch/all.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

// Headers for half-precision tensor access.
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>
#include <c10/cuda/CUDAGuard.h>

#define BLOCKWIDTH 128
#define BLOCKHEIGHT4 16
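// Each 32-bit word of the packed weight matrix stores eight 4-bit codes, so
// BLOCKHEIGHT4 (16) packed rows cover BLOCKWIDTH (128) unpacked weight rows.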

namespace aphrodite {
namespace squeezellm {
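
// Reinterpret a signed int's bits as unsigned so the shift-and-mask
// unpacking below is well-defined.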
__device__ inline unsigned int as_unsigned(int i) {
  return *reinterpret_cast<unsigned int*>(&i);
}

// 4-bit matvec kernel (LUT-based)
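// Shapes (per the launcher below): vec is [batch, vec_height] fp16 read as
// half2, mat is [height, width] int32 with eight 4-bit codes per word,
// lookup_table holds 16 fp16 dequantization values per output column, and
// mul is the [batch, width] accumulator (fp16 on CUDA, fp32 on ROCm).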
__global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
    const half2* __restrict__ vec,
#else
    const __half2* __restrict__ vec,
#endif
    const int* __restrict__ mat,
#ifndef USE_ROCM
    half2* __restrict__ mul,
#else
    float2* __restrict__ mul,
#endif
    const __half* __restrict__ lookup_table, int height, int width, int batch,
    int vec_height) {
  const int blockwidth2 = BLOCKWIDTH / 2;

  int row = BLOCKHEIGHT4 * blockIdx.x;
  int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;

#ifndef USE_ROCM
  __shared__ half2 blockvec[blockwidth2];
#else
  __shared__ __half2 blockvec[blockwidth2];
#endif

  __shared__ __half deq2[16][BLOCKWIDTH];
  int off = threadIdx.x;
  int column_offset = col * 16;
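  // Each thread stages the 16-entry dequantization LUT for its output column
  // into shared memory.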
  for (int val = 0; val < 16; val += 1) {
    int lut_index = column_offset + val;
    deq2[val][off] = lookup_table[lut_index];
  }

  __half res;
#ifndef USE_ROCM
  half2 res2;
  half2 tmp2;
#else
  __half2 res2;
  __half2 tmp2;
#endif

  int i;
  int k;

  unsigned int tmp1;
  unsigned int lut_index1, lut_index2;
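
  // One pass per batch row; each pass reduces a 128-element strip of the
  // input vector against this block's tile of packed weights.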
  for (int b = 0; b < batch; ++b) {
    i = width * row + col;
    res = __int2half_rd(0);
    k = 0;

    __syncthreads();
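    // The first blockwidth2 threads stage this block's 128-element strip of
    // the vector as 64 half2 values; the rest just wait at the barrier.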
    if (threadIdx.x < blockwidth2)
      blockvec[threadIdx.x] =
          vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 +
              threadIdx.x];
    __syncthreads();

    while (k < blockwidth2) {
      tmp1 = as_unsigned(mat[i]);
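      // tmp1 packs eight 4-bit codes: dequantize them pairwise through the
      // shared LUT and fold them in with four half2 FMAs.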
#ifndef USE_ROCM
      res2 = {};
      tmp2 = {};
#else
      res2.x = __half_as_ushort(__float2half(0));
      res2.y = __half_as_ushort(__float2half(0));
      tmp2.x = __half_as_ushort(__float2half(0));
      tmp2.y = __half_as_ushort(__float2half(0));
#endif

      lut_index1 = tmp1 & 0xF;
      lut_index2 = (tmp1 >> 4) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 0], res2);

      lut_index1 = (tmp1 >> 8) & 0xF;
      lut_index2 = (tmp1 >> 12) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 1], res2);

      lut_index1 = (tmp1 >> 16) & 0xF;
      lut_index2 = (tmp1 >> 20) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 2], res2);

      lut_index1 = (tmp1 >> 24) & 0xF;
      lut_index2 = (tmp1 >> 28) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 3], res2);

#ifndef USE_ROCM
      res = __hadd(__hadd(res2.x, res2.y), res);
#else
      res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)),
                   res);
#endif
      i += width;
      k += 4;
    }
    // Pack the scalar result into one lane of a half2 (even columns -> .x,
    // odd columns -> .y) so two adjacent columns share one packed update.
#ifndef USE_ROCM
    half2 res3 = {};
    if (col % 2 == 0) {
      res3.x = res;
    } else {
      res3.y = res;
    }
#else
    __half2 res3;
    res3.x = __half_as_ushort(__float2half(0));
    res3.y = __half_as_ushort(__float2half(0));
    if (col % 2 == 0) {
      res3.x = __half_as_ushort(res);
    } else {
      res3.y = __half_as_ushort(res);
    }
#endif
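    // CUDA has a native half2 atomicAdd; ROCm instead accumulates into a
    // float2 output with two scalar float atomics.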
#ifndef USE_ROCM
    atomicAdd(&mul[b * width / 2 + col / 2], res3);
#else
    int tmp_addr = b * width / 2 + col / 2;
    atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
    atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
#endif
  }
}

}  // namespace squeezellm
}  // namespace aphrodite

// Host-side launcher for the 4-bit LUT-based matvec kernel.
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                     torch::Tensor lookup_table) {
  int height = mat.size(0);
  int width = mat.size(1);
  int batch = vec.size(0);
  int vec_height = vec.size(1);
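
  // Grid: ceil(height / 16) x ceil(width / 128) blocks of 128 threads; each
  // block produces one 128-column tile of the output.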
  dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
              (width + BLOCKWIDTH - 1) / BLOCKWIDTH);
  dim3 threads(BLOCKWIDTH);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  aphrodite::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
#ifndef USE_ROCM
      (half2*)vec.data_ptr<at::Half>(),
#else
      (__half2*)vec.data_ptr<at::Half>(),
#endif
      mat.data_ptr<int>(),
#ifndef USE_ROCM
      (half2*)mul.data_ptr<at::Half>(),
      (__half*)lookup_table.data_ptr<at::Half>(),
#else
      (float2*)mul.data_ptr<float>(),
      (__half*)lookup_table.data_ptr<at::Half>(),
#endif
      height, width, batch, vec_height);
}

#undef BLOCKWIDTH
#undef BLOCKHEIGHT4
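
// Usage sketch (illustrative; not part of the original file). A minimal,
// hypothetical CUDA-side call, assuming the shape contract read off the
// launcher above: mat packs eight 4-bit codes per int32, so
// mat.size(0) == vec.size(1) / 8, and the output is accumulated atomically,
// so it must be zero-initialized before the call.
void example_squeezellm_gemm_call() {
  const int batch = 1, vec_height = 4096, width = 4096;
  const int height = vec_height / 8;  // packed rows: eight codes per int32
  auto fp16_cuda = torch::dtype(torch::kFloat16).device(torch::kCUDA);

  auto vec = torch::randn({batch, vec_height}, fp16_cuda);
  // Arbitrary bit patterns stand in for real packed 4-bit codes here.
  auto mat = torch::randint(0, 1 << 30, {height, width},
                            torch::dtype(torch::kInt32).device(torch::kCUDA));
  auto lookup_table = torch::randn({width, 16}, fp16_cuda);
  auto mul = torch::zeros({batch, width}, fp16_cuda);  // accumulated in place

  squeezellm_gemm(vec, mat, mul, lookup_table);
}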