layernorm_kernels.cu

#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "dispatch_utils.h"
#include "reduction.cuh"

#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>

using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif

namespace aphrodite {

// TODO: Further optimize this kernel.
template <typename scalar_t>
__global__ void rms_norm_kernel(
    scalar_t* __restrict__ out,           // [..., hidden_size]
    const scalar_t* __restrict__ input,   // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float)input[blockIdx.x * hidden_size + idx];
    variance += x * x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] =
        ((scalar_t)(x * s_variance)) * weight[idx];
  }
}
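
/* For reference, each block above normalizes one token of the input:
     out[i] = input[i] * rsqrtf(mean(input^2) + epsilon) * weight[i]
   with the mean taken over the hidden dimension. Despite its name,
   `variance` accumulates a plain sum of squares (no mean subtraction),
   i.e. this is RMSNorm rather than classic LayerNorm. */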

/* Converter structs for the conversion from torch types to HIP/CUDA types,
   and the associated type conversions within HIP/CUDA. These helpers need
   to be implemented for now because the relevant type conversion
   operators/constructors are not consistently implemented by HIP/CUDA, so
   a generic conversion via type casts cannot be implemented.

   Each struct should have the member static constexpr bool `exists`:
   If false, the optimized kernel is not used for the corresponding torch type.
   If true, the struct should be fully defined as shown in the examples below.
 */
template <typename torch_type>
struct _typeConvert {
  static constexpr bool exists = false;
};

#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
struct _typeConvert<c10::Half> {
  static constexpr bool exists = true;
  using hip_type = __half;
  using packed_hip_type = __half2;

  __device__ static inline float convert(hip_type x) { return __half2float(x); }
  __device__ static inline float2 convert(packed_hip_type x) {
    return __half22float2(x);
  }
  __device__ static inline hip_type convert(float x) {
    return __float2half_rn(x);
  }
  __device__ static inline packed_hip_type convert(float2 x) {
    return __float22half2_rn(x);
  }
};

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template <>
struct _typeConvert<c10::BFloat16> {
  static constexpr bool exists = true;
  using hip_type = __nv_bfloat16;
  using packed_hip_type = __nv_bfloat162;

  __device__ static inline float convert(hip_type x) {
    return __bfloat162float(x);
  }
  __device__ static inline float2 convert(packed_hip_type x) {
    return __bfloat1622float2(x);
  }
  __device__ static inline hip_type convert(float x) {
    return __float2bfloat16(x);
  }
  __device__ static inline packed_hip_type convert(float2 x) {
    return __float22bfloat162_rn(x);
  }
};
#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif  // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
        // 12000))
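
/* Sketch of what a further specialization would need (hypothetical; no such
   type is wired up here): set `exists = true`, name the device scalar and
   packed types via `hip_type` / `packed_hip_type`, and supply the four
   convert() overloads (scalar <-> float, packed <-> float2) that _f16Vec
   below relies on. */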

/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
   for appropriate specializations of fused_add_rms_norm_kernel.
   Only functions that are necessary in that kernel are implemented.
   Alignment to 16 bytes is required to use 128-bit global memory ops.
 */
template <typename scalar_t, int width>
struct alignas(16) _f16Vec {
  /* Not theoretically necessary that width is a power of 2 but should
     almost always be the case for optimization purposes */
  static_assert(width > 0 && (width & (width - 1)) == 0,
                "Width is not a positive power of 2!");
  using Converter = _typeConvert<scalar_t>;
  using T1 = typename Converter::hip_type;
  using T2 = typename Converter::packed_hip_type;
  T1 data[width];

  __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
    if constexpr (width % 2 == 0) {
      #pragma unroll
      for (int i = 0; i < width; i += 2) {
        T2 temp{data[i], data[i + 1]};
        temp += T2{other.data[i], other.data[i + 1]};
        data[i] = temp.x;
        data[i + 1] = temp.y;
      }
    } else {
      #pragma unroll
      for (int i = 0; i < width; ++i) data[i] += other.data[i];
    }
    return *this;
  }

  __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
    if constexpr (width % 2 == 0) {
      #pragma unroll
      for (int i = 0; i < width; i += 2) {
        T2 temp{data[i], data[i + 1]};
        temp *= T2{other.data[i], other.data[i + 1]};
        data[i] = temp.x;
        data[i + 1] = temp.y;
      }
    } else {
      #pragma unroll
      for (int i = 0; i < width; ++i) data[i] *= other.data[i];
    }
    return *this;
  }

  __device__ _f16Vec& operator*=(const float scale) {
    if constexpr (width % 2 == 0) {
      #pragma unroll
      for (int i = 0; i < width; i += 2) {
        float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
        temp_f.x *= scale;
        temp_f.y *= scale;
        T2 temp = Converter::convert(temp_f);
        data[i] = temp.x;
        data[i + 1] = temp.y;
      }
    } else {
      #pragma unroll
      for (int i = 0; i < width; ++i) {
        float temp = Converter::convert(data[i]) * scale;
        data[i] = Converter::convert(temp);
      }
    }
    return *this;
  }

  __device__ float sum_squares() const {
    float result = 0.0f;
    if constexpr (width % 2 == 0) {
      #pragma unroll
      for (int i = 0; i < width; i += 2) {
        float2 z = Converter::convert(T2{data[i], data[i + 1]});
        result += z.x * z.x + z.y * z.y;
      }
    } else {
      #pragma unroll
      for (int i = 0; i < width; ++i) {
        float x = Converter::convert(data[i]);
        result += x * x;
      }
    }
    return result;
  }
};
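
/* Illustrative size check (the width-8 FP16 case launched below): the struct
   packs eight 2-byte elements into a single 16-byte, 16-byte-aligned value,
   so one load/store of a _f16Vec maps to a 128-bit global memory op.

     static_assert(sizeof(_f16Vec<c10::Half, 8>) == 16 &&
                   alignof(_f16Vec<c10::Half, 8>) == 16);
*/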

/* Function specialization in the case of FP16/BF16 tensors.
   Additional optimizations we can make in this case are
   packed and vectorized operations, which help with the
   memory latency bottleneck. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_standard_layout_v<_f16Vec<scalar_t, width>> &&
                std::is_trivial_v<_f16Vec<scalar_t, width>>);
  static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);

  const int vec_hidden_size = hidden_size / width;
  __shared__ float s_variance;
  float variance = 0.0f;
  /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. Argument pointers should not be dereferenced
     in this kernel as that would be undefined behavior */
  auto* __restrict__ input_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
  auto* __restrict__ residual_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
  auto* __restrict__ weight_v =
      reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    int id = blockIdx.x * vec_hidden_size + idx;
    _f16Vec<scalar_t, width> temp = input_v[id];
    temp += residual_v[id];
    variance += temp.sum_squares();
    residual_v[id] = temp;
  }
  /* Keep the following if-else block in sync with the
     calculation of max_block_size in fused_add_rms_norm */
  if (num_tokens < 256) {
    variance = blockReduceSum<float, 1024>(variance);
  } else
    variance = blockReduceSum<float, 256>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    int id = blockIdx.x * vec_hidden_size + idx;
    _f16Vec<scalar_t, width> temp = residual_v[id];
    temp *= s_variance;
    temp *= weight_v[idx];
    input_v[id] = temp;
  }
}
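
/* Worked example with hypothetical sizes: for hidden_size = 4096 and
   width = 8, vec_hidden_size = 512, so block blockIdx.x walks the 512
   16-byte vectors of token blockIdx.x. After the kernel, `residual` holds
   input + residual and `input` holds its RMS-normalized, weight-scaled
   value. */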

/* Generic fused_add_rms_norm_kernel
   The width field is not used here but necessary for other specializations.
 */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_kernel(
    scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    scalar_t z = input[blockIdx.x * hidden_size + idx];
    z += residual[blockIdx.x * hidden_size + idx];
    float x = (float)z;
    variance += x * x;
    residual[blockIdx.x * hidden_size + idx] = z;
  }
  /* Keep the following if-else block in sync with the
     calculation of max_block_size in fused_add_rms_norm */
  if (num_tokens < 256) {
    variance = blockReduceSum<float, 1024>(variance);
  } else
    variance = blockReduceSum<float, 256>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)residual[blockIdx.x * hidden_size + idx];
    input[blockIdx.x * hidden_size + idx] =
        ((scalar_t)(x * s_variance)) * weight[idx];
  }
}

}  // namespace aphrodite

void rms_norm(torch::Tensor& out,     // [..., hidden_size]
              torch::Tensor& input,   // [..., hidden_size]
              torch::Tensor& weight,  // [hidden_size]
              double epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "rms_norm_kernel", [&] {
        aphrodite::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
            out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
      });
}
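
/* Illustrative call from C++ (a sketch only; the shapes, dtype, and epsilon
   are assumptions, and num_tokens / hidden_size stand for the model's
   dimensions):

     torch::Tensor input = torch::randn(
         {num_tokens, hidden_size},
         torch::dtype(torch::kHalf).device(torch::kCUDA));
     torch::Tensor weight = torch::ones({hidden_size}, input.options());
     torch::Tensor out = torch::empty_like(input);
     rms_norm(out, input, weight, 1e-6);
*/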

#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                       \
  APHRODITE_DISPATCH_FLOATING_TYPES(                                           \
      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                  \
        aphrodite::fused_add_rms_norm_kernel<scalar_t, width>                  \
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),           \
                                         residual.data_ptr<scalar_t>(),        \
                                         weight.data_ptr<scalar_t>(), epsilon, \
                                         num_tokens, hidden_size);             \
      });
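
// The macro lets the same dispatch boilerplate serve both launches below:
// width 8 selects the packed/vectorized kernel, width 0 the generic fallback.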

void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
                        torch::Tensor& residual,  // [..., hidden_size]
                        torch::Tensor& weight,    // [hidden_size]
                        double epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  /* This kernel is memory-latency bound in many scenarios.
     When num_tokens is large, a smaller block size allows
     for increased block occupancy on CUs and better latency
     hiding on global mem ops. */
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(hidden_size, max_block_size));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  /* If the tensor types are FP16/BF16, try to use the optimized kernel
     with packed + vectorized ops.
     Max optimization is achieved with a width-8 vector of FP16/BF16s
     since we can load at most 128 bits at once in a global memory op.
     However, this requires each tensor's data to be aligned to 16
     bytes.
   */
  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
  auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
  bool ptrs_are_aligned =
      inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
  if (ptrs_are_aligned && hidden_size % 8 == 0) {
    LAUNCH_FUSED_ADD_RMS_NORM(8);
  } else {
    LAUNCH_FUSED_ADD_RMS_NORM(0);
  }
}
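
/* Illustrative call from C++ (a sketch only; shapes, dtype, and epsilon are
   assumptions). Both tensors are updated in place: `input` receives the
   normalized output and `residual` receives the pre-normalization sum.

     auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
     torch::Tensor hidden = torch::randn({num_tokens, hidden_size}, opts);
     torch::Tensor residual = torch::randn({num_tokens, hidden_size}, opts);
     torch::Tensor weight = torch::ones({hidden_size}, opts);
     fused_add_rms_norm(hidden, residual, weight, 1e-6);
     // hidden:   RMSNorm(hidden + residual) * weight
     // residual: elementwise sum of the original hidden and residual
*/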