common.cu

#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "../../reduction.cuh"

namespace aphrodite {

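// CUDA has no native atomicMax for float. This helper exploits the IEEE-754
// bit layout: non-negative floats order the same way as their signed-int bit
// patterns, while for negative floats a larger value has a smaller unsigned
// bit pattern, so an unsigned atomicMin yields the maximum.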
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0)
            ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
            : __uint_as_float(
                  atomicMin((unsigned int*)addr, __float_as_uint(value)));

  return old;
}

#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()

template <bool is_scale_inverted>
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
    float const val, float const scale) {
  float x = 0.0f;
  if constexpr (is_scale_inverted) {
    x = val * scale;
  } else {
    x = val / scale;
  }

  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
  return static_cast<c10::Float8_e4m3fn>(r);
}

// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];
  int64_t i = blockDim.x * blockIdx.x + threadIdx.x;

  // First store the maximum of all values processed by
  // the current thread in cache[threadIdx.x]
  scalar_t tmp = 0.0;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = max(tmp, fabs(x));
    i += blockDim.x * gridDim.x;
  }
  cache[threadIdx.x] = tmp;

  __syncthreads();

  // Now perform parallel reduction within the thread block
  int ib = blockDim.x / 2;
  while (ib != 0) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
    ib /= 2;
  }

  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale,
                   cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
  }
}

template <typename scalar_t>
struct __align__(8) vec4_t {
  scalar_t x;
  scalar_t y;
  scalar_t z;
  scalar_t w;
};

typedef struct __align__(4) {
  c10::Float8_e4m3fn x;
  c10::Float8_e4m3fn y;
  c10::Float8_e4m3fn z;
  c10::Float8_e4m3fn w;
} float8x4_t;

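// The explicit alignments let vec4_t and float8x4_t be moved with wide
// (8-byte / 4-byte) memory transactions rather than four scalar accesses,
// which is what makes the vectorized paths below worthwhile.
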
template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                int64_t const num_elems, int const tid,
                                int const step) {
  // Vectorized input/output to better utilize memory bandwidth.
  vec4_t<scalar_t> const* vectorized_in =
      reinterpret_cast<vec4_t<scalar_t> const*>(input);

  int64_t const num_vec_elems = num_elems >> 2;
  float absmax_val = 0.0f;

#pragma unroll 4
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    vec4_t<scalar_t> in_vec = vectorized_in[i];
    absmax_val = max(absmax_val, fabs(in_vec.x));
    absmax_val = max(absmax_val, fabs(in_vec.y));
    absmax_val = max(absmax_val, fabs(in_vec.z));
    absmax_val = max(absmax_val, fabs(in_vec.w));
  }

  // Handle the remaining elements if num_elems is not divisible by 4
  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
    absmax_val = max(absmax_val, fabs(input[i]));
  }

  return absmax_val;
}

template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
                                          scalar_t const* __restrict__ input,
                                          float const scale,
                                          int64_t const num_elems,
                                          int const tid, int const step) {
  // Vectorized input/output to better utilize memory bandwidth.
  vec4_t<scalar_t> const* vectorized_in =
      reinterpret_cast<vec4_t<scalar_t> const*>(input);
  float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);

  int64_t const num_vec_elems = num_elems >> 2;

#pragma unroll 4
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    vec4_t<scalar_t> in_vec = vectorized_in[i];
    float8x4_t out_vec;

    out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.x), scale);
    out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.y), scale);
    out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.z), scale);
    out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.w), scale);
    vectorized_out[i] = out_vec;
  }

  // Handle the remaining elements if num_elems is not divisible by 4
  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
    out[i] = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(input[i]), scale);
  }
}

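// Quantize the whole flattened input with a single, precomputed scale; each
// thread covers the tensor with a grid-stride step of blockDim.x * gridDim.x.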
template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
                                        const scalar_t* __restrict__ input,
                                        const float* __restrict__ scale,
                                        int64_t num_elems) {
  int tid = blockDim.x * blockIdx.x + threadIdx.x;

  // Invert the scale so that we can use multiplications to avoid expensive
  // division.
  const float inverted_scale = 1.0f / (*scale);
  scaled_fp8_conversion_vec<scalar_t, true>(
      out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
}

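// One thread block per token (row): the block computes the row's absolute
// maximum with blockReduceMax, derives the per-token scale (optionally capped
// by *scale_ub and floored at min_scaling_factor), writes it to
// scale[token_idx], and then quantizes the row with that scale.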
template <typename scalar_t>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
    c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
    scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
    const int hidden_size) {
  float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);

  int const tid = threadIdx.x;
  int const token_idx = blockIdx.x;

  scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
  c10::Float8_e4m3fn* __restrict__ token_output = &out[token_idx * hidden_size];

  // For vectorization, token_input and token_output pointers need to be
  // aligned at 8-byte and 4-byte addresses respectively.
  bool const can_vectorize = hidden_size % 4 == 0;

  float absmax_val = 0.0f;
  if (can_vectorize) {
    absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
  } else {
    for (int i = tid; i < hidden_size; i += blockDim.x) {
      float const x = static_cast<float>(token_input[i]);
      absmax_val = max(absmax_val, fabs(x));
    }
  }

  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
  __shared__ float token_scale;
  if (tid == 0) {
    if (scale_ub) {
      token_scale = min(block_absmax_val_maybe, *scale_ub);
    } else {
      token_scale = block_absmax_val_maybe;
    }
    // token scale computation
    token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
    scale[token_idx] = token_scale;
  }
  __syncthreads();

  // Note that we don't use inverted scales so we can match FBGemm impl.
  if (can_vectorize) {
    scaled_fp8_conversion_vec<scalar_t, false>(
        token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
  } else {
    for (int i = tid; i < hidden_size; i += blockDim.x) {
      token_output[i] = scaled_fp8_conversion<false>(
          static_cast<float>(token_input[i]), token_scale);
    }
  }
}

}  // namespace aphrodite

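// Host-side launchers. The static and dynamic variants launch one 1024-thread
// block per token, but the kernels above use grid-stride loops over the
// flattened input, so the grid size only affects occupancy. The per-token
// variant needs exactly one block per token and caps the block at hidden_size.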
void static_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                             torch::Tensor const& input,  // [..., d]
                             torch::Tensor const& scale)  // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        aphrodite::scaled_fp8_quant_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(out.data_ptr<c10::Float8_e4m3fn>(),
                                         input.data_ptr<scalar_t>(),
                                         scale.data_ptr<float>(), num_elems);
      });
}

void dynamic_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor& scale)        // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        aphrodite::segmented_max_reduction<scalar_t>
            <<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
                                         input.data_ptr<scalar_t>(), num_elems);
        aphrodite::scaled_fp8_quant_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(out.data_ptr<c10::Float8_e4m3fn>(),
                                         input.data_ptr<scalar_t>(),
                                         scale.data_ptr<float>(), num_elems);
      });
}

void dynamic_per_token_scaled_fp8_quant(
    torch::Tensor& out,          // [..., d]
    torch::Tensor const& input,  // [..., d]
    torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
        aphrodite::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(
                out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
                input.data_ptr<scalar_t>(),
                scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                hidden_size);
      });
}
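
// Illustrative call-site sketch (an assumption for clarity; the real call
// sites and op registrations live elsewhere in the codebase). The output must
// be fp8 (Float8_e4m3fn), and for dynamic quantization the scale tensor must
// start out <= 0 so segmented_max_reduction's atomic update works:
//
//   auto out = torch::empty_like(
//       input, input.options().dtype(at::ScalarType::Float8_e4m3fn));
//   auto scale = torch::zeros({1}, input.options().dtype(at::kFloat));
//   dynamic_scaled_fp8_quant(out, input, scale);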