activation_kernels.cu 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. #include <torch/extension.h>
  2. #include <ATen/cuda/CUDAContext.h>
  namespace aphrodite {

  // SiLU activation: x * sigmoid(x), evaluated as x / (1 + exp(-x)).
  // The arithmetic is done in float regardless of T; the result is converted
  // back to T on return.
  template<typename T>
  __device__ __forceinline__ T silu(const T& x) {
    const float xf = static_cast<float>(x);
    return static_cast<T>(xf / (1.0f + expf(-xf)));
  }

  // Gated SiLU: out[t, i] = silu(input[t, 0, i]) * input[t, 1, i].
  //
  // Indexing: blockIdx.x selects the token row t, and the threads of the block
  // walk the d columns of that row in steps of blockDim.x. The kernel uses no
  // shared memory. Both pointers are indexed as dense row-major buffers.
  //
  //   out   : [num_tokens, d]
  //   input : [num_tokens, 2, d]  (first half of each row is the gated value,
  //                                second half is the multiplier)
  //   d     : number of output columns per token
  template<typename scalar_t>
  __global__ void silu_and_mul_kernel(
    scalar_t* __restrict__ out,               // [num_tokens, d]
    const scalar_t* __restrict__ input,       // [num_tokens, 2, d]
    const int d) {
    // Keep the row index and the derived offsets in 64 bits: with 32-bit
    // arithmetic, token_idx * 2 * d wraps once the input holds more than
    // 2^31 - 1 elements and the kernel would then access memory out of bounds.
    const int64_t token_idx = blockIdx.x;
    const int64_t in_row = token_idx * 2 * d;
    const int64_t out_row = token_idx * d;
    for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
      const scalar_t x = __ldg(&input[in_row + idx]);
      const scalar_t y = __ldg(&input[in_row + d + idx]);
      out[out_row + idx] = silu(x) * y;
    }
  }

  } // namespace aphrodite
  // Host launcher for aphrodite::silu_and_mul_kernel.
  //
  // Splits the last dimension of `input` into two halves of width d and writes
  // silu(first_half) * second_half into `out`. The kernel is enqueued on the
  // current CUDA stream and is not synchronized here.
  //
  //   out   : [..., d]      written in place; must be contiguous and have the
  //                         same dtype as `input`
  //   input : [..., 2 * d]  must be contiguous; float, double, half or bfloat16
  //
  // All leading dimensions are flattened into a single token dimension, so the
  // original 2-D [num_tokens, 2 * d] layout behaves exactly as before.
  void silu_and_mul(
    torch::Tensor& out,      // [num_tokens, d]
    torch::Tensor& input)    // [num_tokens, 2 * d]
  {
    // The kernel indexes the raw data pointers as dense row-major buffers and
    // ignores strides, so non-contiguous tensors would give wrong results.
    TORCH_CHECK(input.is_contiguous(), "silu_and_mul: input must be contiguous");
    TORCH_CHECK(out.is_contiguous(), "silu_and_mul: out must be contiguous");

    const int64_t last_dim = input.size(-1);
    const int d = static_cast<int>(last_dim / 2);
    // A launch with a zero-sized grid or block is an invalid configuration;
    // with no elements to process there is nothing to do.
    if (d == 0 || input.numel() == 0) {
      return;
    }
    const int64_t num_tokens = input.numel() / last_dim;

    // One block per token; threads stride over the d columns inside the kernel.
    dim3 grid(static_cast<unsigned int>(num_tokens));
    dim3 block(static_cast<unsigned int>(std::min(d, 1024)));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      input.scalar_type(),
      "silu_and_mul_kernel",
      [&] {
        aphrodite::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
          out.data_ptr<scalar_t>(),
          input.data_ptr<scalar_t>(),
          d);
        // Kernel launches do not return an error code; surface launch
        // configuration failures immediately instead of at a later call.
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
  }