aqlm_cuda_entry.cpp

/*
 * Modified by Neural Magic
 * Adapted from https://github.com/Vahe1994/AQLM
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <torch/all.h>
#include <torch/python.h>

#include <c10/cuda/CUDAGuard.h>

#include <iostream>
#include <cstdlib>
void code1x16_matvec_cuda(
  const void* A,
  const void* B,
  void* C,
  const void* codebook,
  int prob_m,
  int prob_k,
  const int4 codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at most 3 long.
  const int codebook_stride     // as int4.
);

void code2x8_matvec_cuda(
  const void* A,
  const void* B,
  void* C,
  const void* codebook,
  int prob_m,
  int prob_k,
  const int4 codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at most 3 long.
  const int codebook_stride     // as int4.
);
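
// The *_matvec_cuda kernels declared above are implemented in a separate CUDA
// translation unit (presumably aqlm_cuda_kernel.cu, following the upstream
// layout; this file only declares them). Both kernels address the codebook in
// units of int4 (16 bytes), which is why the host wrappers below convert
// element strides into int4 strides.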
void code1x16_matvec(
  const torch::Tensor& A,
  const torch::Tensor& B,
  torch::Tensor& C,
  const torch::Tensor& codebook,
  const int4 codebook_a_sizes  // cumulative sizes of A spanning each codebook, at most 3 long.
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
  int prob_m = C.size(0);
  int prob_k = B.size(0);

  code1x16_matvec_cuda(
    A.data_ptr(),
    B.data_ptr(),
    C.data_ptr(),
    codebook.data_ptr(),
    prob_m,
    prob_k,
    codebook_a_sizes,
    codebook.stride(0) * codebook.element_size() / sizeof(int4)
  );
}
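
// Worked example of the stride conversion above (illustrative numbers, not
// taken from the source): for a half-precision codebook (element_size() == 2)
// whose stride(0) is 65536 * 8 elements, the kernel receives
// 65536 * 8 * 2 / 16 == 65536 int4 units between consecutive codebooks.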
torch::Tensor code1x16_matmat(
  const torch::Tensor& input,
  const torch::Tensor& codes,
  const torch::Tensor& codebooks,
  const torch::Tensor& scales,
  const int4 codebook_a_sizes,
  const std::optional<torch::Tensor>& bias
) {
  auto input_sizes = input.sizes();
  auto out_features = codes.size(0) * codebooks.size(2);
  auto flat_input = input.reshape({-1, input.size(-1)});
  auto flat_output = torch::empty(
    {flat_input.size(0), out_features},
    torch::TensorOptions().dtype(input.dtype()).device(input.device())
  );

  for (int i = 0; i < flat_input.size(0); ++i) {
    auto input_vec = flat_input.index({i});
    auto output_vec = flat_output.index({i});
    code1x16_matvec(
      codes.squeeze(2),
      input_vec,
      output_vec,
      codebooks,
      codebook_a_sizes
    );
  }
  flat_output *= scales.flatten().unsqueeze(0);
  if (bias.has_value()) {
    flat_output += bias->unsqueeze(0);
  }

  auto output_sizes = input_sizes.vec();
  output_sizes.pop_back();
  output_sizes.push_back(-1);
  auto output = flat_output.reshape(output_sizes);
  return output;
}
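
// Note: code1x16_matmat launches one matvec per flattened input row, so the
// GEMM degenerates into a loop of GEMVs, which favors small batch sizes
// (decode-style workloads). That observation follows from the loop structure
// above, not from any statement in the original sources.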
void code2x8_matvec(
  const torch::Tensor& A,
  const torch::Tensor& B,
  torch::Tensor& C,
  const torch::Tensor& codebook,
  const int4 codebook_a_sizes
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
  int prob_m = C.size(0);
  int prob_k = B.size(0);

  code2x8_matvec_cuda(
    A.data_ptr(),
    B.data_ptr(),
    C.data_ptr(),
    codebook.data_ptr(),
    prob_m,
    prob_k,
    codebook_a_sizes,
    2 * codebook.stride(0) * codebook.element_size() / sizeof(int4)
  );
}
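
// The extra factor of 2 in the stride above appears to step over the pair of
// 2^8 codebooks that the 2x8 scheme uses per partition; this is an inference
// from the kernel interface, as the original file does not document it.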
torch::Tensor code2x8_matmat(
  const torch::Tensor& input,
  const torch::Tensor& codes,
  const torch::Tensor& codebooks,
  const torch::Tensor& scales,
  const int4 codebook_a_sizes,
  const std::optional<torch::Tensor>& bias
) {
  auto input_sizes = input.sizes();
  auto out_features = codes.size(0) * codebooks.size(2);
  auto flat_input = input.reshape({-1, input.size(-1)});
  auto flat_output = torch::empty(
    {flat_input.size(0), out_features},
    torch::TensorOptions().dtype(input.dtype()).device(input.device())
  );

  for (int i = 0; i < flat_input.size(0); ++i) {
    auto input_vec = flat_input.index({i});
    auto output_vec = flat_output.index({i});
    code2x8_matvec(
      codes.squeeze(2),
      input_vec,
      output_vec,
      codebooks,
      codebook_a_sizes
    );
  }
  flat_output *= scales.flatten().unsqueeze(0);
  if (bias.has_value()) {
    flat_output += bias->unsqueeze(0);
  }

  auto output_sizes = input_sizes.vec();
  output_sizes.pop_back();
  output_sizes.push_back(-1);
  auto output = flat_output.reshape(output_sizes);
  return output;
}
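
// code1x16_matmat and code2x8_matmat are identical apart from the matvec they
// dispatch to. A minimal de-duplication sketch (an editorial illustration, not
// part of the original file; aqlm_matmat_sketch is a hypothetical name): hoist
// the shared body into a helper templated on the matvec callable.
template <typename MatVec>
static torch::Tensor aqlm_matmat_sketch(
  MatVec matvec,
  const torch::Tensor& input,
  const torch::Tensor& codes,
  const torch::Tensor& codebooks,
  const torch::Tensor& scales,
  const int4 codebook_a_sizes,
  const std::optional<torch::Tensor>& bias
) {
  auto out_features = codes.size(0) * codebooks.size(2);
  auto flat_input = input.reshape({-1, input.size(-1)});
  auto flat_output = torch::empty(
    {flat_input.size(0), out_features},
    torch::TensorOptions().dtype(input.dtype()).device(input.device())
  );
  // One matvec per flattened input row, exactly as in the two functions above.
  for (int i = 0; i < flat_input.size(0); ++i) {
    auto input_vec = flat_input.index({i});
    auto output_vec = flat_output.index({i});
    matvec(codes.squeeze(2), input_vec, output_vec, codebooks, codebook_a_sizes);
  }
  flat_output *= scales.flatten().unsqueeze(0);
  if (bias.has_value()) {
    flat_output += bias->unsqueeze(0);
  }
  auto output_sizes = input.sizes().vec();
  output_sizes.pop_back();
  output_sizes.push_back(-1);
  return flat_output.reshape(output_sizes);
}
// Usage would then be, e.g.:
//   return aqlm_matmat_sketch(code2x8_matvec, input, codes, codebooks, scales,
//                             cumulative_sizes, bias);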
torch::Tensor aqlm_gemm(
  const torch::Tensor& input,
  const torch::Tensor& codes,
  const torch::Tensor& codebooks,
  const torch::Tensor& scales,
  const torch::Tensor& codebook_partition_sizes,
  const std::optional<torch::Tensor>& bias
) {
  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
  int const entries = codebooks.size(1);

  int4 cumulative_sizes;
  auto cumulative_size = &cumulative_sizes.x;
  int i = 0;
  int last = 0;
  assert(codebook_partition_sizes.size(0) <= 4);
  for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) {
    *cumulative_size = codebook_partition_sizes[i].item<int>() + last;
    last = *cumulative_size;
  }
  // Fill the remaining entries with an unreachable sentinel.
  for (; i < 4; ++i, ++cumulative_size) {
    *cumulative_size = last * 10;
  }
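
  // Worked example (illustrative numbers, not from the source): with three
  // partitions of sizes {1024, 1024, 2048}, cumulative_sizes becomes
  // {1024, 2048, 4096, 40960}; the final entry (last * 10) can never be
  // reached by a valid row index, so the kernel never selects a fourth
  // codebook.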
  if (nbooks == 1 && entries == (1 << 16)) {
    return code1x16_matmat(input, codes, codebooks, scales, cumulative_sizes, bias);
  }
  if (nbooks == 2 && entries == (1 << 8)) {
    return code2x8_matmat(input, codes, codebooks, scales, cumulative_sizes, bias);
  }
  TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
              " entries is not currently supported.");
  return {};
}
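
// How aqlm_gemm is exposed to Python is not shown in this file; the
// registration presumably lives in a separate binding translation unit. A
// plausible sketch (an assumption, not the project's actual binding code):
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("aqlm_gemm", &aqlm_gemm,
//           "AQLM quantized matrix multiplication (CUDA)");
//   }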