// fast_hadamard_transform.cpp
  1. /******************************************************************************
  2. * Copyright (c) 2023, Tri Dao.
  3. ******************************************************************************/
  4. #include <ATen/cuda/CUDAContext.h>
  5. #include <c10/cuda/CUDAGuard.h>
  6. #include <torch/all.h>
  7. #include <vector>
  8. #include "fast_hadamard_transform.h"
  9. #include "../core/registration.h"
// Asserts that tensor `x` has exactly the given sizes. The failure message
// stringifies both the tensor expression and the expected size list, e.g.
// "x must have shape (batch_size, dim_og)".
#define CHECK_SHAPE(x, ...) \
  TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \
              #x " must have shape (" #__VA_ARGS__ ")")

// Invokes __VA_ARGS__ (a nullary callable, typically a lambda) with the type
// alias `input_t` bound to the C++ element type matching ITYPE.
// Supported dtypes: fp16 (at::Half), bf16 (at::BFloat16), fp32 (float);
// any other dtype raises an error that embeds NAME.
#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
  if (ITYPE == at::ScalarType::Half) { \
    using input_t = at::Half; \
    __VA_ARGS__(); \
  } else if (ITYPE == at::ScalarType::BFloat16) { \
    using input_t = at::BFloat16; \
    __VA_ARGS__(); \
  } else if (ITYPE == at::ScalarType::Float) { \
    using input_t = float; \
    __VA_ARGS__(); \
  } else { \
    AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), \
             "'"); \
  }
// Kernel launchers implemented in the companion .cu files, instantiated per
// input element type. One launcher per transform family: power-of-two N,
// and 12*2^k / 20*2^k / 28*2^k hidden dimensions respectively.
template <typename input_t>
void fast_hadamard_transform_cuda(HadamardParamsBase& params,
                                  cudaStream_t stream);
template <typename input_t>
void fast_hadamard_transform_12N_cuda(HadamardParamsBase& params,
                                      cudaStream_t stream);
template <typename input_t>
void fast_hadamard_transform_20N_cuda(HadamardParamsBase& params,
                                      cudaStream_t stream);
template <typename input_t>
void fast_hadamard_transform_28N_cuda(HadamardParamsBase& params,
                                      cudaStream_t stream);
  39. void set_hadamard_params(HadamardParamsBase& params,
  40. // sizes
  41. const size_t batch, const size_t dim,
  42. const size_t multiple,
  43. // device pointers
  44. const at::Tensor x, const at::Tensor out,
  45. double scale) {
  46. // Reset the parameters
  47. memset(&params, 0, sizeof(params));
  48. params.batch = batch;
  49. params.dim = dim;
  50. params.log_N = int(ceil(std::log2(dim / multiple)));
  51. // Set the pointers and strides.
  52. params.x_ptr = x.data_ptr();
  53. params.out_ptr = out.data_ptr();
  54. // All stride are in elements, not bytes.
  55. params.x_batch_stride = x.stride(0);
  56. params.out_batch_stride = out.stride(0);
  57. params.scale = scale;
  58. }
  59. at::Tensor fast_hadamard_transform(at::Tensor& x, double scale) {
  60. auto input_type = x.scalar_type();
  61. TORCH_CHECK(input_type == at::ScalarType::Float ||
  62. input_type == at::ScalarType::Half ||
  63. input_type == at::ScalarType::BFloat16);
  64. TORCH_CHECK(x.is_cuda());
  65. const auto shapes_og = x.sizes();
  66. const int dim_og = x.size(-1);
  67. x = x.reshape({-1, dim_og});
  68. if (x.stride(-1) != 1) {
  69. x = x.contiguous();
  70. }
  71. const auto sizes = x.sizes();
  72. const int batch_size = sizes[0];
  73. CHECK_SHAPE(x, batch_size, dim_og);
  74. TORCH_CHECK(x.stride(1) == 1);
  75. if (dim_og % 8 != 0) {
  76. x = torch::nn::functional::pad(
  77. x, torch::nn::functional::PadFuncOptions({0, 8 - dim_og % 8}));
  78. }
  79. const int dim = x.size(1);
  80. TORCH_CHECK(dim % 8 == 0,
  81. "fast_hadamard_transform only supports hidden dimension "
  82. "divisible by 8 for now");
  83. TORCH_CHECK(dim <= 32768,
  84. "fast_hadamard_transform only supports hidden dimension at most "
  85. "32768 for now");
  86. at::Tensor out = torch::empty_like(x);
  87. HadamardParamsBase params;
  88. set_hadamard_params(params, batch_size, dim, 1, x, out, scale);
  89. // Otherwise the kernel will be launched from cuda:0 device
  90. // Cast to char to avoid compiler warning about narrowing
  91. at::cuda::CUDAGuard device_guard{(char)x.get_device()};
  92. auto stream = at::cuda::getCurrentCUDAStream().stream();
  93. DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(
  94. x.scalar_type(), "fast_hadamard_transform",
  95. [&] { fast_hadamard_transform_cuda<input_t>(params, stream); });
  96. if (dim_og % 8 != 0) {
  97. out = out.index({torch::indexing::Slice(),
  98. torch::indexing::Slice(torch::indexing::None, dim_og)});
  99. }
  100. return out.reshape(shapes_og);
  101. }
  102. at::Tensor fast_hadamard_transform_12N(at::Tensor& x, double scale) {
  103. auto input_type = x.scalar_type();
  104. TORCH_CHECK(input_type == at::ScalarType::Float ||
  105. input_type == at::ScalarType::Half ||
  106. input_type == at::ScalarType::BFloat16);
  107. TORCH_CHECK(x.is_cuda());
  108. const auto shapes_og = x.sizes();
  109. const int dim_og = x.size(-1);
  110. x = x.reshape({-1, dim_og});
  111. if (x.stride(-1) != 1) {
  112. x = x.contiguous();
  113. }
  114. const auto sizes = x.sizes();
  115. const int batch_size = sizes[0];
  116. CHECK_SHAPE(x, batch_size, dim_og);
  117. TORCH_CHECK(x.stride(1) == 1);
  118. if (dim_og % (4 * 12) != 0) {
  119. x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions(
  120. {0, (4 * 12) - dim_og % (4 * 12)}));
  121. }
  122. const int dim = x.size(1);
  123. TORCH_CHECK(dim % (4 * 12) == 0,
  124. "fast_hadamard_transform_12N only supports hidden dimension "
  125. "divisible by 48 for now");
  126. TORCH_CHECK(dim <= 12 * 1024,
  127. "fast_hadamard_transform_12N only supports hidden dimension at "
  128. "most 12288 for now");
  129. at::Tensor out = torch::empty_like(x);
  130. HadamardParamsBase params;
  131. set_hadamard_params(params, batch_size, dim, 12, x, out, scale);
  132. // Otherwise the kernel will be launched from cuda:0 device
  133. // Cast to char to avoid compiler warning about narrowing
  134. at::cuda::CUDAGuard device_guard{(char)x.get_device()};
  135. auto stream = at::cuda::getCurrentCUDAStream().stream();
  136. DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(
  137. x.scalar_type(), "fast_hadamard_transform",
  138. [&] { fast_hadamard_transform_12N_cuda<input_t>(params, stream); });
  139. if (dim_og % (4 * 12) != 0) {
  140. out = out.index({torch::indexing::Slice(),
  141. torch::indexing::Slice(torch::indexing::None, dim_og)});
  142. }
  143. return out.reshape(shapes_og);
  144. }
  145. at::Tensor fast_hadamard_transform_20N(at::Tensor& x, double scale) {
  146. auto input_type = x.scalar_type();
  147. TORCH_CHECK(input_type == at::ScalarType::Float ||
  148. input_type == at::ScalarType::Half ||
  149. input_type == at::ScalarType::BFloat16);
  150. TORCH_CHECK(x.is_cuda());
  151. const auto shapes_og = x.sizes();
  152. const int dim_og = x.size(-1);
  153. x = x.reshape({-1, dim_og});
  154. if (x.stride(-1) != 1) {
  155. x = x.contiguous();
  156. }
  157. const auto sizes = x.sizes();
  158. const int batch_size = sizes[0];
  159. CHECK_SHAPE(x, batch_size, dim_og);
  160. TORCH_CHECK(x.stride(1) == 1);
  161. if (dim_og % (4 * 20) != 0) {
  162. x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions(
  163. {0, (4 * 20) - dim_og % (4 * 20)}));
  164. }
  165. const int dim = x.size(1);
  166. TORCH_CHECK(dim % (4 * 20) == 0,
  167. "fast_hadamard_transform_20N only supports hidden dimension "
  168. "divisible by 80 for now");
  169. TORCH_CHECK(dim <= 20 * 1024,
  170. "fast_hadamard_transform_20N only supports hidden dimension at "
  171. "most 20480 for now");
  172. at::Tensor out = torch::empty_like(x);
  173. HadamardParamsBase params;
  174. set_hadamard_params(params, batch_size, dim, 20, x, out, scale);
  175. // Otherwise the kernel will be launched from cuda:0 device
  176. // Cast to char to avoid compiler warning about narrowing
  177. at::cuda::CUDAGuard device_guard{(char)x.get_device()};
  178. auto stream = at::cuda::getCurrentCUDAStream().stream();
  179. DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(
  180. x.scalar_type(), "fast_hadamard_transform",
  181. [&] { fast_hadamard_transform_20N_cuda<input_t>(params, stream); });
  182. if (dim_og % (4 * 20) != 0) {
  183. out = out.index({torch::indexing::Slice(),
  184. torch::indexing::Slice(torch::indexing::None, dim_og)});
  185. }
  186. return out.reshape(shapes_og);
  187. }
  188. at::Tensor fast_hadamard_transform_28N(at::Tensor& x, double scale) {
  189. auto input_type = x.scalar_type();
  190. TORCH_CHECK(input_type == at::ScalarType::Float ||
  191. input_type == at::ScalarType::Half ||
  192. input_type == at::ScalarType::BFloat16);
  193. TORCH_CHECK(x.is_cuda());
  194. const auto shapes_og = x.sizes();
  195. const int dim_og = x.size(-1);
  196. x = x.reshape({-1, dim_og});
  197. if (x.stride(-1) != 1) {
  198. x = x.contiguous();
  199. }
  200. const auto sizes = x.sizes();
  201. const int batch_size = sizes[0];
  202. CHECK_SHAPE(x, batch_size, dim_og);
  203. TORCH_CHECK(x.stride(1) == 1);
  204. if (dim_og % (4 * 28) != 0) {
  205. x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions(
  206. {0, (4 * 28) - dim_og % (4 * 28)}));
  207. }
  208. const int dim = x.size(1);
  209. TORCH_CHECK(dim % (4 * 28) == 0,
  210. "fast_hadamard_transform_28N only supports hidden dimension "
  211. "divisible by 112 for now");
  212. // TORCH_CHECK(dim <= 28 * 1024, "fast_hadamard_transform_28N only supports
  213. // hidden dimension at most 28672 for now");
  214. TORCH_CHECK(dim <= 28 * 2048,
  215. "fast_hadamard_transform_28N only supports hidden dimension at "
  216. "most 28672 for now");
  217. at::Tensor out = torch::empty_like(x);
  218. HadamardParamsBase params;
  219. set_hadamard_params(params, batch_size, dim, 28, x, out, scale);
  220. // Otherwise the kernel will be launched from cuda:0 device
  221. // Cast to char to avoid compiler warning about narrowing
  222. at::cuda::CUDAGuard device_guard{(char)x.get_device()};
  223. auto stream = at::cuda::getCurrentCUDAStream().stream();
  224. DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(
  225. x.scalar_type(), "fast_hadamard_transform",
  226. [&] { fast_hadamard_transform_28N_cuda<input_t>(params, stream); });
  227. if (dim_og % (8 * 28) != 0) {
  228. out = out.index({torch::indexing::Slice(),
  229. torch::indexing::Slice(torch::indexing::None, dim_og)});
  230. }
  231. return out.reshape(shapes_og);
  232. }
// Register the four ops with the PyTorch dispatcher under the library name
// TORCH_EXTENSION_NAME. Each op is def'd with an explicit schema (and the
// function pointer) and then additionally bound to the CUDA dispatch key.
// NOTE(review): the schemas spell the scalar argument as `double`; the
// TorchScript schema grammar conventionally uses `float` for double-precision
// scalars — confirm the schema parser accepts `double` here.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
  m.def("fast_hadamard_transform(Tensor x, double scale) -> Tensor",
        &fast_hadamard_transform);
  m.impl("fast_hadamard_transform", torch::kCUDA, &fast_hadamard_transform);
  m.def("fast_hadamard_transform_12N(Tensor x, double scale) -> Tensor",
        &fast_hadamard_transform_12N);
  m.impl("fast_hadamard_transform_12N", torch::kCUDA,
         &fast_hadamard_transform_12N);
  m.def("fast_hadamard_transform_20N(Tensor x, double scale) -> Tensor",
        &fast_hadamard_transform_20N);
  m.impl("fast_hadamard_transform_20N", torch::kCUDA,
         &fast_hadamard_transform_20N);
  m.def("fast_hadamard_transform_28N(Tensor x, double scale) -> Tensor",
        &fast_hadamard_transform_28N);
  m.impl("fast_hadamard_transform_28N", torch::kCUDA,
         &fast_hadamard_transform_28N);
}

// Emits the Python extension module entry point for this library
// (macro from ../core/registration.h).
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)