common.cu 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. #include <ATen/cuda/CUDAContext.h>
  2. #include <torch/all.h>
  3. #include <c10/cuda/CUDAGuard.h>
  4. #include <cmath>
  5. #include "cuda_compat.h"
  6. #include "dispatch_utils.h"
  7. #ifndef USE_ROCM
  8. #include <cub/util_type.cuh>
  9. #include <cub/cub.cuh>
  10. #else
  11. #include <hipcub/util_type.hpp>
  12. #include <hipcub/hipcub.hpp>
  13. #endif
  #ifdef USE_ROCM
  #include "amd/hip_float8.h"
  // ROCm uses the "fnuz" e4m3 variant.
  using FP8_TYPE = c10::Float8_e4m3fnuz;
  // PyTorch's default maximum for this type (240.0) causes accuracy issues
  // with dynamic quantization, so ROCm saturates at 224.0f instead.
  constexpr auto FP8_E4M3_MAX = 224.0f;
  #else
  // CUDA uses the standard e4m3fn variant and its numeric_limits maximum.
  using FP8_TYPE = c10::Float8_e4m3fn;
  #if defined(_WIN32)
  #define FP8_E4M3_MAX (std::numeric_limits<FP8_TYPE>::max())
  #else
  C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
      std::numeric_limits<FP8_TYPE>::max();
  #endif  // _WIN32
  #endif  // USE_ROCM
  namespace aphrodite {

  // Atomically sets *addr = max(*addr, value) and returns the value *addr held
  // before the update.
  //
  // Works on the IEEE-754 bit pattern. For value >= 0 the signed-int view
  // orders the same way as the float, so an integer atomicMax suffices. For
  // value < 0 a larger float has a smaller unsigned bit pattern, so an
  // unsigned atomicMin is used instead.
  __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
    float old;
    old = (value >= 0)
              ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
              : __uint_as_float(
                    atomicMin((unsigned int*)addr, __float_as_uint(value)));
    return old;
  }

  // Scales `val` and converts it to FP8, saturating to
  // [-FP8_E4M3_MAX, FP8_E4M3_MAX] before the cast.
  //
  //   is_scale_inverted == true : result = val * scale   (scale holds 1 / s)
  //   is_scale_inverted == false: result = val / scale
  template <bool is_scale_inverted>
  __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
                                                            float const scale) {
    float x = 0.0f;
    if constexpr (is_scale_inverted) {
      x = val * scale;
    } else {
      x = val / scale;
    }
    float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
  #ifndef USE_ROCM
    return static_cast<c10::Float8_e4m3fn>(r);
  #else
    // Use hardware cvt instruction for fp8 on rocm
    return c10::Float8_e4m3fnuz(hip_fp8(r).data,
                                c10::Float8_e4m3fnuz::from_bits());
  #endif
  }

  // Returns true when `ptr` is a multiple of `alignment` bytes.
  __device__ __forceinline__ bool is_ptr_aligned(void const* ptr,
                                                 size_t const alignment) {
    return reinterpret_cast<uintptr_t>(ptr) % alignment == 0;
  }

  // Compute the absolute maximum m of the input tensor and store
  // m / FP8_E4M3_MAX in *scale. Each thread block performs a
  // reduction tree and the memory in scale is atomically updated.
  // So to get the right answer, *scale needs to be initialized to
  // a value <= 0.0 and we need to wait for all thread blocks to
  // finish before consuming *scale.
  //
  // Launch requirements: blockDim.x must be a power of two and <= 1024 (the
  // size of `cache`). Any grid size works because of the strided loop.
  template <typename scalar_t>
  __global__ void segmented_max_reduction(float* __restrict__ scale,
                                          const scalar_t* __restrict__ input,
                                          int64_t num_elems) {
    __shared__ float cache[1024];

    // Index math in 64 bits. Otherwise blockDim.x * blockIdx.x is evaluated
    // in 32-bit unsigned arithmetic and wraps for very large grids.
    int64_t const stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
    int64_t i = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;

    // Per-thread running max of |x|. Kept in float (not scalar_t) so that
    // half/bfloat16 inputs are not rounded before they determine the scale.
    float thread_max = 0.0f;
    for (; i < num_elems; i += stride) {
      thread_max = fmaxf(thread_max, fabsf(static_cast<float>(input[i])));
    }
    cache[threadIdx.x] = thread_max;
    __syncthreads();

    // Tree reduction within the thread block; cache[0] ends up holding the
    // block maximum.
    for (int ib = blockDim.x / 2; ib != 0; ib /= 2) {
      if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
        cache[threadIdx.x] = cache[threadIdx.x + ib];
      }
      __syncthreads();
    }

    // Atomically merge this block's result into the global scale.
    if (threadIdx.x == 0) {
      atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
    }
  }

  // Four consecutive scalar_t values, used for vectorized loads.
  template <typename scalar_t>
  struct __align__(8) vec4_t {
    scalar_t x;
    scalar_t y;
    scalar_t z;
    scalar_t w;
  };

  // Four consecutive FP8 values, used for vectorized stores.
  typedef struct __align__(4) {
    FP8_TYPE x;
    FP8_TYPE y;
    FP8_TYPE z;
    FP8_TYPE w;
  } float8x4_t;

  // Returns max |input[i]| over the elements this thread visits. In the
  // vectorized part, thread `tid` handles groups of 4 elements at group
  // indices tid, tid + step, ... The remaining elements are visited one at a
  // time with the same tid/step pattern.
  template <typename scalar_t>
  __device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                  int64_t const num_elems, int64_t const tid,
                                  int64_t const step) {
    // Vectorized input to better utilize memory bandwidth. This is only valid
    // when the base pointer meets vec4_t's alignment; a tensor view with an
    // odd storage offset may not. Setting num_vec_elems to 0 routes every
    // element through the scalar loop below.
    bool const can_vectorize = is_ptr_aligned(input, alignof(vec4_t<scalar_t>));
    int64_t const num_vec_elems = can_vectorize ? (num_elems >> 2) : 0;
    vec4_t<scalar_t> const* vectorized_in =
        reinterpret_cast<vec4_t<scalar_t> const*>(input);

    float absmax_val = 0.0f;
  #pragma unroll 4
    for (int64_t i = tid; i < num_vec_elems; i += step) {
      vec4_t<scalar_t> const in_vec = vectorized_in[i];
      absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(in_vec.x)));
      absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(in_vec.y)));
      absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(in_vec.z)));
      absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(in_vec.w)));
    }
    // Remaining elements: num_elems % 4, or all of them when unaligned.
    for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
      absmax_val = fmaxf(absmax_val, fabsf(static_cast<float>(input[i])));
    }
    return absmax_val;
  }

  // Quantizes input[0, num_elems) into out with scaled_fp8_conversion. Each
  // thread uses the same tid/step visiting pattern as thread_max_vec.
  // Vectorized loads/stores are used only when both pointers are suitably
  // aligned; otherwise every element goes through the scalar path.
  template <typename scalar_t, bool is_scale_inverted>
  __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                            scalar_t const* __restrict__ input,
                                            float const scale,
                                            int64_t const num_elems,
                                            int64_t const tid,
                                            int64_t const step) {
    // Vectorized input/output to better utilize memory bandwidth. This needs
    // vec4_t alignment on input and float8x4_t alignment on out.
    bool const can_vectorize =
        is_ptr_aligned(input, alignof(vec4_t<scalar_t>)) &&
        is_ptr_aligned(out, alignof(float8x4_t));
    int64_t const num_vec_elems = can_vectorize ? (num_elems >> 2) : 0;
    vec4_t<scalar_t> const* vectorized_in =
        reinterpret_cast<vec4_t<scalar_t> const*>(input);
    float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);

  #pragma unroll 4
    for (int64_t i = tid; i < num_vec_elems; i += step) {
      vec4_t<scalar_t> const in_vec = vectorized_in[i];
      float8x4_t out_vec;
      out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
          static_cast<float>(in_vec.x), scale);
      out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
          static_cast<float>(in_vec.y), scale);
      out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
          static_cast<float>(in_vec.z), scale);
      out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
          static_cast<float>(in_vec.w), scale);
      vectorized_out[i] = out_vec;
    }
    // Remaining elements: num_elems % 4, or all of them when unaligned.
    for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
      out[i] = scaled_fp8_conversion<is_scale_inverted>(
          static_cast<float>(input[i]), scale);
    }
  }

  // Per-tensor quantization with a precomputed scale (*scale, read on device).
  // The whole grid strides over num_elems, so any 1-D launch shape is valid.
  template <typename scalar_t>
  __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
                                          const scalar_t* __restrict__ input,
                                          const float* __restrict__ scale,
                                          int64_t num_elems) {
    // 64-bit index/stride. With one 1024-thread block per token, a 32-bit int
    // overflows once the grid exceeds about 2M blocks.
    int64_t const tid =
        static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
    int64_t const step = static_cast<int64_t>(blockDim.x) * gridDim.x;
    // Invert the scale so that we can use multiplications to avoid expensive
    // division.
    const float inverted_scale = 1.0f / (*scale);
    scaled_fp8_conversion_vec<scalar_t, true>(out, input, inverted_scale,
                                              num_elems, tid, step);
  }

  // Per-token dynamic quantization. Expects one thread block per token
  // (gridDim.x == number of tokens) and blockDim.x <= 1024 (the BlockReduce
  // size). For token t it writes:
  //   scale[t] = max(min(absmax_t, *scale_ub if given) / FP8_E4M3_MAX,
  //                  1 / (FP8_E4M3_MAX * 512))
  // and quantizes the token's row by dividing by scale[t].
  template <typename scalar_t>
  __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
      FP8_TYPE* __restrict__ out, float* __restrict__ scale,
      scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
      const int hidden_size) {
    float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
    int const tid = threadIdx.x;
    int const token_idx = blockIdx.x;
    // Compute the row offset in 64 bits; token_idx * hidden_size overflows int
    // once the tensor has more than 2^31 elements.
    int64_t const token_offset = static_cast<int64_t>(token_idx) * hidden_size;
    scalar_t const* __restrict__ token_input = &input[token_offset];
    FP8_TYPE* __restrict__ token_output = &out[token_offset];

    // For vectorization, token_input and token_output pointers need to be
    // aligned at 8-byte and 4-byte addresses respectively. hidden_size % 4
    // keeps every row at the base pointer's alignment; the vectorized helpers
    // additionally verify pointer alignment and fall back to scalar access.
    bool const can_vectorize = hidden_size % 4 == 0;

    float absmax_val = 0.0f;
    if (can_vectorize) {
      absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
    } else {
      for (int i = tid; i < hidden_size; i += blockDim.x) {
        float const x = static_cast<float>(token_input[i]);
        absmax_val = fmaxf(absmax_val, fabsf(x));
      }
    }

    // The reduced value is only valid in thread 0, hence the "_maybe" suffix.
    using BlockReduce = cub::BlockReduce<float, 1024>;
    __shared__ typename BlockReduce::TempStorage reduceStorage;
    float const block_absmax_val_maybe =
        BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);

    // Thread 0 computes the token scale and broadcasts it through shared mem.
    __shared__ float token_scale;
    if (tid == 0) {
      if (scale_ub) {
        token_scale = min(block_absmax_val_maybe, *scale_ub);
      } else {
        token_scale = block_absmax_val_maybe;
      }
      // token scale computation
      token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
      scale[token_idx] = token_scale;
    }
    __syncthreads();

    // Note that we don't use inverted scales so we can match FBGemm impl.
    if (can_vectorize) {
      scaled_fp8_conversion_vec<scalar_t, false>(
          token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
    } else {
      for (int i = tid; i < hidden_size; i += blockDim.x) {
        token_output[i] = scaled_fp8_conversion<false>(
            static_cast<float>(token_input[i]), token_scale);
      }
    }
  }

  }  // namespace aphrodite
  // Static per-tensor FP8 quantization. It quantizes `input` into `out` by
  // multiplying with 1 / scale[0] and saturating to +-FP8_E4M3_MAX.
  // The kernel is launched on the current CUDA stream of input's device.
  void static_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                               torch::Tensor const& input,  // [..., d]
                               torch::Tensor const& scale)  // [1]
  {
    // The kernel walks input and out as flat buffers of numel() elements, so
    // both must be contiguous (matches dynamic_per_token_scaled_fp8_quant).
    TORCH_CHECK(input.is_contiguous());
    TORCH_CHECK(out.is_contiguous());

    int64_t const num_elems = input.numel();
    // Nothing to quantize. Returning here also avoids launching a zero-sized
    // grid and dividing by zero when the last dimension is empty.
    if (num_elems == 0) {
      return;
    }
    int64_t const num_tokens = num_elems / input.size(-1);
    dim3 const grid(num_tokens);
    dim3 const block(1024);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    APHRODITE_DISPATCH_FLOATING_TYPES(
        input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
          aphrodite::scaled_fp8_quant_kernel<scalar_t>
              <<<grid, block, 0, stream>>>(out.data_ptr<FP8_TYPE>(),
                                           input.data_ptr<scalar_t>(),
                                           scale.data_ptr<float>(), num_elems);
        });
  }
  // Dynamic per-tensor FP8 quantization. It launches two kernels, in order,
  // on the current CUDA stream:
  //   1. segmented_max_reduction atomically folds absmax(input) / FP8_E4M3_MAX
  //      into scale[0];
  //   2. scaled_fp8_quant_kernel quantizes input into out using scale[0].
  // Stream ordering guarantees step 2 sees the completed reduction.
  // Precondition: scale[0] must already hold a value <= 0 because the
  // reduction only ever raises it. This function does not reset it
  // (presumably the caller zero-initializes it; confirm at the call site).
  void dynamic_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                                torch::Tensor const& input,  // [..., d]
                                torch::Tensor& scale)        // [1]
  {
    // Both kernels treat input and out as flat buffers of numel() elements.
    TORCH_CHECK(input.is_contiguous());
    TORCH_CHECK(out.is_contiguous());

    int64_t const num_elems = input.numel();
    // Nothing to quantize. Returning here also avoids launching a zero-sized
    // grid and dividing by zero when the last dimension is empty.
    if (num_elems == 0) {
      return;
    }
    int64_t const num_tokens = num_elems / input.size(-1);
    dim3 const grid(num_tokens);
    dim3 const block(1024);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    APHRODITE_DISPATCH_FLOATING_TYPES(
        input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
          aphrodite::segmented_max_reduction<scalar_t>
              <<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
                                           input.data_ptr<scalar_t>(),
                                           num_elems);
          aphrodite::scaled_fp8_quant_kernel<scalar_t>
              <<<grid, block, 0, stream>>>(out.data_ptr<FP8_TYPE>(),
                                           input.data_ptr<scalar_t>(),
                                           scale.data_ptr<float>(), num_elems);
        });
  }
  // Dynamic per-token FP8 quantization. One thread block per token (row of
  // the last dimension) computes that row's scale, writes it to
  // scales[token], and quantizes the row into out. If scale_ub is provided,
  // each row's absmax is clamped to *scale_ub before the scale is derived.
  void dynamic_per_token_scaled_fp8_quant(
      torch::Tensor& out,          // [..., d]
      torch::Tensor const& input,  // [..., d]
      torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
    TORCH_CHECK(input.is_contiguous());
    TORCH_CHECK(out.is_contiguous());

    // Empty input: nothing to do. Continuing would divide by a zero hidden
    // size or launch a zero-sized grid/block.
    if (input.numel() == 0) {
      return;
    }
    int const hidden_size = input.size(-1);
    int const num_tokens = input.numel() / hidden_size;
    dim3 const grid(num_tokens);
    // The kernel's BlockReduce is sized for at most 1024 threads.
    dim3 const block(std::min(hidden_size, 1024));
    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    APHRODITE_DISPATCH_FLOATING_TYPES(
        input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
          aphrodite::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
              <<<grid, block, 0, stream>>>(
                  out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
                  input.data_ptr<scalar_t>(),
                  scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                  hidden_size);
        });
  }