
feat: add layernorm kernels

AlpinDale · 1 year ago
commit 0ec53128b6
3 changed files with 117 additions and 0 deletions
  1. kernels/layernorm.cpp          +14  -0
  2. kernels/layernorm_kernels.cu   +60  -0
  3. kernels/reduction.cuh          +43  -0

kernels/layernorm.cpp (+14 -0)

@@ -0,0 +1,14 @@
+#include <torch/extension.h>
+
+void rms_norm(
+    torch::Tensor& out,
+    torch::Tensor& input,
+    torch::Tensor& weight,
+    float epsilon);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def(
+        "rms_norm",
+        &rms_norm,
+        "Apply Root Mean Square (RMS) Normalization to the input tensors.");
+}
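The binding above only declares rms_norm and exposes it to Python; the expected calling convention (shapes, dtypes, and the in-place output argument) is left implicit. A minimal host-side usage sketch, assuming the extension has been built and linked against libtorch; the tensor sizes and epsilon below are illustrative, not part of this commit:

    #include <torch/torch.h>

    // Declaration from kernels/layernorm.cpp; the definition lives in
    // kernels/layernorm_kernels.cu and is linked in with the extension.
    void rms_norm(torch::Tensor& out, torch::Tensor& input,
                  torch::Tensor& weight, float epsilon);

    int main() {
        // out and input are [num_tokens, hidden_size]; weight is [hidden_size].
        torch::Tensor input  = torch::randn({8, 4096}, torch::dtype(torch::kHalf).device(torch::kCUDA));
        torch::Tensor weight = torch::ones({4096}, torch::dtype(torch::kHalf).device(torch::kCUDA));
        torch::Tensor out    = torch::empty_like(input);
        rms_norm(out, input, weight, 1e-6f);  // writes the normalized rows into `out`
        return 0;
    }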

kernels/layernorm_kernels.cu (+60 -0)

@@ -0,0 +1,60 @@
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include "reduction.cuh"
+
+namespace aphrodite {
+template<typename scalar_t>
+__global__ void rms_norm_kernel(
+    scalar_t* __restrict__ out,
+    const scalar_t* __restrict__ input,
+    const scalar_t* __restrict__ weight,
+    const float epsilon,
+    const int num_tokens,
+    const int hidden_size) {
+    __shared__ float s_variance;
+    float variance = 0.0f;
+
+    for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+        const float x = (float) input[blockIdx.x * hidden_size + idx];
+        variance += x * x;
+    }
+    variance = blockReduceSum<float>(variance);
+    if (threadIdx.x == 0) {
+        s_variance = rsqrtf(variance / hidden_size + epsilon);
+    }
+    __syncthreads();
+
+    for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
+        float x = (float) input[blockIdx.x * hidden_size + idx];
+        out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
+    }
+}
+} // namespace aphrodite
+
+void rms_norm(
+    torch::Tensor& out,
+    torch::Tensor& input,
+    torch::Tensor& weight,
+    float epsilon) {
+    int num_tokens = input.size(0);
+    int hidden_size = input.size(1);
+
+    dim3 grid(num_tokens);
+    dim3 block(std::min(hidden_size, 1024));
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        at::ScalarType::Half,
+        at::ScalarType::BFloat16,
+        input.scalar_type(),
+        "rms_norm_kernel",
+        [&] {
+            aphrodite::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+                out.data_ptr<scalar_t>(),
+                input.data_ptr<scalar_t>(),
+                weight.data_ptr<scalar_t>(),
+                epsilon,
+                num_tokens,
+                hidden_size);
+        });
+}
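For orientation: the host wrapper launches one thread block per token (grid = num_tokens) with up to 1024 threads striding over hidden_size, and each block computes out = x * rsqrt(mean(x^2) + epsilon) * weight, accumulating the sum of squares in float before casting back to the input dtype. A hypothetical CPU reference of the same math, useful only for checking the kernel's output and not part of this commit:

    #include <cmath>

    // Per-token RMS normalization, mirroring rms_norm_kernel in plain C++.
    void rms_norm_ref(float* out, const float* input, const float* weight,
                      float epsilon, int num_tokens, int hidden_size) {
        for (int t = 0; t < num_tokens; ++t) {
            const float* x = input + t * hidden_size;
            float sum_sq = 0.0f;
            for (int i = 0; i < hidden_size; ++i)
                sum_sq += x[i] * x[i];  // same accumulation as the kernel's first loop
            const float inv_rms = 1.0f / std::sqrt(sum_sq / hidden_size + epsilon);
            for (int i = 0; i < hidden_size; ++i)
                out[t * hidden_size + i] = x[i] * inv_rms * weight[i];
        }
    }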

kernels/reduction.cuh (+43 -0)

@@ -0,0 +1,43 @@
+/*
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
+ * Copyright (c) 2023 The PygmalionAI team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+namespace aphrodite {
+template<typename T>
+__inline__ __device__ T blockReduceSum(T val)
+{
+    static __shared__ T shared[32];
+    int                 lane = threadIdx.x & 0x1f;
+    int                 wid  = threadIdx.x >> 5;
+
+    val = warpReduceSum<T>(val);
+
+    if (lane == 0)
+        shared[wid] = val;
+
+    __syncthreads();
+
+    val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
+    val = warpReduceSum<T>(val);
+    return val;
+
+}
+
+
+} // namespace aphrodite
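Note that blockReduceSum sums each warp's partial value with warpReduceSum, stages the 32 per-warp results in shared memory, and reduces them with a second warp-level pass; warpReduceSum itself is not defined anywhere in this diff. In the FasterTransformer header this file is adapted from, it is the standard shuffle-based warp reduction, roughly as sketched below (reproduced for reference, not part of this commit):

    #define FINAL_MASK 0xffffffff

    // Sum `val` across the 32 lanes of a warp with butterfly shuffles;
    // every lane ends up holding the full warp sum.
    template<typename T>
    __inline__ __device__ T warpReduceSum(T val)
    {
    #pragma unroll
        for (int mask = 16; mask > 0; mask >>= 1)
            val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
        return val;
    }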