|
@@ -101,11 +101,11 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
|
|
|
vec4_t<scalar_t> const* vectorized_in =
|
|
|
reinterpret_cast<vec4_t<scalar_t> const*>(input);
|
|
|
|
|
|
- int const num_vec_elems = num_elems >> 2;
|
|
|
+ int64_t const num_vec_elems = num_elems >> 2;
|
|
|
float absmax_val = 0.0f;
|
|
|
|
|
|
#pragma unroll 4
|
|
|
- for (int i = tid; i < num_vec_elems; i += step) {
|
|
|
+ for (int64_t i = tid; i < num_vec_elems; i += step) {
|
|
|
vec4_t<scalar_t> in_vec = vectorized_in[i];
|
|
|
absmax_val = max(absmax_val, fabs(in_vec.x));
|
|
|
absmax_val = max(absmax_val, fabs(in_vec.y));
|
|
@@ -114,7 +114,7 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
|
|
|
}
|
|
|
|
|
|
// Handle the remaining elements if num_elems is not divisible by 4
|
|
|
- for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
|
|
|
+ for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
|
|
|
absmax_val = max(absmax_val, fabs(input[i]));
|
|
|
}
|
|
|
|
|
@@ -132,10 +132,10 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
|
|
|
reinterpret_cast<vec4_t<scalar_t> const*>(input);
|
|
|
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
|
|
|
|
|
|
- int const num_vec_elems = num_elems >> 2;
|
|
|
+ int64_t const num_vec_elems = num_elems >> 2;
|
|
|
|
|
|
#pragma unroll 4
|
|
|
- for (int i = tid; i < num_vec_elems; i += step) {
|
|
|
+ for (int64_t i = tid; i < num_vec_elems; i += step) {
|
|
|
vec4_t<scalar_t> in_vec = vectorized_in[i];
|
|
|
float8x4_t out_vec;
|
|
|
|
|
@@ -151,7 +151,7 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
|
|
|
}
|
|
|
|
|
|
// Handle the remaining elements if num_elems is not divisible by 4
|
|
|
- for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
|
|
|
+ for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
|
|
|
out[i] = scaled_fp8_conversion<is_scale_inverted>(
|
|
|
static_cast<float>(input[i]), scale);
|
|
|
}
|