123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596 |
- /******************************************************************************
- * Copyright (c) 2023, Tri Dao.
- ******************************************************************************/
#include <c10/util/BFloat16.h>
#include <c10/util/Exception.h>  // For TORCH_CHECK
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>

#include "causal_conv1d.h"
#include "causal_conv1d_common.h"
#include "static_switch.h"
// Compile-time configuration bundle for causal_conv1d_update_kernel.
//
//   kNThreads_ : threads per block.
//   kWidth_    : number of filter taps.
//   input_t_   : element type of x / conv_state / out (must be 2 or 4 bytes).
//   weight_t_  : element type of the weight and bias buffers.
template <int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
struct Causal_conv1d_update_kernel_traits {
    // Element types forwarded to the kernel.
    using weight_t = weight_t_;
    using input_t = input_t_;

    // Launch / filter geometry.
    static constexpr int kWidth = kWidth_;
    static constexpr int kNThreads = kNThreads_;

    // Byte width of one input element; only 16- and 32-bit types are accepted.
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4,
                  "input_t must be a 2-byte or 4-byte type");
};
// Single-time-step causal depthwise conv1d update.
//
// Thread mapping: blockIdx.x selects the batch entry and
// blockIdx.y * kNThreads + threadIdx.x selects the channel. Each thread
// therefore owns one (batch, channel) pair. No shared memory or block-level
// synchronization is used.
//
// For each in-range channel the thread:
//   1. loads the optional bias and the kWidth filter taps, widened to float;
//   2. builds the input window from conv_state slots 1..kWidth-1 followed by
//      x[0], then writes that window back into conv_state slots
//      0..kWidth-1 (shift left by one, append the new sample);
//   3. writes out[0] = bias + sum_k taps[k] * window[k], optionally passed
//      through SiLU (v / (1 + exp(-v))).
//
// NOTE(review): this appears to assume conv_state holds exactly kWidth slots
// along its l dimension and that exactly one time step is processed per call
// (only x[0] is read). Confirm against the caller.
template <typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_update_kernel(ConvParamsBase params) {
    using input_t = typename Ktraits::input_t;
    using weight_t = typename Ktraits::weight_t;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kWidth = Ktraits::kWidth;

    const int batch = blockIdx.x;
    const int channel = blockIdx.y * kNThreads + threadIdx.x;

    // The last block along y may extend past the channel count. Those threads
    // have nothing to read or write. There is no block-wide barrier below, so
    // leaving early is safe.
    if (channel >= params.dim) { return; }

    const input_t *x = reinterpret_cast<input_t *>(params.x_ptr)
        + batch * params.x_batch_stride
        + channel * params.x_c_stride;
    input_t *state = reinterpret_cast<input_t *>(params.conv_state_ptr)
        + batch * params.conv_state_batch_stride
        + channel * params.conv_state_c_stride;
    const weight_t *w = reinterpret_cast<weight_t *>(params.weight_ptr)
        + channel * params.weight_c_stride;
    input_t *out = reinterpret_cast<input_t *>(params.out_ptr)
        + batch * params.out_batch_stride
        + channel * params.out_c_stride;

    // Accumulator starts at the bias (0 when no bias buffer is supplied).
    float acc = 0.f;
    if (params.bias_ptr != nullptr) {
        acc = float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel]);
    }

    float taps[kWidth];
    #pragma unroll
    for (int k = 0; k < kWidth; ++k) {
        taps[k] = float(w[k * params.weight_width_stride]);
    }

    // Gather the shifted state plus the new sample before overwriting the
    // state, so every read of the old state precedes any write to it.
    float window[kWidth];
    #pragma unroll
    for (int k = 0; k < kWidth - 1; ++k) {
        window[k] = float(state[(k + 1) * params.conv_state_l_stride]);
    }
    window[kWidth - 1] = float(x[0]);
    #pragma unroll
    for (int k = 0; k < kWidth; ++k) {
        state[k * params.conv_state_l_stride] = input_t(window[k]);
    }

    #pragma unroll
    for (int k = 0; k < kWidth; ++k) { acc += taps[k] * window[k]; }
    if (params.silu_activation) { acc = acc / (1.f + expf(-acc)); }
    out[0] = input_t(acc);
}
// Launches causal_conv1d_update_kernel for one compile-time (kNThreads, kWidth,
// input_t, weight_t) configuration on `stream`.
//
// Grid layout: x = batch entries, y = ceil(dim / kNThreads) channel blocks.
// Each block has kNThreads threads. Launch errors are surfaced by
// C10_CUDA_KERNEL_LAUNCH_CHECK.
//
// Empty inputs (batch or dim <= 0) return without launching. A zero-sized
// grid is an invalid launch configuration and would otherwise make the
// launch check throw, even though there is simply no work to do.
template <int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
    using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
    if (params.batch <= 0 || params.dim <= 0) { return; }
    dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
    auto kernel = &causal_conv1d_update_kernel<Ktraits>;
    kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Host entry point for the single-step causal conv1d update.
//
// Dispatches on the runtime filter width to a compile-time instantiation that
// uses 64 threads per block. Only widths 2, 3 and 4 are compiled. Any other
// width raises an error instead of silently returning with `out` unwritten.
template <typename input_t, typename weight_t>
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
    switch (params.width) {
        case 2:
            causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
            break;
        case 3:
            causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
            break;
        case 4:
            causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
            break;
        default:
            TORCH_CHECK(false, "causal_conv1d_update only supports width 2, 3 or 4, got ",
                        params.width);
    }
}
// Explicit instantiations of the host entry point for every (input, weight)
// element-type pairing among float, at::Half and at::BFloat16. Callers in other
// translation units can then link against these without seeing the template
// definition.
template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
|