common.cu

#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace aphrodite {
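
// Atomic max for float built from integer atomics: for non-negative values
// the IEEE-754 bit pattern preserves ordering under a signed-integer compare,
// while for negative values the ordering is reversed, so an unsigned
// atomicMin on the bit pattern yields the float maximum.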
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0)
            ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
            : __uint_as_float(
                  atomicMin((unsigned int*)addr, __float_as_uint(value)));

  return old;
}

#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
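
// Convert a single value to FP8 E4M3: multiply by the precomputed reciprocal
// of the scale, then clamp to the finite FP8 range before the narrowing cast.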
template <typename scalar_t>
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
    const scalar_t val, const float inverted_scale) {
  float x = static_cast<float>(val) * inverted_scale;
  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
  return static_cast<c10::Float8_e4m3fn>(r);
}

// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];
  int i = blockDim.x * blockIdx.x + threadIdx.x;

  // First store the maximum of all values processed by
  // the current thread in cache[threadIdx.x]
  scalar_t tmp = 0.0;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = max(tmp, fabs(x));
    i += blockDim.x * gridDim.x;
  }
  cache[threadIdx.x] = tmp;

  __syncthreads();

  // Now perform parallel reduction within the thread block
  int ib = blockDim.x / 2;
  while (ib != 0) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
    ib /= 2;
  }

  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale,
                   cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
  }
}
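
// Packed four-element vector types so the quantization kernel can issue
// wider, aligned loads and stores.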
template <typename scalar_t>
struct __align__(8) vec4_t {
  scalar_t x;
  scalar_t y;
  scalar_t z;
  scalar_t w;
};

typedef struct __align__(4) {
  c10::Float8_e4m3fn x;
  c10::Float8_e4m3fn y;
  c10::Float8_e4m3fn z;
  c10::Float8_e4m3fn w;
} float8x4_t;

template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
                                        const scalar_t* __restrict__ input,
                                        const float* __restrict__ scale,
                                        int64_t num_elems) {
  int tid = blockDim.x * blockIdx.x + threadIdx.x;

  // Invert the scale so that we can use multiplications to avoid expensive
  // division.
  const float inverted_scale = 1.0f / (*scale);

  // Vectorized input/output to better utilize memory bandwidth.
  const vec4_t<scalar_t>* vectorized_in =
      reinterpret_cast<const vec4_t<scalar_t>*>(input);
  float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);

  int num_vec_elems = num_elems >> 2;

#pragma unroll 4
  for (int i = tid; i < num_vec_elems; i += blockDim.x * gridDim.x) {
    vec4_t<scalar_t> in_vec = vectorized_in[i];
    float8x4_t out_vec;

    out_vec.x = scaled_fp8_conversion(in_vec.x, inverted_scale);
    out_vec.y = scaled_fp8_conversion(in_vec.y, inverted_scale);
    out_vec.z = scaled_fp8_conversion(in_vec.z, inverted_scale);
    out_vec.w = scaled_fp8_conversion(in_vec.w, inverted_scale);
    vectorized_out[i] = out_vec;
  }

  // Handle the remaining elements if num_elems is not divisible by 4
  for (int i = num_vec_elems * 4 + tid; i < num_elems;
       i += blockDim.x * gridDim.x) {
    out[i] = scaled_fp8_conversion(input[i], inverted_scale);
  }
}

}  // namespace aphrodite
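
// Quantize input into out using a caller-provided per-tensor scale.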
void static_scaled_fp8_quant(torch::Tensor& out,    // [..., d]
                             torch::Tensor& input,  // [..., d]
                             torch::Tensor& scale)  // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        aphrodite::scaled_fp8_quant_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(out.data_ptr<c10::Float8_e4m3fn>(),
                                         input.data_ptr<scalar_t>(),
                                         scale.data_ptr<float>(), num_elems);
      });
}
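
// Compute the per-tensor scale from the input's absolute maximum, then
// quantize. Per the contract of segmented_max_reduction, scale must be
// initialized to a value <= 0.0 before this is called.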
void dynamic_scaled_fp8_quant(torch::Tensor& out,    // [..., d]
                              torch::Tensor& input,  // [..., d]
                              torch::Tensor& scale)  // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        aphrodite::segmented_max_reduction<scalar_t>
            <<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
                                         input.data_ptr<scalar_t>(), num_elems);
        aphrodite::scaled_fp8_quant_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(out.data_ptr<c10::Float8_e4m3fn>(),
                                         input.data_ptr<scalar_t>(),
                                         scale.data_ptr<float>(), num_elems);
      });
}
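
// Illustrative host-side usage (a sketch, not part of this file; it assumes
// these entry points are visible to the calling host code and that the
// PyTorch build in use supports the Float8_e4m3fn dtype):
//
//   torch::Tensor input = torch::randn({num_tokens, d},
//       torch::dtype(torch::kFloat16).device(torch::kCUDA));
//   torch::Tensor out = torch::empty(input.sizes(),
//       input.options().dtype(at::ScalarType::Float8_e4m3fn));
//   // Zero-initialized, satisfying segmented_max_reduction's requirement.
//   torch::Tensor scale = torch::zeros({1},
//       torch::dtype(torch::kFloat32).device(torch::kCUDA));
//   dynamic_scaled_fp8_quant(out, input, scale);  // fills both out and scale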