Explorar el Código

include fp8 compilation in rocm

AlpinDale hace 7 meses
padre
commit
8de8034f8b
Se ha modificado 1 fichero con 11 adiciones y 10 borrados
  1. 11 10
      kernels/quantization/quant_ops.h

+ 11 - 10
kernels/quantization/quant_ops.h

@@ -142,16 +142,6 @@ torch::Tensor marlin_gemm(
     int64_t size_n,
     int64_t size_k);
 
-void static_scaled_fp8_quant(
-  torch::Tensor& out,
-  torch::Tensor& input,
-  torch::Tensor& scale);
-
-void dynamic_scaled_fp8_quant(
-  torch::Tensor& out,
-  torch::Tensor& input,
-  torch::Tensor& scale);
-
 // QuIP#
 at::Tensor e8p_mm_origorder(
     const at::Tensor& A,
@@ -188,3 +178,14 @@ void squeezellm_gemm(
   torch::Tensor mat,
   torch::Tensor mul,
   torch::Tensor lookup_table);
+
+// FP8
+void static_scaled_fp8_quant(
+  torch::Tensor& out,
+  torch::Tensor& input,
+  torch::Tensor& scale);
+
+void dynamic_scaled_fp8_quant(
+  torch::Tensor& out,
+  torch::Tensor& input,
+  torch::Tensor& scale);