Browse Source

include fp8 compilation in rocm

AlpinDale 7 tháng trước cách đây
mục cha
commit
8de8034f8b
1 tập tin đã thay đổi với 11 bổ sung10 xóa
  1. 11 10
      kernels/quantization/quant_ops.h

+ 11 - 10
kernels/quantization/quant_ops.h

@@ -142,16 +142,6 @@ torch::Tensor marlin_gemm(
     int64_t size_n,
     int64_t size_k);
 
-void static_scaled_fp8_quant(
-  torch::Tensor& out,
-  torch::Tensor& input,
-  torch::Tensor& scale);
-
-void dynamic_scaled_fp8_quant(
-  torch::Tensor& out,
-  torch::Tensor& input,
-  torch::Tensor& scale);
-
 // QuIP#
 at::Tensor e8p_mm_origorder(
     const at::Tensor& A,
@@ -188,3 +178,14 @@ void squeezellm_gemm(
   torch::Tensor mat,
   torch::Tensor mul,
   torch::Tensor lookup_table);
+
+// FP8
+void static_scaled_fp8_quant(
+  torch::Tensor& out,
+  torch::Tensor& input,
+  torch::Tensor& scale);
+
+void dynamic_scaled_fp8_quant(
+  torch::Tensor& out,
+  torch::Tensor& input,
+  torch::Tensor& scale);