// fp8_cuda_kernels.cu

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace aphrodite {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0)
            ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
            : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
  return old;
}
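
// Why the bit tricks above work: for non-negative IEEE-754 floats the bit
// pattern viewed as a signed int is ordered the same way as the float value
// (e.g. __float_as_int(0.5f) == 0x3F000000 < __float_as_int(1.0f) == 0x3F800000
// < __float_as_int(2.0f) == 0x40000000), so an integer atomicMax is a float max.
// For negative floats the unsigned view is ordered in reverse, so atomicMin on
// the unsigned bits keeps the largest (least negative) float.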

// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template<typename scalar_t>
__global__ void segmented_max_reduction(
    float* __restrict__ scale,
    const scalar_t* __restrict__ input,
    int64_t num_elems) {
  __shared__ float cache[1024];
  int i = blockDim.x * blockIdx.x + threadIdx.x;

  // First store the maximum of all values processed by
  // the current thread in cache[threadIdx.x]
  scalar_t tmp = 0.0;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = max(tmp, fabs(x));
    i += blockDim.x * gridDim.x;
  }
  cache[threadIdx.x] = tmp;
  __syncthreads();

  // Now perform parallel reduction within the thread block
  int ib = blockDim.x / 2;
  while (ib != 0) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
    ib /= 2;
  }

  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale,
                   cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
  }
}
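
// Taken together, segmented_max_reduction (above) and scaled_fp8_quant_kernel
// (below) implement, in effect:
//
//   scale  = max(|input|) / 448.0   // 448 == max representable e4m3fn value
//   out[i] = fp8_e4m3(input[i] / scale)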

template<typename scalar_t>
__global__ void scaled_fp8_quant_kernel(
    c10::Float8_e4m3fn* __restrict__ out,
    const scalar_t* __restrict__ input,
    const float* __restrict__ scale,
    int64_t num_elems) {
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  // Grid-stride loop: divide each element by the shared scale and narrow to FP8.
  while (i < num_elems) {
    out[i] = static_cast<c10::Float8_e4m3fn>(input[i] / *scale);
    i += blockDim.x * gridDim.x;
  }
}

} // namespace aphrodite
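
// Hedged usage sketch (not taken from this repo): the entry point below assumes
// the caller has already allocated `out` and `scale`, e.g. roughly as follows;
// the dtype constant names are assumptions of this sketch.
//
//   // FP8 output tensor with the same shape as the input
//   auto out   = torch::empty_like(input, input.options().dtype(torch::kFloat8_e4m3fn));
//   // One-element scale, zero-initialized so atomicMaxFloat starts from a value <= 0.0
//   auto scale = torch::zeros({1}, input.options().dtype(torch::kFloat32));
//   scaled_fp8_quant(out, input, scale);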

void scaled_fp8_quant(
    torch::Tensor& out,    // [..., d]
    torch::Tensor& input,  // [..., d]
    torch::Tensor& scale)  // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(),
      "scaled_fp8_quant_kernel",
      [&] {
        // First pass: reduce |input| down to a single per-tensor scale.
        aphrodite::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
            scale.data_ptr<float>(),
            input.data_ptr<scalar_t>(),
            num_elems);
        // Second pass: quantize every element with that scale.
        aphrodite::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
            out.data_ptr<c10::Float8_e4m3fn>(),
            input.data_ptr<scalar_t>(),
            scale.data_ptr<float>(),
            num_elems);
      });
}
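
// Hedged binding sketch (an assumption, not part of this file): a typical torch
// extension would register scaled_fp8_quant in a separate bindings file roughly
// like this; the exposed name and docstring are hypothetical.
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("scaled_fp8_quant", &scaled_fp8_quant,
//           "Compute the per-tensor FP8 scale for input and quantize into out");
//   }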