#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "dispatch_utils.h"
#include "reduction.cuh"

namespace aphrodite {

// TODO: Further optimize this kernel.
template<typename scalar_t>
__global__ void rms_norm_kernel(
  scalar_t* __restrict__ out,             // [..., hidden_size]
  const scalar_t* __restrict__ input,     // [..., hidden_size]
  const scalar_t* __restrict__ weight,    // [hidden_size]
  const float epsilon,
  const int num_tokens,
  const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  // Each block handles one token: accumulate the sum of squares in fp32.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float) input[blockIdx.x * hidden_size + idx];
    variance += x * x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // Normalize and apply the learned per-channel weight.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
  }
}

// TODO: Further optimize this kernel.
template<typename scalar_t>
__global__ void fused_add_rms_norm_kernel(
  scalar_t* __restrict__ input,           // [..., hidden_size]
  scalar_t* __restrict__ residual,        // [..., hidden_size]
  const scalar_t* __restrict__ weight,    // [hidden_size]
  const float epsilon,
  const int num_tokens,
  const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  // residual <- input + residual, accumulating the sum of squares of the sum.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) input[blockIdx.x * hidden_size + idx];
    x += (float) residual[blockIdx.x * hidden_size + idx];
    variance += x * x;
    residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // input <- RMSNorm(residual) * weight.
  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) residual[blockIdx.x * hidden_size + idx];
    input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
  }
}

} // namespace aphrodite

void rms_norm(
  torch::Tensor& out,      // [..., hidden_size]
  torch::Tensor& input,    // [..., hidden_size]
  torch::Tensor& weight,   // [hidden_size]
  float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  // One block per token; up to 1024 threads strided over the hidden dimension.
  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "rms_norm_kernel",
    [&] {
      aphrodite::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        epsilon,
        num_tokens,
        hidden_size);
    });
}

void fused_add_rms_norm(
  torch::Tensor& input,     // [..., hidden_size]
  torch::Tensor& residual,  // [..., hidden_size]
  torch::Tensor& weight,    // [hidden_size]
  float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "fused_add_rms_norm_kernel",
    [&] {
      aphrodite::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        input.data_ptr<scalar_t>(),
        residual.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        epsilon,
        num_tokens,
        hidden_size);
    });
}
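
// ---------------------------------------------------------------------------
// Illustrative only (not part of the original file): a minimal host-side
// sketch of how the two ops above could be invoked, assuming this translation
// unit is built into a torch extension. The tensor shapes, the epsilon value,
// and the helper name `smoke_test_rms_norm` are arbitrary choices for the
// example.
// ---------------------------------------------------------------------------
static void smoke_test_rms_norm() {
  // [num_tokens, hidden_size] activations and a per-channel weight vector.
  torch::Tensor input = torch::randn(
      {4, 4096}, torch::dtype(torch::kFloat).device(torch::kCUDA));
  torch::Tensor weight = torch::ones(
      {4096}, torch::dtype(torch::kFloat).device(torch::kCUDA));
  torch::Tensor out = torch::empty_like(input);

  // out = RMSNorm(input) * weight
  rms_norm(out, input, weight, 1e-6f);

  // Fused variant: residual <- input + residual,
  // then input <- RMSNorm(residual) * weight (both updated in place).
  torch::Tensor residual = torch::zeros_like(input);
  fused_add_rms_norm(input, residual, weight, 1e-6f);
}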