layernorm_kernels.cu

#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "dispatch_utils.h"

#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>

using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif

namespace aphrodite {

// TODO: Further optimize this kernel.
template <typename scalar_t>
__global__ void rms_norm_kernel(
    scalar_t* __restrict__ out,           // [..., hidden_size]
    const scalar_t* __restrict__ input,   // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float)input[blockIdx.x * hidden_size + idx];
    variance += x * x;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] =
        ((scalar_t)(x * s_variance)) * weight[idx];
  }
}
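
/* Note (added for clarity): each block handles one token. The kernel above
   computes, per token,
       out[i] = (input[i] * rsqrtf(mean_j(input[j]^2) + epsilon)) * weight[i]
   i.e. RMSNorm: one pass accumulates the per-thread sum of squares, a block
   reduction forms the mean, and a second pass applies the scale and weight. */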

/* Converter structs for the conversion from torch types to HIP/CUDA types,
   and the associated type conversions within HIP/CUDA. These helpers need
   to be implemented for now because the relevant type conversion
   operators/constructors are not consistently implemented by HIP/CUDA, so
   a generic conversion via type casts cannot be implemented.

   Each struct should have the member static constexpr bool `exists`:
   If false, the optimized kernel is not used for the corresponding torch type.
   If true, the struct should be fully defined as shown in the examples below.
 */
template <typename torch_type>
struct _typeConvert {
  static constexpr bool exists = false;
};

#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
struct _typeConvert<c10::Half> {
  static constexpr bool exists = true;
  using hip_type = __half;
  using packed_hip_type = __half2;

  __device__ static inline float convert(hip_type x) { return __half2float(x); }
  __device__ static inline float2 convert(packed_hip_type x) {
    return __half22float2(x);
  }
  __device__ static inline hip_type convert(float x) {
    return __float2half_rn(x);
  }
  __device__ static inline packed_hip_type convert(float2 x) {
    return __float22half2_rn(x);
  }
};

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template <>
struct _typeConvert<c10::BFloat16> {
  static constexpr bool exists = true;
  using hip_type = __nv_bfloat16;
  using packed_hip_type = __nv_bfloat162;

  __device__ static inline float convert(hip_type x) {
    return __bfloat162float(x);
  }
  __device__ static inline float2 convert(packed_hip_type x) {
    return __bfloat1622float2(x);
  }
  __device__ static inline hip_type convert(float x) {
    return __float2bfloat16(x);
  }
  __device__ static inline packed_hip_type convert(float2 x) {
    return __float22bfloat162_rn(x);
  }
};
#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif  // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
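
/* Hedged sanity checks (added as illustration; not in the original file).
   The `exists` flag is what the launch logic keys on: torch types without a
   specialization always fall back to the generic (width == 0) kernel. */
static_assert(!_typeConvert<float>::exists,
              "float has no packed specialization and uses the generic kernel");
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
static_assert(_typeConvert<c10::Half>::exists,
              "half may use the packed/vectorized kernel on this build");
#endif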

/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
   for appropriate specializations of fused_add_rms_norm_kernel.
   Only functions that are necessary in that kernel are implemented.
   Alignment to 16 bytes is required to use 128-bit global memory ops.
 */
template <typename scalar_t, int width>
struct alignas(16) _f16Vec {
  /* Not theoretically necessary that width is a power of 2 but should
     almost always be the case for optimization purposes */
  static_assert(width > 0 && (width & (width - 1)) == 0,
                "Width is not a positive power of 2!");
  using Converter = _typeConvert<scalar_t>;
  using T1 = typename Converter::hip_type;
  using T2 = typename Converter::packed_hip_type;
  T1 data[width];

  __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
    if constexpr (width % 2 == 0) {
#pragma unroll
      for (int i = 0; i < width; i += 2) {
        T2 temp{data[i], data[i + 1]};
        temp += T2{other.data[i], other.data[i + 1]};
        data[i] = temp.x;
        data[i + 1] = temp.y;
      }
    } else {
#pragma unroll
      for (int i = 0; i < width; ++i) data[i] += other.data[i];
    }
    return *this;
  }

  __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
    if constexpr (width % 2 == 0) {
#pragma unroll
      for (int i = 0; i < width; i += 2) {
        T2 temp{data[i], data[i + 1]};
        temp *= T2{other.data[i], other.data[i + 1]};
        data[i] = temp.x;
        data[i + 1] = temp.y;
      }
    } else {
#pragma unroll
      for (int i = 0; i < width; ++i) data[i] *= other.data[i];
    }
    return *this;
  }

  __device__ _f16Vec& operator*=(const float scale) {
    if constexpr (width % 2 == 0) {
#pragma unroll
      for (int i = 0; i < width; i += 2) {
        float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
        temp_f.x *= scale;
        temp_f.y *= scale;
        T2 temp = Converter::convert(temp_f);
        data[i] = temp.x;
        data[i + 1] = temp.y;
      }
    } else {
#pragma unroll
      for (int i = 0; i < width; ++i) {
        float temp = Converter::convert(data[i]) * scale;
        data[i] = Converter::convert(temp);
      }
    }
    return *this;
  }

  __device__ float sum_squares() const {
    float result = 0.0f;
    if constexpr (width % 2 == 0) {
#pragma unroll
      for (int i = 0; i < width; i += 2) {
        float2 z = Converter::convert(T2{data[i], data[i + 1]});
        result += z.x * z.x + z.y * z.y;
      }
    } else {
#pragma unroll
      for (int i = 0; i < width; ++i) {
        float x = Converter::convert(data[i]);
        result += x * x;
      }
    }
    return result;
  }
};
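
/* Hedged compile-time illustration (added; not in the original file): with the
   width-8 specialization used below, one _f16Vec packs eight 16-bit values
   into a single 16-byte, 16-byte-aligned object, which is what enables the
   128-bit global loads/stores. */
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
static_assert(sizeof(_f16Vec<c10::Half, 8>) == 16);
static_assert(alignof(_f16Vec<c10::Half, 8>) == 16);
#endif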

/* Function specialization in the case of FP16/BF16 tensors.
   Additional optimizations we can make in this case are
   packed and vectorized operations, which help with the
   memory latency bottleneck. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_standard_layout_v<_f16Vec<scalar_t, width>> &&
                std::is_trivial_v<_f16Vec<scalar_t, width>>);
  static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);

  const int vec_hidden_size = hidden_size / width;
  __shared__ float s_variance;
  float variance = 0.0f;
  /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. Argument pointers should not be dereferenced
     in this kernel as that would be undefined behavior */
  auto* __restrict__ input_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
  auto* __restrict__ residual_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
  auto* __restrict__ weight_v =
      reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    int id = blockIdx.x * vec_hidden_size + idx;
    _f16Vec<scalar_t, width> temp = input_v[id];
    temp += residual_v[id];
    variance += temp.sum_squares();
    residual_v[id] = temp;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    int id = blockIdx.x * vec_hidden_size + idx;
    _f16Vec<scalar_t, width> temp = residual_v[id];
    temp *= s_variance;
    temp *= weight_v[idx];
    input_v[id] = temp;
  }
}
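
/* Worked example (added for clarity): with hidden_size = 4096 and width = 8,
   vec_hidden_size = 512, so each pass over a token issues 512 packed 128-bit
   loads/stores instead of 4096 scalar 16-bit ones, and the residual add,
   sum of squares, and scaling all operate on __half2 / __nv_bfloat162 pairs. */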

/* Generic fused_add_rms_norm_kernel
   The width field is not used here but necessary for other specializations.
 */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    scalar_t z = input[blockIdx.x * hidden_size + idx];
    z += residual[blockIdx.x * hidden_size + idx];
    float x = (float)z;
    variance += x * x;
    residual[blockIdx.x * hidden_size + idx] = z;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)residual[blockIdx.x * hidden_size + idx];
    input[blockIdx.x * hidden_size + idx] =
        ((scalar_t)(x * s_variance)) * weight[idx];
  }
}

}  // namespace aphrodite

void rms_norm(torch::Tensor& out,     // [..., hidden_size]
              torch::Tensor& input,   // [..., hidden_size]
              torch::Tensor& weight,  // [hidden_size]
              double epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "rms_norm_kernel", [&] {
        aphrodite::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
            out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
      });
}
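
/* Hedged testing sketch (illustration only; not part of the original file):
   rms_norm can be validated against an eager ATen reference such as
       auto x = input.to(torch::kFloat32);
       auto inv_rms = torch::rsqrt(x.pow(2).mean({-1}, true) + epsilon);
       auto ref = (x * inv_rms).to(input.scalar_type()) * weight;
   where `x`, `inv_rms`, and `ref` are illustrative names; `out` should then
   match `ref` up to half/bfloat16 rounding error. */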

#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                       \
  APHRODITE_DISPATCH_FLOATING_TYPES(                                           \
      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                  \
        aphrodite::fused_add_rms_norm_kernel<scalar_t, width>                  \
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),           \
                                         residual.data_ptr<scalar_t>(),        \
                                         weight.data_ptr<scalar_t>(), epsilon, \
                                         num_tokens, hidden_size);             \
      });

void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
                        torch::Tensor& residual,  // [..., hidden_size]
                        torch::Tensor& weight,    // [hidden_size]
                        double epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  /* This kernel is memory-latency bound in many scenarios.
     When num_tokens is large, a smaller block size allows
     for increased block occupancy on CUs and better latency
     hiding on global mem ops. */
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(hidden_size, max_block_size));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  /* If the tensor types are FP16/BF16, try to use the optimized kernel
     with packed + vectorized ops.
     Max optimization is achieved with a width-8 vector of FP16/BF16s
     since we can load at most 128 bits at once in a global memory op.
     However, this requires each tensor's data to be aligned to 16
     bytes.
   */
  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
  auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
  bool ptrs_are_aligned =
      inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
  if (ptrs_are_aligned && hidden_size % 8 == 0) {
    LAUNCH_FUSED_ADD_RMS_NORM(8);
  } else {
    LAUNCH_FUSED_ADD_RMS_NORM(0);
  }
}
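
/* Hedged usage note (added for clarity; not part of the original file):
   fused_add_rms_norm updates both tensors in place. Conceptually,
       residual = input + residual;            // residual holds the pre-norm sum
       input    = RMSNorm(residual) * weight;  // same math as rms_norm above
   The width-8 path is taken only when input, residual, and weight are all
   16-byte aligned and hidden_size is divisible by 8; otherwise the generic
   width-0 kernel is dispatched. */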