// layernorm_kernels.cu: RMSNorm and fused residual-add + RMSNorm kernels.

  #include <torch/extension.h>
  #include <ATen/cuda/CUDAContext.h>
  #include <c10/cuda/CUDAGuard.h>

  #include <algorithm>
  #include <cstdint>
  #include <type_traits>

  #include "dispatch_utils.h"
  #include "reduction.cuh"

  #ifndef USE_ROCM
  #include <cuda_bf16.h>
  #include <cuda_fp16.h>
  #else
  #include <hip/hip_bf16.h>
  #include <hip/hip_fp16.h>
  using __nv_bfloat16 = __hip_bfloat16;
  using __nv_bfloat162 = __hip_bfloat162;
  #endif
  namespace aphrodite {

  /* RMSNorm over the last dimension of a row-major [num_tokens, hidden_size]
     buffer. For each row t:
       out[t, i] = (scalar_t)(input[t, i] * rsqrt(mean_j(input[t, j]^2) + epsilon))
                   * weight[i]

     Launch layout (as configured by the host wrapper `rms_norm`):
       - one block per row: the row index is blockIdx.x;
       - each thread handles columns threadIdx.x, threadIdx.x + blockDim.x, ...
     `num_tokens` is not read by this kernel.
     TODO: Further optimize this kernel. */
  template<typename scalar_t>
  __global__ void rms_norm_kernel(
      scalar_t* __restrict__ out,           // [..., hidden_size]
      const scalar_t* __restrict__ input,   // [..., hidden_size]
      const scalar_t* __restrict__ weight,  // [hidden_size]
      const float epsilon,
      const int num_tokens,
      const int hidden_size) {
    __shared__ float s_variance;
    // 64-bit row base: blockIdx.x * hidden_size can exceed INT_MAX for large
    // inputs and would silently wrap in 32-bit int arithmetic.
    const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * hidden_size;

    // Pass 1: each thread accumulates a partial sum of squares in float.
    float variance = 0.0f;
    for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
      const float x = (float) input[row_offset + idx];
      variance += x * x;
    }
    // Block-wide reduction (blockReduceSum is not defined in this file;
    // presumably it comes from reduction.cuh). Thread 0 turns the sum into
    // the reciprocal RMS and publishes it through shared memory.
    variance = blockReduceSum<float>(variance);
    if (threadIdx.x == 0) {
      s_variance = rsqrtf(variance / hidden_size + epsilon);
    }
    __syncthreads();

    // Pass 2: normalize, round to scalar_t, then apply the weight.
    // NOTE(review): rounding to scalar_t before the weight multiply looks
    // intentional (presumably to match a reference implementation) - keep the
    // order unless confirmed otherwise.
    for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
      const float x = (float) input[row_offset + idx];
      out[row_offset + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
    }
  }

  /* Converter structs for the conversion from torch types to HIP/CUDA types,
     and the associated type conversions within HIP/CUDA. These helpers need
     to be implemented for now because the relevant type conversion
     operators/constructors are not consistently implemented by HIP/CUDA, so
     a generic conversion via type casts cannot be implemented.

     Each struct must define the member `static constexpr bool exists`:
       - false: the vectorized fused_add_rms_norm_kernel is not used for that
         torch type (the generic overload is selected instead);
       - true: the struct must also provide `hip_type`, `packed_hip_type` and
         the four `convert` overloads, as in the specializations below.
  */
  template<typename torch_type>
  struct _typeConvert { static constexpr bool exists = false; };

  #if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
  // CUDA < 12.0 runs into issues with packed type conversion
  template<>
  struct _typeConvert<c10::Half> {
    static constexpr bool exists = true;
    using hip_type = __half;
    using packed_hip_type = __half2;

    __device__ static inline float convert(hip_type x) { return __half2float(x); }
    __device__ static inline float2 convert(packed_hip_type x) { return __half22float2(x); }
    __device__ static inline hip_type convert(float x) { return __float2half_rn(x); }
    __device__ static inline packed_hip_type convert(float2 x) { return __float22half2_rn(x); }
  };

  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // CUDA_ARCH < 800 does not have BF16 support
  // TODO: Add in ROCm support once public headers handle bf16 maturely
  // NOTE(review): __CUDA_ARCH__ is undefined in the host compilation pass, so
  // host code sees _typeConvert<c10::BFloat16>::exists == false while SM80+
  // device passes see true. Confirm that the resulting host/device overload
  // selection of fused_add_rms_norm_kernel for BFloat16 is intended.
  template<>
  struct _typeConvert<c10::BFloat16> {
    static constexpr bool exists = true;
    using hip_type = __nv_bfloat16;
    using packed_hip_type = __nv_bfloat162;

    __device__ static inline float convert(hip_type x) { return __bfloat162float(x); }
    __device__ static inline float2 convert(packed_hip_type x) { return __bfloat1622float2(x); }
    __device__ static inline hip_type convert(float x) { return __float2bfloat16(x); }
    __device__ static inline packed_hip_type convert(float2 x) { return __float22bfloat162_rn(x); }
  };
  #endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  #endif  // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))

  /* Vector POD struct to generate vectorized and packed FP16/BF16 ops
     for appropriate specializations of fused_add_rms_norm_kernel.
     Only functions that are necessary in that kernel are implemented.
     Alignment to 16 bytes is required to use 128-bit global memory ops.

     Even widths process elements in pairs through packed_hip_type; a width
     of 1 falls back to element-wise ops. Vector-vector += and *= operate on
     T1/T2 values directly, while scaling by a float and sum_squares convert
     through float.
  */
  template<typename scalar_t, int width>
  struct alignas(16) _f16Vec {
    /* Not theoretically necessary that width is a power of 2 but should
       almost always be the case for optimization purposes */
    static_assert(width > 0 && (width & (width - 1)) == 0,
                  "Width is not a positive power of 2!");
    using Converter = _typeConvert<scalar_t>;
    using T1 = typename Converter::hip_type;
    using T2 = typename Converter::packed_hip_type;
    T1 data[width];

    // Element-wise addition of another vector.
    __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
      if constexpr (width % 2 == 0) {
  #pragma unroll
        for (int i = 0; i < width; i += 2) {
          T2 temp{data[i], data[i+1]};
          temp += T2{other.data[i], other.data[i+1]};
          data[i] = temp.x;
          data[i+1] = temp.y;
        }
      } else {
  #pragma unroll
        for (int i = 0; i < width; ++i)
          data[i] += other.data[i];
      }
      return *this;
    }

    // Element-wise multiplication by another vector.
    __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
      if constexpr (width % 2 == 0) {
  #pragma unroll
        for (int i = 0; i < width; i += 2) {
          T2 temp{data[i], data[i+1]};
          temp *= T2{other.data[i], other.data[i+1]};
          data[i] = temp.x;
          data[i+1] = temp.y;
        }
      } else {
  #pragma unroll
        for (int i = 0; i < width; ++i)
          data[i] *= other.data[i];
      }
      return *this;
    }

    // Multiply every element by a float scale (computed in float, then
    // converted back to T1/T2).
    __device__ _f16Vec& operator*=(const float scale) {
      if constexpr (width % 2 == 0) {
  #pragma unroll
        for (int i = 0; i < width; i += 2) {
          float2 temp_f = Converter::convert(T2{data[i], data[i+1]});
          temp_f.x *= scale;
          temp_f.y *= scale;
          T2 temp = Converter::convert(temp_f);
          data[i] = temp.x;
          data[i+1] = temp.y;
        }
      } else {
  #pragma unroll
        for (int i = 0; i < width; ++i) {
          float temp = Converter::convert(data[i]) * scale;
          data[i] = Converter::convert(temp);
        }
      }
      return *this;
    }

    // Sum of squares of all elements, accumulated in float.
    __device__ float sum_squares() const {
      float result = 0.0f;
      if constexpr (width % 2 == 0) {
  #pragma unroll
        for (int i = 0; i < width; i += 2) {
          float2 z = Converter::convert(T2{data[i], data[i+1]});
          result += z.x * z.x + z.y * z.y;
        }
      } else {
  #pragma unroll
        for (int i = 0; i < width; ++i) {
          float x = Converter::convert(data[i]);
          result += x * x;
        }
      }
      return result;
    }
  };

  /* Fused residual add + RMSNorm, FP16/BF16 specialization (selected when
     width > 0 and _typeConvert<scalar_t>::exists). For each row, in place:
       residual <- input + residual
       input    <- RMSNorm(residual) * weight
     Rows are processed as `width`-element _f16Vec chunks so that packed and
     vectorized operations help with the memory latency bottleneck.

     Preconditions: hidden_size % width == 0 and all three pointers aligned
     to 16 bytes (the host wrapper checks both before choosing width 8).
     Launch layout: one block per row (blockIdx.x). `num_tokens` only selects
     the blockReduceSum variant and must agree with the host's block size. */
  template<typename scalar_t, int width>
  __global__ std::enable_if_t<
      (width > 0) && _typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
      scalar_t* __restrict__ input,           // [..., hidden_size]
      scalar_t* __restrict__ residual,        // [..., hidden_size]
      const scalar_t* __restrict__ weight,    // [hidden_size]
      const float epsilon,
      const int num_tokens,
      const int hidden_size) {
    using Vec = _f16Vec<scalar_t, width>;
    // Sanity checks on our vector struct and type-punned pointer arithmetic.
    // (trivial + standard-layout replaces the C++20-deprecated std::is_pod_v.)
    static_assert(std::is_trivial_v<Vec> && std::is_standard_layout_v<Vec>);
    static_assert(sizeof(Vec) == sizeof(scalar_t) * width);

    const int vec_hidden_size = hidden_size / width;
    // 64-bit row base (in units of Vec) to avoid int32 overflow.
    const int64_t row_offset =
        static_cast<int64_t>(blockIdx.x) * vec_hidden_size;
    __shared__ float s_variance;
    float variance = 0.0f;

    /* These and the argument pointers are all declared `restrict` as they are
       not aliased in practice. Argument pointers should not be dereferenced
       in this kernel as that would be undefined behavior */
    auto* __restrict__ input_v = reinterpret_cast<Vec*>(input);
    auto* __restrict__ residual_v = reinterpret_cast<Vec*>(residual);
    auto* __restrict__ weight_v = reinterpret_cast<const Vec*>(weight);

    // Pass 1: residual <- input + residual, accumulating the sum of squares
    // of the updated residual.
    for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
      const int64_t id = row_offset + idx;
      Vec temp = input_v[id];
      temp += residual_v[id];
      variance += temp.sum_squares();
      residual_v[id] = temp;
    }

    /* Keep the following if-else block in sync with the
       calculation of max_block_size in fused_add_rms_norm */
    if (num_tokens < 256) {
      variance = blockReduceSum<float, 1024>(variance);
    } else {
      variance = blockReduceSum<float, 256>(variance);
    }
    if (threadIdx.x == 0) {
      s_variance = rsqrtf(variance / hidden_size + epsilon);
    }
    __syncthreads();

    // Pass 2: input <- residual * s_variance * weight.
    for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
      const int64_t id = row_offset + idx;
      Vec temp = residual_v[id];
      temp *= s_variance;
      temp *= weight_v[idx];
      input_v[id] = temp;
    }
  }

  /* Generic fused residual add + RMSNorm, selected when width == 0 or when
     _typeConvert<scalar_t>::exists is false (e.g. float, which has no
     specialization above). Same in-place semantics as the vectorized
     overload, processed one element at a time.
     The width field is not used here but necessary for other specializations.
  */
  template<typename scalar_t, int width>
  __global__ std::enable_if_t<
      (width == 0) || !_typeConvert<scalar_t>::exists> fused_add_rms_norm_kernel(
      scalar_t* __restrict__ input,           // [..., hidden_size]
      scalar_t* __restrict__ residual,        // [..., hidden_size]
      const scalar_t* __restrict__ weight,    // [hidden_size]
      const float epsilon,
      const int num_tokens,
      const int hidden_size) {
    __shared__ float s_variance;
    // 64-bit row base to avoid int32 overflow of blockIdx.x * hidden_size.
    const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * hidden_size;
    float variance = 0.0f;

    // Pass 1: residual <- input + residual (added in scalar_t), accumulating
    // the sum of squares of the updated residual in float.
    for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
      scalar_t z = input[row_offset + idx];
      z += residual[row_offset + idx];
      const float x = (float) z;
      variance += x * x;
      residual[row_offset + idx] = z;
    }

    /* Keep the following if-else block in sync with the
       calculation of max_block_size in fused_add_rms_norm */
    if (num_tokens < 256) {
      variance = blockReduceSum<float, 1024>(variance);
    } else {
      variance = blockReduceSum<float, 256>(variance);
    }
    if (threadIdx.x == 0) {
      s_variance = rsqrtf(variance / hidden_size + epsilon);
    }
    __syncthreads();

    // Pass 2: input <- (scalar_t)(residual * s_variance) * weight.
    for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
      const float x = (float) residual[row_offset + idx];
      input[row_offset + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
    }
  }

  }  // namespace aphrodite
  /* Host entry point: RMSNorm over the last dimension of `input`, written to
     `out` (out = normalized(input) * weight, row by row).

     Launches aphrodite::rms_norm_kernel on the current CUDA stream of
     `input`'s device with one block per row and min(hidden_size, 1024)
     threads per block.

     Requirements (enforced below): `input`, `out` and `weight` contiguous,
     `out` the same number of elements as `input`, `weight` with hidden_size
     elements. An empty `input` is a no-op.
  */
  void rms_norm(
      torch::Tensor& out,     // [..., hidden_size]
      torch::Tensor& input,   // [..., hidden_size]
      torch::Tensor& weight,  // [hidden_size]
      float epsilon) {
    // The kernel addresses row t as t * hidden_size + i, which is only valid
    // for densely packed row-major tensors.
    TORCH_CHECK(input.is_contiguous(), "rms_norm: input must be contiguous");
    TORCH_CHECK(out.is_contiguous(), "rms_norm: out must be contiguous");
    TORCH_CHECK(weight.is_contiguous(), "rms_norm: weight must be contiguous");
    TORCH_CHECK(out.numel() == input.numel(),
                "rms_norm: out and input must have the same number of elements");

    const int hidden_size = input.size(-1);
    TORCH_CHECK(weight.numel() == hidden_size,
                "rms_norm: weight must have hidden_size elements");

    // Nothing to do for an empty input. Returning early also avoids a
    // zero-sized grid (an invalid launch configuration) and the division by
    // zero below when hidden_size == 0.
    if (input.numel() == 0) {
      return;
    }
    const int num_tokens = input.numel() / hidden_size;

    dim3 grid(num_tokens);
    dim3 block(std::min(hidden_size, 1024));
    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    APHRODITE_DISPATCH_FLOATING_TYPES(
        input.scalar_type(),
        "rms_norm_kernel",
        [&] {
          aphrodite::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
              out.data_ptr<scalar_t>(),
              input.data_ptr<scalar_t>(),
              weight.data_ptr<scalar_t>(),
              epsilon,
              num_tokens,
              hidden_size);
        });
  }
  /* Dispatches on input.scalar_type() and launches
     aphrodite::fused_add_rms_norm_kernel<scalar_t, width> with the launch
     configuration <<<grid, block, 0, stream>>>.
     The expansion site must have `input`, `residual`, `weight`, `epsilon`,
     `num_tokens`, `hidden_size`, `grid`, `block` and `stream` in scope.
     The expansion already ends with a semicolon. */
  #define LAUNCH_FUSED_ADD_RMS_NORM(width)                                  \
    APHRODITE_DISPATCH_FLOATING_TYPES(                                      \
        input.scalar_type(), "fused_add_rms_norm_kernel", [&] {             \
          aphrodite::fused_add_rms_norm_kernel<scalar_t, width>             \
              <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),      \
                                           residual.data_ptr<scalar_t>(),   \
                                           weight.data_ptr<scalar_t>(),     \
                                           epsilon, num_tokens,             \
                                           hidden_size);                    \
        });
  /* Host entry point for the fused residual add + RMSNorm, in place on both
     buffers, row by row over the last dimension:
       residual <- input + residual
       input    <- RMSNorm(residual) * weight

     Requirements (enforced below): all tensors contiguous, `residual` the
     same number of elements as `input`, `weight` with hidden_size elements.
     An empty `input` is a no-op.
  */
  void fused_add_rms_norm(
      torch::Tensor& input,     // [..., hidden_size]
      torch::Tensor& residual,  // [..., hidden_size]
      torch::Tensor& weight,    // [hidden_size]
      float epsilon) {
    // The kernels address row t as t * hidden_size + i, which is only valid
    // for densely packed row-major tensors.
    TORCH_CHECK(input.is_contiguous(),
                "fused_add_rms_norm: input must be contiguous");
    TORCH_CHECK(residual.is_contiguous(),
                "fused_add_rms_norm: residual must be contiguous");
    TORCH_CHECK(weight.is_contiguous(),
                "fused_add_rms_norm: weight must be contiguous");
    TORCH_CHECK(residual.numel() == input.numel(),
                "fused_add_rms_norm: residual and input must have the same "
                "number of elements");

    const int hidden_size = input.size(-1);
    TORCH_CHECK(weight.numel() == hidden_size,
                "fused_add_rms_norm: weight must have hidden_size elements");

    // Nothing to do for an empty input. Returning early also avoids a
    // zero-sized grid (an invalid launch configuration) and the division by
    // zero below when hidden_size == 0.
    if (input.numel() == 0) {
      return;
    }
    const int num_tokens = input.numel() / hidden_size;
    dim3 grid(num_tokens);

    /* This kernel is memory-latency bound in many scenarios.
       When num_tokens is large, a smaller block size allows
       for increased block occupancy on CUs and better latency
       hiding on global mem ops.
       Keep this threshold in sync with the blockReduceSum<float, N>
       selection inside fused_add_rms_norm_kernel. */
    const int max_block_size = (num_tokens < 256) ? 1024 : 256;
    dim3 block(std::min(hidden_size, max_block_size));
    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    /* If the tensor types are FP16/BF16, try to use the optimized kernel
       with packed + vectorized ops.
       Max optimization is achieved with a width-8 vector of FP16/BF16s
       since we can load at most 128 bits at once in a global memory op.
       However, this requires each tensor's data to be aligned to 16
       bytes. Dtypes without a _typeConvert specialization still resolve to
       the generic kernel overload even when width 8 is requested. */
    const auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
    const auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
    const auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
    const bool ptrs_are_aligned =
        inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
    if (ptrs_are_aligned && hidden_size % 8 == 0) {
      LAUNCH_FUSED_ADD_RMS_NORM(8);
    } else {
      LAUNCH_FUSED_ADD_RMS_NORM(0);
    }
  }