#pragma once

#include <ATen/cuda/CUDAContext.h>
#ifndef USE_ROCM
#include <cooperative_groups.h>
#else
#include <hip/hip_cooperative_groups.h>
#endif
#ifndef USE_ROCM
#include <cuda/pipeline>
#endif
#include <cuda_runtime.h>
#include <iostream>
#include <stdio.h>

#include "vec_dtypes.cuh"

namespace cg = cooperative_groups;

#ifdef USE_ROCM
template <size_t len>
__host__ __device__
inline void* memcpy_blocking(void *dst, const void *src) {
  // does not handle the case of long dtypes
  char *d = reinterpret_cast<char *>(dst);
  const char *s = reinterpret_cast<const char *>(src);
  size_t i = 0;
#pragma unroll
  for (i = 0; i < len; ++i) {
    d[i] = s[i];
  }
  return dst;
}
#endif

#ifndef USE_ROCM

// nthrs = (32, 4)
// Shrink kernel: accumulates scale * dot(W[idx, j, :], X[batch, :]) into Y,
// overlapping global->shared copies with the dot product via a two-stage
// cuda::pipeline.
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
          size_t W_copy_size, int tx, int ty, int tz, typename in_T,
          typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
  if (idx < 0) {
    return;
  }

  auto block = cg::this_thread_block();
  size_t j = blockIdx.x;
  constexpr size_t num_pipeline_stages = 2;
  constexpr size_t tile_size = tx * ty * vec_size;
  __shared__ W_T W_shared[num_pipeline_stages * tile_size];
  __shared__ in_T X_shared[num_pipeline_stages * tile_size];
  __shared__ float y_warpwise[ty];

  size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  auto pipe = cuda::make_pipeline();

  // pipeline load W/X and compute WX;
  pipe.producer_acquire();
  cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
                     W + (idx * feat_out + j) * feat_in +
                         (threadIdx.y * tx + threadIdx.x) * vec_size,
                     cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
  cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
                     X + (batch_idx * feat_in) +
                         (threadIdx.y * tx + threadIdx.x) * vec_size,
                     cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
  pipe.producer_commit();
  size_t copy_idx, compute_idx;
  float y = 0.f;
  vec_t<in_T, vec_size> x_vec;
  vec_t<W_T, vec_size> w_vec;
  size_t tile_idx;

#pragma unroll
  for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
       ++tile_idx) {
    copy_idx = tile_idx % num_pipeline_stages;
    // pipeline stage: async copy W/X fragments
    pipe.producer_acquire();
    if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
      cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         W + (idx * feat_out + j) * feat_in +
                             tile_idx * tile_size +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
      cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         X + (batch_idx * feat_in) + tile_idx * tile_size +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
    }
    pipe.producer_commit();

    compute_idx = (tile_idx - 1) % num_pipeline_stages;
    // pipeline stage: compute WX
    pipe.consumer_wait();
    block.sync();
    x_vec.load(X_shared + X_shared_offset[compute_idx] +
               (threadIdx.y * tx + threadIdx.x) * vec_size);
    w_vec.load(W_shared + W_shared_offset[compute_idx] +
               (threadIdx.y * tx + threadIdx.x) * vec_size);
    float sum = 0.f;
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      sum += float(w_vec[i]) * float(x_vec[i]) * scale;
    }
#pragma unroll
    for (size_t offset = tx / 2; offset > 0; offset /= 2) {
      sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    y_warpwise[threadIdx.y] = sum;
    block.sync();
#pragma unroll
    for (size_t i = 0; i < ty; ++i) {
      y += y_warpwise[i];
    }

    block.sync();
    pipe.consumer_release();
  }

  compute_idx = (tile_idx - 1) % num_pipeline_stages;
  // final pipeline stage
  pipe.consumer_wait();
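  // Note: the loop above always exits with one tile's copy still in flight,
  // so the epilogue below repeats the compute stage once more for that last
  // fragment before the pipeline is released.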
  block.sync();
  x_vec.load(X_shared + X_shared_offset[compute_idx] +
             (threadIdx.y * tx + threadIdx.x) * vec_size);
  w_vec.load(W_shared + W_shared_offset[compute_idx] +
             (threadIdx.y * tx + threadIdx.x) * vec_size);
  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
  }
#pragma unroll
  for (size_t offset = tx / 2; offset > 0; offset /= 2) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  y_warpwise[threadIdx.y] =
      ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
          ? sum
          : 0.f;
  block.sync();
#pragma unroll
  for (size_t i = 0; i < ty; ++i) {
    y += y_warpwise[i];
  }

  block.sync();
  pipe.consumer_release();

  // write Y;
  if (block.thread_rank() == 0) {
    Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
  }
}

#else

// ROCm shrink kernel: same reduction as above, but without cuda::pipeline;
// tiles are loaded straight from global memory each iteration.
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
          size_t W_copy_size, int tx, int ty, int tz, typename in_T,
          typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
  if (idx < 0) {
    return;
  }

  size_t j = blockIdx.x;
  constexpr size_t tile_size = tx * ty * vec_size;
  constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size;
  __shared__ float y_warpwise[ty];

  float y = 0;
  vec_t<in_T, vec_size> x_vec;
  vec_t<W_T, vec_size> w_vec;
  size_t tile_idx;

#pragma unroll
  for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
    if (tile_idx * tile_size +
            (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 <
        feat_in) {
      x_vec.load(X + (batch_idx * feat_in) + tile_idx * tile_size +
                 (threadIdx.y * tx + threadIdx.x) * vec_size);
      w_vec.load(W + (idx * feat_out + j) * feat_in + tile_idx * tile_size +
                 (threadIdx.y * tx + threadIdx.x) * vec_size);
    }

    float sum = 0.f;
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      sum += convert_type<W_T, float>(w_vec[i]) *
             convert_type<in_T, float>(x_vec[i]) * scale;
    }
#pragma unroll
    for (size_t offset = tx / 2; offset > 0; offset /= 2) {
      sum += APHRODITE_SHFL_DOWN_SYNC(sum, offset);
    }

    __syncthreads();

    if (tile_idx * tile_size +
            (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 <
        feat_in) {
      y += sum;
    }
  }

  if (threadIdx.x == 0) {
    y_warpwise[threadIdx.y] = y;
  }
  __syncthreads();

  float y_write = 0.f;
#pragma unroll
  for (size_t i = 0; i < ty; ++i) {
    y_write += y_warpwise[i];
  }

  // write Y;
  if (threadIdx.x == 0 && threadIdx.y == 0) {
    size_t y_idx = batch_idx * full_y_size + y_offset + j;
    Y[y_idx] =
        aphrodite_add<out_T>(Y[y_idx], convert_type<float, out_T>(y_write));
  }
}

#endif

// nthrs = (2, 16, 4)
// Expand kernel: each (threadIdx.z, threadIdx.y) pair produces one output
// element; the tx threads of the pair cooperatively reduce the feat_in-long
// dot product and lane x == 0 writes the result.
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
          typename in_T, typename out_T, typename W_T>
__global__ void
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
  if (idx < 0) {
    return;
  }

  auto block = cg::this_thread_block();
  size_t tile_idx = blockIdx.x;

  // load X;
  vec_t<in_T, vec_size> x_vec;
  x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);

  // load W;
  vec_t<W_T, vec_size> w_vec;
  w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
             block.thread_rank() * vec_size);

  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
#ifndef USE_ROCM
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
#else
    sum += convert_type<W_T, float>(w_vec[i]) *
           convert_type<in_T, float>(x_vec[i]) * scale;
#endif
  }

  cg::thread_block_tile<tx> g = cg::tiled_partition<tx>(block);
#pragma unroll
  for (size_t offset = tx / 2; offset > 0; offset /= 2) {
    sum += g.shfl_down(sum, offset);
  }
  sum = g.shfl(sum, 0);

  if (threadIdx.x == 0) {
#ifndef USE_ROCM
    Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
      threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
#else
    size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
                   threadIdx.z * ty + threadIdx.y;
    Y[y_idx] = aphrodite_add<out_T>(Y[y_idx], convert_type<float, out_T>(sum));
#endif
  }
}

// Host-side dispatcher: selects the expand or shrink kernel and a launch
// geometry based on the compile-time feat_in/feat_out pair.
template <int feat_in, int feat_out, typename in_T, typename out_T,
          typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                 const W_T *__restrict__ W,
                 const int64_t *__restrict__ indicies, int64_t y_offset,
                 int64_t full_y_size, int64_t batch_size, int64_t num_layers,
                 int64_t layer_idx, float scale) {
  constexpr size_t vec_size = 8;
  constexpr int tz = 4;
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if constexpr (feat_in <= feat_out) {
    static_assert(feat_in % vec_size == 0);
    constexpr int tx = feat_in / vec_size;

    static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
                  (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
                  (8 % tx == 0 && feat_out % (8 / tx * tz) == 0));

    if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
      constexpr int ty = 32 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
      constexpr int ty = 16 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else {
      constexpr int ty = 8 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    }
  } else {
#ifndef USE_ROCM
    static_assert(feat_in % (vec_size * 32) == 0 ||
                  feat_in % (vec_size * 16) == 0 ||
                  feat_in % (vec_size * 8) == 0);

    if constexpr (feat_in % (vec_size * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
                         vec_size * sizeof(W_T), tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(in_T) / 2,
                         vec_size * sizeof(W_T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
      constexpr int tx = 16;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(in_T) / 2,
                         vec_size * sizeof(W_T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    }
#else
    constexpr size_t rocm_warp_size = warpSize;

#define CHECK_INPUT_TILEABLE_BY(vec_size_) \
    feat_in % (rocm_warp_size * vec_size_) == 0

#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_)        \
  if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) {                          \
    constexpr size_t vec_size_shrink = vec_size_;                            \
    constexpr int tx = tx_;                                                  \
    constexpr int ty = ty_;                                                  \
    dim3 nblks(feat_out, batch_size);                                        \
    dim3 nthrs(tx, ty);                                                      \
    bgmv_shrink_kernel<feat_in, feat_out, vec_size_shrink,                   \
                       vec_size_shrink * sizeof(in_T),                       \
                       vec_size_shrink * sizeof(W_T),                        \
                       tx, ty, tz>                                           \
      <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,             \
                                    full_y_size, num_layers, layer_idx,      \
                                    scale);                                  \
  }

    static_assert(CHECK_INPUT_TILEABLE_BY(32) ||
                  CHECK_INPUT_TILEABLE_BY(16) ||
                  CHECK_INPUT_TILEABLE_BY( 8) ||
                  CHECK_INPUT_TILEABLE_BY( 4) ||
                  CHECK_INPUT_TILEABLE_BY( 2) ||
                  CHECK_INPUT_TILEABLE_BY( 1));

    LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size)
    else
    LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size)
    else
    LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size,  8/vec_size)
    else
    LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4)
    else
    LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2)
    else
    LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1)

#undef CHECK_INPUT_TILEABLE_BY
#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM
#endif
  }
}

#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)                 \
  template void bgmv_kernel<feat_in, feat_out>(                        \
      out_T * __restrict__ Y, const in_T *__restrict__ X,              \
      const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \
      int64_t y_offset, int64_t full_y_size, int64_t batch_size,       \
      int64_t num_layers, int64_t layer_idx, float scale);

#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \
  INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)

#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
  INST_BGMV(narrow, wide, in_T, out_T, W_T)               \
  INST_BGMV(wide, narrow, in_T, out_T, W_T)
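// Usage sketch (illustrative only, not part of this header): a separate .cu
// translation unit would include this file and expand the instantiation
// macros for each supported dtype/shape combination. The dtype name (nv_half)
// and the 16/4096 shape pair below are hypothetical placeholders, not values
// defined here.
//
//   #include "bgmv_impl.cuh"
//   INST_BGMV_TWOSIDE(nv_half, nv_half, nv_half, 16, 4096)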