q_matrix.cuh 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. /*
  2. * Adapted from https://github.com/turboderp/exllamav2
  3. * Copyright (c) 2024 turboderp
  4. *
  5. * Permission is hereby granted, free of charge, to any person obtaining a copy
  6. * of this software and associated documentation files (the "Software"), to deal
  7. * in the Software without restriction, including without limitation the rights
  8. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  9. * copies of the Software, and to permit persons to whom the Software is
  10. * furnished to do so, subject to the following conditions:
  11. *
  12. * The above copyright notice and this permission notice shall be included in all
  13. * copies or substantial portions of the Software.
  14. *
  15. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  16. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  17. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  18. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  19. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  20. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  21. * SOFTWARE.
  22. */
  23. #ifndef _q_matrix_cuh
  24. #define _q_matrix_cuh
  25. #include <cuda_runtime.h>
  26. #include <cuda_fp16.h>
  27. #include <cstdint>
  28. #include <cstdio>
  29. namespace aphrodite {
  30. namespace exl2 {
  // NOTE(review): MAX_SUPERGROUPS is not referenced anywhere in this header.
  // Being a preprocessor macro it ignores the enclosing namespaces and leaks
  // into every file that includes this header. Presumably it bounds a
  // "supergroup" count used by the kernels in the .cu sources -- confirm there.
  #define MAX_SUPERGROUPS 16

  // Descriptor for one EXL2-quantized weight matrix: its shape, its number of
  // quantization groups, and the tensor pointers handed to the constructor.
  //
  // The constructor, destructor, reconstruct() and make_sequential() are only
  // declared here; their definitions live in another translation unit
  // (presumably q_matrix.cu -- confirm).
  //
  // Every data member carries a default member initializer. Any field the
  // out-of-line constructor does not assign therefore reads as a well-defined
  // 0 / false / NULL instead of an indeterminate value. A constructor that does
  // assign a field still overrides the default, so existing behavior is kept.
  //
  // NOTE(review): this header does not show whether ~QMatrix() releases any of
  // the pointed-to buffers. If it does, the implicitly generated copy
  // constructor / copy assignment would lead to a double free -- confirm in the
  // definition before copying QMatrix instances.
  class QMatrix
  {
  public:
      // Device ordinal for this matrix (presumably copied from `_device`).
      int device = 0;

      // GPTQ-layout flag. NOTE(review): the constructor declared below takes no
      // GPTQ tensors, so this flag, gptq_groupsize, cuda_gptq_qzeros and
      // cuda_gptq_scales look inherited from upstream exllamav2 -- confirm
      // whether anything still sets or reads them.
      bool is_gptq = false;

      // Matrix geometry and group count (presumably copied from `_height`,
      // `_width` and `_groups`).
      int height = 0;
      int width = 0;
      int groups = 0;
      int gptq_groupsize = 0;

      // Presumably the number of rows stored at 8/6/5/4/3/2 bits per weight,
      // derived by the constructor -- confirm in the definition.
      int rows_8 = 0;
      int rows_6 = 0;
      int rows_5 = 0;
      int rows_4 = 0;
      int rows_3 = 0;
      int rows_2 = 0;

      // Tensor pointers supplied to the constructor. The `cuda_` prefix
      // suggests device memory; that is an assumption, not something this
      // header establishes.
      uint32_t* cuda_q_weight = NULL;
      uint16_t* cuda_q_perm = NULL;       // presumably a row permutation
      uint16_t* cuda_q_invperm = NULL;    // presumably the inverse of q_perm
      uint32_t* cuda_q_scale = NULL;
      half* cuda_q_scale_max = NULL;
      uint16_t* cuda_q_groups = NULL;
      uint16_t* cuda_q_group_map = NULL;
      uint32_t* cuda_gptq_qzeros = NULL;
      half* cuda_gptq_scales = NULL;

      // Not a constructor argument. The name suggests a temporary
      // dequantization buffer -- confirm. Defaults to NULL like the other
      // pointers so it is never read while indeterminate.
      half* temp_dq = NULL;

      // Error flag; presumably set when construction or a later step fails.
      // Defaults to false so it is well-defined even if never assigned.
      bool failed = false;

      // Builds the descriptor from a device ordinal, the matrix shape, the
      // group count, and the quantized tensors (weights, row permutation and
      // its inverse, scales, per-group max scales, group table, group map).
      // Shapes and memory spaces of the pointer arguments are not visible
      // here -- see the definition.
      QMatrix
      (
          const int _device,
          const int _height,
          const int _width,
          const int _groups,
          uint32_t* _q_weight,
          uint16_t* _q_perm,
          uint16_t* _q_invperm,
          uint32_t* _q_scale,
          half* _q_scale_max,
          uint16_t* _q_groups,
          uint16_t* _q_group_map
      );

      ~QMatrix();

      // Presumably writes the dequantized fp16 weights into `out` (assumed to
      // hold height * width halves) -- confirm in the definition.
      void reconstruct(half* out);

      // The name suggests it reorders rows according to a host-side
      // (`cpu_`-prefixed) group-index array. The bool presumably reports
      // success -- confirm in the definition.
      bool make_sequential(const uint32_t* cpu_g_idx);

  private:
  };
  77. } // namespace exl2
  78. } // namespace aphrodite
  79. #endif