Browse Source

fix: use int64_t for indices in fp8 kernels

AlpinDale 7 months ago
parent
commit
31552a81ff
1 changed files with 6 additions and 6 deletions
  1. 6 6
      kernels/quantization/fp8/common.cu

+ 6 - 6
kernels/quantization/fp8/common.cu

@@ -101,11 +101,11 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
   vec4_t<scalar_t> const* vectorized_in =
       reinterpret_cast<vec4_t<scalar_t> const*>(input);
 
-  int const num_vec_elems = num_elems >> 2;
+  int64_t const num_vec_elems = num_elems >> 2;
   float absmax_val = 0.0f;
 
 #pragma unroll 4
-  for (int i = tid; i < num_vec_elems; i += step) {
+  for (int64_t i = tid; i < num_vec_elems; i += step) {
     vec4_t<scalar_t> in_vec = vectorized_in[i];
     absmax_val = max(absmax_val, fabs(in_vec.x));
     absmax_val = max(absmax_val, fabs(in_vec.y));
@@ -114,7 +114,7 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
   }
 
   // Handle the remaining elements if num_elems is not divisible by 4
-  for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
+  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
     absmax_val = max(absmax_val, fabs(input[i]));
   }
 
@@ -132,10 +132,10 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
       reinterpret_cast<vec4_t<scalar_t> const*>(input);
   float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
 
-  int const num_vec_elems = num_elems >> 2;
+  int64_t const num_vec_elems = num_elems >> 2;
 
 #pragma unroll 4
-  for (int i = tid; i < num_vec_elems; i += step) {
+  for (int64_t i = tid; i < num_vec_elems; i += step) {
     vec4_t<scalar_t> in_vec = vectorized_in[i];
     float8x4_t out_vec;
 
@@ -151,7 +151,7 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
   }
 
   // Handle the remaining elements if num_elems is not divisible by 4
-  for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
+  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
     out[i] = scaled_fp8_conversion<is_scale_inverted>(
         static_cast<float>(input[i]), scale);
   }