#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAException.h>

#include <cmath>
#include <cstdint>
#include <limits>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace aphrodite {

// Atomically sets *addr to max(*addr, value) and returns the previous value.
// For value >= 0 it applies the signed-integer atomicMax to the float's bit
// pattern; for value < 0 it applies the unsigned-integer atomicMin. This
// relies on IEEE-754 bit patterns ordering like signed ints for non-negative
// floats and like reversed unsigned ints for negative floats.
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0)
            ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
            : __uint_as_float(
                  atomicMin((unsigned int*)addr, __float_as_uint(value)));

  return old;
}

#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()

// Multiplies val by inverted_scale, saturates the product to
// [-FP8_E4M3_MAX, FP8_E4M3_MAX] and converts it to c10::Float8_e4m3fn.
template <typename scalar_t>
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
    const scalar_t val, const float inverted_scale) {
  float x = static_cast<float>(val) * inverted_scale;
  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
  return static_cast<c10::Float8_e4m3fn>(r);
}

// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
//
// Launch requirements: 1-D grid of any size (the input is walked with a
// grid-stride loop); blockDim.x must be a power of two and at most 1024,
// because the shared cache holds 1024 entries and the reduction halves the
// active range each step.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];

  // 64-bit indexing: num_elems is int64_t, and blockDim.x * gridDim.x can
  // exceed the 32-bit range for large launches.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  int64_t i = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;

  // First store the maximum |x| over all values processed by the current
  // thread in cache[threadIdx.x]. Accumulate in float, the same type used by
  // the shared cache and by the atomic update below.
  float tmp = 0.0f;
  for (; i < num_elems; i += stride) {
    tmp = fmaxf(tmp, fabsf(static_cast<float>(input[i])));
  }
  cache[threadIdx.x] = tmp;

  __syncthreads();

  // Now perform parallel reduction within the thread block. Every thread
  // reaches the __syncthreads() on each iteration (the loop bound is uniform).
  for (unsigned int ib = blockDim.x / 2; ib != 0; ib /= 2) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
  }

  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location.
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
  }
}

// Four consecutive scalar_t input values, loaded as one unit.
template <typename scalar_t>
struct __align__(8) vec4_t {
  scalar_t x;
  scalar_t y;
  scalar_t z;
  scalar_t w;
};

// Four consecutive FP8 output values, stored as one 4-byte unit.
typedef struct __align__(4) {
  c10::Float8_e4m3fn x;
  c10::Float8_e4m3fn y;
  c10::Float8_e4m3fn z;
  c10::Float8_e4m3fn w;
} float8x4_t;

// out[k] = saturate_fp8(input[k] / *scale) for k in [0, num_elems).
//
// Launch requirements: 1-D grid and block of any size (grid-stride loops).
// input and out are treated as flat arrays of num_elems elements.
// NOTE(review): if *scale is 0, inverted_scale is +inf and 0 * inf is NaN;
// the fmin/fmax clamp then likely saturates such values rather than yielding
// 0. Confirm whether callers guard against a zero scale.
template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
                                        const scalar_t* __restrict__ input,
                                        const float* __restrict__ scale,
                                        int64_t num_elems) {
  // 64-bit indexing, matching the int64_t element count.
  const int64_t tid =
      static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;

  // Invert the scale so that we can use multiplications to avoid expensive
  // division.
  const float inverted_scale = 1.0f / (*scale);

  // Vectorized input/output to better utilize memory bandwidth. The 4-wide
  // path is only valid when both pointers meet the alignment of the vector
  // types. A tensor view with a storage offset may not, so in that case the
  // vector loop is skipped and the scalar loop below handles every element.
  const bool vectorizable =
      reinterpret_cast<uintptr_t>(input) % alignof(vec4_t<scalar_t>) == 0 &&
      reinterpret_cast<uintptr_t>(out) % alignof(float8x4_t) == 0;
  const int64_t num_vec_elems = vectorizable ? (num_elems >> 2) : 0;

  const vec4_t<scalar_t>* vectorized_in =
      reinterpret_cast<const vec4_t<scalar_t>*>(input);
  float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);

#pragma unroll 4
  for (int64_t i = tid; i < num_vec_elems; i += stride) {
    vec4_t<scalar_t> in_vec = vectorized_in[i];
    float8x4_t out_vec;

    out_vec.x = scaled_fp8_conversion(in_vec.x, inverted_scale);
    out_vec.y = scaled_fp8_conversion(in_vec.y, inverted_scale);
    out_vec.z = scaled_fp8_conversion(in_vec.z, inverted_scale);
    out_vec.w = scaled_fp8_conversion(in_vec.w, inverted_scale);
    vectorized_out[i] = out_vec;
  }

  // Handle the remaining elements: the tail when num_elems is not divisible
  // by 4, or all elements when the vector path was skipped.
  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += stride) {
    out[i] = scaled_fp8_conversion(input[i], inverted_scale);
  }
}

}  // namespace aphrodite

// Quantizes input to FP8 (e4m3fn) using the caller-provided per-tensor scale:
// out = saturate_fp8(input / scale[0]). An empty input is a no-op.
void static_scaled_fp8_quant(torch::Tensor& out,    // [..., d]
                             torch::Tensor& input,  // [..., d]
                             torch::Tensor& scale)  // [1]
{
  const int64_t num_elems = input.numel();
  // Nothing to do for an empty tensor. Returning here also avoids dividing by
  // a zero-sized last dimension and launching a zero-block grid.
  if (num_elems == 0) {
    return;
  }
  const int64_t num_tokens = num_elems / input.size(-1);
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        aphrodite::scaled_fp8_quant_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(out.data_ptr<c10::Float8_e4m3fn>(),
                                         input.data_ptr<scalar_t>(),
                                         scale.data_ptr<float>(), num_elems);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}

// Computes a per-tensor scale max|input| / FP8_E4M3_MAX into scale[0], then
// quantizes input to FP8 (e4m3fn) with it. scale[0] must already hold a value
// <= 0 (for example zero-initialized by the caller), because the reduction
// only raises it atomically. Both kernels are enqueued on the same stream, so
// the quantization sees the completed scale. An empty input is a no-op and
// leaves scale untouched.
void dynamic_scaled_fp8_quant(torch::Tensor& out,    // [..., d]
                              torch::Tensor& input,  // [..., d]
                              torch::Tensor& scale)  // [1]
{
  const int64_t num_elems = input.numel();
  // Nothing to do for an empty tensor. Returning here also avoids dividing by
  // a zero-sized last dimension and launching a zero-block grid.
  if (num_elems == 0) {
    return;
  }
  const int64_t num_tokens = num_elems / input.size(-1);
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        aphrodite::segmented_max_reduction<scalar_t>
            <<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
                                         input.data_ptr<scalar_t>(),
                                         num_elems);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
        aphrodite::scaled_fp8_quant_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(out.data_ptr<c10::Float8_e4m3fn>(),
                                         input.data_ptr<scalar_t>(),
                                         scale.data_ptr<float>(), num_elems);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}