// causal_conv1d_fwd.cu

/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>

#include "causal_conv1d.h"
#include "causal_conv1d_common.h"
#include "static_switch.h"

template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
struct Causal_conv1d_fwd_kernel_traits {
    using input_t = input_t_;
    using weight_t = weight_t_;
    static constexpr int kNThreads = kNThreads_;
    static constexpr int kWidth = kWidth_;
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    // Each thread handles 16 bytes per chunk: 4 fp32 elements or 8 fp16/bf16 elements.
    static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
    static_assert(kWidth <= kNElts);
    static constexpr bool kIsVecLoad = kIsVecLoad_;
    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
    static constexpr int kSmemIOSize = kIsVecLoad
        ? 0
        : std::max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
    static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
    static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
};
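
// Illustrative sanity check of the sizes above (a sketch, not used by the kernels; the
// `fwd_traits_example` namespace is ours, not part of the original interface): for fp16
// inputs with 128 threads, each chunk covers 128 * 8 = 1024 sequence positions and the
// exchange buffer takes 128 * 2 B * 8 = 2 KB of shared memory.
namespace fwd_traits_example {
    using ExampleTraits = Causal_conv1d_fwd_kernel_traits<128, 4, /*kIsVecLoad=*/true, at::Half, float>;
    static_assert(ExampleTraits::kNElts == 8, "fp16 packs 8 elements per 16-byte vector");
    static_assert(ExampleTraits::kNThreads * ExampleTraits::kNElts == 1024, "sequence positions per chunk");
    static_assert(ExampleTraits::kSmemExchangeSize == 2048, "2 KB exchange buffer");
}  // namespace fwd_traits_example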

template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
    constexpr int kWidth = Ktraits::kWidth;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNElts = Ktraits::kNElts;
    static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
    using input_t = typename Ktraits::input_t;
    using vec_t = typename Ktraits::vec_t;
    using weight_t = typename Ktraits::weight_t;

    // Shared memory.
    extern __shared__ char smem_[];
    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
    auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
    auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
    vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);

    const int tidx = threadIdx.x;
    const int batch_id = blockIdx.x;
    const int channel_id = blockIdx.y;
    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
        + channel_id * params.x_c_stride;
    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
        + channel_id * params.out_c_stride;
    float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);

    // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
    if (tidx == 0) {
        input_t zeros[kNElts] = {0};
        smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0];
    }

    float weight_vals[kWidth];
    #pragma unroll
    for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }

    constexpr int kChunkSize = kNThreads * kNElts;
    const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
    for (int chunk = 0; chunk < n_chunks; ++chunk) {
        input_t x_vals_load[2 * kNElts] = {0};
        if constexpr(kIsVecLoad) {
            Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
        } else {
            __syncthreads();
            Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
        }
        x += kChunkSize;
        __syncthreads();
        // Thread kNThreads - 1 doesn't write yet, so that thread 0 can read
        // the last elements of the previous chunk.
        if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
        __syncthreads();
        reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
        __syncthreads();
        // Now thread kNThreads - 1 can write the last elements of the current chunk.
        if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
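
        // Concretely (illustrative numbers, not required by the code): with fp16 inputs and
        // 128 threads, each thread owns 8 consecutive positions of the current 1024-element
        // chunk in x_vals_load[8..15]; after the exchange above, x_vals_load[0..7] holds the
        // 8 positions just before them, i.e. the previous thread's slice (or, for thread 0,
        // the tail of the previous chunk, which is zero-initialized before the first chunk).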
        float x_vals[2 * kNElts];
        #pragma unroll
        for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }

        float out_vals[kNElts];
        #pragma unroll
        for (int i = 0; i < kNElts; ++i) {
            out_vals[i] = bias_val;
            #pragma unroll
            for (int w = 0; w < kWidth; ++w) {
                out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
            }
        }
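
        // Worked example of the indexing above for kWidth = 4 (a sketch; nothing new is computed):
        // out_vals[i] = bias + w[0]*x[i-3] + w[1]*x[i-2] + w[2]*x[i-1] + w[3]*x[i],
        // where x[j] stands for x_vals[kNElts + j]; negative j falls into the halo copied from
        // the previous thread / previous chunk, which is what makes the convolution causal.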
        if (params.silu_activation) {
            #pragma unroll
            for (int i = 0; i < kNElts; ++i) {
                // SiLU: x * sigmoid(x).
                out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
            }
        }

        input_t out_vals_store[kNElts];
        #pragma unroll
        for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
        if constexpr(kIsVecLoad) {
            Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
        } else {
            Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
        }
        out += kChunkSize;
    }
}

template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
    static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
    BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
        using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
        constexpr int kSmemSize = Ktraits::kSmemSize;
        dim3 grid(params.batch, params.dim);
        auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
        if (kSmemSize >= 48 * 1024) {
            C10_CUDA_CHECK(cudaFuncSetAttribute(
                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
        }
        kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
}

template<typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
    if (params.width == 2) {
        causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
    } else if (params.width == 3) {
        causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
    } else if (params.width == 4) {
        causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
    }
}
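
// Hedged usage sketch (ours, not part of the original interface; never called in this file): shows
// how a caller might fill the ConvParamsBase fields that the channel-first kernel reads, assuming
// x and out are contiguous (batch, dim, seqlen) tensors and weight is a contiguous (dim, width)
// tensor. Zero-initialization leaves the optional pointers (bias, seq_idx, initial/final states) null.
static inline void causal_conv1d_fwd_example(void *x, void *weight, void *bias, void *out,
                                             int batch, int dim, int seqlen, int width,
                                             bool silu, cudaStream_t stream) {
    ConvParamsBase params = {};
    params.batch = batch; params.dim = dim; params.seqlen = seqlen; params.width = width;
    params.silu_activation = silu;
    params.x_ptr = x; params.weight_ptr = weight; params.bias_ptr = bias; params.out_ptr = out;
    // Contiguous (batch, dim, seqlen) layout for x and out, (dim, width) for weight.
    params.x_batch_stride = dim * seqlen; params.x_c_stride = seqlen; params.x_l_stride = 1;
    params.out_batch_stride = dim * seqlen; params.out_c_stride = seqlen; params.out_l_stride = 1;
    params.weight_c_stride = width; params.weight_width_stride = 1;
    causal_conv1d_fwd_cuda<at::Half, float>(params, stream);  // fp16 activations, fp32 weights
}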

template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
struct Causal_conv1d_channellast_fwd_kernel_traits {
    // The cache line is 128 bytes, and we try to read 16 bytes per thread.
    // So we have 8 threads per "row", covering 32 (fp32) or 64 (fp16/bf16) elements in the
    // channel dimension. That leaves 4 columns per warp, and so 16 columns per block
    // (assuming each block has 128 threads). Each load is 16 x (32|64) elements in the
    // L x C dimensions.
    using input_t = input_t_;
    using weight_t = weight_t_;
    static constexpr int kNThreads = kNThreads_;
    static_assert(kNThreads % 32 == 0);
    static constexpr int kNWarps = kNThreads / 32;
    static constexpr int kWidth = kWidth_;
    static constexpr int kChunkSizeL = kChunkSizeL_;
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
    static constexpr int kNEltsPerRow = 128 / kNBytes;
    static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts;  // Always 8 for now
    static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
    static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow;  // Always 4 for now
    static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
    static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
    static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
    static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
    static constexpr bool kIsVecLoad = kIsVecLoad_;
    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
};
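
// Illustrative check of the layout arithmetic described above (a sketch, not used by the kernels;
// the `channellast_traits_example` namespace is ours): with 128 threads, fp16 inputs and an
// L-chunk of 64, each 128-byte row spans 64 channels read by 8 threads, a warp covers 4 L
// positions per load, the block covers 16, and 4 loads cover the 64-position chunk.
namespace channellast_traits_example {
    using ExampleTraits = Causal_conv1d_channellast_fwd_kernel_traits<128, 4, 64, true, at::Half, float>;
    static_assert(ExampleTraits::kNEltsPerRow == 64 && ExampleTraits::kNThreadsPerRow == 8, "one 128-byte row per 8 threads");
    static_assert(ExampleTraits::kNColsPerWarp == 4 && ExampleTraits::kNColsPerLoad == 16, "L positions per warp / per block load");
    static_assert(ExampleTraits::kNLoads == 4, "4 loads cover a 64-position L-chunk");
}  // namespace channellast_traits_example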

template<typename Ktraits, bool kHasSeqIdx>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
    constexpr int kWidth = Ktraits::kWidth;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNElts = Ktraits::kNElts;
    constexpr int kNWarp = Ktraits::kNWarps;
    constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
    constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
    constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
    constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
    using input_t = typename Ktraits::input_t;
    using vec_t = typename Ktraits::vec_t;
    using weight_t = typename Ktraits::weight_t;

    // Shared memory.
    __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];

    const int batch_id = blockIdx.x;
    const int chunk_l_id = blockIdx.y;
    const int chunk_c_id = blockIdx.z;
    const int tid = threadIdx.x;
    const int l_idx = tid / kNThreadsPerC;
    const int c_idx = tid % kNThreadsPerC;
    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
        + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
        + chunk_c_id * kChunkSizeC * params.weight_c_stride;
    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
        + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
    int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
        + batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
    input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
        : reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
    // The last L-chunk will also have enough info to write to final states, since it also contains
    // a few x values from the previous L-chunk.
    input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr
        : reinterpret_cast<input_t *>(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;

    // Load the current L-chunk into shared memory (out-of-range L or C positions are zero-filled).
    #pragma unroll
    for (int l = 0; l < Ktraits::kNLoads; ++l) {
        input_t x_vals_load[kNElts] = {0};
        if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
        }
        reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
    }
    // Load the elements from the previous chunk (or from initial_states) that are needed for the convolution.
    if (l_idx < kWidth - 1) {
        input_t x_vals_load[kNElts] = {0};
        if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
            && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
        } else if (initial_states != nullptr
                   && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
                   && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
        }
        reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
    }
    __syncthreads();

    if (final_states != nullptr
        && l_idx < kWidth - 1
        && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
        // x_smem[0] contains the element at index chunk_l_id * kChunkSizeL - (kWidth - 1), so the
        // last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored at
        // x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx],
        // i.e. x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL][c_idx].
        *reinterpret_cast<vec_t *>(final_states) = reinterpret_cast<vec_t *>(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx];
    }

    constexpr int kLPerThread = std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
    static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
    constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
    static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
    // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
    static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
    static_assert((kLPerThread & (kLPerThread - 1)) == 0);
    static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
    static_assert(kNThreadsPerRow <= 32);

    const int row_idx = tid / kNThreadsPerRow;
    const int col_idx = tid % kNThreadsPerRow;
    float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
    float weight_vals[kWidth] = {0};
    if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
        #pragma unroll
        for (int w = 0; w < kWidth; ++w) {
            weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
        }
    }
    float x_vals[kWidth - 1 + kLPerThread];
    #pragma unroll
    for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
        x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
    }
    int seq_idx_thread[kWidth - 1 + kLPerThread];
    if constexpr (kHasSeqIdx) {
        #pragma unroll
        for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
            seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
        }
    }

    float out_vals[kLPerThread];
    #pragma unroll
    for (int i = 0; i < kLPerThread; ++i) {
        out_vals[i] = bias_val;
        const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
        #pragma unroll
        for (int w = 0; w < kWidth; ++w) {
            if constexpr (!kHasSeqIdx) {
                out_vals[i] += weight_vals[w] * x_vals[i + w];
            } else {
                // Zero out contributions from positions that belong to a different sequence.
                out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
            }
        }
        if (params.silu_activation) { out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
    }

    __syncthreads();
    // Reuse x_smem to stage the outputs, converting from the per-thread (channel, L) compute
    // layout back to the vectorized (L, C) layout used for the global-memory stores.
    #pragma unroll
    for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
    __syncthreads();

    #pragma unroll
    for (int l = 0; l < Ktraits::kNLoads; ++l) {
        input_t out_vals_store[kNElts];
        reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
        if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
            *reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
        }
    }
}

template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_channellast_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
    BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
        using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
        // constexpr int kSmemSize = Ktraits::kSmemSize;
        constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
        constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
        const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
        const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
        dim3 grid(params.batch, n_chunks_L, n_chunks_C);
        dim3 block(Ktraits::kNThreads);
        auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits, kHasSeqIdx>;
        kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
}

template<typename input_t, typename weight_t>
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
    if (params.width == 2) {
        causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
    } else if (params.width == 3) {
        causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
    } else if (params.width == 4) {
        causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
    }
}
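
// Layout note for the channel-last path (our reading of the stride accesses above, stated as an
// assumption rather than a documented contract): x and out are indexed as
// base + b * batch_stride + l * l_stride + c, i.e. the channel stride is implicitly 1, so a
// contiguous (batch, seqlen, dim) tensor fits with x_l_stride == dim, and the vectorized 16-byte
// loads along the channel dimension expect dim to be a multiple of kNElts. As a worked grid
// example: batch = 8, seqlen = 4096, dim = 2048 with fp16 inputs gives
// grid = (8, 4096 / 64, 2048 / 64) = (8, 64, 32) blocks of 128 threads.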

template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);

template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);