// common.cu

#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>
#include <limits>

#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "../../reduction.cuh"

#ifndef USE_ROCM
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
    std::numeric_limits<FP8_TYPE>::max();
#else
#include "amd/hip_float8.h"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from pytorch (240.0) will cause accuracy
// issues when running dynamic quantization. Here use 224.0f for ROCm.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif

namespace aphrodite {
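
// Atomic max for floats, built from integer atomics on the raw bits:
// non-negative floats order the same way as their signed-int bit patterns
// (so atomicMax on int bits works), while negative floats order in reverse
// of their unsigned bit patterns (so atomicMin on unsigned bits selects the
// larger, i.e. less negative, value).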
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0)
            ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
            : __uint_as_float(
                  atomicMin((unsigned int*)addr, __float_as_uint(value)));
  return old;
}
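
// Scale a single value and convert it to FP8, clamping to
// [-FP8_E4M3_MAX, FP8_E4M3_MAX]. When is_scale_inverted is true the caller
// passes 1/scale so the conversion uses a multiply instead of a divide.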
template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
                                                          float const scale) {
  float x = 0.0f;
  if constexpr (is_scale_inverted) {
    x = val * scale;
  } else {
    x = val / scale;
  }

  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
#ifndef USE_ROCM
  return static_cast<c10::Float8_e4m3fn>(r);
#else
  // Use hardware cvt instruction for fp8 on rocm
  return c10::Float8_e4m3fnuz(hip_fp8(r).data,
                              c10::Float8_e4m3fnuz::from_bits());
#endif
}

// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];
  int64_t i = blockDim.x * blockIdx.x + threadIdx.x;

  // First store the maximum of all values processed by
  // the current thread in cache[threadIdx.x]
  scalar_t tmp = 0.0;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = max(tmp, fabs(x));
    i += blockDim.x * gridDim.x;
  }
  cache[threadIdx.x] = tmp;
  __syncthreads();

  // Now perform parallel reduction within the thread block
  int ib = blockDim.x / 2;
  while (ib != 0) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
    ib /= 2;
  }

  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
  }
}
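
// 4-element vector types for vectorized memory access: vec4_t packs four
// input elements (8-byte aligned), float8x4_t packs four FP8 outputs
// (4-byte aligned).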
template <typename scalar_t>
struct __align__(8) vec4_t {
  scalar_t x;
  scalar_t y;
  scalar_t z;
  scalar_t w;
};

typedef struct __align__(4) {
  FP8_TYPE x;
  FP8_TYPE y;
  FP8_TYPE z;
  FP8_TYPE w;
} float8x4_t;
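
// Per-thread absolute maximum over a strided slice of the input, using
// 4-wide vector loads with a scalar loop for the tail elements.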
template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                int64_t const num_elems, int const tid,
                                int const step) {
  // Vectorized input/output to better utilize memory bandwidth.
  vec4_t<scalar_t> const* vectorized_in =
      reinterpret_cast<vec4_t<scalar_t> const*>(input);

  int64_t const num_vec_elems = num_elems >> 2;
  float absmax_val = 0.0f;

#pragma unroll 4
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    vec4_t<scalar_t> in_vec = vectorized_in[i];
    absmax_val = max(absmax_val, fabs(in_vec.x));
    absmax_val = max(absmax_val, fabs(in_vec.y));
    absmax_val = max(absmax_val, fabs(in_vec.z));
    absmax_val = max(absmax_val, fabs(in_vec.w));
  }

  // Handle the remaining elements if num_elems is not divisible by 4
  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
    absmax_val = max(absmax_val, fabs(input[i]));
  }

  return absmax_val;
}
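
// Quantize a strided slice of the input into `out`, four elements at a
// time, with a scalar fallback for the tail when num_elems is not a
// multiple of 4.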
template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                          scalar_t const* __restrict__ input,
                                          float const scale,
                                          int64_t const num_elems,
                                          int const tid, int const step) {
  // Vectorized input/output to better utilize memory bandwidth.
  vec4_t<scalar_t> const* vectorized_in =
      reinterpret_cast<vec4_t<scalar_t> const*>(input);
  float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);

  int64_t const num_vec_elems = num_elems >> 2;

#pragma unroll 4
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    vec4_t<scalar_t> in_vec = vectorized_in[i];
    float8x4_t out_vec;

    out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.x), scale);
    out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.y), scale);
    out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.z), scale);
    out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.w), scale);
    vectorized_out[i] = out_vec;
  }

  // Handle the remaining elements if num_elems is not divisible by 4
  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
    out[i] = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(input[i]), scale);
  }
}
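
// Quantize the whole tensor with a single, already-computed per-tensor
// scale; each thread covers a grid-strided range of elements.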
template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
                                        const scalar_t* __restrict__ input,
                                        const float* __restrict__ scale,
                                        int64_t num_elems) {
  int tid = blockDim.x * blockIdx.x + threadIdx.x;

  // Invert the scale so that we can use multiplications to avoid expensive
  // division.
  const float inverted_scale = 1.0f / (*scale);
  scaled_fp8_conversion_vec<scalar_t, true>(
      out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
}
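
// One thread block per token: compute the token's absmax, derive a
// per-token scale (optionally clamped by scale_ub and floored by
// min_scaling_factor), then quantize that token's hidden_size elements.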
template <typename scalar_t>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
    FP8_TYPE* __restrict__ out, float* __restrict__ scale,
    scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
    const int hidden_size) {
  float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);

  int const tid = threadIdx.x;
  int const token_idx = blockIdx.x;

  scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
  FP8_TYPE* __restrict__ token_output = &out[token_idx * hidden_size];

  // For vectorization, token_input and token_output pointers need to be
  // aligned at 8-byte and 4-byte addresses respectively.
  bool const can_vectorize = hidden_size % 4 == 0;

  float absmax_val = 0.0f;
  if (can_vectorize) {
    absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
  } else {
    for (int i = tid; i < hidden_size; i += blockDim.x) {
      float const x = static_cast<float>(token_input[i]);
      absmax_val = max(absmax_val, fabs(x));
    }
  }

  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
  __shared__ float token_scale;
  if (tid == 0) {
    if (scale_ub) {
      token_scale = min(block_absmax_val_maybe, *scale_ub);
    } else {
      token_scale = block_absmax_val_maybe;
    }
    // token scale computation
    token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
    scale[token_idx] = token_scale;
  }
  __syncthreads();

  // Note that we don't use inverted scales so we can match FBGemm impl.
  if (can_vectorize) {
    scaled_fp8_conversion_vec<scalar_t, false>(
        token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
  } else {
    for (int i = tid; i < hidden_size; i += blockDim.x) {
      token_output[i] = scaled_fp8_conversion<false>(
          static_cast<float>(token_input[i]), token_scale);
    }
  }
}

}  // namespace aphrodite
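
// Host entry point for static per-tensor quantization: launches
// scaled_fp8_quant_kernel with the provided one-element scale tensor.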
void static_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                             torch::Tensor const& input,  // [..., d]
                             torch::Tensor const& scale)  // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        aphrodite::scaled_fp8_quant_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(out.data_ptr<FP8_TYPE>(),
                                         input.data_ptr<scalar_t>(),
                                         scale.data_ptr<float>(), num_elems);
      });
}
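
// Host entry point for dynamic per-tensor quantization: first reduces the
// global absmax into `scale`, then quantizes with it. Note that `scale` is
// not initialized here; per the segmented_max_reduction comment above, it
// must already hold a value <= 0.0 (e.g. a freshly zeroed tensor).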
void dynamic_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor& scale)        // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        aphrodite::segmented_max_reduction<scalar_t>
            <<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
                                         input.data_ptr<scalar_t>(),
                                         num_elems);
        aphrodite::scaled_fp8_quant_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(out.data_ptr<FP8_TYPE>(),
                                         input.data_ptr<scalar_t>(),
                                         scale.data_ptr<float>(), num_elems);
      });
}
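
// Host entry point for dynamic per-token quantization: one thread block per
// token, with an optional upper bound (scale_ub) on the per-token scale.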
void dynamic_per_token_scaled_fp8_quant(
    torch::Tensor& out,          // [..., d]
    torch::Tensor const& input,  // [..., d]
    torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
        aphrodite::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(
                out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
                input.data_ptr<scalar_t>(),
                scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                hidden_size);
      });
}