1
0

bgmv_impl.cuh 11 KB


  1. #pragma once
  2. #include <ATen/cuda/CUDAContext.h>
  3. #include <cooperative_groups.h>
  4. #include <cuda/pipeline>
  5. #include <cuda_runtime.h>
  6. #include <iostream>
  7. #include <stdio.h>
  8. #include "vec_dtypes.cuh"
  9. namespace cg = cooperative_groups;
  // Shrink kernel (selected by bgmv_kernel when feat_in >= feat_out):
//   Y[b * full_y_size + y_offset + j] += scale * dot(W[idx][j][:], X[b][:])
// with idx = indicies[b] * num_layers + layer_idx; batch rows whose idx is
// negative are skipped. W is addressed as idx * feat_out + j rows of feat_in
// elements; X as rows of feat_in elements.
//
// Launch layout: grid = (feat_out, batch_size), block = (tx, ty), e.g.
// nthrs = (32, 4). Each block produces one output element j for batch row
// blockIdx.y. feat_in is consumed in tiles of tx * ty * vec_size elements that
// are double-buffered in static shared memory through a two-stage
// cuda::pipeline fed by cuda::memcpy_async.
//
// Preconditions (guaranteed by bgmv_kernel's dispatch):
//   * feat_in % (tx * vec_size) == 0, so a tile is in or out of bounds at the
//     granularity of whole thread rows (one threadIdx.y value);
//   * tx is a power of two no larger than 32, so each thread row sits inside
//     one warp and is reduced with warp shuffles;
//   * X_copy_size == vec_size * sizeof(in_T), W_copy_size == vec_size * sizeof(W_T).
// tz is unused here; it only keeps the template signature uniform.
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
          size_t W_copy_size, int tx, int ty, int tz, typename in_T,
          typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  constexpr size_t num_pipeline_stages = 2;
  constexpr size_t tile_size = tx * ty * vec_size;
  constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size;
  static_assert(num_tiles > 0, "bgmv shrink: feat_in must be positive");
  static_assert(tx > 0 && tx <= 32 && (tx & (tx - 1)) == 0,
                "bgmv shrink: tx must be a power of two <= 32");
  static_assert(X_copy_size == vec_size * sizeof(in_T) &&
                    W_copy_size == vec_size * sizeof(W_T),
                "bgmv shrink: copy sizes must cover exactly vec_size elements");

  const size_t batch_idx = blockIdx.y;
  const int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
  if (idx < 0) {
    return;  // uniform across the block: batch_idx depends only on blockIdx
  }
  auto block = cg::this_thread_block();
  const size_t j = blockIdx.x;

  // cuda::aligned_size_t<N> requires N-byte aligned source and destination
  // (and vec_t::load is presumably vectorized, see vec_dtypes.cuh); static
  // __shared__ arrays are otherwise only guaranteed the element alignment.
  __shared__ alignas(W_copy_size) W_T W_shared[num_pipeline_stages * tile_size];
  __shared__ alignas(X_copy_size) in_T X_shared[num_pipeline_stages * tile_size];
  __shared__ float y_warpwise[ty];

  // Offset of this thread's vec_size-element slice within a tile.
  const size_t thread_offset = (threadIdx.y * tx + threadIdx.x) * vec_size;
  // Offset of this thread row's first element within a tile.
  const size_t row_offset = threadIdx.y * tx * vec_size;
  const W_T *W_row = W + (idx * feat_out + j) * feat_in;
  const in_T *X_row = X + batch_idx * feat_in;

  // Whether tile `t` has data for this thread row. Only the last tile can be
  // partially out of bounds.
  auto row_in_bounds = [&](size_t t) {
    return t * tile_size + row_offset < static_cast<size_t>(feat_in);
  };

  auto pipe = cuda::make_pipeline();

  // Producer: stage tile `t` of W and X into pipeline buffer `stage`. Rows
  // past feat_in skip the copy (it would read past the end of W / X) but still
  // commit, so acquire/commit/wait/release stay balanced for every thread.
  auto load_tile = [&](size_t t, size_t stage) {
    pipe.producer_acquire();
    if (row_in_bounds(t)) {
      const size_t src = t * tile_size + thread_offset;
      const size_t dst = stage * tile_size + thread_offset;
      cuda::memcpy_async(W_shared + dst, W_row + src,
                         cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
      cuda::memcpy_async(X_shared + dst, X_row + src,
                         cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
    }
    pipe.producer_commit();
  };

  // Consumer: this thread's scaled partial dot product over buffer `stage`,
  // tree-reduced across its thread row. Afterwards only the lane with
  // threadIdx.x == 0 holds the complete row sum.
  auto row_dot = [&](size_t stage) {
    vec_t<in_T, vec_size> x_vec;
    vec_t<W_T, vec_size> w_vec;
    x_vec.load(X_shared + stage * tile_size + thread_offset);
    w_vec.load(W_shared + stage * tile_size + thread_offset);
    float sum = 0.f;
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      sum += float(w_vec[i]) * float(x_vec[i]) * scale;
    }
    // width = tx confines the reduction to this thread row when tx < 32.
#pragma unroll
    for (unsigned offset = tx / 2; offset > 0; offset /= 2) {
      sum += __shfl_down_sync(0xffffffff, sum, offset, tx);
    }
    return sum;
  };

  float y = 0.f;
  // Fold one partial sum per thread row into y (every thread ends with the
  // same y). Only lane 0 of each row publishes: it alone holds the full sum,
  // and simultaneous stores of differing values from several lanes to one
  // shared address leave the surviving value unspecified.
  auto accumulate = [&](float row_sum) {
    if (threadIdx.x == 0) {
      y_warpwise[threadIdx.y] = row_sum;
    }
    block.sync();
#pragma unroll
    for (int i = 0; i < ty; ++i) {
      y += y_warpwise[i];
    }
    block.sync();  // y_warpwise is overwritten by the next tile
  };

  // Prologue: start fetching tile 0.
  load_tile(0, 0);
  // Steady state: prefetch tile t while reducing tile t - 1. Tiles before the
  // last are fully in bounds, so no masking is needed here.
#pragma unroll
  for (size_t t = 1; t < num_tiles; ++t) {
    load_tile(t, t % num_pipeline_stages);
    pipe.consumer_wait();
    block.sync();
    accumulate(row_dot((t - 1) % num_pipeline_stages));
    pipe.consumer_release();
  }

  // Drain: reduce the last tile, masking thread rows that lie past feat_in.
  constexpr size_t last_tile = num_tiles - 1;
  pipe.consumer_wait();
  block.sync();
  const float last_sum = row_dot(last_tile % num_pipeline_stages);
  accumulate(row_in_bounds(last_tile) ? last_sum : 0.f);
  pipe.consumer_release();

  // One block owns output element j for this batch row, so a plain += is
  // race-free.
  if (block.thread_rank() == 0) {
    Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
  }
}
  // Expand kernel (selected by bgmv_kernel when feat_in < feat_out):
//   Y[b * full_y_size + y_offset + r] += scale * dot(W[idx][r][:], X[b][:])
// for each output row r, with idx = indicies[b] * num_layers + layer_idx;
// batch rows whose idx is negative are skipped.
//
// Launch layout: grid = (feat_out / (ty * tz), batch_size),
// block = (tx, ty, tz) with tx == feat_in / vec_size, e.g. nthrs = (2, 16, 4).
// Each (threadIdx.y, threadIdx.z) thread row owns one output row; its tx
// threads each take a vec_size-element slice of feat_in, loaded directly from
// global memory, and combine their partial sums with tile shuffles.
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
          typename in_T, typename out_T, typename W_T>
__global__ void
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  const size_t b = blockIdx.y;
  const int64_t lora = indicies[b] * num_layers + layer_idx;
  if (lora < 0) {
    return;
  }

  auto cta = cg::this_thread_block();
  constexpr int rows_per_block = tz * ty;
  const size_t tile = blockIdx.x;
  // First output row handled by this block.
  const size_t first_row = tile * rows_per_block;

  // This thread's slice of the input vector.
  vec_t<in_T, vec_size> xs;
  xs.load(X + b * feat_in + threadIdx.x * vec_size);

  // thread_rank() advances x fastest, then y, then z, so this lands on row
  // (threadIdx.z * ty + threadIdx.y), column threadIdx.x * vec_size of the
  // block's W tile.
  const W_T *w_tile = W + (lora * feat_out + first_row) * feat_in;
  vec_t<W_T, vec_size> ws;
  ws.load(w_tile + cta.thread_rank() * vec_size);

  float acc = 0.f;
#pragma unroll
  for (size_t k = 0; k < vec_size; ++k) {
    acc += float(ws[k]) * float(xs[k]) * scale;
  }

  // Sum across the tx threads of this row, then broadcast lane 0's total.
  auto row = cg::tiled_partition<tx>(cta);
#pragma unroll
  for (int delta = tx / 2; delta > 0; delta >>= 1) {
    acc += row.shfl_down(acc, delta);
  }
  acc = row.shfl(acc, 0);

  if (threadIdx.x == 0) {
    Y[b * full_y_size + y_offset + first_row + threadIdx.z * ty +
      threadIdx.y] += static_cast<out_T>(acc);
  }
}
  // Host-side launcher for the batched gather matrix-vector product (BGMV).
// For every batch row b whose slot idx = indicies[b] * num_layers + layer_idx
// is non-negative, accumulates
//   Y[b * full_y_size + y_offset + r] += scale * sum_k W[idx][r][k] * X[b][k]
// for r in [0, feat_out), where W[idx] is a row-major feat_out x feat_in
// matrix and X rows hold feat_in elements.
//
// feat_in < feat_out dispatches to bgmv_expand_kernel; otherwise to
// bgmv_shrink_kernel. Block shapes are chosen entirely at compile time; an
// unsupported (feat_in, feat_out) pair is rejected by static_assert.
// Kernels are launched asynchronously on PyTorch's current CUDA stream, and no
// launch error is checked here.
// NOTE(review): batch_size becomes gridDim.y, which CUDA caps at 65535.
template <int feat_in, int feat_out, typename in_T, typename out_T,
          typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                 const W_T *__restrict__ W,
                 const int64_t *__restrict__ indicies, int64_t y_offset,
                 int64_t full_y_size, int64_t batch_size, int64_t num_layers,
                 int64_t layer_idx, float scale) {
  constexpr size_t vec_size = 8;
  constexpr int tz = 4;
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  if constexpr (feat_in < feat_out) {
    // Expand: tx threads cover one full input row (vec_size elements each);
    // tx * ty threads per z-slice, preferring 32, then 16, then 8.
    static_assert(feat_in % vec_size == 0,
                  "bgmv expand: feat_in must be a multiple of vec_size");
    constexpr int tx = feat_in / vec_size;
    static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
                      (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
                      (8 % tx == 0 && feat_out % (8 / tx * tz) == 0),
                  "bgmv expand: unsupported (feat_in, feat_out) shape");
    // Every branch must be `if constexpr`: with a plain `if`, the untaken
    // branches are still instantiated, and there ty = N / tx can be 0
    // (division by zero in the grid size, kernel instantiated with ty == 0).
    if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
      constexpr int ty = 32 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);
      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
      constexpr int ty = 16 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);
      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else {
      constexpr int ty = 8 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);
      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    }
  } else {
    // Shrink: one block per output element. Each (tx, 4) block walks feat_in
    // in tiles; feat_in must be a multiple of tx * (elements per thread). The
    // assert mirrors the three branches below.
    static_assert(feat_in % (vec_size * 32) == 0 ||
                      feat_in % (vec_size / 2 * 32) == 0 ||
                      feat_in % (vec_size / 2 * 16) == 0,
                  "bgmv shrink: feat_in must be a multiple of 64");
    if constexpr (feat_in % (vec_size * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;
      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);
      bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
                         vec_size * sizeof(W_T), tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;
      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);
      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(in_T) / 2,
                         vec_size * sizeof(W_T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
      constexpr int tx = 16;
      constexpr int ty = 4;
      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);
      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(in_T) / 2,
                         vec_size * sizeof(W_T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    }
  }
}
  // Explicitly instantiates bgmv_kernel for one (FEAT_IN, FEAT_OUT) shape and
// one (input, output, weight) element-type combination.
#define INST_BGMV(FEAT_IN, FEAT_OUT, IN_T, OUT_T, WEIGHT_T)                    \
  template void bgmv_kernel<FEAT_IN, FEAT_OUT>(                                \
      OUT_T *__restrict__ Y, const IN_T *__restrict__ X,                       \
      const WEIGHT_T *__restrict__ W, const int64_t *__restrict__ indicies,    \
      int64_t y_offset, int64_t full_y_size, int64_t batch_size,               \
      int64_t num_layers, int64_t layer_idx, float scale);

// Instantiates both directions for a (NARROW, WIDE) pair: NARROW -> WIDE and
// WIDE -> NARROW. When NARROW < WIDE, bgmv_kernel routes the first to the
// expand kernel and the second to the shrink kernel. NARROW must differ from
// WIDE, or the same specialization would be instantiated twice.
#define INST_BGMV_TWOSIDE(IN_T, OUT_T, WEIGHT_T, NARROW, WIDE)                 \
  INST_BGMV(NARROW, WIDE, IN_T, OUT_T, WEIGHT_T)                               \
  INST_BGMV(WIDE, NARROW, IN_T, OUT_T, WEIGHT_T)