matrix.cuh 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. // Adapted from turboderp exllama: https://github.com/turboderp/exllama
  2. #ifndef _matrix_cuh
  3. #define _matrix_cuh
  4. #include <cuda_runtime.h>
  5. #include <cuda_fp16.h>
  // Read-only, row-major view over a buffer of `half` values.
  //
  // The view only stores the pointer and the logical dimensions it is given; it
  // never allocates or frees anything. No accessor performs bounds checking, and
  // `height` is stored but not consulted by any accessor in this class.
  class MatrixView_half
  {
  public:
      const half* data;
      const int height;
      const int width;

      __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
          : data(data), height(height), width(width)
      { }

      // Element at (row, column).
      __device__ __forceinline__ half item(int row, int column) const
      {
          return data[row * width + column];
      }

      // Reads the buffer as an array of half2 and returns entry
      // (row * width + column) / 2, i.e. the pair that contains the addressed
      // element. The buffer stays const-qualified through the reinterpretation.
      // NOTE(review): presumably `data` must be half2-aligned for this load --
      // confirm against the allocation site.
      __device__ __forceinline__ half2 item_half2(int row, int column) const
      {
          return reinterpret_cast<const half2*>(data)[(row * width + column) / 2];
      }

      // Element at (row, column) duplicated into both lanes of a half2.
      __device__ __forceinline__ half2 item_half2half2(int row, int column) const
      {
          return __half2half2(data[row * width + column]);
      }

      // Address of element (row, column).
      __device__ __forceinline__ const half* item_ptr(int row, int column) const
      {
          return &data[row * width + column];
      }
  };
  // Mutable, row-major view over a buffer of `half` values.
  //
  // Holds the pointer and dimensions exactly as passed in; it does not allocate
  // or free the buffer and performs no bounds checking.
  class MatrixView_half_rw
  {
  public:
      half* data;
      const int height;
      const int width;

      __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
          : data(data), height(height), width(width)
      { }

      // Element at (row, column).
      __device__ __forceinline__ half item(int row, int column) const
      {
          return data[flat_index(row, column)];
      }

      // Entry flat_index / 2 of the buffer viewed as half2, i.e. the pair that
      // contains element (row, column).
      __device__ __forceinline__ half2 item_half2(int row, int column) const
      {
          const half2* pairs = reinterpret_cast<const half2*>(data);
          return pairs[flat_index(row, column) / 2];
      }

      // Writable address of element (row, column).
      __device__ __forceinline__ half* item_ptr(int row, int column)
      {
          return data + flat_index(row, column);
      }

      // Stores `value` at (row, column).
      __device__ __forceinline__ void set(int row, int column, half value)
      {
          data[flat_index(row, column)] = value;
      }

      // Stores `value` into entry flat_index / 2 of the buffer viewed as half2.
      __device__ __forceinline__ void set_half2(int row, int column, half2 value)
      {
          half2* pairs = reinterpret_cast<half2*>(data);
          pairs[flat_index(row, column) / 2] = value;
      }

  private:
      // Row-major offset of (row, column) in units of half.
      __device__ __forceinline__ int flat_index(int row, int column) const
      {
          return row * width + column;
      }
  };
  // Read-only view over 4-bit values packed eight to a uint32_t along each row.
  //
  // `width` counts 4-bit columns. The word holding (row, column) is found at
  // row * width / 8 + column / 8, and within that word the value occupies the
  // four bits starting at bit 4 * (column & 7).
  class MatrixView_q4_row
  {
  public:
      const uint32_t* data;
      const int height;
      const int width;

      __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
          : data(data), height(height), width(width)
      { }

      // Unpacked 4-bit value (0..15) at (row, column).
      __device__ __forceinline__ int item(int row, int column) const
      {
          const int word_index = row * width / 8 + column / 8;
          const int nibble_shift = (column & 0x07) * 4;
          const uint32_t word = data[word_index];
          return (word >> nibble_shift) & 0x0f;
      }
  };
  // Read-only view over 4-bit values packed eight to a uint32_t along each
  // column.
  //
  // The word at (row / 8) * width + column holds eight vertically adjacent
  // values of that column; the value for `row` occupies the four bits starting
  // at bit 4 * (row & 7). The view stores the pointer as given and performs no
  // bounds checking.
  class MatrixView_q4_column
  {
  public:
      const uint32_t* data;
      const int height;
      const int width;

      __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
          : data(data), height(height), width(width)
      { }

      // Unpacked 4-bit value (0..15) at (row, column).
      __device__ __forceinline__ int item(int row, int column) const
      {
          int shift = (row & 0x07) * 4;
          return (data[row / 8 * width + column] >> shift) & 0x0f;
      }

      // Whole packed word containing (row, column). Const-qualified like item(),
      // so it can be called through a const view as well as a mutable one.
      __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) const
      {
          return data[row / 8 * width + column];
      }

      // Address of the packed word containing (row, column). Const-qualified
      // for the same reason as item_uint32_t().
      __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) const
      {
          return &data[row / 8 * width + column];
      }
  };
  67. // TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu
  // Accumulated dot product of 8-element row vectors in h and quantized column
  // vectors in v, constant zero/scale.
  //
  // Each of the `count` iterations reads one packed word of v (eight 4-bit
  // values, lowest nibble first), subtracts v_zero from every value, and
  // multiplies the results pairwise with the next four half2 values of h. The
  // per-iteration sum is scaled by v_scale_2 and added onto the running total,
  // which starts at `acc`. After each iteration the v pointer advances by
  // v_.width words and the h pointer by four half2 values.
  //
  // The two half2 lanes are accumulated independently; this function does not
  // add them together.
  // NOTE(review): h is read through a half2 pointer, so the starting element is
  // presumably expected to be half2-aligned -- confirm with callers.
  __device__ __forceinline__ half2 dot_product_8
  (
      const half2 acc,
      MatrixView_half& h_,
      const int h_row,
      const int h_column,                 // divisible by 8
      MatrixView_q4_column& v_,
      const int v_row,                    // divisible by 8
      const int v_column,
      const half2 v_scale_2,
      const uint32_t v_zero,              // + 1 (!!)
      const int count
  )
  {
      const half2* h_pairs = reinterpret_cast<const half2*>(h_.item_ptr(h_row, h_column));
      const uint32_t* v_words = v_.item_uint32_ptr(v_row, v_column);
      half2 total = acc;

      for (int remaining = count; remaining > 0; --remaining)
      {
          const uint32_t bits = *v_words;
          v_words += v_.width;

          // The subtraction is carried out in unsigned arithmetic and the result
          // converted to int, so a nibble smaller than v_zero comes out negative.
          // The top nibble needs no mask: the shift leaves only four bits.
          const half2 q01 = __halves2half2(
              __int2half_rn(static_cast<int>(((bits      ) & 0x0f) - v_zero)),
              __int2half_rn(static_cast<int>(((bits >>  4) & 0x0f) - v_zero)));
          const half2 q23 = __halves2half2(
              __int2half_rn(static_cast<int>(((bits >>  8) & 0x0f) - v_zero)),
              __int2half_rn(static_cast<int>(((bits >> 12) & 0x0f) - v_zero)));
          const half2 q45 = __halves2half2(
              __int2half_rn(static_cast<int>(((bits >> 16) & 0x0f) - v_zero)),
              __int2half_rn(static_cast<int>(((bits >> 20) & 0x0f) - v_zero)));
          const half2 q67 = __halves2half2(
              __int2half_rn(static_cast<int>(((bits >> 24) & 0x0f) - v_zero)),
              __int2half_rn(static_cast<int>(((bits >> 28)       ) - v_zero)));

          // A table-driven decode (q4_table[v_zero - 1][byte], one lookup per
          // byte of the word) was tried here instead of the arithmetic above;
          // the original note says constant memory was apparently too slow.

          half2 partial = __hmul2(h_pairs[0], q01);
          partial = __hfma2(h_pairs[1], q23, partial);
          partial = __hfma2(h_pairs[2], q45, partial);
          partial = __hfma2(h_pairs[3], q67, partial);
          h_pairs += 4;

          total = __hfma2(v_scale_2, partial, total);
      }

      return total;
  }
  // Scalar-half accumulated dot product of 8-element row vectors in h and
  // quantized column vectors in v, constant zero/scale.
  //
  // Each of the `count` iterations reads one packed word of v (eight 4-bit
  // values, lowest nibble first), subtracts v_zero from every value, and
  // multiplies the results element by element with the next eight halves of h.
  // The per-iteration sum is scaled by v_scale and added onto the running
  // total, which starts at `acc`. After each iteration the v pointer advances
  // by v_.width words and the h pointer by eight halves.
  __device__ __forceinline__ half dot_product_8_h
  (
      const half acc,
      MatrixView_half& h_,
      const int h_row,
      const int h_column,                 // divisible by 8
      MatrixView_q4_column& v_,
      const int v_row,                    // divisible by 8
      const int v_column,
      const half v_scale,
      const uint32_t v_zero,              // + 1 (!!)
      const int count
  )
  {
      const half* h_vals = h_.item_ptr(h_row, h_column);
      const uint32_t* v_words = v_.item_uint32_ptr(v_row, v_column);
      half total = acc;

      for (int remaining = count; remaining > 0; --remaining)
      {
          const uint32_t bits = *v_words;
          v_words += v_.width;

          // The subtraction is carried out in unsigned arithmetic and the result
          // converted to int, so a nibble smaller than v_zero comes out negative.
          // The top nibble needs no mask: the shift leaves only four bits.
          const half q0 = __int2half_rn(static_cast<int>(((bits      ) & 0x0f) - v_zero));
          const half q1 = __int2half_rn(static_cast<int>(((bits >>  4) & 0x0f) - v_zero));
          const half q2 = __int2half_rn(static_cast<int>(((bits >>  8) & 0x0f) - v_zero));
          const half q3 = __int2half_rn(static_cast<int>(((bits >> 12) & 0x0f) - v_zero));
          const half q4 = __int2half_rn(static_cast<int>(((bits >> 16) & 0x0f) - v_zero));
          const half q5 = __int2half_rn(static_cast<int>(((bits >> 20) & 0x0f) - v_zero));
          const half q6 = __int2half_rn(static_cast<int>(((bits >> 24) & 0x0f) - v_zero));
          const half q7 = __int2half_rn(static_cast<int>(((bits >> 28)       ) - v_zero));

          half partial = __hmul(h_vals[0], q0);
          partial = __hfma(h_vals[1], q1, partial);
          partial = __hfma(h_vals[2], q2, partial);
          partial = __hfma(h_vals[3], q3, partial);
          partial = __hfma(h_vals[4], q4, partial);
          partial = __hfma(h_vals[5], q5, partial);
          partial = __hfma(h_vals[6], q6, partial);
          partial = __hfma(h_vals[7], q7, partial);
          h_vals += 8;

          total = __hfma(v_scale, partial, total);
      }

      return total;
  }
  // Accumulated dot product of 8-element row vectors in h and quantized column
  // vectors in v, constant zero/scale, with x_map.
  //
  // Instead of reading h contiguously, the column of each h element is looked
  // up in x_map: consecutive entries starting at x_map[h_column] are used as
  // column offsets into row h_row of h. Each of the `count` iterations consumes
  // eight x_map entries and one packed word of v (eight 4-bit values, lowest
  // nibble first, each reduced by v_zero), scales the per-iteration sum by
  // v_scale_2 and adds it onto the running total, which starts at `acc`.
  //
  // The two half2 lanes are accumulated independently; this function does not
  // add them together.
  __device__ __forceinline__ half2 dot_product_8_x_map
  (
      const half2 acc,
      MatrixView_half& h_,
      const int h_row,
      const int h_column,                 // divisible by 8
      MatrixView_q4_column& v_,
      const int v_row,                    // divisible by 8
      const int v_column,
      const half2 v_scale_2,
      const uint32_t v_zero,              // + 1 (!!)
      const int count,
      const uint32_t* x_map
  )
  {
      const half* h_row_base = h_.item_ptr(h_row, 0);
      const uint32_t* cols = x_map + h_column;
      const uint32_t* v_words = v_.item_uint32_ptr(v_row, v_column);
      half2 total = acc;

      for (int remaining = count; remaining > 0; --remaining)
      {
          const uint32_t bits = *v_words;
          v_words += v_.width;

          // The subtraction is carried out in unsigned arithmetic and the result
          // converted to int, so a nibble smaller than v_zero comes out negative.
          // The top nibble needs no mask: the shift leaves only four bits.
          const half2 q01 = __halves2half2(
              __int2half_rn(static_cast<int>(((bits      ) & 0x0f) - v_zero)),
              __int2half_rn(static_cast<int>(((bits >>  4) & 0x0f) - v_zero)));
          const half2 q23 = __halves2half2(
              __int2half_rn(static_cast<int>(((bits >>  8) & 0x0f) - v_zero)),
              __int2half_rn(static_cast<int>(((bits >> 12) & 0x0f) - v_zero)));
          const half2 q45 = __halves2half2(
              __int2half_rn(static_cast<int>(((bits >> 16) & 0x0f) - v_zero)),
              __int2half_rn(static_cast<int>(((bits >> 20) & 0x0f) - v_zero)));
          const half2 q67 = __halves2half2(
              __int2half_rn(static_cast<int>(((bits >> 24) & 0x0f) - v_zero)),
              __int2half_rn(static_cast<int>(((bits >> 28)       ) - v_zero)));

          // Gather the eight h elements selected by the next eight map entries.
          const half2 h01 = __halves2half2(h_row_base[cols[0]], h_row_base[cols[1]]);
          const half2 h23 = __halves2half2(h_row_base[cols[2]], h_row_base[cols[3]]);
          const half2 h45 = __halves2half2(h_row_base[cols[4]], h_row_base[cols[5]]);
          const half2 h67 = __halves2half2(h_row_base[cols[6]], h_row_base[cols[7]]);
          cols += 8;

          half2 partial = __hmul2(h01, q01);
          partial = __hfma2(h23, q23, partial);
          partial = __hfma2(h45, q45, partial);
          partial = __hfma2(h67, q67, partial);

          total = __hfma2(v_scale_2, partial, total);
      }

      return total;
  }
  // Scalar-half accumulated dot product of 8-element row vectors in h and
  // quantized column vectors in v, constant zero/scale, with x_map.
  //
  // The column of each h element is looked up in x_map: consecutive entries
  // starting at x_map[h_column] are used as column offsets into row h_row of h.
  // Each of the `count` iterations consumes eight x_map entries and one packed
  // word of v (eight 4-bit values, lowest nibble first, each reduced by
  // v_zero), scales the per-iteration sum by v_scale and adds it onto the
  // running total, which starts at `acc`.
  __device__ __forceinline__ half dot_product_8_x_map_h
  (
      const half acc,
      MatrixView_half& h_,
      const int h_row,
      const int h_column,                 // divisible by 8
      MatrixView_q4_column& v_,
      const int v_row,                    // divisible by 8
      const int v_column,
      const half v_scale,
      const uint32_t v_zero,              // + 1 (!!)
      const int count,
      const uint32_t* x_map
  )
  {
      const half* h_row_base = h_.item_ptr(h_row, 0);
      const uint32_t* cols = x_map + h_column;
      const uint32_t* v_words = v_.item_uint32_ptr(v_row, v_column);
      half total = acc;

      for (int remaining = count; remaining > 0; --remaining)
      {
          const uint32_t bits = *v_words;
          v_words += v_.width;

          // The subtraction is carried out in unsigned arithmetic and the result
          // converted to int, so a nibble smaller than v_zero comes out negative.
          // The top nibble needs no mask: the shift leaves only four bits.
          const half q0 = __int2half_rn(static_cast<int>(((bits      ) & 0x0f) - v_zero));
          const half q1 = __int2half_rn(static_cast<int>(((bits >>  4) & 0x0f) - v_zero));
          const half q2 = __int2half_rn(static_cast<int>(((bits >>  8) & 0x0f) - v_zero));
          const half q3 = __int2half_rn(static_cast<int>(((bits >> 12) & 0x0f) - v_zero));
          const half q4 = __int2half_rn(static_cast<int>(((bits >> 16) & 0x0f) - v_zero));
          const half q5 = __int2half_rn(static_cast<int>(((bits >> 20) & 0x0f) - v_zero));
          const half q6 = __int2half_rn(static_cast<int>(((bits >> 24) & 0x0f) - v_zero));
          const half q7 = __int2half_rn(static_cast<int>(((bits >> 28)       ) - v_zero));

          // h elements are gathered through the next eight map entries.
          half partial = __hmul(h_row_base[cols[0]], q0);
          partial = __hfma(h_row_base[cols[1]], q1, partial);
          partial = __hfma(h_row_base[cols[2]], q2, partial);
          partial = __hfma(h_row_base[cols[3]], q3, partial);
          partial = __hfma(h_row_base[cols[4]], q4, partial);
          partial = __hfma(h_row_base[cols[5]], q5, partial);
          partial = __hfma(h_row_base[cols[6]], q6, partial);
          partial = __hfma(h_row_base[cols[7]], q7, partial);
          cols += 8;

          total = __hfma(v_scale, partial, total);
      }

      return total;
  }
  250. #endif