fast_hadamard_transform.cpp

/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
#include <vector>

#include "fast_hadamard_transform.h"

#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)                    \
    if (ITYPE == at::ScalarType::Half) {                                            \
        using input_t = at::Half;                                                   \
        __VA_ARGS__();                                                              \
    } else if (ITYPE == at::ScalarType::BFloat16) {                                 \
        using input_t = at::BFloat16;                                               \
        __VA_ARGS__();                                                              \
    } else if (ITYPE == at::ScalarType::Float) {                                    \
        using input_t = float;                                                      \
        __VA_ARGS__();                                                              \
    } else {                                                                        \
        AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
    }
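
// Illustrative note (added, not from the original source): for a Half tensor the
// dispatch macro above expands, roughly, to
//
//     if (ITYPE == at::ScalarType::Half) {
//         using input_t = at::Half;
//         [&] { fast_hadamard_transform_cuda<input_t>(params, stream); }();
//     } else if ...
//
// so the lambda passed as __VA_ARGS__ is compiled once per supported dtype, each
// time with its own input_t.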

template<typename input_t>
void fast_hadamard_transform_cuda(HadamardParamsBase &params, cudaStream_t stream);

template<typename input_t>
void fast_hadamard_transform_12N_cuda(HadamardParamsBase &params, cudaStream_t stream);

template<typename input_t>
void fast_hadamard_transform_20N_cuda(HadamardParamsBase &params, cudaStream_t stream);

template<typename input_t>
void fast_hadamard_transform_28N_cuda(HadamardParamsBase &params, cudaStream_t stream);

void set_hadamard_params(HadamardParamsBase &params,
                         // sizes
                         const size_t batch,
                         const size_t dim,
                         const size_t multiple,
                         // device pointers
                         const at::Tensor x,
                         const at::Tensor out,
                         float scale
                         ) {
    // Reset the parameters
    memset(&params, 0, sizeof(params));

    params.batch = batch;
    params.dim = dim;
    params.log_N = int(ceil(std::log2(dim / multiple)));

    // Set the pointers and strides.
    params.x_ptr = x.data_ptr();
    params.out_ptr = out.data_ptr();
    // All strides are in elements, not bytes.
    params.x_batch_stride = x.stride(0);
    params.out_batch_stride = out.stride(0);

    params.scale = scale;
}
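
// Worked example (added note, not in the original source): for the 12N variant with
// dim = 12288 and multiple = 12, dim / multiple = 1024, so log_N = ceil(log2(1024)) = 10.
// log_N thus describes only the power-of-two factor of the dimension; the 12/20/28
// prefactor is presumably handled inside the corresponding 12N/20N/28N kernel.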

at::Tensor
fast_hadamard_transform(at::Tensor &x, float scale) {
    auto input_type = x.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);

    TORCH_CHECK(x.is_cuda());

    const auto shapes_og = x.sizes();
    const int dim_og = x.size(-1);
    x = x.reshape({-1, dim_og});
    if (x.stride(-1) != 1) { x = x.contiguous(); }
    const auto sizes = x.sizes();
    const int batch_size = sizes[0];

    CHECK_SHAPE(x, batch_size, dim_og);
    TORCH_CHECK(x.stride(1) == 1);

    if (dim_og % 8 != 0) {
        x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, 8 - dim_og % 8}));
    }
    const int dim = x.size(1);

    TORCH_CHECK(dim % 8 == 0, "fast_hadamard_transform only supports hidden dimension divisible by 8 for now");
    TORCH_CHECK(dim <= 32768, "fast_hadamard_transform only supports hidden dimension at most 32768 for now");

    at::Tensor out = torch::empty_like(x);

    HadamardParamsBase params;
    set_hadamard_params(params, batch_size, dim, 1, x, out, scale);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "fast_hadamard_transform", [&] {
        fast_hadamard_transform_cuda<input_t>(params, stream);
    });

    if (dim_og % 8 != 0) {
        out = out.index({torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, dim_og)});
    }
    return out.reshape(shapes_og);
}

at::Tensor
fast_hadamard_transform_12N(at::Tensor &x, float scale) {
    auto input_type = x.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);

    TORCH_CHECK(x.is_cuda());

    const auto shapes_og = x.sizes();
    const int dim_og = x.size(-1);
    x = x.reshape({-1, dim_og});
    if (x.stride(-1) != 1) { x = x.contiguous(); }
    const auto sizes = x.sizes();
    const int batch_size = sizes[0];

    CHECK_SHAPE(x, batch_size, dim_og);
    TORCH_CHECK(x.stride(1) == 1);

    if (dim_og % (4 * 12) != 0) {
        x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, (4 * 12) - dim_og % (4 * 12)}));
    }
    const int dim = x.size(1);

    TORCH_CHECK(dim % (4 * 12) == 0, "fast_hadamard_transform_12N only supports hidden dimension divisible by 48 for now");
    TORCH_CHECK(dim <= 12 * 1024, "fast_hadamard_transform_12N only supports hidden dimension at most 12288 for now");

    at::Tensor out = torch::empty_like(x);

    HadamardParamsBase params;
    set_hadamard_params(params, batch_size, dim, 12, x, out, scale);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "fast_hadamard_transform", [&] {
        fast_hadamard_transform_12N_cuda<input_t>(params, stream);
    });

    if (dim_og % (4 * 12) != 0) {
        out = out.index({torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, dim_og)});
    }
    return out.reshape(shapes_og);
}

at::Tensor
fast_hadamard_transform_20N(at::Tensor &x, float scale) {
    auto input_type = x.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);

    TORCH_CHECK(x.is_cuda());

    const auto shapes_og = x.sizes();
    const int dim_og = x.size(-1);
    x = x.reshape({-1, dim_og});
    if (x.stride(-1) != 1) { x = x.contiguous(); }
    const auto sizes = x.sizes();
    const int batch_size = sizes[0];

    CHECK_SHAPE(x, batch_size, dim_og);
    TORCH_CHECK(x.stride(1) == 1);

    if (dim_og % (4 * 20) != 0) {
        x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, (4 * 20) - dim_og % (4 * 20)}));
    }
    const int dim = x.size(1);

    TORCH_CHECK(dim % (4 * 20) == 0, "fast_hadamard_transform_20N only supports hidden dimension divisible by 80 for now");
    TORCH_CHECK(dim <= 20 * 1024, "fast_hadamard_transform_20N only supports hidden dimension at most 20480 for now");

    at::Tensor out = torch::empty_like(x);

    HadamardParamsBase params;
    set_hadamard_params(params, batch_size, dim, 20, x, out, scale);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "fast_hadamard_transform", [&] {
        fast_hadamard_transform_20N_cuda<input_t>(params, stream);
    });

    if (dim_og % (4 * 20) != 0) {
        out = out.index({torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, dim_og)});
    }
    return out.reshape(shapes_og);
}

at::Tensor
fast_hadamard_transform_28N(at::Tensor &x, float scale) {
    auto input_type = x.scalar_type();
    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);

    TORCH_CHECK(x.is_cuda());

    const auto shapes_og = x.sizes();
    const int dim_og = x.size(-1);
    x = x.reshape({-1, dim_og});
    if (x.stride(-1) != 1) { x = x.contiguous(); }
    const auto sizes = x.sizes();
    const int batch_size = sizes[0];

    CHECK_SHAPE(x, batch_size, dim_og);
    TORCH_CHECK(x.stride(1) == 1);

    if (dim_og % (4 * 28) != 0) {
        x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, (4 * 28) - dim_og % (4 * 28)}));
    }
    const int dim = x.size(1);

    TORCH_CHECK(dim % (4 * 28) == 0, "fast_hadamard_transform_28N only supports hidden dimension divisible by 112 for now");
    TORCH_CHECK(dim <= 28 * 1024, "fast_hadamard_transform_28N only supports hidden dimension at most 28672 for now");

    at::Tensor out = torch::empty_like(x);

    HadamardParamsBase params;
    set_hadamard_params(params, batch_size, dim, 28, x, out, scale);

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
    auto stream = at::cuda::getCurrentCUDAStream().stream();
    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "fast_hadamard_transform", [&] {
        fast_hadamard_transform_28N_cuda<input_t>(params, stream);
    });

    // Slice off the padding (if any) so the output matches the original last dimension.
    if (dim_og % (4 * 28) != 0) {
        out = out.index({torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, dim_og)});
    }
    return out.reshape(shapes_og);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("fast_hadamard_transform", &fast_hadamard_transform, "Fast Hadamard transform");
    m.def("fast_hadamard_transform_12N", &fast_hadamard_transform_12N, "Fast Hadamard transform with dimension = 12 * power of 2");
    m.def("fast_hadamard_transform_20N", &fast_hadamard_transform_20N, "Fast Hadamard transform with dimension = 20 * power of 2");
    m.def("fast_hadamard_transform_28N", &fast_hadamard_transform_28N, "Fast Hadamard transform with dimension = 28 * power of 2");
}
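
// Example usage from Python (added sketch, not part of the original source; the
// compiled extension name depends on the build setup -- `fast_hadamard_transform_cuda`
// is assumed here):
//
//     import torch
//     import fast_hadamard_transform_cuda
//
//     x = torch.randn(4, 2048, device="cuda", dtype=torch.float16)
//     # Scale by 1/sqrt(dim) so the transform is orthonormal.
//     out = fast_hadamard_transform_cuda.fast_hadamard_transform(x, 2048 ** -0.5)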