
feat: add swiglu activation

AlpinDale 1 year ago
parent
commit
28866137ea
2 changed files with 56 additions and 0 deletions
  1. kernels/activation.cpp (+12, −0)
  2. kernels/activation_kernels.cu (+44, −0)

+ 12 - 0
kernels/activation.cpp

@@ -0,0 +1,12 @@
+#include <torch/extension.h>
+
+void silu_and_mul(
+    torch::Tensor& out,
+    torch::Tensor& input);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def(
+        "silu_and_mul",
+        &silu_and_mul,
+        "Activation function for SwiGLU.");
+}
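
For reference, a minimal usage sketch (not part of this commit): it assumes the two files in this commit are built as a torch extension with torch.utils.cpp_extension.load, and the module name "activation" is illustrative.

import torch
from torch.utils.cpp_extension import load

# Build the extension from this commit's two kernel files; the name is arbitrary.
activation = load(
    name="activation",
    sources=["kernels/activation.cpp", "kernels/activation_kernels.cu"],
)

num_tokens, d = 4, 128
input = torch.randn(num_tokens, 2 * d, device="cuda", dtype=torch.float16)
out = torch.empty(num_tokens, d, device="cuda", dtype=torch.float16)
activation.silu_and_mul(out, input)  # out = silu(input[:, :d]) * input[:, d:]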

+ 44 - 0
kernels/activation_kernels.cu

@@ -0,0 +1,44 @@
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+
+namespace aphrodite {
+template<typename T>
+__device__ __forceinline__ T silu(const T& x) {
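+    // x * sigmoid(x), evaluated in fp32 for numerical accuracy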
+    return (T) (((float) x) / (1.0f + expf((float) -x)));
+}
+
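+// One thread block per token; threads stride over the d output elements,
+// so d may exceed the block size (capped at 1024 in the launcher below).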
+template<typename scalar_t>
+__global__ void silu_and_mul_kernel(
+    scalar_t* __restrict__ out,          // [num_tokens, d]
+    const scalar_t* __restrict__ input,  // [num_tokens, 2 * d]
+    const int d) {
+    const int token_idx = blockIdx.x;
+    for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+        const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
+        const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
+        out[token_idx * d + idx] = silu(x) * y;
+    }
+}
+} // namespace aphrodite
+
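+// Host launcher. Expects out of shape [num_tokens, d] and input of shape
+// [num_tokens, 2 * d]; both are assumed contiguous CUDA tensors.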
+void silu_and_mul(
+    torch::Tensor& out,
+    torch::Tensor& input)
+{
+    int num_tokens = input.size(0);
+    int d = input.size(1) / 2;
+
+    dim3 grid(num_tokens);
+    dim3 block(std::min(d, 1024));
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        at::ScalarType::Half,
+        at::ScalarType::BFloat16,
+        input.scalar_type(),
+        "silu_and_mul_kernel",
+        [&] {
+        aphrodite::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
+            out.data_ptr<scalar_t>(),
+            input.data_ptr<scalar_t>(),
+            d);
+        });
+}
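
The launcher uses one CUDA block per token and min(d, 1024) threads per block; the strided loop inside the kernel covers the remaining elements when d exceeds 1024. For testing, a pure-PyTorch reference of the same computation (a sketch, not part of the commit):

import torch

def silu_and_mul_ref(input: torch.Tensor) -> torch.Tensor:
    # input: [num_tokens, 2 * d] -> output: [num_tokens, d]
    x, y = input.chunk(2, dim=-1)
    return torch.nn.functional.silu(x) * y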