@@ -28,6 +28,7 @@ namespace gptq {
 #define DIVIDE(x, size) (((x) + (size) - 1) / (size))

 #if defined(USE_ROCM)
+#include <hipblas/hipblas.h>
 __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
                                                                hipblasOperation_t transA,
                                                                hipblasOperation_t transB,
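The added `#include <hipblas/hipblas.h>` presumably lets the ROCm branch name the hipBLAS types (`hipblasStatus_t`, `hipblasHandle_t`, `hipblasOperation_t`) without relying on a transitive include. The hunk truncates the `__compat_hipblasHgemm` signature after `transB`, so the sketch below fills in the remaining parameters by assumption, mirroring the standard Hgemm argument order; the shim's job is simply to bridge CUDA-style `half` pointers to the `hipblasHalf` that `hipblasHgemm` expects.

```cpp
// Sketch of a compat shim, assuming the truncated parameters follow the
// standard Hgemm order. half and hipblasHalf are bit-identical fp16 storage.
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(
    hipblasHandle_t handle, hipblasOperation_t transA, hipblasOperation_t transB,
    int m, int n, int k,
    const half* alpha, const half* AP, int lda,
    const half* BP, int ldb,
    const half* beta, half* CP, int ldc) {
  return hipblasHgemm(handle, transA, transB, m, n, k,
                      reinterpret_cast<const hipblasHalf*>(alpha),
                      reinterpret_cast<const hipblasHalf*>(AP), lda,
                      reinterpret_cast<const hipblasHalf*>(BP), ldb,
                      reinterpret_cast<const hipblasHalf*>(beta),
                      reinterpret_cast<hipblasHalf*>(CP), ldc);
}
```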
@@ -286,7 +287,8 @@ void gemm_half_q_half_cuda_part

     fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);

-    kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    kernel<<<gridDim, blockDim, 0, stream>>>
     (
         a,
         b_q_weight,
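Every launch-site change in this patch is the same one-line fix: the kernel moves off the legacy default stream and onto PyTorch's current stream. PyTorch creates its streams as non-blocking, so work enqueued on the default stream is not implicitly ordered with ATen ops running under a `torch.cuda.stream(...)` context; launching on `at::cuda::getCurrentCUDAStream()` restores that ordering. A minimal sketch of the pattern, assuming `<ATen/cuda/CUDAContext.h>` is included in this translation unit (the `at::cuda::` call implies it):

```cpp
#include <ATen/cuda/CUDAContext.h>  // at::cuda::getCurrentCUDAStream()

// Launch-helper sketch: the third launch parameter is dynamic shared memory
// (0 here), the fourth is the stream. CUDAStream converts implicitly to
// cudaStream_t (hipStream_t after hipify on ROCm builds).
template <typename Kernel, typename... Args>
void launch_on_torch_stream(Kernel kernel, dim3 grid, dim3 block, Args... args) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  kernel<<<grid, block, 0, stream>>>(args...);
}
```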
@@ -433,7 +435,8 @@ void reconstruct_exllama
     gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);

-    reconstruct_exllama_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         b_q_weight,
         b_q_perm,
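The grid math in the context lines above relies on the `DIVIDE` macro from the first hunk, which is ceiling division: any remainder yields one extra, partially filled block, so a `height` or `width` that is not a multiple of `BLOCK_KN_SIZE` is still fully covered. A compile-time check of that behavior:

```cpp
// Ceiling division as defined in the first hunk.
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))

static_assert(DIVIDE(256, 128) == 2, "exact multiple: no extra block");
static_assert(DIVIDE(257, 128) == 3, "a single leftover element adds a block");
```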
@@ -520,16 +523,25 @@ __global__ void gemm_half_q_half_alt_kernel(
             zeros_tmp[tmp_k] = zero;
         }
         for (int m = 0; m < b_end; m++) {
+#ifndef USE_ROCM
             res2 = {};
+#else
+            res2.x = __half_as_ushort(__float2half(0));
+            res2.y = __half_as_ushort(__float2half(0));
+#endif
             res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
             res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
             res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2);
             res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2);
+#ifndef USE_ROCM
             res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
+#else
+            res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
+#endif
         }
         i += width;
         k += 4;
-    }
+}
     for (int m = 0; m < b_end; m++) {
         atomicAdd(&mul[(b + m) * width + w], res[m]);
     }
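The `USE_ROCM` branches in this hunk bridge a header difference, not a math difference: with CUDA's `cuda_fp16.h` the `x`/`y` members of `__half2` are `__half`, so `res2 = {};` zero-initializes it and `__hadd(res2.x, res2.y)` compiles directly, while HIP exposes those members as raw `unsigned short` bit patterns, hence the `__half_as_ushort`/`__ushort_as_half` round trips. A sketch of how the same difference could be hidden behind helpers (names here are illustrative, not from the patch):

```cpp
// Hypothetical portability helpers isolating the CUDA/ROCm difference in how
// __half2 components are exposed; mirrors the #ifdef logic in the hunk above.
__device__ __forceinline__ half2 half2_zero() {
#ifndef USE_ROCM
  return half2{};                                // members are __half
#else
  half2 h;
  h.x = __half_as_ushort(__float2half(0.0f));    // members are ushort bits
  h.y = __half_as_ushort(__float2half(0.0f));
  return h;
#endif
}

__device__ __forceinline__ half half2_sum(half2 v) {
#ifndef USE_ROCM
  return __hadd(v.x, v.y);
#else
  return __hadd(__ushort_as_half(v.x), __ushort_as_half(v.y));
#endif
}
```

Both toolchains also provide `__low2half`/`__high2half`, which would sidestep the member-type question entirely; the patch keeps the explicit bit-cast form.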
@@ -557,7 +569,8 @@ void gemm_half_q_half_alt
     gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX);
     gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);

-    gemm_half_q_half_alt_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    gemm_half_q_half_alt_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         (const half2*) a,
         b_q_weight,
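Note the `(const half2*) a` cast: this kernel reads activations as packed `half2` pairs, which is what lets each `__hfma2` line in the previous hunk retire two multiply-adds at once (each byte of a 32-bit weight word indexes the `deq2` table for a dequantized pair, so the four lines together consume one packed word). For reference, a scalar restatement of what one `__hfma2` computes per lane:

```cpp
#include <cuda_fp16.h>

// Reference semantics of __hfma2(a, b, c): a fused multiply-add applied
// independently to the low and high halves of the operands.
__device__ half2 hfma2_ref(half2 a, half2 b, half2 c) {
  return __halves2half2(
      __hfma(__low2half(a),  __low2half(b),  __low2half(c)),
      __hfma(__high2half(a), __high2half(b), __high2half(c)));
}
```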
@@ -629,7 +642,8 @@ void reconstruct_gptq
     blockDim.y = 1;
     gridDim.y = DIVIDE(height, 8);
     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
-    reconstruct_gptq_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         b_q_weight,
         b_gptq_scales,
@@ -784,7 +798,8 @@ void shuffle_exllama_weight
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = height / 8;

-    make_sequential_kernel<<<gridDim, blockDim>>>
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
     (
         q_weight,
         new_qweight,
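One detail worth flagging in this hunk's context lines: `gridDim.y` uses exact division (`height / 8`) where every other launch uses `DIVIDE`. That appears deliberate rather than an oversight: with 4-bit GPTQ packing, eight quantized values share one 32-bit word of `q_weight`, so `height` should be a multiple of 8 by construction and no remainder block can arise. The packing arithmetic, under that assumption:

```cpp
// Assumption: 4-bit quantization, values packed into 32-bit words.
constexpr int kBitsPerValue = 4;
constexpr int kValuesPerWord = 32 / kBitsPerValue;  // 8
static_assert(kValuesPerWord == 8, "eight 4-bit values per packed word");
```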
@@ -803,7 +818,8 @@ void shuffle_exllama_weight
     blockDim.y = 1;
     gridDim.x = DIVIDE(width, THREADS_X);
     gridDim.y = 1;
-    shuffle_kernel<<<gridDim, blockDim>>>(q_weight, height, width);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width);
 }

 } // namespace gptq
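None of the touched launch sites check whether the enqueue succeeded. Since the patch already depends on ATen for the stream lookup, a natural follow-up (not part of this patch) would be PyTorch's standard post-launch assertion after each `<<<...>>>`, shown here against the last hunk's launch:

```cpp
#include <c10/cuda/CUDAException.h>  // C10_CUDA_KERNEL_LAUNCH_CHECK

shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width);
// Throws c10::Error right away if the launch failed, rather than letting the
// error surface at the next synchronizing CUDA call.
C10_CUDA_KERNEL_LAUNCH_CHECK();
```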