/*
 * Adapted from https://github.com/InternLM/lmdeploy
 * Copyright (c) OpenMMLab. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iostream>

// half/bf16 types used below; common.h may already pull these in.
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include "common.h"

namespace aphrodite {
namespace autoquant {
// Writes a 4-bit value into nibble `index` (0-7) of the 32-bit word at
// `address`. A CAS loop is needed because concurrent threads may be
// updating different nibbles of the same word.
__device__ void atomic_assign_u4(uint32_t* address, uint32_t index, uint32_t value)
{
    uint32_t old = *address;
    uint32_t assumed;
    do {
        assumed = old;
        // Clear the target nibble, then OR in the new value.
        uint32_t tmp = (assumed & ~(0xfu << (index * 4u))) | (value << (index * 4u));
        old          = atomicCAS(address, assumed, tmp);
    } while (assumed != old);
}

// Reads the 4-bit value stored in nibble `index` (0-7) of `*address`.
__device__ uint32_t read_u4(const uint32_t* address, uint32_t index)
{
    return (*address >> (index * 4u)) & 0xfu;
}
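// Illustrative sketch, not part of the original file: round-tripping one
// nibble through the helpers above. The kernel name is hypothetical; the
// CAS loop in atomic_assign_u4 only matters when several threads write
// different nibbles of the same word concurrently.
__global__ void nibble_roundtrip_example(uint32_t* word)
{
    atomic_assign_u4(word, /*index=*/3, /*value=*/0xAu);  // set nibble 3 to 0xA
    uint32_t v = read_u4(word, /*index=*/3);              // v == 0xA
    (void)v;
}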
// Permutes a tensor of packed 4-bit values. `dims` gives the source shape;
// the template parameters `Ds...` list the source axes in destination order.
// Each element is unpacked from its source word, its destination linear
// index is computed from the permuted multi-index, and it is written back
// with an atomic nibble store.
template<int... Ds>
__global__ void permute_u4(uint* dst, const uint* src, Array<int, sizeof...(Ds)> dims)
{
    constexpr int N = sizeof...(Ds);

    size_t count = 1;
    PRAGMA_UNROLL
    for (int i = 0; i < N; ++i) {
        count *= dims[i];
    }

    constexpr int order[] = {Ds...};

    for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
        // Decompose the linear source index into a multi-index.
        int indices[N]{};
        PRAGMA_UNROLL
        for (int j = N - 1, ii = i; j >= 0; --j) {
            indices[j] = ii % dims[j];
            ii /= dims[j];
        }

        auto data = read_u4(src + i / 8, i % 8);

        // Re-linearize with the axes read in destination order.
        int index = 0;
        PRAGMA_UNROLL
        for (int j = N - 1, stride = 1; j >= 0; --j) {
            index += indices[order[j]] * stride;
            stride *= dims[order[j]];
        }

        atomic_assign_u4(dst + index / 8, index % 8, data);
    }
}
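// Illustrative sketch, not part of the original file: the same index
// remapping as permute_u4, computed on the host for a hypothetical 3-D
// case. With dims = {2, 3, 4} and order = {2, 0, 1}, the source element
// at (i0, i1, i2) lands at destination multi-index (i2, i0, i1) in a
// tensor of shape {4, 2, 3}.
static int permute_index_host(int i, const int (&dims)[3], const int (&order)[3])
{
    int indices[3]{};
    for (int j = 2; j >= 0; --j) {  // linear index -> multi-index
        indices[j] = i % dims[j];
        i /= dims[j];
    }
    int index = 0;
    for (int j = 2, stride = 1; j >= 0; --j) {  // re-linearize, permuted
        index += indices[order[j]] * stride;
        stride *= dims[order[j]];
    }
    return index;
}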
void reformat_s4_k8_m(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
{
    // permutation for [k/8, m] layout
    Array<int, 10> shape{k / 32, 2, 2, m / 32, 2, 2, 8, 2, 2, 2};
    //          |warp|  lane  | 2x2 | a0-7 |
    permute_u4<0, 3, 6, 8, 9, 1, 4, 7, 2, 5><<<512, 512, 0, st>>>(dst, src, shape);
}

void reformat_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
{
    // permutation for [k, m/8] layout
    Array<int, 10> shape{k / 32, 2, 2, 4, 2, m / 32, 2, 2, 2, 4};
    //          |warp|  lane  | 2x2 | a0-7 |
    permute_u4<0, 5, 9, 8, 3, 1, 6, 4, 2, 7><<<512, 512, 0, st>>>(dst, src, shape);
}
// Dequantizes packed 4-bit zero points to fp16, eight values per input word.
__global__ void dequantize_s4_offset_64(uint4* dst, const uint32_t* src, size_t count)
{
    for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
        dst[i] = dequantize_s4_to_fp16x2_v2(src[i]);
    }
}

// bf16 variant of the zero-point dequantization above.
__global__ void dequantize_s4_offset_64_bf16(uint4* dst, const uint32_t* src, size_t count)
{
    for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
        dst[i] = dequantize_s4_to_bf16x2_v2(src[i]);
    }
}

// Packs each (zero, scale) pair into a single half2 so one load fetches
// both quantization parameters for a group.
__global__ void merge_Q(half2* Q, const half* scales, const half* zeros, int count)
{
    for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
        Q[i] = __halves2half2(zeros[i], scales[i]);
    }
}

__global__ void merge_Q(__nv_bfloat162* Q, const __nv_bfloat16* scales, const __nv_bfloat16* zeros, int count)
{
    for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
        Q[i] = halves2bfloat162(zeros[i], scales[i]);
    }
}
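// Note (not in the original file): __halves2half2(a, b) places `a` in the
// low (.x) half and `b` in the high (.y) half, so each Q[i] above packs
// (zero_point, scale) for one quantization group. A consumer can unpack it
// with the standard intrinsics:
//
//   half zero  = __low2half(Q[i]);   // == zeros[i]
//   half scale = __high2half(Q[i]);  // == scales[i]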
// Converts [k, m/8]-packed 4-bit weights into the layout the fused GEMM
// expects: dequantize the packed zero points into `workspace`, merge them
// with the scales into Q_dst, and reorder the weight words into A_dst.
void convert_s4_k_m8(uint32_t*       A_dst,
                     half2*          Q_dst,
                     half*           workspace,
                     const uint32_t* A_src,
                     const half*     scales,
                     const uint32_t* qzeros,
                     int             m,
                     int             k,
                     int             group_size,
                     cudaStream_t    st)
{
    dequantize_s4_offset_64<<<256, 256, 0, st>>>((uint4*)workspace, qzeros, k / group_size * m / 8);
    merge_Q<<<256, 256, 0, st>>>(Q_dst, scales, workspace, k / group_size * m);
    reformat_s4_k_m8(A_dst, A_src, m, k, st);
}

// __nv_bfloat16 overload of the conversion above.
void convert_s4_k_m8(uint32_t*            A_dst,
                     __nv_bfloat162*      Q_dst,
                     __nv_bfloat16*       workspace,
                     const uint32_t*      A_src,
                     const __nv_bfloat16* scales,
                     const uint32_t*      qzeros,
                     int                  m,
                     int                  k,
                     int                  group_size,
                     cudaStream_t         st)
{
    dequantize_s4_offset_64_bf16<<<256, 256, 0, st>>>((uint4*)workspace, qzeros, k / group_size * m / 8);
    merge_Q<<<256, 256, 0, st>>>(Q_dst, scales, workspace, k / group_size * m);
    reformat_s4_k_m8(A_dst, A_src, m, k, st);
}
void transpose_qk_s4_k_m8_hf(uint32_t* dst, const uint32_t* src, int m, int k, int size_per_head, cudaStream_t st)
{
    Array<int, 7> shape{k, m / size_per_head, 2, size_per_head / 2 / 8, 2, 2, 2};
    // dequant  transpose    quant
    // 0123456 -> 0123564 -> 0135642 -> 0135264
    permute_u4<0, 1, 3, 5, 2, 6, 4><<<512, 512, 0, st>>>(dst, src, shape);
}

// [2, k, m/8] -> [k, m/8, 2]: interleaves two stacked weight tensors (e.g.
// the gate/up projections of a gated MLP) for consumption by a single GEMM.
void fuse_w1_w3_s4_k_m8(uint32_t* dst, const uint32_t* src, int m, int k, cudaStream_t st)
{
    Array<int, 6> shape{2, k, m / 8, 2, 2, 2};
    // dequant  transpose   quant
    // 012345 -> 012453 -> 124530 -> 124053
    permute_u4<1, 2, 4, 0, 5, 3><<<512, 512, 0, st>>>(dst, src, shape);
}
__global__ void dequantize_s4_kernel(uint4* dst, const uint* src, size_t count)
{
    for (int i = threadIdx.x + blockDim.x * blockIdx.x; i < count; i += blockDim.x * gridDim.x) {
        dst[i] = dequantize_s4_to_fp16x2(src[i]);
    }
}

void dequantize_s4(uint4* dst, const uint32_t* src, size_t count, cudaStream_t st)
{
    // Launch on the caller's stream (the original launch omitted `st`).
    dequantize_s4_kernel<<<512, 512, 0, st>>>(dst, src, count);
}
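// ---------------------------------------------------------------------------
// Illustrative usage sketch, not part of the original file. The dimensions
// and group size are assumptions; the buffer sizes follow the launches
// above: k/group_size * m (zero, scale) pairs, with qzeros packing eight
// 4-bit zero points per uint32_t. Error checking is omitted for brevity.
static void example_convert_s4_k_m8(cudaStream_t st)
{
    int    m = 4096, k = 4096, group_size = 128;
    size_t n_q = size_t(k) / group_size * m;  // number of quantization groups

    uint32_t *A_src, *A_dst, *qzeros;
    half *    scales, *workspace;
    half2*    Q_dst;
    cudaMalloc(&A_src, size_t(k) * m / 8 * sizeof(uint32_t));  // packed 4-bit weights
    cudaMalloc(&A_dst, size_t(k) * m / 8 * sizeof(uint32_t));
    cudaMalloc(&qzeros, n_q / 8 * sizeof(uint32_t));           // packed 4-bit zeros
    cudaMalloc(&scales, n_q * sizeof(half));
    cudaMalloc(&workspace, n_q * sizeof(half));
    cudaMalloc(&Q_dst, n_q * sizeof(half2));

    convert_s4_k_m8(A_dst, Q_dst, workspace, A_src, scales, qzeros, m, k, group_size, st);
}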
}  // namespace autoquant
}  // namespace aphrodite