1
0

warp_iterator.h 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. /*
  2. * Adapted from https://github.com/InternLM/lmdeploy
  3. * Copyright (c) OpenMMLab. All rights reserved.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. #pragma once
  18. #include "common.h"
  19. namespace aphrodite {
  20. namespace autoquant {
  // WarpIteratorA: per-warp iterator over the quantized A operand held in shared memory.
  //
  // A is stored packed (8 four-bit values per 32-bit word, read as uint4) and is
  // dequantized on the fly into T_BC fragments. Q holds per-group parameters that are
  // applied with apply_Q() (presumably scale / zero-point -- see common.h, TODO confirm).
  // `load` produces the fragments for one k-step; `next_stage` moves to the next slot of
  // the STAGES-deep shared-memory ring buffer.
  //
  // Template parameters, as used in this struct:
  //   CTA_M            stride (in elements) between consecutive k-slices of smem_A and
  //                    between consecutive quantization groups of smem_Q
  //   CTA_K            not referenced here
  //   WARP_M           rows per warp, processed as ITER_X = WARP_M / 32 chunks of 32 rows
  //   WARP_K           only checked for compatibility with GROUP_SIZE
  //   OP_M, OP_K       MMA tile sizes; ITER_M = 32 / OP_M tiles per 32-row chunk
  //   GROUP_SIZE       k-extent of one quantization group
  //   STAGES           number of pipeline stages
  //   kSizePerStageA   size of one stage of smem_A, in uint4 elements
  //   kSizePerStageQ   size of one stage of smem_Q, in T_Q elements
  //   T_BC             16-bit output element type; `half` selects the fp16 dequantizer,
  //                    any other type the bf16 dequantizer
  //   T_Q              quantization-parameter type, moved around as one 32-bit word
  template<int CTA_M,
           int CTA_K,
           int WARP_M,
           int WARP_K,
           int OP_M,
           int OP_K,
           int GROUP_SIZE,
           int STAGES,
           int kSizePerStageA,
           int kSizePerStageQ,
           typename T_BC,
           typename T_Q>
  struct WarpIteratorA {
      static_assert(WARP_K % GROUP_SIZE == 0 || GROUP_SIZE % WARP_K == 0);

      static constexpr int ITER_M = 32 / OP_M;
      static constexpr int ITER_X = WARP_M / 32;

      // Each uint4 of A covers two k-steps (two words per k-step) and frag_Q_ holds four
      // entries per chunk, so at most two OP_M tiles per 32-row chunk can be addressed.
      static_assert(ITER_M <= 2, "WarpIteratorA supports at most two OP_M tiles per 32 rows");
      // Q entries are copied through a 32-bit reinterpretation, and four of them are
      // overlaid on one uint4 in load().
      static_assert(sizeof(T_Q) == sizeof(uint), "T_Q must be a 32-bit type");

      uint4 frag_A4_[ITER_X];   // packed A: 8 values per uint
      T_Q   frag_Q_[ITER_X][4]; // 4 m8k8 tiles along M per 32-row chunk

      const uint4* smem_A_;
      const T_Q*   smem_Q_;

      const int offset_m_;
      const int offset_m_Q_;

      int stage_{0};
      int offset_A_{0};
      int offset_Q_{0};

      // smem_A / smem_Q: base of stage 0 of the A / Q buffers.
      // lane_id:         selects the Q row (lane_id / 4) inside each m8 tile.
      // offset_m:        added to every A index; its 32-aligned part also positions the Q reads.
      // warp_id and offset_k are accepted for interface compatibility but are not used.
      __device__ WarpIteratorA(uint4* smem_A, T_Q* smem_Q, int warp_id, int lane_id, int offset_m, int offset_k):
          smem_A_(smem_A), smem_Q_(smem_Q), offset_m_(offset_m), offset_m_Q_(offset_m / 32 * 32 + lane_id / 4)
      {
      }

      // Dequantizes this warp's A fragments for k-step `iter_k` of the current stage.
      //
      // data:   receives ITER_X * ITER_M fragments of 8 T_BC values, ordered [x][iter_m].
      // iter_k: must be a compile-time constant, i.e. the caller's loop over iter_k is
      //         expected to be fully unrolled so the parity / group checks below fold away.
      __device__ void load(Array<T_BC, 8>* data, int iter_k)
      {
          // A is fetched only on even k-steps: one uint4 supplies words for two k-steps.
          // smem_A uint4 [SLICE_K/32, CTA_M/32, WARP_SIZE], load as uint4 to avoid bank-conflicts
          if (iter_k % 2 == 0) {
              PRAGMA_UNROLL
              for (int x = 0; x < ITER_X; ++x) {
                  frag_A4_[x] = smem_A_[offset_A_ + (iter_k / 2) * CTA_M + x * 32 + offset_m_];
              }
          }

          // Q is refreshed only when this k-step starts a new quantization group.
          if (iter_k * OP_K % GROUP_SIZE == 0) {
              const int g = iter_k * OP_K / GROUP_SIZE;
              PRAGMA_UNROLL
              for (int x = 0; x < ITER_X; ++x) {
                  PRAGMA_UNROLL
                  for (int i = 0; i < 4; ++i) {
                      const int mm = offset_m_Q_ + x * 32 + i * 8;  // stride of m8k8 tile
                      ((uint&)frag_Q_[x][i]) = ((const uint&)smem_Q_[offset_Q_ + g * CTA_M + mm]);
                  }
              }
          }

          PRAGMA_UNROLL
          for (int x = 0; x < ITER_X; ++x) {
              const uint* frag_A = (const uint*)&frag_A4_[x];
              PRAGMA_UNROLL
              for (int iter_m = 0; iter_m < ITER_M; ++iter_m) {
                  // Words 0-1 belong to the even k-step, words 2-3 to the odd one.
                  const uint packed = frag_A[iter_k % 2 * 2 + iter_m];
                  uint4      tmp;
                  // Compile-time dispatch: only the dequantizer matching T_BC is instantiated.
                  if constexpr (std::is_same<T_BC, half>::value) {
                      tmp = dequantize_s4_to_fp16x2_v2(packed);
                  }
                  else {
                      tmp = dequantize_s4_to_bf16x2_v2(packed);
                  }
                  auto& vec = (Array<T_Q, 4>&)tmp;
                  // Elements 0/2 use the first Q entry of this tile pair, 1/3 the second.
                  vec[0] = apply_Q(vec[0], frag_Q_[x][iter_m * 2]);
                  vec[1] = apply_Q(vec[1], frag_Q_[x][iter_m * 2 + 1]);
                  vec[2] = apply_Q(vec[2], frag_Q_[x][iter_m * 2]);
                  vec[3] = apply_Q(vec[3], frag_Q_[x][iter_m * 2 + 1]);
                  data[x * ITER_M + iter_m] = (Array<T_BC, 8>&)vec;
              }
          }
      }

      // Advances to the next pipeline stage (wrapping to 0 after STAGES - 1) and rebases
      // the A / Q read offsets onto that stage.
      __device__ void next_stage()
      {
          ++stage_;
          if (stage_ >= STAGES) {
              stage_ = 0;
          }
          offset_A_ = stage_ * kSizePerStageA;
          offset_Q_ = stage_ * kSizePerStageQ;
      }
  };
  // WarpIteratorB: per-warp iterator that loads B operand fragments from shared memory
  // with ldmatrix (m8n8 tiles of 16-bit elements).
  //
  // Each stage holds CTA_N rows (n) of SMEM_STRIDE T_BC elements (k); stage s starts
  // s * CTA_N * SMEM_STRIDE * sizeof(T_BC) bytes past the base address. Addresses are
  // 32-bit values handed straight to the ldmatrix_* helpers (presumably shared-space
  // addresses, e.g. from __cvta_generic_to_shared -- confirm at the call site).
  //
  // Template parameters, as used in this struct:
  //   CTA_N        rows per stage
  //   CTA_K        not referenced here
  //   WARP_N       n-extent handled by one warp (ITER_N = WARP_N / OP_N tiles)
  //   WARP_K       only used to derive ITER_K (exposed for callers)
  //   OP_N, OP_K   MMA tile sizes; must be 8 and 16
  //   SMEM_STRIDE  row pitch of the shared-memory tile, in T_BC elements
  //   STAGES       number of pipeline stages
  //   T_BC         16-bit element type
  template<int CTA_N, int CTA_K, int WARP_N, int WARP_K, int OP_N, int OP_K, int SMEM_STRIDE, int STAGES, typename T_BC>
  struct WarpIteratorB {
      static constexpr int kLdsmNum = WARP_N == 8 ? 2 : 4;
      static constexpr int ITER_N   = WARP_N / OP_N;
      static constexpr int ITER_K   = WARP_K / OP_K;

      static_assert(OP_N == 8 && OP_K == 16);
      // ldmatrix .b16 moves 16-bit elements; the byte offsets below rely on it too.
      static_assert(sizeof(T_BC) == 2, "T_BC must be a 16-bit type");
      // The x4 path consumes two n-tiles per instruction; an odd ITER_N would overrun `data`.
      static_assert(kLdsmNum == 2 || ITER_N % 2 == 0, "WARP_N must be 8 or a multiple of 16");

      const int warp_id_n_;
      const int lane_id_;
      const int ldsm_group_id_;
      const int offset_k_;
      int       offset_n_;

      const uint32_t smem_base_ptr_;
      uint32_t       smem_ptr_;

      int stage_{0};

      // Lane-to-address mapping: lanes form groups of 8 (ldsm_group_id_ = lane / 8), each
      // group supplying the 8 row addresses of one 8x8 matrix. Odd groups read the upper
      // 8 k-columns; with x4, groups 2-3 also step to the next 8 rows in n. With x2
      // (WARP_N == 8) that n step is removed. Finally each warp is shifted by
      // warp_id_n * WARP_N rows.
      // Initializers are listed in declaration order (the order C++ actually uses).
      __device__ WarpIteratorB(uint32_t smem_int_ptr, int warp_id_n, int lane_id, int offset_k):
          warp_id_n_(warp_id_n),
          lane_id_(lane_id),
          ldsm_group_id_(lane_id / 8),
          offset_k_(ldsm_group_id_ % 2 * 8 + offset_k),
          offset_n_(ldsm_group_id_ / 2 * 8 + lane_id % 8),
          smem_base_ptr_(smem_int_ptr),
          smem_ptr_(smem_base_ptr_)
      {
          if constexpr (kLdsmNum == 2) {
              offset_n_ -= ldsm_group_id_ / 2 * 8;
          }
          offset_n_ += warp_id_n_ * WARP_N;
      }

      // Loads the B fragments for k-step `iter_k` of the current stage.
      // data: receives ITER_N fragments of 4 T_BC values (two 32-bit registers each);
      //       one x4 ldmatrix fills two consecutive fragments, one x2 fills a single one.
      __device__ void load(Array<T_BC, 4>* data, int iter_k)
      {
          const int kk  = iter_k * OP_K + offset_k_;
          auto      ptr = (uint*)data;
          PRAGMA_UNROLL
          for (int iter_n = 0; iter_n < ITER_N;) {
              const int nn  = offset_n_ + iter_n * OP_N;
              auto      src = smem_ptr_ + sizeof(T_BC) * (nn * SMEM_STRIDE + kk);
              if constexpr (kLdsmNum == 4) {
                  ldmatrix_m8n8_x4_b16(ptr[0], ptr[1], ptr[2], ptr[3], src);
                  ptr += 4;
                  iter_n += 2;
              }
              else {
                  ldmatrix_m8n8_x2_b16(ptr[0], ptr[1], src);
                  ptr += 2;
                  iter_n += 1;
              }
          }
      }

      // Advances to the next pipeline stage (wrapping to 0 after STAGES - 1). The stage
      // stride uses sizeof(T_BC), consistent with the addressing in load().
      __device__ void next_stage()
      {
          ++stage_;
          if (stage_ >= STAGES) {
              stage_ = 0;
          }
          smem_ptr_ = smem_base_ptr_ + stage_ * sizeof(T_BC) * CTA_N * SMEM_STRIDE;
      }
  };
  164. } // namespace autoquant
  165. } // namespace aphrodite