gptq_marlin.cuh

#pragma once

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>

namespace gptq_marlin {
// 8 warps are a good choice since every SM has 4 schedulers and having more
// than 1 warp per scheduler allows some more latency hiding. At the same
// time, we want relatively few warps to have many registers per warp and
// small tiles.
static constexpr int default_threads = 256;
static constexpr int pipe_stages =
    4;  // 4 pipeline stages fit into shared memory

static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;

static constexpr int tile_size = 16;
static constexpr int max_par = 16;
// Fixed-size vector of n elements of type T, used to move data in wide,
// register-friendly chunks.
template <typename T, int n>
struct Vec {
  T elems[n];
  __device__ T& operator[](int i) { return elems[i]; }
};

using I4 = Vec<int, 4>;
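
// Usage sketch (illustrative only, not part of the original header): Vec is a
// plain aggregate, so a thread can pass an I4 around by value and address its
// lanes with operator[]. The name `sum_i4` is hypothetical.
__device__ inline int sum_i4(I4 v) {
  int s = 0;
  for (int i = 0; i < 4; ++i) s += v[i];  // per-lane access via operator[]
  return s;
}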
// Integer division rounding up.
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
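
// Worked example (hypothetical compile-time checks, not part of the original
// header): div_ceil rounds up, so tiling e.g. 100 columns at min_thread_n = 64
// columns per block requires div_ceil(100, 64) = 2 blocks.
static_assert(div_ceil(100, 64) == 2, "partial tile rounds up");
static_assert(div_ceil(128, 64) == 2, "exact multiple does not round up");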
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  // No support for async copies (cp.async requires sm_80 or newer).
#else
// Asynchronously copy 16 bytes from global to shared memory, but only if
// `pred` is true (the PTX predicate `p` guards the cp.async instruction).
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
                                      bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   .reg .pred p;\n"
      "   setp.ne.b32 p, %0, 0;\n"
      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
      "}\n" ::"r"((int)pred),
      "r"(smem), "l"(glob_ptr), "n"(BYTES));
}
// Asynchronously copy 16 bytes from global to shared memory (unconditional).
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "{\n"
      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
      "}\n" ::"r"(smem),
      "l"(glob_ptr), "n"(BYTES));
}
// Commit all prior cp.async instructions into one group.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

// Wait until at most `n` committed cp.async groups are still in flight.
template <int n>
__device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}
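
// A minimal sketch of how the primitives above compose into a software
// pipeline (hypothetical fragment, not part of the original header; the names
// `prefetch_pipeline`, `sh_buf`, `glob`, and `iters` are assumed). Shown from
// a single thread's view; a real kernel distributes chunks across threads.
// `sh_buf` must point to __shared__ memory.
__device__ inline void prefetch_pipeline(I4* sh_buf, const I4* glob,
                                         int iters) {
  for (int i = 0; i < iters; ++i) {
    // Start one 16-byte async copy into this iteration's pipeline slot.
    cp_async4(&sh_buf[i % pipe_stages], &glob[i]);
    cp_async_fence();  // commit the copy as its own group
    if (i >= pipe_stages - 1) {
      // Leave at most pipe_stages - 2 groups in flight, i.e. block until the
      // oldest outstanding copy has landed in shared memory.
      cp_async_wait<pipe_stages - 2>();
      __syncthreads();  // make the bytes visible to all warps before use
    }
  }
}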
#endif
}  // namespace gptq_marlin