causal_conv1d_update.cu 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  /******************************************************************************
  * Copyright (c) 2023, Tri Dao.
  ******************************************************************************/
  #include <c10/util/BFloat16.h>
  #include <c10/util/Exception.h> // For TORCH_CHECK
  #include <c10/util/Half.h>
  #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
  #include <cub/block/block_load.cuh>
  #include <cub/block/block_store.cuh>
  #include "causal_conv1d.h"
  #include "causal_conv1d_common.h"
  #include "static_switch.h"
  // Compile-time configuration for causal_conv1d_update_kernel and its launcher.
  //   kNThreads - threads per block; the kernel maps one thread to one channel
  //   kWidth    - number of filter taps (length of the rolling conv_state window)
  //   kNBytes   - sizeof(input_t); only 2- and 4-byte input element types are accepted
  template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
  struct Causal_conv1d_update_kernel_traits {
      static constexpr int kNThreads = kNThreads_;
      static constexpr int kWidth = kWidth_;

      using input_t = input_t_;
      using weight_t = weight_t_;

      static constexpr int kNBytes = static_cast<int>(sizeof(input_t_));
      static_assert(kNBytes == 2 || kNBytes == 4, "input_t must be a 2- or 4-byte type");
  };
  // Single-step update of a causal depthwise 1-D convolution.
  //
  // Each thread owns one (batch, channel) pair. For that pair it:
  //   1. shifts the channel's conv_state window left by one slot,
  //   2. appends the new input sample x[0] as the last slot,
  //   3. stores the whole window back into conv_state, and
  //   4. writes bias + sum_k weight[k] * window[k] to out[0], optionally passed through
  //      SiLU (v / (1 + exp(-v))) when params.silu_activation is set.
  // All arithmetic is done in float. Values are converted back to input_t only on store.
  //
  // Expected launch, matching causal_conv1d_update_launch:
  //   gridDim.x = params.batch, gridDim.y = ceil(params.dim / kNThreads),
  //   blockDim.x = kNThreads, no shared memory.
  // Threads whose channel index is >= params.dim return without touching memory.
  template<typename Ktraits>
  __global__ __launch_bounds__(Ktraits::kNThreads)
  void causal_conv1d_update_kernel(ConvParamsBase params) {
      using input_t = typename Ktraits::input_t;
      using weight_t = typename Ktraits::weight_t;
      constexpr int kWidth = Ktraits::kWidth;

      const int batch_id = blockIdx.x;
      const int channel_id = blockIdx.y * Ktraits::kNThreads + threadIdx.x;
      // Tail of the last channel block: nothing to read or write.
      if (channel_id >= params.dim) { return; }

      // Base pointers for this (batch, channel). Offsets are element counts, since they
      // are added to typed pointers.
      input_t *x = reinterpret_cast<input_t *>(params.x_ptr)
          + batch_id * params.x_batch_stride
          + channel_id * params.x_c_stride;
      input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
          + batch_id * params.conv_state_batch_stride
          + channel_id * params.conv_state_c_stride;
      const weight_t *weight = reinterpret_cast<const weight_t *>(params.weight_ptr)
          + channel_id * params.weight_c_stride;
      input_t *out = reinterpret_cast<input_t *>(params.out_ptr)
          + batch_id * params.out_batch_stride
          + channel_id * params.out_c_stride;

      // The bias is optional; with no bias pointer, accumulation starts from zero.
      float acc = (params.bias_ptr != nullptr)
          ? float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id])
          : 0.f;

      // Filter taps for this channel.
      float taps[kWidth];
      #pragma unroll
      for (int k = 0; k < kWidth; ++k) {
          taps[k] = float(weight[k * params.weight_width_stride]);
      }

      // Shift the window left by one slot and append the new sample. Every conv_state read
      // finishes before the first conv_state write below, so the in-place update is safe.
      float window[kWidth];
      #pragma unroll
      for (int k = 0; k + 1 < kWidth; ++k) {
          window[k] = float(conv_state[(k + 1) * params.conv_state_l_stride]);
      }
      window[kWidth - 1] = float(x[0]);
      #pragma unroll
      for (int k = 0; k < kWidth; ++k) {
          conv_state[k * params.conv_state_l_stride] = input_t(window[k]);
      }

      // Dot product of taps and window, accumulated in tap order.
      #pragma unroll
      for (int k = 0; k < kWidth; ++k) {
          acc += taps[k] * window[k];
      }
      if (params.silu_activation) {
          acc = acc / (1.f + expf(-acc));
      }
      out[0] = input_t(acc);
  }
  // Host-side launcher for one (kNThreads, kWidth, input_t, weight_t) instantiation of
  // causal_conv1d_update_kernel, enqueued on `stream`.
  //   grid.x  = params.batch
  //   grid.y  = ceil(params.dim / kNThreads) channel blocks
  //   block.x = kNThreads threads, no dynamic shared memory
  // The launch is checked immediately afterwards with C10_CUDA_KERNEL_LAUNCH_CHECK.
  template<int kNThreads, int kWidth, typename input_t, typename weight_t>
  void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
      using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
      const auto n_channel_blocks = (params.dim + kNThreads - 1) / kNThreads;  // ceil-div
      const dim3 grid(params.batch, n_channel_blocks);
      const dim3 block(Ktraits::kNThreads);
      causal_conv1d_update_kernel<Ktraits><<<grid, block, 0, stream>>>(params);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
  // Dispatches to the causal_conv1d_update_launch instantiation that matches
  // params.width, using 64 threads (one channel per thread) per block.
  //
  // Only filter widths 2, 3 and 4 are compiled. Any other width fails loudly through
  // TORCH_CHECK, which throws c10::Error. Previously it returned silently, leaving out and
  // conv_state unwritten with no error reported.
  template<typename input_t, typename weight_t>
  void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
      constexpr int kNThreads = 64;
      if (params.width == 2) {
          causal_conv1d_update_launch<kNThreads, 2, input_t, weight_t>(params, stream);
      } else if (params.width == 3) {
          causal_conv1d_update_launch<kNThreads, 3, input_t, weight_t>(params, stream);
      } else if (params.width == 4) {
          causal_conv1d_update_launch<kNThreads, 4, input_t, weight_t>(params, stream);
      } else {
          TORCH_CHECK(false, "causal_conv1d_update: unsupported conv width ", params.width,
                      " (supported widths: 2, 3, 4)");
      }
  }
  // Explicit instantiations of causal_conv1d_update_cuda for every (input_t, weight_t)
  // combination of float, at::Half and at::BFloat16, so all nine definitions are emitted in
  // this translation unit. The helper macro is #undef'd right after use so it does not leak.
  #define CAUSAL_CONV1D_UPDATE_INSTANTIATE(INPUT_T, WEIGHT_T) \
      template void causal_conv1d_update_cuda<INPUT_T, WEIGHT_T>(ConvParamsBase &params, cudaStream_t stream);

  // float weights
  CAUSAL_CONV1D_UPDATE_INSTANTIATE(float, float)
  CAUSAL_CONV1D_UPDATE_INSTANTIATE(at::Half, float)
  CAUSAL_CONV1D_UPDATE_INSTANTIATE(at::BFloat16, float)
  // at::Half weights
  CAUSAL_CONV1D_UPDATE_INSTANTIATE(float, at::Half)
  CAUSAL_CONV1D_UPDATE_INSTANTIATE(at::Half, at::Half)
  CAUSAL_CONV1D_UPDATE_INSTANTIATE(at::BFloat16, at::Half)
  // at::BFloat16 weights
  CAUSAL_CONV1D_UPDATE_INSTANTIATE(float, at::BFloat16)
  CAUSAL_CONV1D_UPDATE_INSTANTIATE(at::Half, at::BFloat16)
  CAUSAL_CONV1D_UPDATE_INSTANTIATE(at::BFloat16, at::BFloat16)

  #undef CAUSAL_CONV1D_UPDATE_INSTANTIATE