// common.cu (11 KB)
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "../../reduction.cuh"
  8. #ifndef USE_ROCM
  9. using FP8_TYPE = c10::Float8_e4m3fn;
  10. #ifdef _WIN32
  11. #define FP8_E4M3_MAX (std::numeric_limits<FP8_TYPE>::max())
  12. #else
  13. C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
  14. std::numeric_limits<FP8_TYPE>::max();
  15. #endif
  16. #else
  17. #include "amd/hip_float8.h"
  18. using FP8_TYPE = c10::Float8_e4m3fnuz;
  19. // Using the default max value from pytorch (240.0) will cause accuracy
  20. // issue when running dynamic quantization. Here use 224.0f for rocm.
  21. constexpr auto FP8_E4M3_MAX = 224.0f;
  22. #endif
  23. namespace aphrodite {
  24. __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  25. float old;
  26. old = (value >= 0)
  27. ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
  28. : __uint_as_float(
  29. atomicMin((unsigned int*)addr, __float_as_uint(value)));
  30. return old;
  31. }
  32. template <bool is_scale_inverted>
  33. __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
  34. float const scale) {
  35. float x = 0.0f;
  36. if constexpr (is_scale_inverted) {
  37. x = val * scale;
  38. } else {
  39. x = val / scale;
  40. }
  41. float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
  42. #ifndef USE_ROCM
  43. return static_cast<c10::Float8_e4m3fn>(r);
  44. #else
  45. // Use hardware cvt instruction for fp8 on rocm
  46. return c10::Float8_e4m3fnuz(hip_fp8(r).data,
  47. c10::Float8_e4m3fnuz::from_bits());
  48. #endif
  49. }
  50. // Compute the absolute maximum m of the input tensor and store
  51. // m / float8_e4m3::max() in *scale. Each thread block performs a
  52. // reduction tree and the memory in scale is atomically updated.
  53. // So to get the right answer, *scale needs to be initialized to
  54. // a value <= 0.0 and we need to wait for all thread blocks to
  55. // finish before consuming *scale.
  56. template <typename scalar_t>
  57. __global__ void segmented_max_reduction(float* __restrict__ scale,
  58. const scalar_t* __restrict__ input,
  59. int64_t num_elems) {
  60. __shared__ float cache[1024];
  61. int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
  62. // First store maximum for all values processes by
  63. // the current thread in cache[threadIdx.x]
  64. scalar_t tmp = 0.0;
  65. while (i < num_elems) {
  66. float x = static_cast<float>(input[i]);
  67. tmp = max(tmp, fabs(x));
  68. i += blockDim.x * gridDim.x;
  69. }
  70. cache[threadIdx.x] = tmp;
  71. __syncthreads();
  72. // Now perform parallel reduction within the thread block
  73. int ib = blockDim.x / 2;
  74. while (ib != 0) {
  75. if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
  76. cache[threadIdx.x] = cache[threadIdx.x + ib];
  77. }
  78. __syncthreads();
  79. ib /= 2;
  80. }
  81. // Finally, since cache[0] contains the maximum for this thread block,
  82. // atomically write the max to the target location
  83. if (threadIdx.x == 0) {
  84. atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
  85. }
  86. }
  87. template <typename scalar_t>
  88. struct __align__(8) vec4_t {
  89. scalar_t x;
  90. scalar_t y;
  91. scalar_t z;
  92. scalar_t w;
  93. };
  94. typedef struct __align__(4) {
  95. FP8_TYPE x;
  96. FP8_TYPE y;
  97. FP8_TYPE z;
  98. FP8_TYPE w;
  99. }
  100. float8x4_t;
  101. template <typename scalar_t>
  102. __device__ float thread_max_vec(scalar_t const* __restrict__ input,
  103. int64_t const num_elems, int const tid,
  104. int const step) {
  105. // Vectorized input/output to better utilize memory bandwidth.
  106. vec4_t<scalar_t> const* vectorized_in =
  107. reinterpret_cast<vec4_t<scalar_t> const*>(input);
  108. int64_t const num_vec_elems = num_elems >> 2;
  109. float absmax_val = 0.0f;
  110. #pragma unroll 4
  111. for (int64_t i = tid; i < num_vec_elems; i += step) {
  112. vec4_t<scalar_t> in_vec = vectorized_in[i];
  113. absmax_val = max(absmax_val, fabs(in_vec.x));
  114. absmax_val = max(absmax_val, fabs(in_vec.y));
  115. absmax_val = max(absmax_val, fabs(in_vec.z));
  116. absmax_val = max(absmax_val, fabs(in_vec.w));
  117. }
  118. // Handle the remaining elements if num_elems is not divisible by 4
  119. for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
  120. absmax_val = max(absmax_val, fabs(input[i]));
  121. }
  122. return absmax_val;
  123. }
  124. template <typename scalar_t, bool is_scale_inverted>
  125. __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
  126. scalar_t const* __restrict__ input,
  127. float const scale,
  128. int64_t const num_elems,
  129. int const tid, int const step) {
  130. // Vectorized input/output to better utilize memory bandwidth.
  131. vec4_t<scalar_t> const* vectorized_in =
  132. reinterpret_cast<vec4_t<scalar_t> const*>(input);
  133. float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
  134. int64_t const num_vec_elems = num_elems >> 2;
  135. #pragma unroll 4
  136. for (int64_t i = tid; i < num_vec_elems; i += step) {
  137. vec4_t<scalar_t> in_vec = vectorized_in[i];
  138. float8x4_t out_vec;
  139. out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
  140. static_cast<float>(in_vec.x), scale);
  141. out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
  142. static_cast<float>(in_vec.y), scale);
  143. out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
  144. static_cast<float>(in_vec.z), scale);
  145. out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
  146. static_cast<float>(in_vec.w), scale);
  147. vectorized_out[i] = out_vec;
  148. }
  149. // Handle the remaining elements if num_elems is not divisible by 4
  150. for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
  151. out[i] = scaled_fp8_conversion<is_scale_inverted>(
  152. static_cast<float>(input[i]), scale);
  153. }
  154. }
  155. template <typename scalar_t>
  156. __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
  157. const scalar_t* __restrict__ input,
  158. const float* __restrict__ scale,
  159. int64_t num_elems) {
  160. int tid = blockDim.x * blockIdx.x + threadIdx.x;
  161. // Invert the scale so that we can use multiplications to avoid expensive
  162. // division.
  163. const float inverted_scale = 1.0f / (*scale);
  164. scaled_fp8_conversion_vec<scalar_t, true>(
  165. out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
  166. }
  167. template <typename scalar_t>
  168. __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
  169. FP8_TYPE* __restrict__ out, float* __restrict__ scale,
  170. scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
  171. const int hidden_size) {
  172. float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
  173. int const tid = threadIdx.x;
  174. int const token_idx = blockIdx.x;
  175. scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
  176. FP8_TYPE* __restrict__ token_output = &out[token_idx * hidden_size];
  177. // For vectorization, token_input and token_output pointers need to be
  178. // aligned at 8-byte and 4-byte addresses respectively.
  179. bool const can_vectorize = hidden_size % 4 == 0;
  180. float absmax_val = 0.0f;
  181. if (can_vectorize) {
  182. absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
  183. } else {
  184. for (int i = tid; i < hidden_size; i += blockDim.x) {
  185. float const x = static_cast<float>(token_input[i]);
  186. absmax_val = max(absmax_val, fabs(x));
  187. }
  188. }
  189. float const block_absmax_val_maybe = blockReduceMax(absmax_val);
  190. __shared__ float token_scale;
  191. if (tid == 0) {
  192. if (scale_ub) {
  193. token_scale = min(block_absmax_val_maybe, *scale_ub);
  194. } else {
  195. token_scale = block_absmax_val_maybe;
  196. }
  197. // token scale computation
  198. token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
  199. scale[token_idx] = token_scale;
  200. }
  201. __syncthreads();
  202. // Note that we don't use inverted scales so we can match FBGemm impl.
  203. if (can_vectorize) {
  204. scaled_fp8_conversion_vec<scalar_t, false>(
  205. token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
  206. } else {
  207. for (int i = tid; i < hidden_size; i += blockDim.x) {
  208. token_output[i] = scaled_fp8_conversion<false>(
  209. static_cast<float>(token_input[i]), token_scale);
  210. }
  211. }
  212. }
  213. } // namespace aphrodite
  214. void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
  215. torch::Tensor const& input, // [..., d]
  216. torch::Tensor const& scale) // [1]
  217. {
  218. int64_t num_tokens = input.numel() / input.size(-1);
  219. int64_t num_elems = input.numel();
  220. dim3 grid(num_tokens);
  221. dim3 block(1024);
  222. const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  223. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  224. APHRODITE_DISPATCH_FLOATING_TYPES(
  225. input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
  226. aphrodite::scaled_fp8_quant_kernel<scalar_t>
  227. <<<grid, block, 0, stream>>>(out.data_ptr<FP8_TYPE>(),
  228. input.data_ptr<scalar_t>(),
  229. scale.data_ptr<float>(), num_elems);
  230. });
  231. }
  232. void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
  233. torch::Tensor const& input, // [..., d]
  234. torch::Tensor& scale) // [1]
  235. {
  236. int64_t num_tokens = input.numel() / input.size(-1);
  237. int64_t num_elems = input.numel();
  238. dim3 grid(num_tokens);
  239. dim3 block(1024);
  240. const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  241. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  242. APHRODITE_DISPATCH_FLOATING_TYPES(
  243. input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
  244. aphrodite::segmented_max_reduction<scalar_t>
  245. <<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
  246. input.data_ptr<scalar_t>(), num_elems);
  247. aphrodite::scaled_fp8_quant_kernel<scalar_t>
  248. <<<grid, block, 0, stream>>>(out.data_ptr<FP8_TYPE>(),
  249. input.data_ptr<scalar_t>(),
  250. scale.data_ptr<float>(), num_elems);
  251. });
  252. }
  253. void dynamic_per_token_scaled_fp8_quant(
  254. torch::Tensor& out, // [..., d]
  255. torch::Tensor const& input, // [..., d]
  256. torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
  257. TORCH_CHECK(input.is_contiguous());
  258. TORCH_CHECK(out.is_contiguous());
  259. int const hidden_size = input.size(-1);
  260. int const num_tokens = input.numel() / hidden_size;
  261. dim3 const grid(num_tokens);
  262. dim3 const block(std::min(hidden_size, 1024));
  263. const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  264. const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  265. APHRODITE_DISPATCH_FLOATING_TYPES(
  266. input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
  267. aphrodite::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
  268. <<<grid, block, 0, stream>>>(
  269. out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
  270. input.data_ptr<scalar_t>(),
  271. scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
  272. hidden_size);
  273. });
  274. }