#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "dispatch_utils.h"
#include "reduction.cuh"

namespace aphrodite {

// TODO: Further optimize this kernel.
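// Applies RMS normalization over the last dimension of `input`:
//   out[t, i] = input[t, i] * rsqrt(mean(input[t, :]^2) + epsilon) * weight[i]
// One thread block handles one token; threads stride over hidden_size and the
// per-token mean of squares is obtained with a block-wide reduction.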
template<typename scalar_t>
__global__ void rms_norm_kernel(
  scalar_t* __restrict__ out,           // [..., hidden_size]
  const scalar_t* __restrict__ input,   // [..., hidden_size]
  const scalar_t* __restrict__ weight,  // [hidden_size]
  const float epsilon,
  const int num_tokens,
  const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float) input[blockIdx.x * hidden_size + idx];
    variance += x * x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
  }
}

// TODO: Further optimize this kernel.
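// Fuses the residual addition with RMS normalization. For each token:
//   residual <- input + residual                    (written back in place)
//   input    <- RMSNorm(input + residual) * weight  (written back in place)
// Updating both tensors in place avoids an extra pass over the hidden states.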
template<typename scalar_t>
__global__ void fused_add_rms_norm_kernel(
  scalar_t* __restrict__ input,         // [..., hidden_size]
  scalar_t* __restrict__ residual,      // [..., hidden_size]
  const scalar_t* __restrict__ weight,  // [hidden_size]
  const float epsilon,
  const int num_tokens,
  const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) input[blockIdx.x * hidden_size + idx];
    x += (float) residual[blockIdx.x * hidden_size + idx];
    variance += x * x;
    residual[blockIdx.x * hidden_size + idx] = (scalar_t) x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) residual[blockIdx.x * hidden_size + idx];
    input[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
  }
}

} // namespace aphrodite
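// Host-side launcher for rms_norm_kernel: one CUDA block per token, with up
// to 1024 threads striding over hidden_size. Dispatch over the tensor's
// floating-point dtype is handled by APHRODITE_DISPATCH_FLOATING_TYPES
// (defined in dispatch_utils.h).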
void rms_norm(
  torch::Tensor& out,     // [..., hidden_size]
  torch::Tensor& input,   // [..., hidden_size]
  torch::Tensor& weight,  // [hidden_size]
  float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "rms_norm_kernel",
    [&] {
      aphrodite::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        epsilon,
        num_tokens,
        hidden_size);
    });
}
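// Host-side launcher for fused_add_rms_norm_kernel. Note the in-place
// semantics: `input` receives the normalized output and `residual` receives
// the updated residual stream (input + residual).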
void fused_add_rms_norm(
  torch::Tensor& input,     // [..., hidden_size]
  torch::Tensor& residual,  // [..., hidden_size]
  torch::Tensor& weight,    // [hidden_size]
  float epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
    "fused_add_rms_norm_kernel",
    [&] {
      aphrodite::fused_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        input.data_ptr<scalar_t>(),
        residual.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        epsilon,
        num_tokens,
        hidden_size);
    });
}
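// Example binding (sketch): these entry points are assumed to be registered
// with Python elsewhere in the extension; a minimal pybind11 registration
// using the standard torch-extension macro would look roughly like:
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("rms_norm", &rms_norm,
//           "Apply RMS normalization over the last dimension");
//     m.def("fused_add_rms_norm", &fused_add_rms_norm,
//           "Fused in-place residual add + RMS normalization");
//   }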