causal_conv1d.cu

// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu
// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "causal_conv1d.h"
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK

#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>

#include "static_switch.h"

#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)                \
    if (ITYPE == at::ScalarType::Half) {                                              \
        using input_t = at::Half;                                                     \
        using weight_t = at::Half;                                                    \
        __VA_ARGS__();                                                                \
    } else if (ITYPE == at::ScalarType::BFloat16) {                                   \
        using input_t = at::BFloat16;                                                 \
        using weight_t = at::BFloat16;                                                \
        __VA_ARGS__();                                                                \
    } else if (ITYPE == at::ScalarType::Float) {                                      \
        using input_t = float;                                                        \
        using weight_t = float;                                                       \
        __VA_ARGS__();                                                                \
    } else {                                                                          \
        AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'");   \
    }
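
// Note: the dispatch macro above instantiates its lambda body once per supported
// dtype (fp32, fp16, bf16), always with weight_t == input_t. Mixed input/weight
// precision is deliberately not compiled here (see the matching check in
// causal_conv1d_update) to keep the binary size down.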
template<typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
template <typename input_t, typename weight_t>
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);

template<typename input_t, typename weight_t>
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream);

void set_conv_params_fwd(ConvParamsBase &params,
                         // sizes
                         const size_t batch,
                         const size_t dim,
                         const size_t seqlen,
                         const size_t width,
                         // device pointers
                         const at::Tensor x,
                         const at::Tensor weight,
                         const at::Tensor out,
                         void* bias_ptr,
                         bool silu_activation) {
    // Reset the parameters
    memset(&params, 0, sizeof(params));

    params.batch = batch;
    params.dim = dim;
    params.seqlen = seqlen;
    params.width = width;
    params.silu_activation = silu_activation;

    // Set the pointers and strides.
    params.x_ptr = x.data_ptr();
    params.weight_ptr = weight.data_ptr();
    params.bias_ptr = bias_ptr;
    params.out_ptr = out.data_ptr();
    // All strides are in elements, not bytes.
    params.x_batch_stride = x.stride(0);
    params.x_c_stride = x.stride(1);
    params.x_l_stride = x.stride(-1);
    params.weight_c_stride = weight.stride(0);
    params.weight_width_stride = weight.stride(1);
    params.out_batch_stride = out.stride(0);
    params.out_c_stride = out.stride(1);
    params.out_l_stride = out.stride(-1);
}
at::Tensor
causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
                  const c10::optional<at::Tensor> &bias_,
                  const c10::optional<at::Tensor> &seq_idx_,
                  const c10::optional<at::Tensor> &initial_states_,
                  const c10::optional<at::Tensor> &final_states_out_,
                  bool silu_activation) {
    auto input_type = x.scalar_type();
    auto weight_type = weight.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);

    TORCH_CHECK(x.is_cuda());
    TORCH_CHECK(weight.is_cuda());

    const auto sizes = x.sizes();
    const int batch_size = sizes[0];
    const int dim = sizes[1];
    const int seqlen = sizes[2];
    const int width = weight.size(-1);

    CHECK_SHAPE(x, batch_size, dim, seqlen);
    CHECK_SHAPE(weight, dim, width);

    TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
    const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;

    if (is_channel_last) {
        TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
        TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
    }
    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");

    if (bias_.has_value()) {
        auto bias = bias_.value();
        TORCH_CHECK(bias.scalar_type() == weight_type);
        TORCH_CHECK(bias.is_cuda());
        TORCH_CHECK(bias.stride(-1) == 1);
        CHECK_SHAPE(bias, dim);
    }

    if (seq_idx_.has_value()) {
        TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout");
        auto seq_idx = seq_idx_.value();
        TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
        TORCH_CHECK(seq_idx.is_cuda());
        TORCH_CHECK(seq_idx.is_contiguous());
        CHECK_SHAPE(seq_idx, batch_size, seqlen);
    }

    at::Tensor out = torch::empty_like(x);

    ConvParamsBase params;
    set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
                        bias_.has_value() ? bias_.value().data_ptr() : nullptr,
                        silu_activation);

    if (seq_idx_.has_value()) {
        params.seq_idx_ptr = seq_idx_.value().data_ptr();
    } else {
        params.seq_idx_ptr = nullptr;
    }

    if (initial_states_.has_value()) {
        TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
        auto initial_states = initial_states_.value();
        TORCH_CHECK(initial_states.scalar_type() == input_type);
        TORCH_CHECK(initial_states.is_cuda());
        CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
        TORCH_CHECK(initial_states.stride(1) == 1);
        params.initial_states_ptr = initial_states.data_ptr();
        params.initial_states_batch_stride = initial_states.stride(0);
        params.initial_states_c_stride = initial_states.stride(1);
        params.initial_states_l_stride = initial_states.stride(2);
    } else {
        params.initial_states_ptr = nullptr;
    }

    if (final_states_out_.has_value()) {
        TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout");
        auto final_states = final_states_out_.value();
        TORCH_CHECK(final_states.scalar_type() == input_type);
        TORCH_CHECK(final_states.is_cuda());
        CHECK_SHAPE(final_states, batch_size, dim, width - 1);
        TORCH_CHECK(final_states.stride(1) == 1);
        params.final_states_ptr = final_states.data_ptr();
        params.final_states_batch_stride = final_states.stride(0);
        params.final_states_c_stride = final_states.stride(1);
        params.final_states_l_stride = final_states.stride(2);
    } else {
        params.final_states_ptr = nullptr;
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
        if (!is_channel_last) {
            causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
        } else {
            causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
        }
    });
    return out;
}
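
// Illustrative host-side usage (a sketch only; variable names are made up and the
// op is normally reached through the extension's Python bindings, not called
// directly like this). Shapes and dtypes follow the checks above:
//
//   // x: (batch, dim, seqlen), weight: (dim, width) with width in [2, 4].
//   at::Tensor x = torch::randn({2, 64, 128},
//                               torch::dtype(torch::kFloat16).device(torch::kCUDA));
//   at::Tensor w = torch::randn({64, 4}, x.options());
//   at::Tensor y = causal_conv1d_fwd(x, w,
//                                    /*bias_=*/c10::nullopt,
//                                    /*seq_idx_=*/c10::nullopt,
//                                    /*initial_states_=*/c10::nullopt,
//                                    /*final_states_out_=*/c10::nullopt,
//                                    /*silu_activation=*/true);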
at::Tensor
causal_conv1d_update(const at::Tensor &x,
                     const at::Tensor &conv_state,
                     const at::Tensor &weight,
                     const c10::optional<at::Tensor> &bias_,
                     bool silu_activation,
                     const c10::optional<at::Tensor> &conv_state_indices_) {
    auto input_type = x.scalar_type();
    auto weight_type = weight.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
    TORCH_CHECK(weight_type == input_type, "weight type must equal input type; other combinations are disabled due to binary size limitations");
    TORCH_CHECK(conv_state.scalar_type() == input_type);

    TORCH_CHECK(x.is_cuda());
    TORCH_CHECK(conv_state.is_cuda());
    TORCH_CHECK(weight.is_cuda());

    const auto sizes = x.sizes();
    const int batch_size = sizes[0];
    const int dim = sizes[1];
    const int width = weight.size(-1);

    CHECK_SHAPE(x, batch_size, dim);
    CHECK_SHAPE(weight, dim, width);

    TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");

    if (bias_.has_value()) {
        auto bias = bias_.value();
        TORCH_CHECK(bias.scalar_type() == weight_type);
        TORCH_CHECK(bias.is_cuda());
        TORCH_CHECK(bias.stride(-1) == 1);
        CHECK_SHAPE(bias, dim);
    }

    at::Tensor out = torch::empty_like(x);

    ConvParamsBase params;
    set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out,
                        bias_.has_value() ? bias_.value().data_ptr() : nullptr,
                        silu_activation);
    params.conv_state_ptr = conv_state.data_ptr();
    // All strides are in elements, not bytes.
    params.conv_state_batch_stride = conv_state.stride(0);
    params.conv_state_c_stride = conv_state.stride(1);
    params.conv_state_l_stride = conv_state.stride(2);

    if (conv_state_indices_.has_value()) {
        auto conv_state_indices = conv_state_indices_.value();
        TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32);
        TORCH_CHECK(conv_state_indices.is_cuda());
        TORCH_CHECK(conv_state_indices.stride(0) == 1);
        CHECK_SHAPE(conv_state_indices, batch_size);

        int conv_state_entries = conv_state.size(0);
        CHECK_SHAPE(conv_state, conv_state_entries, dim, width);

        params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
    } else {
        CHECK_SHAPE(conv_state, batch_size, dim, width);
        params.conv_state_indices_ptr = nullptr;
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
        causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
    });
    return out;
}
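
// Illustrative decode-step usage (a sketch only; variable names are made up,
// shapes follow the checks above):
//
//   // x: (batch, dim) -- one new timestep per sequence.
//   // conv_state: (batch, dim, width) window of the most recent inputs, updated
//   // in place by the kernel. If conv_state_indices (int32, shape (batch,)) is
//   // given, batch row b instead reads/writes row conv_state_indices[b] of a
//   // larger (num_entries, dim, width) state tensor, e.g. a persistent cache.
//   at::Tensor y = causal_conv1d_update(x, conv_state, w, /*bias_=*/c10::nullopt,
//                                       /*silu_activation=*/true,
//                                       /*conv_state_indices_=*/c10::nullopt);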
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
struct Causal_conv1d_fwd_kernel_traits {
    using input_t = input_t_;
    using weight_t = weight_t_;
    static constexpr int kNThreads = kNThreads_;
    static constexpr int kWidth = kWidth_;
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
    static_assert(kWidth <= kNElts);
    static constexpr bool kIsVecLoad = kIsVecLoad_;
    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
    using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
    using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
    using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
    static constexpr int kSmemIOSize = kIsVecLoad
        ? 0
        : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
    static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
    static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
};
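
// How the non-channel-last forward kernel below is organized: each block handles
// one (batch, channel) pair and walks the sequence in chunks of kNThreads * kNElts
// elements. Every thread works on 2 * kNElts values -- the kNElts it just loaded
// plus the kNElts loaded by its left neighbor, passed through the smem_exchange
// buffer -- so the width-wide causal window never has to re-read global memory
// across chunk boundaries.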
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
    constexpr int kWidth = Ktraits::kWidth;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNElts = Ktraits::kNElts;
    static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
    using input_t = typename Ktraits::input_t;
    using vec_t = typename Ktraits::vec_t;
    using weight_t = typename Ktraits::weight_t;

    // Shared memory.
    extern __shared__ char smem_[];
    auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
    auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
    auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
    auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
    vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);

    const int tidx = threadIdx.x;
    const int batch_id = blockIdx.x;
    const int channel_id = blockIdx.y;
    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
        + channel_id * params.x_c_stride;
    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
        + channel_id * params.out_c_stride;
    float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);

    // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
    if (tidx == 0) {
        input_t zeros[kNElts] = {0};
        smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0];
    }

    float weight_vals[kWidth];
    #pragma unroll
    for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }

    constexpr int kChunkSize = kNThreads * kNElts;
    const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
    for (int chunk = 0; chunk < n_chunks; ++chunk) {
        input_t x_vals_load[2 * kNElts] = {0};
        if constexpr(kIsVecLoad) {
            typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
        } else {
            __syncthreads();
            typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
        }
        x += kChunkSize;
        __syncthreads();
        // Thread kNThreads - 1 doesn't write yet, so that thread 0 can read
        // the last elements of the previous chunk.
        if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
        __syncthreads();
        reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
        __syncthreads();
        // Now thread kNThreads - 1 can write the last elements of the current chunk.
        if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }

        float x_vals[2 * kNElts];
        #pragma unroll
        for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }

        float out_vals[kNElts];
        #pragma unroll
        for (int i = 0; i < kNElts; ++i) {
            out_vals[i] = bias_val;
            #pragma unroll
            for (int w = 0; w < kWidth; ++w) {
                out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
            }
        }

        if (params.silu_activation) {
            #pragma unroll
            for (int i = 0; i < kNElts; ++i) {
                // SiLU: x * sigmoid(x) == x / (1 + exp(-x)).
                out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
            }
        }

        input_t out_vals_store[kNElts];
        #pragma unroll
        for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
        if constexpr(kIsVecLoad) {
            typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
        } else {
            typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
        }
        out += kChunkSize;
    }
}
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
    static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
    BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
        using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
        constexpr int kSmemSize = Ktraits::kSmemSize;
        dim3 grid(params.batch, params.dim);

        auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;

        if (kSmemSize >= 48 * 1024) {
            #ifndef USE_ROCM
            C10_CUDA_CHECK(cudaFuncSetAttribute(
                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
            #else
            // There is a slight signature discrepancy between the HIP and CUDA "FuncSetAttribute" functions.
            C10_CUDA_CHECK(cudaFuncSetAttribute(
                (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
            std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU, which is currently a no-op (in ROCm versions <= 6.1). This might lead to undefined behavior.\n" << std::endl;
            #endif
        }
        kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
}

template<typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
    if (params.width == 2) {
        causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
    } else if (params.width == 3) {
        causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
    } else if (params.width == 4) {
        causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
    }
}
template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
struct Causal_conv1d_channellast_fwd_kernel_traits {
    // The cache line is 128 bytes, and we try to read 16 bytes per thread.
    // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
    // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
    // threads). Each load is 16 x 32|64 elements in the L x C dimensions.
    using input_t = input_t_;
    using weight_t = weight_t_;
    static constexpr int kNThreads = kNThreads_;
    static_assert(kNThreads % 32 == 0);
    static constexpr int kNWarps = kNThreads / 32;
    static constexpr int kWidth = kWidth_;
    static constexpr int kChunkSizeL = kChunkSizeL_;
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
    static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
    static constexpr int kNEltsPerRow = 128 / kNBytes;
    static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts;  // Always 8 for now
    static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
    static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow;  // Always 4 for now
    static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
    static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
    static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
    static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
    static constexpr bool kIsVecLoad = kIsVecLoad_;
    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
    // using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
    // using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
    // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
    //                                            sizeof(typename BlockStoreT::TempStorage)});
    // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
};
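
// Summary of the channel-last kernel below: each block owns a (kChunkSizeL x
// kChunkSizeC) tile of one batch element. Phase 1 cooperatively stages the tile
// into shared memory with 16-byte vectorized loads, together with the kWidth - 1
// preceding rows needed for the causal window (taken from the previous L-chunk,
// from initial_states, or zeros). Phase 2 re-partitions the threads so that each
// one convolves kLPerThread consecutive positions of a single channel, zeroing
// contributions that cross a sequence boundary given by seq_idx, and the results
// are written back out through the same shared-memory tile.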
template<typename Ktraits, bool kHasSeqIdx>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
    constexpr int kWidth = Ktraits::kWidth;
    constexpr int kNThreads = Ktraits::kNThreads;
    constexpr int kNElts = Ktraits::kNElts;
    constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
    constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
    constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
    constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
    using input_t = typename Ktraits::input_t;
    using vec_t = typename Ktraits::vec_t;
    using weight_t = typename Ktraits::weight_t;

    // Shared memory.
    __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];

    const int batch_id = blockIdx.x;
    const int chunk_l_id = blockIdx.y;
    const int chunk_c_id = blockIdx.z;
    const int tid = threadIdx.x;
    const int l_idx = tid / kNThreadsPerC;
    const int c_idx = tid % kNThreadsPerC;
    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
        + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
        + chunk_c_id * kChunkSizeC * params.weight_c_stride;
    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
        + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
    int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
        + batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
    input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
        : reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
    // The last L-chunk will also have enough info to write to final states, since it also contains a few x values
    // from the previous L-chunk.
    input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr
        : reinterpret_cast<input_t *>(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;

    #pragma unroll
    for (int l = 0; l < Ktraits::kNLoads; ++l) {
        input_t x_vals_load[kNElts] = {0};
        if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
        }
        reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
    }
    // Load the elements from the previous chunk that are needed for convolution.
    if (l_idx < kWidth - 1) {
        input_t x_vals_load[kNElts] = {0};
        if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
            && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
        } else if (initial_states != nullptr
                   && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
                   && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
            reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
        }
        reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
    }

    __syncthreads();

    if (final_states != nullptr
        && l_idx < kWidth - 1
        && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
        // x_smem[0] contains the element at index chunk_l_id * kChunkSizeL - (kWidth - 1).
        // So the last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in
        // x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx].
        *reinterpret_cast<vec_t *>(final_states) = reinterpret_cast<vec_t *>(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx];
    }

    constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
    static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
    constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
    static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
    // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
    static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
    static_assert((kLPerThread & (kLPerThread - 1)) == 0);
    static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
    static_assert(kNThreadsPerRow <= 32);

    const int row_idx = tid / kNThreadsPerRow;
    const int col_idx = tid % kNThreadsPerRow;
    float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
    float weight_vals[kWidth] = {0};
    if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
        #pragma unroll
        for (int w = 0; w < kWidth; ++w) {
            weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
        }
    }
    float x_vals[kWidth - 1 + kLPerThread];
    #pragma unroll
    for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
        x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
    }
    int seq_idx_thread[kWidth - 1 + kLPerThread];
    if constexpr (kHasSeqIdx) {
        #pragma unroll
        for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
            seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
        }
    }

    float out_vals[kLPerThread];
    #pragma unroll
    for (int i = 0; i < kLPerThread; ++i) {
        out_vals[i] = bias_val;
        const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
        #pragma unroll
        for (int w = 0; w < kWidth; ++w) {
            if constexpr (!kHasSeqIdx) {
                out_vals[i] += weight_vals[w] * x_vals[i + w];
            } else {
                out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
            }
        }
        if (params.silu_activation) { out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
    }

    __syncthreads();
    #pragma unroll
    for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
    __syncthreads();

    #pragma unroll
    for (int l = 0; l < Ktraits::kNLoads; ++l) {
        input_t out_vals_store[kNElts];
        reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
        if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
            && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
            *reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
        }
    }
}
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_channellast_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
    BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
        using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
        // constexpr int kSmemSize = Ktraits::kSmemSize;
        constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
        constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
        const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
        const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
        dim3 grid(params.batch, n_chunks_L, n_chunks_C);
        dim3 block(Ktraits::kNThreads);
        auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits, kHasSeqIdx>;
        // if (kSmemSize >= 48 * 1024) {
        //     C10_CUDA_CHECK(cudaFuncSetAttribute(
        //         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
        // }
        // kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
        kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
}

template<typename input_t, typename weight_t>
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
    if (params.width == 2) {
        causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
    } else if (params.width == 3) {
        causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
    } else if (params.width == 4) {
        causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
    }
}

template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
///////

template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
struct Causal_conv1d_update_kernel_traits {
    using input_t = input_t_;
    using weight_t = weight_t_;
    static constexpr int kNThreads = kNThreads_;
    static constexpr int kWidth = kWidth_;
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4);
};
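
// Summary of the update kernel below: it handles a single decode step. Each thread
// owns one channel of one batch element, shifts that channel's conv_state window
// left by one position, appends the incoming x value, and emits a single width-wide
// dot product (plus bias and optional SiLU).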
template<typename Ktraits>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_update_kernel(ConvParamsBase params) {
    constexpr int kWidth = Ktraits::kWidth;
    constexpr int kNThreads = Ktraits::kNThreads;
    using input_t = typename Ktraits::input_t;
    using weight_t = typename Ktraits::weight_t;

    const int tidx = threadIdx.x;
    const int batch_id = blockIdx.x;
    const int channel_id = blockIdx.y * kNThreads + tidx;
    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
        + channel_id * params.x_c_stride;

    // If params.conv_state_indices_ptr is set, then the conv state is gathered from the conv state tensor
    // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
    const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
        ? batch_id
        : params.conv_state_indices_ptr[batch_id];
    input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
        + conv_state_batch_coord * params.conv_state_batch_stride
        + channel_id * params.conv_state_c_stride;
    weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
        + channel_id * params.out_c_stride;
    float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);

    float weight_vals[kWidth] = {0};
    if (channel_id < params.dim) {
        #pragma unroll
        for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
    }

    float x_vals[kWidth] = {0};
    if (channel_id < params.dim) {
        // Shift the conv state left by one position and append the new input value.
        #pragma unroll
        for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); }
        x_vals[kWidth - 1] = float(x[0]);
        #pragma unroll
        for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); }
    }

    float out_val = bias_val;
    #pragma unroll
    for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; }
    if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
    if (channel_id < params.dim) { out[0] = input_t(out_val); }
}
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
    using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
    dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
    auto kernel = &causal_conv1d_update_kernel<Ktraits>;
    kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template<typename input_t, typename weight_t>
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
    if (params.width == 2) {
        causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
    } else if (params.width == 3) {
        causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
    } else if (params.width == 4) {
        causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
    }
}

template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);