// bgmv_impl.cuh
#pragma once

#include <ATen/cuda/CUDAContext.h>
#ifndef USE_ROCM
#include <cooperative_groups.h>
#else
#include <hip/hip_cooperative_groups.h>
#endif
#ifndef USE_ROCM
#include <cuda/pipeline>
#endif
#include <cuda_runtime.h>
#include <iostream>
#include <stdio.h>

#include "vec_dtypes.cuh"

namespace cg = cooperative_groups;
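
// On ROCm, <cuda/pipeline> and cuda::memcpy_async are unavailable (note the
// guarded includes above); memcpy_blocking below is a plain byte-wise copy
// helper provided for HIP builds as a stand-in.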
#ifdef USE_ROCM
template <size_t len>
__host__ __device__
inline void* memcpy_blocking(void *dst, const void *src) {
  // does not handle the case of long dtypes
  char *d = reinterpret_cast<char*>(dst);
  const char *s = reinterpret_cast<const char *>(src);
  size_t i = 0;
#pragma unroll
  for (i = 0; i < len; ++i) {
    d[i] = s[i];
  }
  return dst;
}
#endif
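
// Shrink kernel (the feat_in > feat_out case dispatched in bgmv_kernel below).
// Grid: blockIdx.x = output feature j, blockIdx.y = batch index. The CUDA
// variant stages tiles of W and X through shared memory with a two-stage
// cuda::pipeline (double buffering), reduces the partial dot products with
// __shfl_down_sync inside a warp and the y_warpwise array across warps, and
// lets thread 0 accumulate the scaled result into
// Y[batch_idx * full_y_size + y_offset + j].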
#ifndef USE_ROCM

// nthrs = (32, 4)
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
          size_t W_copy_size, int tx, int ty, int tz, typename in_T,
          typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
  if (idx < 0) {
    return;
  }

  auto block = cg::this_thread_block();
  size_t j = blockIdx.x;
  constexpr size_t num_pipeline_stages = 2;
  constexpr size_t tile_size = tx * ty * vec_size;
  __shared__ W_T W_shared[num_pipeline_stages * tile_size];
  __shared__ in_T X_shared[num_pipeline_stages * tile_size];
  __shared__ float y_warpwise[ty];

  size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  auto pipe = cuda::make_pipeline();

  // pipeline load W/X and compute WX;
  pipe.producer_acquire();
  cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
                     W + (idx * feat_out + j) * feat_in +
                         (threadIdx.y * tx + threadIdx.x) * vec_size,
                     cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
  cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
                     X + (batch_idx * feat_in) +
                         (threadIdx.y * tx + threadIdx.x) * vec_size,
                     cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
  pipe.producer_commit();

  size_t copy_idx, compute_idx;
  float y = 0.f;
  vec_t<in_T, vec_size> x_vec;
  vec_t<W_T, vec_size> w_vec;
  size_t tile_idx;

#pragma unroll
  for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
       ++tile_idx) {
    copy_idx = tile_idx % num_pipeline_stages;
    // pipeline stage: async copy W fragment
    pipe.producer_acquire();
    if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
      cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         W + (idx * feat_out + j) * feat_in +
                             tile_idx * tile_size +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
      cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         X + (batch_idx * feat_in) + tile_idx * tile_size +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
    }
    pipe.producer_commit();

    compute_idx = (tile_idx - 1) % num_pipeline_stages;
    // pipeline stage: compute WX
    pipe.consumer_wait();
    block.sync();
    x_vec.load(X_shared + X_shared_offset[compute_idx] +
               (threadIdx.y * tx + threadIdx.x) * vec_size);
    w_vec.load(W_shared + W_shared_offset[compute_idx] +
               (threadIdx.y * tx + threadIdx.x) * vec_size);
    float sum = 0.f;
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      sum += float(w_vec[i]) * float(x_vec[i]) * scale;
    }
#pragma unroll
    for (size_t offset = tx / 2; offset > 0; offset /= 2) {
      sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    y_warpwise[threadIdx.y] = sum;
    block.sync();
#pragma unroll
    for (size_t i = 0; i < ty; ++i) {
      y += y_warpwise[i];
    }

    block.sync();
    pipe.consumer_release();
  }

  compute_idx = (tile_idx - 1) % num_pipeline_stages;
  // final pipeline stage
  pipe.consumer_wait();
  block.sync();
  x_vec.load(X_shared + X_shared_offset[compute_idx] +
             (threadIdx.y * tx + threadIdx.x) * vec_size);
  w_vec.load(W_shared + W_shared_offset[compute_idx] +
             (threadIdx.y * tx + threadIdx.x) * vec_size);
  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
  }
#pragma unroll
  for (size_t offset = tx / 2; offset > 0; offset /= 2) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  y_warpwise[threadIdx.y] =
      ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
          ? sum
          : 0.f;
  block.sync();
#pragma unroll
  for (size_t i = 0; i < ty; ++i) {
    y += y_warpwise[i];
  }

  block.sync();
  pipe.consumer_release();

  // write Y;
  if (block.thread_rank() == 0) {
    Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
  }
}
#else
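
// ROCm path: same shrink computation without cuda::pipeline. Each tile of
// W/X is loaded directly into registers via vec_t::load, guarded so the last
// (possibly partial) tile does not read past feat_in; the warp-level
// reduction uses APHRODITE_SHFL_DOWN_SYNC and the per-warp partial sums are
// combined through the y_warpwise shared array before thread (0, 0) writes Y.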
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
          size_t W_copy_size, int tx, int ty, int tz, typename in_T,
          typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
  if (idx < 0) {
    return;
  }

  size_t j = blockIdx.x;
  constexpr size_t tile_size = tx * ty * vec_size;
  constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size;
  __shared__ float y_warpwise[ty];

  float y = 0;
  vec_t<in_T, vec_size> x_vec;
  vec_t<W_T, vec_size> w_vec;
  size_t tile_idx;

#pragma unroll
  for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
    if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
      x_vec.load(X + (batch_idx * feat_in) +
                 tile_idx * tile_size +
                 (threadIdx.y * tx + threadIdx.x) * vec_size);
      w_vec.load(W + (idx * feat_out + j) * feat_in +
                 tile_idx * tile_size +
                 (threadIdx.y * tx + threadIdx.x) * vec_size);
    }

    float sum = 0.f;
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
    }
#pragma unroll
    for (size_t offset = tx / 2; offset > 0; offset /= 2) {
      sum += APHRODITE_SHFL_DOWN_SYNC(sum, offset);
    }

    __syncthreads();

    if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
      y += sum;
    }
  }

  if (threadIdx.x == 0) {
    y_warpwise[threadIdx.y] = y;
  }
  __syncthreads();

  float y_write = 0.f;
#pragma unroll
  for (size_t i = 0; i < ty; ++i) {
    y_write += y_warpwise[i];
  }

  // write Y;
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    size_t y_idx = batch_idx * full_y_size + y_offset + j;
    Y[y_idx] = aphrodite_add<out_T>(Y[y_idx], convert_type<float, out_T>(y_write));
  }
}

#endif
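
// Expand kernel (the feat_in <= feat_out case dispatched below). Each thread
// block handles tz * ty output features of one batch element: the tx threads
// along x cooperatively compute one dot product of length feat_in (vec_size
// elements per thread), reduce it with a tx-wide cooperative-groups shuffle,
// and the lane with threadIdx.x == 0 accumulates its tile's result into Y.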
// nthrs = (2, 16, 4)
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
          typename in_T, typename out_T, typename W_T>
__global__ void
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
  if (idx < 0) {
    return;
  }

  auto block = cg::this_thread_block();
  size_t tile_idx = blockIdx.x;

  // load X;
  vec_t<in_T, vec_size> x_vec;
  x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);

  // load W;
  vec_t<W_T, vec_size> w_vec;
  w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
             block.thread_rank() * vec_size);

  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
#ifndef USE_ROCM
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
#else
    sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
#endif
  }

  cg::thread_block_tile g = cg::tiled_partition<tx>(block);
#pragma unroll
  for (size_t offset = tx / 2; offset > 0; offset /= 2) {
    sum += g.shfl_down(sum, offset);
  }
  sum = g.shfl(sum, 0);

  if (threadIdx.x == 0) {
#ifndef USE_ROCM
    Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
      threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
#else
    size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
                   threadIdx.z * ty + threadIdx.y;
    Y[y_idx] = aphrodite_add<out_T>(Y[y_idx], convert_type<float, out_T>(sum));
#endif
  }
}
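
// Host-side dispatcher. feat_in <= feat_out selects the expand kernel,
// otherwise the shrink kernel. vec_size is fixed at 8 elements per thread and
// tz at 4; tx/ty come from the compile-time divisibility checks below
// (tx * ty of 32, 16, or 8 for expand; a 32- or 16-lane warp slice with
// ty = 4 for the CUDA shrink path). Launches go to the current ATen CUDA
// stream.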
template <int feat_in, int feat_out, typename in_T, typename out_T,
          typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                 const W_T *__restrict__ W,
                 const int64_t *__restrict__ indicies, int64_t y_offset,
                 int64_t full_y_size, int64_t batch_size, int64_t num_layers,
                 int64_t layer_idx, float scale) {
  constexpr size_t vec_size = 8;
  constexpr int tz = 4;
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if constexpr (feat_in <= feat_out) {
    static_assert(feat_in % vec_size == 0);
    constexpr int tx = feat_in / vec_size;

    static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
                  (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
                  (8 % tx == 0 && feat_out % (8 / tx * tz) == 0));

    if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
      constexpr int ty = 32 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
      constexpr int ty = 16 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else {
      constexpr int ty = 8 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    }
  } else {
#ifndef USE_ROCM
    static_assert(feat_in % (vec_size * 32) == 0 ||
                  feat_in % (vec_size * 16) == 0 ||
                  feat_in % (vec_size * 8) == 0);

    if constexpr (feat_in % (vec_size * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
                         vec_size * sizeof(W_T), tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(in_T) / 2,
                         vec_size * sizeof(W_T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
      constexpr int tx = 16;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(in_T) / 2,
                         vec_size * sizeof(W_T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    }
#else
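    // ROCm shrink dispatch: LAUNCH_BGMV_SHRINK_KERNELS_ROCM expands to an
    // "if constexpr" launch for one tiling factor; the bare "else" tokens
    // between the invocations below chain them so that exactly one
    // specialization is emitted per (feat_in, feat_out) instantiation.
    // rocm_warp_size is the HIP warpSize (typically 64 on CDNA GPUs).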
    constexpr size_t rocm_warp_size = warpSize;

#define CHECK_INPUT_TILEABLE_BY(vec_size_) \
  feat_in % (rocm_warp_size * vec_size_) == 0

#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_)      \
  if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) {                        \
    constexpr size_t vec_size_shrink = vec_size_;                          \
    constexpr int tx = tx_;                                                \
    constexpr int ty = ty_;                                                \
    dim3 nblks(feat_out, batch_size);                                      \
    dim3 nthrs(tx, ty);                                                    \
    bgmv_shrink_kernel<feat_in, feat_out, vec_size_shrink,                 \
                       vec_size_shrink * sizeof(in_T),                     \
                       vec_size_shrink * sizeof(W_T),                      \
                       tx, ty, tz>                                         \
        <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,         \
                                      full_y_size, num_layers, layer_idx,  \
                                      scale);                              \
  }

    static_assert(CHECK_INPUT_TILEABLE_BY(32) ||
                  CHECK_INPUT_TILEABLE_BY(16) ||
                  CHECK_INPUT_TILEABLE_BY( 8) ||
                  CHECK_INPUT_TILEABLE_BY( 4) ||
                  CHECK_INPUT_TILEABLE_BY( 2) ||
                  CHECK_INPUT_TILEABLE_BY( 1));

    LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size)
    else
    LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size)
    else
    LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size,  8/vec_size)
    else
    LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4)
    else
    LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2)
    else
    LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1)

#undef CHECK_INPUT_TILEABLE_BY
#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM
#endif
  }
}
#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)                   \
  template void bgmv_kernel<feat_in, feat_out>(                          \
      out_T * __restrict__ Y, const in_T *__restrict__ X,                \
      const W_T *__restrict__ W, const int64_t *__restrict__ indicies,   \
      int64_t y_offset, int64_t full_y_size, int64_t batch_size,         \
      int64_t num_layers, int64_t layer_idx, float scale);

#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \
  INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)

#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
  INST_BGMV(narrow, wide, in_T, out_T, W_T)               \
  INST_BGMV(wide, narrow, in_T, out_T, W_T)
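
// Illustrative usage only: concrete explicit instantiations are expected to
// live in separate .cu translation units that include this header. The dtype
// alias and shapes below are placeholders, not values taken from this file.
//
//   #include "bgmv_impl.cuh"
//   // Instantiates both directions for a hypothetical fp16 configuration:
//   // shrink (4096 -> 16) and expand (16 -> 4096).
//   INST_BGMV_TWOSIDE(nv_half, nv_half, nv_half, 16, 4096)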