// causal_conv1d.cu (39 KB)
  1. #include <torch/all.h>
  2. #include <ATen/cuda/CUDAContext.h>
  3. #include <c10/cuda/CUDAGuard.h>
  4. #include "causal_conv1d.h"
  5. #include <c10/util/BFloat16.h>
  6. #include <c10/util/Half.h>
  7. #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
  8. #include <cub/block/block_load.cuh>
  9. #include <cub/block/block_store.cuh>
  10. #include "static_switch.h"
  11. #define CHECK_SHAPE(x, ...) \
  12. TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \
  13. #x " must have shape (" #__VA_ARGS__ ")")
  14. #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
  15. if (ITYPE == at::ScalarType::Half) { \
  16. using input_t = at::Half; \
  17. __VA_ARGS__(); \
  18. } else if (ITYPE == at::ScalarType::BFloat16) { \
  19. using input_t = at::BFloat16; \
  20. __VA_ARGS__(); \
  21. } else if (ITYPE == at::ScalarType::Float) { \
  22. using input_t = float; \
  23. __VA_ARGS__(); \
  24. } else { \
  25. AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), \
  26. "'"); \
  27. }
  28. #define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
  29. if (WTYPE == at::ScalarType::Half) { \
  30. using weight_t = at::Half; \
  31. __VA_ARGS__(); \
  32. } else if (WTYPE == at::ScalarType::BFloat16) { \
  33. using weight_t = at::BFloat16; \
  34. __VA_ARGS__(); \
  35. } else if (WTYPE == at::ScalarType::Float) { \
  36. using weight_t = float; \
  37. __VA_ARGS__(); \
  38. } else { \
  39. AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), \
  40. "'"); \
  41. }
  42. template <typename input_t, typename weight_t>
  43. void causal_conv1d_fwd_cuda(ConvParamsBase& params, cudaStream_t stream);
  44. template <typename input_t, typename weight_t>
  45. void causal_conv1d_channellast_fwd_cuda(ConvParamsBase& params,
  46. cudaStream_t stream);
  47. template <typename input_t, typename weight_t>
  48. void causal_conv1d_update_cuda(ConvParamsBase& params, cudaStream_t stream);
  49. void set_conv_params_fwd(ConvParamsBase& params,
  50. // sizes
  51. const size_t batch, const size_t dim,
  52. const size_t seqlen, const size_t width,
  53. // device pointers
  54. const at::Tensor x, const at::Tensor weight,
  55. const at::Tensor out, void* bias_ptr,
  56. bool silu_activation) {
  57. // Reset the parameters
  58. memset(&params, 0, sizeof(params));
  59. params.batch = batch;
  60. params.dim = dim;
  61. params.seqlen = seqlen;
  62. params.width = width;
  63. params.silu_activation = silu_activation;
  64. // Set the pointers and strides.
  65. params.x_ptr = x.data_ptr();
  66. params.weight_ptr = weight.data_ptr();
  67. params.bias_ptr = bias_ptr;
  68. params.out_ptr = out.data_ptr();
  69. // All stride are in elements, not bytes.
  70. params.x_batch_stride = x.stride(0);
  71. params.x_c_stride = x.stride(1);
  72. params.x_l_stride = x.stride(-1);
  73. params.weight_c_stride = weight.stride(0);
  74. params.weight_width_stride = weight.stride(1);
  75. params.out_batch_stride = out.stride(0);
  76. params.out_c_stride = out.stride(1);
  77. params.out_l_stride = out.stride(-1);
  78. }
  79. at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
  80. const c10::optional<at::Tensor>& bias_,
  81. const c10::optional<at::Tensor>& seq_idx_,
  82. const c10::optional<at::Tensor>& seq_pos_idx_,
  83. const c10::optional<at::Tensor>& initial_states_,
  84. const c10::optional<at::Tensor>& final_states_out_,
  85. bool silu_activation) {
  86. auto input_type = x.scalar_type();
  87. auto weight_type = weight.scalar_type();
  88. TORCH_CHECK(input_type == at::ScalarType::Float ||
  89. input_type == at::ScalarType::Half ||
  90. input_type == at::ScalarType::BFloat16);
  91. TORCH_CHECK(weight_type == at::ScalarType::Float ||
  92. weight_type == at::ScalarType::Half ||
  93. weight_type == at::ScalarType::BFloat16);
  94. TORCH_CHECK(x.is_cuda());
  95. TORCH_CHECK(weight.is_cuda());
  96. const auto sizes = x.sizes();
  97. const int batch_size = sizes[0];
  98. const int dim = sizes[1];
  99. const int seqlen = sizes[2];
  100. const int width = weight.size(-1);
  101. CHECK_SHAPE(x, batch_size, dim, seqlen);
  102. CHECK_SHAPE(weight, dim, width);
  103. TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
  104. const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
  105. if (is_channel_last) {
  106. TORCH_CHECK(
  107. dim % 8 == 0,
  108. "causal_conv1d only supports channel dimension divisible by 8 for now");
  109. TORCH_CHECK(x.stride(2) % 8 == 0 && x.stride(0) % 8 == 0,
  110. "causal_conv1d with channel last layout requires strides "
  111. "(x.stride(0) and x.stride(2)) to be multiples of 8");
  112. }
  113. TORCH_CHECK(width >= 2 && width <= 4,
  114. "causal_conv1d only supports width between 2 and 4");
  115. if (bias_.has_value()) {
  116. auto bias = bias_.value();
  117. TORCH_CHECK(bias.scalar_type() == weight_type);
  118. TORCH_CHECK(bias.is_cuda());
  119. TORCH_CHECK(bias.stride(-1) == 1);
  120. CHECK_SHAPE(bias, dim);
  121. }
  122. if (seq_idx_.has_value()) {
  123. TORCH_CHECK(is_channel_last,
  124. "seq_idx is only supported for channel last layout");
  125. auto seq_idx = seq_idx_.value();
  126. TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
  127. TORCH_CHECK(seq_idx.is_cuda());
  128. TORCH_CHECK(seq_idx.is_contiguous());
  129. CHECK_SHAPE(seq_idx, batch_size, seqlen);
  130. }
  131. if (seq_pos_idx_.has_value()) {
  132. auto seq_pos_idx = seq_pos_idx_.value();
  133. TORCH_CHECK(seq_pos_idx.scalar_type() == torch::kInt32);
  134. TORCH_CHECK(seq_pos_idx.is_cuda());
  135. TORCH_CHECK(seq_pos_idx.is_contiguous());
  136. CHECK_SHAPE(seq_pos_idx, batch_size, seqlen);
  137. }
  138. at::Tensor out = torch::empty_like(x);
  139. ConvParamsBase params;
  140. set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
  141. bias_.has_value() ? bias_.value().data_ptr() : nullptr,
  142. silu_activation);
  143. if (seq_idx_.has_value()) {
  144. params.seq_idx_ptr = seq_idx_.value().data_ptr();
  145. } else {
  146. params.seq_idx_ptr = nullptr;
  147. }
  148. if (seq_pos_idx_.has_value()) {
  149. params.seq_pos_idx_ptr = seq_pos_idx_.value().data_ptr();
  150. } else {
  151. params.seq_pos_idx_ptr = nullptr;
  152. }
  153. if (initial_states_.has_value()) {
  154. TORCH_CHECK(is_channel_last,
  155. "initial_states is only supported for channel last layout");
  156. auto initial_states = initial_states_.value();
  157. TORCH_CHECK(initial_states.scalar_type() == input_type);
  158. TORCH_CHECK(initial_states.is_cuda());
  159. CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
  160. TORCH_CHECK(initial_states.stride(1) == 1);
  161. params.initial_states_ptr = initial_states.data_ptr();
  162. params.initial_states_batch_stride = initial_states.stride(0);
  163. params.initial_states_c_stride = initial_states.stride(1);
  164. params.initial_states_l_stride = initial_states.stride(2);
  165. } else {
  166. params.initial_states_ptr = nullptr;
  167. }
  168. if (final_states_out_.has_value()) {
  169. TORCH_CHECK(is_channel_last,
  170. "final_states is only supported for channel last layout");
  171. auto final_states = final_states_out_.value();
  172. TORCH_CHECK(final_states.scalar_type() == input_type);
  173. TORCH_CHECK(final_states.is_cuda());
  174. CHECK_SHAPE(final_states, batch_size, dim, width - 1);
  175. TORCH_CHECK(final_states.stride(1) == 1);
  176. params.final_states_ptr = final_states.data_ptr();
  177. params.final_states_batch_stride = final_states.stride(0);
  178. params.final_states_c_stride = final_states.stride(1);
  179. params.final_states_l_stride = final_states.stride(2);
  180. } else {
  181. params.final_states_ptr = nullptr;
  182. }
  183. // Otherwise the kernel will be launched from cuda:0 device
  184. // Cast to char to avoid compiler warning about narrowing
  185. at::cuda::CUDAGuard device_guard{(char)x.get_device()};
  186. auto stream = at::cuda::getCurrentCUDAStream().stream();
  187. DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(
  188. x.scalar_type(), "causal_conv1d_fwd", [&] {
  189. DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(
  190. weight.scalar_type(), "causal_conv1d_fwd", [&] {
  191. if (!is_channel_last) {
  192. causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
  193. } else {
  194. causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params,
  195. stream);
  196. }
  197. });
  198. });
  199. return out;
  200. }
  201. at::Tensor causal_conv1d_update(const at::Tensor& x,
  202. const at::Tensor& conv_state,
  203. const at::Tensor& weight,
  204. const c10::optional<at::Tensor>& bias_,
  205. bool silu_activation) {
  206. auto input_type = x.scalar_type();
  207. auto weight_type = weight.scalar_type();
  208. TORCH_CHECK(input_type == at::ScalarType::Float ||
  209. input_type == at::ScalarType::Half ||
  210. input_type == at::ScalarType::BFloat16);
  211. TORCH_CHECK(weight_type == at::ScalarType::Float ||
  212. weight_type == at::ScalarType::Half ||
  213. weight_type == at::ScalarType::BFloat16);
  214. TORCH_CHECK(conv_state.scalar_type() == input_type);
  215. TORCH_CHECK(x.is_cuda());
  216. TORCH_CHECK(conv_state.is_cuda());
  217. TORCH_CHECK(weight.is_cuda());
  218. const auto sizes = x.sizes();
  219. const int batch_size = sizes[0];
  220. const int dim = sizes[1];
  221. const int width = weight.size(-1);
  222. CHECK_SHAPE(x, batch_size, dim);
  223. CHECK_SHAPE(conv_state, batch_size, dim, width);
  224. CHECK_SHAPE(weight, dim, width);
  225. TORCH_CHECK(width >= 2 && width <= 4,
  226. "causal_conv1d only supports width between 2 and 4");
  227. if (bias_.has_value()) {
  228. auto bias = bias_.value();
  229. TORCH_CHECK(bias.scalar_type() == weight_type);
  230. TORCH_CHECK(bias.is_cuda());
  231. TORCH_CHECK(bias.stride(-1) == 1);
  232. CHECK_SHAPE(bias, dim);
  233. }
  234. at::Tensor out = torch::empty_like(x);
  235. ConvParamsBase params;
  236. set_conv_params_fwd(
  237. params, batch_size, dim, /*seqlen=*/1, width, x, weight, out,
  238. bias_.has_value() ? bias_.value().data_ptr() : nullptr, silu_activation);
  239. params.conv_state_ptr = conv_state.data_ptr();
  240. // All stride are in elements, not bytes.
  241. params.conv_state_batch_stride = conv_state.stride(0);
  242. params.conv_state_c_stride = conv_state.stride(1);
  243. params.conv_state_l_stride = conv_state.stride(2);
  244. // Otherwise the kernel will be launched from cuda:0 device
  245. // Cast to char to avoid compiler warning about narrowing
  246. at::cuda::CUDAGuard device_guard{(char)x.get_device()};
  247. auto stream = at::cuda::getCurrentCUDAStream().stream();
  248. DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(
  249. x.scalar_type(), "causal_conv1d_update", [&] {
  250. DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(
  251. weight.scalar_type(), "causal_conv1d_update", [&] {
  252. causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
  253. });
  254. });
  255. return out;
  256. }
  257. template <int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_,
  258. typename weight_t_>
  259. struct Causal_conv1d_fwd_kernel_traits {
  260. using input_t = input_t_;
  261. using weight_t = weight_t_;
  262. static constexpr int kNThreads = kNThreads_;
  263. static constexpr int kWidth = kWidth_;
  264. static constexpr int kNBytes = sizeof(input_t);
  265. static_assert(kNBytes == 2 || kNBytes == 4);
  266. static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
  267. static_assert(kWidth <= kNElts);
  268. static constexpr bool kIsVecLoad = kIsVecLoad_;
  269. static constexpr int kNLoadsIndex = kNElts / 4;
  270. using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
  271. using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts,
  272. cub::BLOCK_LOAD_WARP_TRANSPOSE>;
  273. using BlockLoadVecT =
  274. cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
  275. using BlockLoadIndexT =
  276. cub::BlockLoad<int, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
  277. using BlockLoadIndexVecT = cub::BlockLoad<int4, kNThreads, kNLoadsIndex,
  278. !(kIsVecLoad && kNLoadsIndex == 1)
  279. ? cub::BLOCK_LOAD_WARP_TRANSPOSE
  280. : cub::BLOCK_LOAD_DIRECT>;
  281. using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts,
  282. cub::BLOCK_STORE_WARP_TRANSPOSE>;
  283. using BlockStoreVecT =
  284. cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
  285. static constexpr int kSmemIOSize =
  286. (kIsVecLoad && kNLoadsIndex == 1)
  287. ? 0
  288. : std::max({sizeof(typename BlockLoadT::TempStorage),
  289. sizeof(typename BlockStoreT::TempStorage),
  290. sizeof(typename BlockLoadIndexT::TempStorage),
  291. sizeof(typename BlockLoadIndexVecT::TempStorage)});
  292. static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
  293. static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
  294. };
  295. template <typename Ktraits, bool kHasSeqPosIdx>
  296. __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_fwd_kernel(
  297. ConvParamsBase params) {
  298. constexpr int kWidth = Ktraits::kWidth;
  299. constexpr int kNThreads = Ktraits::kNThreads;
  300. constexpr int kNElts = Ktraits::kNElts;
  301. static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
  302. using input_t = typename Ktraits::input_t;
  303. using vec_t = typename Ktraits::vec_t;
  304. using weight_t = typename Ktraits::weight_t;
  305. // Shared memory.
  306. extern __shared__ char smem_[];
  307. [[maybe_unused]] auto& smem_load =
  308. reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
  309. [[maybe_unused]] auto& smem_load_vec =
  310. reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
  311. [[maybe_unused]] auto& smem_load_index =
  312. reinterpret_cast<typename Ktraits::BlockLoadIndexT::TempStorage&>(smem_);
  313. [[maybe_unused]] auto& smem_load_index_vec =
  314. reinterpret_cast<typename Ktraits::BlockLoadIndexVecT::TempStorage&>(
  315. smem_);
  316. [[maybe_unused]] auto& smem_store =
  317. reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
  318. [[maybe_unused]] auto& smem_store_vec =
  319. reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
  320. vec_t* smem_exchange = reinterpret_cast<vec_t*>(smem_ + Ktraits::kSmemIOSize);
  321. const int tidx = threadIdx.x;
  322. const int batch_id = blockIdx.x;
  323. const int channel_id = blockIdx.y;
  324. input_t* x = reinterpret_cast<input_t*>(params.x_ptr) +
  325. batch_id * params.x_batch_stride +
  326. channel_id * params.x_c_stride;
  327. weight_t* weight = reinterpret_cast<weight_t*>(params.weight_ptr) +
  328. channel_id * params.weight_c_stride;
  329. input_t* out = reinterpret_cast<input_t*>(params.out_ptr) +
  330. batch_id * params.out_batch_stride +
  331. channel_id * params.out_c_stride;
  332. float bias_val =
  333. params.bias_ptr == nullptr
  334. ? 0.f
  335. : float(reinterpret_cast<weight_t*>(params.bias_ptr)[channel_id]);
  336. int* seq_pos_idx = !kHasSeqPosIdx
  337. ? nullptr
  338. : reinterpret_cast<int*>(params.seq_pos_idx_ptr) +
  339. batch_id * params.seqlen;
  340. // Thread 0 will load the last elements of the previous chunk, so we
  341. // initialize those to 0.
  342. if (tidx == 0) {
  343. input_t zeros[kNElts] = {0};
  344. smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t*>(zeros)[0];
  345. }
  346. float weight_vals[kWidth];
  347. #pragma unroll
  348. for (int i = 0; i < kWidth; ++i) {
  349. weight_vals[i] = float(weight[i * params.weight_width_stride]);
  350. }
  351. constexpr int kChunkSize = kNThreads * kNElts;
  352. const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
  353. for (int chunk = 0; chunk < n_chunks; ++chunk) {
  354. input_t x_vals_load[2 * kNElts] = {0};
  355. int seq_pos_idx_load[kNElts];
  356. if constexpr (kIsVecLoad) {
  357. Ktraits::BlockLoadVecT(smem_load_vec)
  358. .Load(reinterpret_cast<vec_t*>(x),
  359. *reinterpret_cast<vec_t(*)[1]>(&x_vals_load[kNElts]),
  360. (params.seqlen - chunk * kChunkSize) / kNElts);
  361. if (kHasSeqPosIdx)
  362. Ktraits::BlockLoadIndexVecT(smem_load_index_vec)
  363. .Load(reinterpret_cast<int4*>(seq_pos_idx),
  364. *reinterpret_cast<int4(*)[Ktraits::kNLoadsIndex]>(
  365. seq_pos_idx_load),
  366. (params.seqlen - chunk * kChunkSize) / kNElts *
  367. Ktraits::kNLoadsIndex);
  368. } else {
  369. __syncthreads();
  370. Ktraits::BlockLoadT(smem_load).Load(
  371. x, *reinterpret_cast<input_t(*)[kNElts]>(&x_vals_load[kNElts]),
  372. params.seqlen - chunk * kChunkSize);
  373. if (kHasSeqPosIdx)
  374. Ktraits::BlockLoadIndexT(smem_load_index)
  375. .Load(seq_pos_idx, seq_pos_idx_load,
  376. (params.seqlen - chunk * kChunkSize), 0);
  377. }
  378. x += kChunkSize;
  379. if (kHasSeqPosIdx) seq_pos_idx += kChunkSize;
  380. __syncthreads();
  381. // Thread kNThreads - 1 don't write yet, so that thread 0 can read
  382. // the last elements of the previous chunk.
  383. if (tidx < kNThreads - 1) {
  384. smem_exchange[tidx] = reinterpret_cast<vec_t*>(x_vals_load)[1];
  385. }
  386. __syncthreads();
  387. reinterpret_cast<vec_t*>(x_vals_load)[0] =
  388. smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
  389. __syncthreads();
  390. // Now thread kNThreads - 1 can write the last elements of the current
  391. // chunk.
  392. if (tidx == kNThreads - 1) {
  393. smem_exchange[tidx] = reinterpret_cast<vec_t*>(x_vals_load)[1];
  394. }
  395. float x_vals[2 * kNElts];
  396. #pragma unroll
  397. for (int i = 0; i < 2 * kNElts; ++i) {
  398. x_vals[i] = float(x_vals_load[i]);
  399. }
  400. float out_vals[kNElts];
  401. #pragma unroll
  402. for (int i = 0; i < kNElts; ++i) {
  403. out_vals[i] = bias_val;
  404. int w = 0;
  405. if (kHasSeqPosIdx) {
  406. if (seq_pos_idx_load[i] < kWidth) {
  407. w = kWidth - seq_pos_idx_load[i] - 1;
  408. }
  409. }
  410. for (; w < kWidth; ++w) {
  411. out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
  412. }
  413. }
  414. if (params.silu_activation) {
  415. #pragma unroll
  416. for (int i = 0; i < kNElts; ++i) {
  417. out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
  418. }
  419. }
  420. input_t out_vals_store[kNElts];
  421. #pragma unroll
  422. for (int i = 0; i < kNElts; ++i) {
  423. out_vals_store[i] = out_vals[i];
  424. }
  425. if constexpr (kIsVecLoad) {
  426. Ktraits::BlockStoreVecT(smem_store_vec)
  427. .Store(reinterpret_cast<vec_t*>(out),
  428. reinterpret_cast<vec_t(&)[1]>(out_vals_store),
  429. (params.seqlen - chunk * kChunkSize) / kNElts);
  430. } else {
  431. Ktraits::BlockStoreT(smem_store)
  432. .Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
  433. }
  434. out += kChunkSize;
  435. }
  436. }
  437. template <int kNThreads, int kWidth, typename input_t, typename weight_t>
  438. void causal_conv1d_fwd_launch(ConvParamsBase& params, cudaStream_t stream) {
  439. static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
  440. BOOL_SWITCH(params.seq_pos_idx_ptr != nullptr, kHasSeqPosIdx, [&] {
  441. BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
  442. using Ktraits =
  443. Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad,
  444. input_t, weight_t>;
  445. constexpr int kSmemSize = Ktraits::kSmemSize;
  446. dim3 grid(params.batch, params.dim);
  447. auto kernel = &causal_conv1d_fwd_kernel<Ktraits, kHasSeqPosIdx>;
  448. if (kSmemSize >= 48 * 1024) {
  449. C10_CUDA_CHECK(cudaFuncSetAttribute(
  450. kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
  451. }
  452. kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
  453. C10_CUDA_KERNEL_LAUNCH_CHECK();
  454. });
  455. });
  456. }
  457. template <typename input_t, typename weight_t>
  458. void causal_conv1d_fwd_cuda(ConvParamsBase& params, cudaStream_t stream) {
  459. if (params.width == 2) {
  460. causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
  461. } else if (params.width == 3) {
  462. causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
  463. } else if (params.width == 4) {
  464. causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
  465. }
  466. }
  467. template <int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_,
  468. typename input_t_, typename weight_t_>
  469. struct Causal_conv1d_channellast_fwd_kernel_traits {
  470. // The cache line is 128 bytes, and we try to read 16 bytes per thread.
  471. // So we have 8 threads per "row", so 32 or 64 elements in the channel
  472. // dimension. That leaves 4 columns per warp, and so 16 columns per block
  473. // (assuming each block has 128 threads). Each each load is 16 x 32|64
  474. // elements in the L x C dimensions.
  475. using input_t = input_t_;
  476. using weight_t = weight_t_;
  477. static constexpr int kNThreads = kNThreads_;
  478. static_assert(kNThreads % 32 == 0);
  479. static constexpr int kNWarps = kNThreads / 32;
  480. static constexpr int kWidth = kWidth_;
  481. static constexpr int kChunkSizeL = kChunkSizeL_;
  482. static constexpr int kNBytes = sizeof(input_t);
  483. static_assert(kNBytes == 2 || kNBytes == 4);
  484. static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
  485. static constexpr int kNEltsPerRow = 128 / kNBytes;
  486. static constexpr int kNThreadsPerRow =
  487. kNEltsPerRow / kNElts; // Always 8 for now
  488. static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
  489. static constexpr int kNColsPerWarp =
  490. 32 / kNThreadsPerRow; // Always 4 for now
  491. static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
  492. static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
  493. static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
  494. static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
  495. static constexpr bool kIsVecLoad = kIsVecLoad_;
  496. using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
  497. // using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems,
  498. // cub::BLOCK_LOAD_WARP_TRANSPOSE>; using BlockStoreT =
  499. // cub::BlockStore<input_t, kNThreads, kNItems,
  500. // cub::BLOCK_STORE_WARP_TRANSPOSE>; static constexpr int kSmemSize =
  501. // std::max({sizeof(typename BlockLoadT::TempStorage),
  502. // sizeof(typename
  503. // BlockStoreT::TempStorage)});
  504. // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
  505. };
  506. template <typename Ktraits, bool kHasSeqIdx>
  507. __global__
  508. __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_channellast_fwd_kernel(
  509. ConvParamsBase params) {
  510. constexpr int kWidth = Ktraits::kWidth;
  511. constexpr int kNThreads = Ktraits::kNThreads;
  512. constexpr int kNElts = Ktraits::kNElts;
  513. constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
  514. constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
  515. constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
  516. constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
  517. using input_t = typename Ktraits::input_t;
  518. using vec_t = typename Ktraits::vec_t;
  519. using weight_t = typename Ktraits::weight_t;
  520. // Shared memory.
  521. __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];
  522. const int batch_id = blockIdx.x;
  523. const int chunk_l_id = blockIdx.y;
  524. const int chunk_c_id = blockIdx.z;
  525. const int tid = threadIdx.x;
  526. const int l_idx = tid / kNThreadsPerC;
  527. const int c_idx = tid % kNThreadsPerC;
  528. input_t* x = reinterpret_cast<input_t*>(params.x_ptr) +
  529. batch_id * params.x_batch_stride +
  530. (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride +
  531. chunk_c_id * kChunkSizeC + c_idx * kNElts;
  532. weight_t* weight = reinterpret_cast<weight_t*>(params.weight_ptr) +
  533. chunk_c_id * kChunkSizeC * params.weight_c_stride;
  534. input_t* out = reinterpret_cast<input_t*>(params.out_ptr) +
  535. batch_id * params.out_batch_stride +
  536. (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride +
  537. chunk_c_id * kChunkSizeC + c_idx * kNElts;
  538. [[maybe_unused]] int* seq_idx =
  539. !kHasSeqIdx ? nullptr
  540. : reinterpret_cast<int*>(params.seq_idx_ptr) +
  541. batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
  542. input_t* initial_states =
  543. params.initial_states_ptr == nullptr || chunk_l_id > 0
  544. ? nullptr
  545. : reinterpret_cast<input_t*>(params.initial_states_ptr) +
  546. batch_id * params.initial_states_batch_stride +
  547. l_idx * params.initial_states_l_stride +
  548. chunk_c_id * kChunkSizeC + c_idx * kNElts;
  549. // The last L-chunk will also have enough info to write to final states, since
  550. // it also contain a few x values from the previous L-chunk.
  551. input_t* final_states =
  552. params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1
  553. ? nullptr
  554. : reinterpret_cast<input_t*>(params.final_states_ptr) +
  555. batch_id * params.final_states_batch_stride +
  556. l_idx * params.final_states_l_stride +
  557. chunk_c_id * kChunkSizeC + c_idx * kNElts;
  558. #pragma unroll
  559. for (int l = 0; l < Ktraits::kNLoads; ++l) {
  560. input_t x_vals_load[kNElts] = {0};
  561. if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen &&
  562. chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
  563. reinterpret_cast<vec_t*>(x_vals_load)[0] =
  564. *reinterpret_cast<vec_t*>(x + l * kLPerLoad * params.x_l_stride);
  565. }
  566. reinterpret_cast<vec_t*>(
  567. x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] =
  568. reinterpret_cast<vec_t*>(x_vals_load)[0];
  569. }
  570. // Load the elements from the previous chunk that are needed for convolution.
  571. if (l_idx < kWidth - 1) {
  572. input_t x_vals_load[kNElts] = {0};
  573. if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0 &&
  574. chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen &&
  575. chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
  576. reinterpret_cast<vec_t*>(x_vals_load)[0] =
  577. *reinterpret_cast<vec_t*>(x - (kWidth - 1) * params.x_l_stride);
  578. } else if (initial_states != nullptr &&
  579. chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0 &&
  580. chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
  581. reinterpret_cast<vec_t*>(x_vals_load)[0] =
  582. *reinterpret_cast<vec_t*>(initial_states);
  583. }
  584. reinterpret_cast<vec_t*>(x_smem[l_idx])[c_idx] =
  585. reinterpret_cast<vec_t*>(x_vals_load)[0];
  586. }
  587. __syncthreads();
  588. if (final_states != nullptr && l_idx < kWidth - 1 &&
  589. chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
  590. // x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth -
  591. // 1) So last few elements (index params.seqlen - kWidth + 1 + l_idx) are
  592. // stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id *
  593. // kChunkSizeL - kWidth + 1)][c_idx]
  594. *reinterpret_cast<vec_t*>(final_states) = reinterpret_cast<vec_t*>(
  595. x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx];
  596. }
  597. constexpr int kLPerThread =
  598. std::min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
  599. static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
  600. constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
  601. static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
  602. // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for
  603. // simplicity
  604. static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
  605. static_assert((kLPerThread & (kLPerThread - 1)) == 0);
  606. static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
  607. static_assert(kNThreadsPerRow <= 32);
  608. const int row_idx = tid / kNThreadsPerRow;
  609. const int col_idx = tid % kNThreadsPerRow;
  610. float bias_val =
  611. params.bias_ptr == nullptr ||
  612. chunk_c_id * kChunkSizeC + row_idx >= params.dim
  613. ? 0.f
  614. : float(reinterpret_cast<weight_t*>(
  615. params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
  616. float weight_vals[kWidth] = {0};
  617. if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
  618. #pragma unroll
  619. for (int w = 0; w < kWidth; ++w) {
  620. weight_vals[w] = weight[row_idx * params.weight_c_stride +
  621. w * params.weight_width_stride];
  622. }
  623. }
  624. float x_vals[kWidth - 1 + kLPerThread];
  625. #pragma unroll
  626. for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
  627. x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
  628. }
  629. int seq_idx_thread[kWidth - 1 + kLPerThread];
  630. if constexpr (kHasSeqIdx) {
  631. #pragma unroll
  632. for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
  633. seq_idx_thread[i] =
  634. chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >=
  635. 0
  636. ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)]
  637. : -1;
  638. }
  639. }
  640. float out_vals[kLPerThread];
  641. #pragma unroll
  642. for (int i = 0; i < kLPerThread; ++i) {
  643. out_vals[i] = bias_val;
  644. [[maybe_unused]] const int seq_idx_cur =
  645. !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
  646. #pragma unroll
  647. for (int w = 0; w < kWidth; ++w) {
  648. if constexpr (!kHasSeqIdx) {
  649. out_vals[i] += weight_vals[w] * x_vals[i + w];
  650. } else {
  651. out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur
  652. ? weight_vals[w] * x_vals[i + w]
  653. : 0.f;
  654. }
  655. }
  656. if (params.silu_activation) {
  657. out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
  658. }
  659. }
  660. __syncthreads();
  661. #pragma unroll
  662. for (int i = 0; i < kLPerThread; ++i) {
  663. x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i];
  664. }
  665. __syncthreads();
  666. #pragma unroll
  667. for (int l = 0; l < Ktraits::kNLoads; ++l) {
  668. input_t out_vals_store[kNElts];
  669. reinterpret_cast<vec_t*>(out_vals_store)[0] =
  670. reinterpret_cast<vec_t*>(x_smem[l * kLPerLoad + l_idx])[c_idx];
  671. if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen &&
  672. chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
  673. *reinterpret_cast<vec_t*>(out + l * kLPerLoad * params.out_l_stride) =
  674. reinterpret_cast<vec_t*>(out_vals_store)[0];
  675. }
  676. }
  677. }
  // Host-side launcher for the channel-last forward kernel, for one compile-time
  // (thread count, filter width, input type, weight type) combination.
  //
  // The kernel is additionally specialized on whether params.seq_idx_ptr is
  // set (kHasSeqIdx). The grid is (batch, sequence tiles, channel tiles), with
  // both tile counts rounded up so partial tiles at the end are still covered.
  // The launch requests 0 bytes of dynamic shared memory; the former
  // kSmemSize / cudaFuncAttributeMaxDynamicSharedMemorySize opt-in path is not
  // used.
  template <int kNThreads, int kWidth, typename input_t, typename weight_t>
  void causal_conv1d_channellast_fwd_launch(ConvParamsBase& params,
                                            cudaStream_t stream) {
    BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
      using Ktraits =
          Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true,
                                                      input_t, weight_t>;
      constexpr int kTileL = Ktraits::kChunkSizeL;
      constexpr int kTileC = Ktraits::kNEltsPerRow;

      // Ceil-divide the sequence length and channel count by the tile extents.
      const int num_tiles_l = (params.seqlen + kTileL - 1) / kTileL;
      const int num_tiles_c = (params.dim + kTileC - 1) / kTileC;

      const dim3 grid(params.batch, num_tiles_l, num_tiles_c);
      const dim3 block(Ktraits::kNThreads);

      causal_conv1d_channellast_fwd_kernel<Ktraits, kHasSeqIdx>
          <<<grid, block, 0, stream>>>(params);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  }
  // Runtime-to-compile-time dispatch on the filter width for the channel-last
  // forward pass. Widths 2, 3 and 4 are launched with 128 threads per block;
  // for any other width no kernel is launched.
  template <typename input_t, typename weight_t>
  void causal_conv1d_channellast_fwd_cuda(ConvParamsBase& params,
                                          cudaStream_t stream) {
    constexpr int kThreadsPerBlock = 128;
    switch (params.width) {
      case 2:
        causal_conv1d_channellast_fwd_launch<kThreadsPerBlock, 2, input_t,
                                             weight_t>(params, stream);
        break;
      case 3:
        causal_conv1d_channellast_fwd_launch<kThreadsPerBlock, 3, input_t,
                                             weight_t>(params, stream);
        break;
      case 4:
        causal_conv1d_channellast_fwd_launch<kThreadsPerBlock, 4, input_t,
                                             weight_t>(params, stream);
        break;
      default:
        // Unsupported widths fall through without launching anything.
        break;
    }
  }
  // Explicit instantiations of the forward entry points for every
  // (input_t, weight_t) pair drawn from {float, at::Half, at::BFloat16}.
  // Grouped by input type; the order of explicit instantiations is irrelevant.

  // causal_conv1d_fwd_cuda, input_t = float
  template void causal_conv1d_fwd_cuda<float, float>(
      ConvParamsBase& params, cudaStream_t stream);
  template void causal_conv1d_fwd_cuda<float, at::Half>(
      ConvParamsBase& params, cudaStream_t stream);
  template void causal_conv1d_fwd_cuda<float, at::BFloat16>(
      ConvParamsBase& params, cudaStream_t stream);

  // causal_conv1d_fwd_cuda, input_t = at::Half
  template void causal_conv1d_fwd_cuda<at::Half, float>(
      ConvParamsBase& params, cudaStream_t stream);
  template void causal_conv1d_fwd_cuda<at::Half, at::Half>(
      ConvParamsBase& params, cudaStream_t stream);
  template void causal_conv1d_fwd_cuda<at::Half, at::BFloat16>(
      ConvParamsBase& params, cudaStream_t stream);

  // causal_conv1d_fwd_cuda, input_t = at::BFloat16
  template void causal_conv1d_fwd_cuda<at::BFloat16, float>(
      ConvParamsBase& params, cudaStream_t stream);
  template void causal_conv1d_fwd_cuda<at::BFloat16, at::Half>(
      ConvParamsBase& params, cudaStream_t stream);
  template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(
      ConvParamsBase& params, cudaStream_t stream);

  // causal_conv1d_channellast_fwd_cuda, input_t = float
  template void causal_conv1d_channellast_fwd_cuda<float, float>(
      ConvParamsBase& params, cudaStream_t stream);
  template void causal_conv1d_channellast_fwd_cuda<float, at::Half>(
      ConvParamsBase& params, cudaStream_t stream);
  template void causal_conv1d_channellast_fwd_cuda<float, at::BFloat16>(
      ConvParamsBase& params, cudaStream_t stream);

  // causal_conv1d_channellast_fwd_cuda, input_t = at::Half
  template void causal_conv1d_channellast_fwd_cuda<at::Half, float>(
      ConvParamsBase& params, cudaStream_t stream);
  template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(
      ConvParamsBase& params, cudaStream_t stream);
  template void causal_conv1d_channellast_fwd_cuda<at::Half, at::BFloat16>(
      ConvParamsBase& params, cudaStream_t stream);

  // causal_conv1d_channellast_fwd_cuda, input_t = at::BFloat16
  template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, float>(
      ConvParamsBase& params, cudaStream_t stream);
  template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::Half>(
      ConvParamsBase& params, cudaStream_t stream);
  template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(
      ConvParamsBase& params, cudaStream_t stream);
  752. ///////
  // Compile-time configuration bundle for causal_conv1d_update_kernel.
  //
  //   kNThreads_ : threads per block
  //   kWidth_    : number of filter taps
  //   input_t_   : element type of x / conv_state / out (must be 2 or 4 bytes)
  //   weight_t_  : element type of weight / bias
  template <int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
  struct Causal_conv1d_update_kernel_traits {
    using weight_t = weight_t_;
    using input_t = input_t_;

    static constexpr int kWidth = kWidth_;
    static constexpr int kNThreads = kNThreads_;

    // Size in bytes of one input element; only 16- and 32-bit types are allowed.
    static constexpr int kNBytes = sizeof(input_t_);
    static_assert(kNBytes == 2 || kNBytes == 4);
  };
  // Single-step causal conv1d update.
  //
  // Thread mapping: blockIdx.x selects the batch entry and
  // blockIdx.y * kNThreads + threadIdx.x selects the channel, so each thread
  // owns exactly one (batch, channel) pair. Threads whose channel index is not
  // below params.dim do no memory traffic.
  //
  // Per thread:
  //   1. load the kWidth filter taps for the channel,
  //   2. build the input window from conv_state positions 1..kWidth-1 followed
  //      by the new sample x[0], and write that window back to conv_state
  //      positions 0..kWidth-1,
  //   3. accumulate bias + sum(weight[i] * window[i]) in float,
  //   4. apply out / (1 + exp(-out)) when params.silu_activation is set,
  //   5. store the result to out[0].
  template <typename Ktraits>
  __global__
  __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_update_kernel(
      ConvParamsBase params) {
    using input_t = typename Ktraits::input_t;
    using weight_t = typename Ktraits::weight_t;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kWidth = Ktraits::kWidth;

    const int batch = blockIdx.x;
    const int channel = blockIdx.y * kNThreads + threadIdx.x;

    // The grid's y extent is rounded up, so the last block may have threads
    // past the final channel. They have nothing to read or write, and the
    // kernel contains no barrier, so they can leave immediately.
    if (channel >= params.dim) {
      return;
    }

    // Per-(batch, channel) base pointers.
    input_t* x_in = reinterpret_cast<input_t*>(params.x_ptr) +
                    batch * params.x_batch_stride + channel * params.x_c_stride;
    input_t* state = reinterpret_cast<input_t*>(params.conv_state_ptr) +
                     batch * params.conv_state_batch_stride +
                     channel * params.conv_state_c_stride;
    weight_t* taps = reinterpret_cast<weight_t*>(params.weight_ptr) +
                     channel * params.weight_c_stride;
    input_t* y_out = reinterpret_cast<input_t*>(params.out_ptr) +
                     batch * params.out_batch_stride +
                     channel * params.out_c_stride;

    // Filter taps, converted to float.
    float w[kWidth];
  #pragma unroll
    for (int k = 0; k < kWidth; ++k) {
      w[k] = float(taps[k * params.weight_width_stride]);
    }

    // Input window: the stored state shifted left by one, then the new sample.
    float window[kWidth];
  #pragma unroll
    for (int k = 0; k < kWidth - 1; ++k) {
      window[k] = float(state[(k + 1) * params.conv_state_l_stride]);
    }
    window[kWidth - 1] = float(x_in[0]);

    // Persist the shifted window; every read above precedes these writes.
  #pragma unroll
    for (int k = 0; k < kWidth; ++k) {
      state[k * params.conv_state_l_stride] = input_t(window[k]);
    }

    // Bias is optional; a null bias_ptr contributes zero.
    float acc = 0.f;
    if (params.bias_ptr != nullptr) {
      acc = float(reinterpret_cast<weight_t*>(params.bias_ptr)[channel]);
    }
  #pragma unroll
    for (int k = 0; k < kWidth; ++k) {
      acc += w[k] * window[k];
    }

    if (params.silu_activation) {
      acc = acc / (1 + expf(-acc));
    }

    y_out[0] = input_t(acc);
  }
  // Host-side launcher for causal_conv1d_update_kernel with a fixed thread count
  // and filter width. The grid is (batch, ceil(dim / kNThreads)); each block has
  // kNThreads threads and uses no dynamic shared memory.
  template <int kNThreads, int kWidth, typename input_t, typename weight_t>
  void causal_conv1d_update_launch(ConvParamsBase& params, cudaStream_t stream) {
    using Ktraits =
        Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;

    // Number of blocks along y needed to cover every channel (rounded up).
    const auto channel_blocks = (params.dim + kNThreads - 1) / kNThreads;
    const dim3 grid(params.batch, channel_blocks);

    causal_conv1d_update_kernel<Ktraits>
        <<<grid, Ktraits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  }
  // Runtime-to-compile-time dispatch on the filter width for the single-step
  // update. Widths 2, 3 and 4 are launched with 64 threads per block; for any
  // other width no kernel is launched.
  template <typename input_t, typename weight_t>
  void causal_conv1d_update_cuda(ConvParamsBase& params, cudaStream_t stream) {
    constexpr int kThreadsPerBlock = 64;
    switch (params.width) {
      case 2:
        causal_conv1d_update_launch<kThreadsPerBlock, 2, input_t, weight_t>(
            params, stream);
        break;
      case 3:
        causal_conv1d_update_launch<kThreadsPerBlock, 3, input_t, weight_t>(
            params, stream);
        break;
      case 4:
        causal_conv1d_update_launch<kThreadsPerBlock, 4, input_t, weight_t>(
            params, stream);
        break;
      default:
        // Unsupported widths fall through without launching anything.
        break;
    }
  }
  // Explicit instantiations of causal_conv1d_update_cuda for every
  // (input_t, weight_t) pair drawn from {float, at::Half, at::BFloat16}.
  // Grouped by input type; the order of explicit instantiations is irrelevant.

  // input_t = float
  template void causal_conv1d_update_cuda<float, float>(
      ConvParamsBase& params, cudaStream_t stream);
  template void causal_conv1d_update_cuda<float, at::Half>(
      ConvParamsBase& params, cudaStream_t stream);
  template void causal_conv1d_update_cuda<float, at::BFloat16>(
      ConvParamsBase& params, cudaStream_t stream);

  // input_t = at::Half
  template void causal_conv1d_update_cuda<at::Half, float>(
      ConvParamsBase& params, cudaStream_t stream);
  template void causal_conv1d_update_cuda<at::Half, at::Half>(
      ConvParamsBase& params, cudaStream_t stream);
  template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(
      ConvParamsBase& params, cudaStream_t stream);

  // input_t = at::BFloat16
  template void causal_conv1d_update_cuda<at::BFloat16, float>(
      ConvParamsBase& params, cudaStream_t stream);
  template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(
      ConvParamsBase& params, cudaStream_t stream);
  template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(
      ConvParamsBase& params, cudaStream_t stream);