common.cu

#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace aphrodite {
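
// CUDA has no atomicMax overload for float, so the helper below uses the
// standard bit-reinterpretation trick: for non-negative operands the IEEE-754
// bit pattern, read as a signed int, orders the same way the floats do, so
// atomicMax on the int view yields the float max; for negative operands the
// unsigned ordering is reversed, so atomicMin on the unsigned view does the
// job. This assumes no NaNs reach the atomic, which holds here because only
// non-negative absolute maxima (and a scale initialized <= 0.0) are involved.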
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0)
            ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
            : __uint_as_float(
                  atomicMin((unsigned int*)addr, __float_as_uint(value)));
  return old;
}

#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()

template <typename scalar_t>
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
    const scalar_t val, const float scale) {
  float x = static_cast<float>(val) / scale;
  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
  return static_cast<c10::Float8_e4m3fn>(r);
}
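
// Worked example of the mapping above: Float8_e4m3fn's max is 448, so with an
// observed absmax of 8.0 the dynamic path stores scale = 8.0 / 448. An input
// of 8.0 then divides to exactly 448 and quantizes to the largest
// representable fp8 value, while anything that divides out of range is
// clamped to +/-448 before the cast.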

// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  // First store the maximum of all values processed by
  // the current thread in cache[threadIdx.x]
  scalar_t tmp = 0.0;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = max(tmp, fabs(x));
    i += blockDim.x * gridDim.x;
  }
  cache[threadIdx.x] = tmp;
  __syncthreads();

  // Now perform parallel reduction within the thread block
  int ib = blockDim.x / 2;
  while (ib != 0) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
    ib /= 2;
  }

  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale,
                   cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
  }
}
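
// Usage note: the reduction above only max-combines into *scale, so the
// caller of the dynamic path must supply a scale tensor already initialized
// to a value <= 0.0 (e.g. allocated with torch::zeros; how the caller
// allocates it is outside this file).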

template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
                                        const scalar_t* __restrict__ input,
                                        const float* __restrict__ scale,
                                        int64_t num_elems) {
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  while (i < num_elems) {
    out[i] = scaled_fp8_conversion(input[i], *scale);
    i += blockDim.x * gridDim.x;
  }
}

}  // namespace aphrodite

void static_scaled_fp8_quant(torch::Tensor& out,    // [..., d]
                             torch::Tensor& input,  // [..., d]
                             torch::Tensor& scale)  // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        aphrodite::scaled_fp8_quant_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(out.data_ptr<c10::Float8_e4m3fn>(),
                                         input.data_ptr<scalar_t>(),
                                         scale.data_ptr<float>(), num_elems);
      });
}
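
// Call sketch for the static path (assumes a libtorch caller on a CUDA
// device; the tensor names, shapes, and the 0.02f scale are illustrative
// only, not part of this file):
//   auto opts  = torch::TensorOptions().device(torch::kCUDA);
//   auto input = torch::randn({16, 4096}, opts.dtype(torch::kFloat16));
//   auto out   = torch::empty_like(input, opts.dtype(at::ScalarType::Float8_e4m3fn));
//   auto scale = torch::full({1}, 0.02f, opts.dtype(torch::kFloat32));  // precomputed offline
//   static_scaled_fp8_quant(out, input, scale);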

void dynamic_scaled_fp8_quant(torch::Tensor& out,    // [..., d]
                              torch::Tensor& input,  // [..., d]
                              torch::Tensor& scale)  // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        aphrodite::segmented_max_reduction<scalar_t>
            <<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
                                         input.data_ptr<scalar_t>(), num_elems);
        aphrodite::scaled_fp8_quant_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(out.data_ptr<c10::Float8_e4m3fn>(),
                                         input.data_ptr<scalar_t>(),
                                         scale.data_ptr<float>(), num_elems);
      });
}
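
// Call sketch for the dynamic path (same caveats as above; the scale tensor
// must be zero-initialized so the atomic max reduction starts from <= 0.0):
//   auto opts  = torch::TensorOptions().device(torch::kCUDA);
//   auto input = torch::randn({16, 4096}, opts.dtype(torch::kFloat16));
//   auto out   = torch::empty_like(input, opts.dtype(at::ScalarType::Float8_e4m3fn));
//   auto scale = torch::zeros({1}, opts.dtype(torch::kFloat32));
//   dynamic_scaled_fp8_quant(out, input, scale);
//   // Afterwards, scale holds absmax(input) / 448 and out holds the
//   // fp8-quantized values; both kernels run on the current CUDA stream.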