
cpu: add support for W8A8 quantization via compressed-tensor (#1017)

AlpinDale, 2 months ago
commit f2b6dc3872

+ 39 - 12
Dockerfile.cpu

@@ -2,15 +2,21 @@
 
 FROM ubuntu:22.04 AS cpu-test-1
 
-RUN apt-get update -y \
-    && apt-get install -y curl git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \
+ENV CCACHE_DIR=/root/.cache/ccache
+
+ENV CMAKE_CXX_COMPILER_LAUNCHER=ccache
+
+RUN --mount=type=cache,target=/var/cache/apt \
+    apt-get update -y \
+    && apt-get install -y curl ccache git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \
     && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
     && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
 
 # https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html
 # intel-openmp provides additional performance improvement vs. openmp
 # tcmalloc provides better memory allocation efficiency, e.g, holding memory in caches to speed up access of commonly-used objects.
-RUN pip install intel-openmp
+RUN --mount=type=cache,target=/root/.cache/pip \
+    pip install intel-openmp
 
 ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so"
 
@@ -18,26 +24,47 @@ RUN echo 'ulimit -c 0' >> ~/.bashrc
 
 RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.4.0%2Bgitfbaa4bc-cp310-cp310-linux_x86_64.whl
 
-RUN pip install --upgrade pip \
-    && pip install wheel packaging ninja "setuptools>=49.4.0" numpy
+ENV PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cpu
+RUN --mount=type=cache,target=/root/.cache/pip \
+    --mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \
+    pip install --upgrade pip && \
+    pip install -r requirements-build.txt
+
+# install oneDNN
+RUN git clone -b rls-v3.5 https://github.com/oneapi-src/oneDNN.git
+
+RUN --mount=type=cache,target=/root/.cache/ccache \
+    cmake -B ./oneDNN/build -S ./oneDNN -G Ninja -DONEDNN_LIBRARY_TYPE=STATIC \ 
+    -DONEDNN_BUILD_DOC=OFF \ 
+    -DONEDNN_BUILD_EXAMPLES=OFF \ 
+    -DONEDNN_BUILD_TESTS=OFF \ 
+    -DONEDNN_BUILD_GRAPH=OFF \ 
+    -DONEDNN_ENABLE_WORKLOAD=INFERENCE \ 
+    -DONEDNN_ENABLE_PRIMITIVE=MATMUL && \
+    cmake --build ./oneDNN/build --target install --config Release
 
 FROM cpu-test-1 AS build
 
-COPY ./ /workspace/aphrodite-engine
-
 WORKDIR /workspace/aphrodite-engine
 
-RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
+RUN --mount=type=cache,target=/root/.cache/pip \
+    --mount=type=bind,src=requirements-common.txt,target=requirements-common.txt \
+    --mount=type=bind,src=requirements-cpu.txt,target=requirements-cpu.txt \
+    pip install -v -r requirements-cpu.txt
+
+COPY ./ ./
 
 # Support for building with non-AVX512 Aphrodite: docker build --build-arg APHRODITE_CPU_DISABLE_AVX512="true" ...
 ARG APHRODITE_CPU_DISABLE_AVX512
 ENV APHRODITE_CPU_DISABLE_AVX512=${APHRODITE_CPU_DISABLE_AVX512}
 
-RUN APHRODITE_TARGET_DEVICE=cpu python3 setup.py install
-RUN pip install triton
+RUN --mount=type=cache,target=/root/.cache/pip \
+    --mount=type=cache,target=/root/.cache/ccache \
+    APHRODITE_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \
+    pip install dist/*.whl
 
 WORKDIR /workspace/
 
-RUN ln -s /workspace/aphrodite-engine/examples && ln -s /workspace/aphrodite-engine/tests/benchmarks
+RUN ln -s /workspace/aphrodite-engine/tests && ln -s /workspace/aphrodite-engine/examples && ln -s /workspace/aphrodite-engine/benchmarks
 
-ENTRYPOINT ["python3", "-m", "aphrodite.endpoints.openai.api_server"]
+ENTRYPOINT ["python3", "-m", "aphrodite.entrypoints.openai.api_server"]

+ 1 - 1
aphrodite/common/config.py

@@ -1039,7 +1039,7 @@ class ParallelConfig:
             from aphrodite.executor import ray_utils
             backend = "mp"
             ray_found = ray_utils.ray_is_available()
-            if cuda_device_count_stateless() < self.world_size:
+            if not is_cpu() and cuda_device_count_stateless() < self.world_size:
                 if not ray_found:
                     raise ValueError("Unable to load Ray which is "
                                      "required for multi-node inference, "

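Since this hunk only relaxes the Ray requirement, here is a minimal sketch of the resulting backend selection; the function and parameter names are simplified stand-ins, only the control flow mirrors ParallelConfig:

    # Hypothetical helper mirroring the guard above: on CPU there are no CUDA
    # devices to count, so the multi-node/Ray check only applies to GPU runs.
    def pick_backend(is_cpu: bool, world_size: int,
                     cuda_device_count: int, ray_available: bool) -> str:
        backend = "mp"
        if not is_cpu and cuda_device_count < world_size:
            if not ray_available:
                raise ValueError("Ray is required for multi-node inference")
            backend = "ray"
        return backend
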
+ 14 - 1
aphrodite/executor/cpu_executor.py

@@ -6,7 +6,8 @@ import torch
 from loguru import logger
 
 import aphrodite.common.envs as envs
-from aphrodite.common.config import CacheConfig, ModelConfig, SchedulerConfig
+from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
+                                     SchedulerConfig)
 from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.common.utils import (GiB_bytes, get_aphrodite_instance_id,
                                     get_distributed_init_method, get_open_port,
@@ -60,6 +61,8 @@ class CPUExecutor(ExecutorBase):
         self.cache_config = _verify_and_get_cache_config(self.cache_config)
         self.scheduler_config = _verify_and_get_scheduler_config(
             self.scheduler_config)
+        self.parallel_config = _verify_and_get_parallel_config(
+            self.parallel_config)
 
         # Multiprocessing-based executor does not support multi-node setting.
         # Since it only works for single node, we can use the loopback address
@@ -354,6 +357,16 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
     return config
 
 
+def _verify_and_get_parallel_config(config: ParallelConfig) -> ParallelConfig:
+    if (config.distributed_executor_backend is not None
+            and config.distributed_executor_backend != "mp"):
+        logger.warning(
+            f"{config.distributed_executor_backend} is not supported on CPU, "
+            "fallback to mp distributed executor backend.")
+        config.distributed_executor_backend = "mp"
+    return config
+
+
 def _driver_method_invoker(driver, method: str, *args, **kwargs):
     return getattr(driver, method)(*args, **kwargs)
 

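The new _verify_and_get_parallel_config simply coerces any other backend to "mp". An illustrative call with a stub config object (the stub is not the real ParallelConfig):

    # Illustrative only: demonstrate the fallback with a minimal stand-in.
    from dataclasses import dataclass
    from typing import Optional

    from aphrodite.executor.cpu_executor import _verify_and_get_parallel_config

    @dataclass
    class _StubParallelConfig:
        distributed_executor_backend: Optional[str] = "ray"

    cfg = _verify_and_get_parallel_config(_StubParallelConfig())  # logs a warning
    assert cfg.distributed_executor_backend == "mp"
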
+ 2 - 2
aphrodite/modeling/model_loader/loader.py

@@ -91,8 +91,8 @@ def _get_quantization_config(
     """Get the quantization config."""
     if model_config.quantization is not None:
         quant_config = get_quant_config(model_config, load_config)
-        if not current_platform.is_tpu():
-            capability = current_platform.get_device_capability()
+        capability = current_platform.get_device_capability()  # type: ignore
+        if capability is not None:
             capability = capability[0] * 10 + capability[1]
             if capability < quant_config.get_min_capability():
                 raise ValueError(

+ 10 - 0
aphrodite/platforms/__init__.py

@@ -9,6 +9,13 @@ try:
 except ImportError:
     libtpu = None
 
+is_cpu = False
+try:
+    from importlib.metadata import version
+    is_cpu = "cpu" in version("aphrodite-engine")
+except Exception:
+    pass
+
 if libtpu is not None:
     # people might install pytorch built with cuda but run on tpu
     # so we need to check tpu first
@@ -20,6 +27,9 @@ elif torch.version.cuda is not None:
 elif torch.version.hip is not None:
     from .rocm import RocmPlatform
     current_platform = RocmPlatform()
+elif is_cpu:
+    from .cpu import CpuPlatform
+    current_platform = CpuPlatform()
 else:
     current_platform = UnspecifiedPlatform()
 

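The CPU build is detected purely from the installed package's version string. A quick check of what that looks like; the distribution name comes from the hunk above, the concrete version shown is only an example:

    # On a wheel built with APHRODITE_TARGET_DEVICE=cpu the local version tag
    # is expected to contain "cpu"; the exact string below is illustrative.
    from importlib.metadata import version

    v = version("aphrodite-engine")
    print(v)           # e.g. "0.6.0+cpu"
    print("cpu" in v)  # True -> CpuPlatform is selected at import time
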
+ 13 - 0
aphrodite/platforms/cpu.py

@@ -0,0 +1,13 @@
+import torch
+
+from .interface import Platform, PlatformEnum
+
+
+class CpuPlatform(Platform):
+    _enum = PlatformEnum.CPU
+    @staticmethod
+    def get_device_name(device_id: int = 0) -> str:
+        return "cpu"
+    @staticmethod
+    def inference_mode():
+        return torch.no_grad()

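A short usage sketch of the new platform object on a CPU build; the return values follow directly from this class and the Platform base-class changes below:

    import torch
    from aphrodite.platforms import current_platform

    assert current_platform.is_cpu()                 # added to Platform below
    print(current_platform.get_device_name())        # "cpu"
    print(current_platform.get_device_capability())  # None (base-class default)
    with current_platform.inference_mode():          # plain torch.no_grad()
        _ = torch.ones(2) + 1
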
+ 7 - 3
aphrodite/platforms/interface.py

@@ -1,5 +1,5 @@
 import enum
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 
@@ -8,6 +8,7 @@ class PlatformEnum(enum.Enum):
     CUDA = enum.auto()
     ROCM = enum.auto()
     TPU = enum.auto()
+    CPU = enum.auto()
     UNSPECIFIED = enum.auto()
 
 
@@ -23,9 +24,12 @@ class Platform:
     def is_tpu(self) -> bool:
         return self._enum == PlatformEnum.TPU
 
+    def is_cpu(self) -> bool:
+        return self._enum == PlatformEnum.CPU
+
     @staticmethod
-    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
-        raise NotImplementedError
+    def get_device_capability(device_id: int = 0) -> Optional[Tuple[int, int]]:
+        return None
 
     @staticmethod
     def get_device_name(device_id: int = 0) -> str:

+ 0 - 6
aphrodite/platforms/tpu.py

@@ -1,5 +1,3 @@
-from typing import Tuple
-
 import torch
 
 from .interface import Platform, PlatformEnum
@@ -8,10 +6,6 @@ from .interface import Platform, PlatformEnum
 class TpuPlatform(Platform):
     _enum = PlatformEnum.TPU
 
-    @staticmethod
-    def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
-        raise RuntimeError("TPU does not have device capability.")
-
     @staticmethod
     def inference_mode():
         return torch.no_grad()

+ 12 - 9
aphrodite/quantization/compressed_tensors/compressed_tensors.py

@@ -117,15 +117,18 @@ class CompressedTensorsConfig(QuantizationConfig):
     def _check_scheme_supported(self,
                                 min_capability: int,
                                 error: bool = True) -> bool:
-        capability = current_platform.get_device_capability()
-        capability = capability[0] * 10 + capability[1]
-        supported = capability >= min_capability
-        if error and not supported:
-            raise RuntimeError(
-                "Quantization scheme is not supported for ",
-                f"the current GPU. Min capability: {min_capability}. ",
-                f"Current capability: {capability}.")
-        return supported
+        capability = current_platform.get_device_capability()  # type: ignore
+        if capability is not None:
+            capability = capability[0] * 10 + capability[1]
+            supported = capability >= min_capability
+            if error and not supported:
+                raise RuntimeError(
+                    "Quantization scheme is not supported for ",
+                    f"the current GPU. Min capability: {min_capability}. ",
+                    f"Current capability: {capability}.")
+            return supported
+        else:
+            return False
 
     def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                                input_quant: BaseModel) -> bool:

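With get_device_capability() returning None on capability-less platforms, the gate reduces to the following standalone sketch (not the actual method, which can also raise when error=True):

    from typing import Optional, Tuple

    def scheme_supported(capability: Optional[Tuple[int, int]],
                         min_capability: int) -> bool:
        if capability is None:          # CPU: report unsupported, don't raise
            return False
        major, minor = capability
        return major * 10 + minor >= min_capability

    assert scheme_supported(None, 75) is False     # CPU path
    assert scheme_supported((8, 0), 75) is True    # e.g. an A100 (sm_80)
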
+ 5 - 3
aphrodite/worker/cpu_worker.py

@@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Tuple
 
 import torch
 import torch.distributed
+from loguru import logger
 
 import aphrodite.common.envs as envs
 from aphrodite.attention import get_attn_backend
@@ -16,8 +17,8 @@ from aphrodite.distributed import (ensure_model_parallel_initialized,
 from aphrodite.modeling import set_random_seed
 from aphrodite.worker.cpu_model_runner import CPUModelRunner
 from aphrodite.worker.worker_base import (LocalOrDistributedWorkerBase,
-                                                LoraNotSupportedWorkerBase,
-                                                WorkerInput)
+                                          LoraNotSupportedWorkerBase,
+                                          WorkerInput)
 
 APHRODITE_CPU_OMP_THREADS_BIND = envs.APHRODITE_CPU_OMP_THREADS_BIND
 
@@ -180,7 +181,8 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
 
     def init_device(self) -> None:
         if self.local_omp_cpuid != "all":
-            torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
+            ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
+            logger.info(ret)
         self.init_distributed_environment()
         # Set random seed.
         set_random_seed(self.model_config.seed)

+ 11 - 6
cmake/cpu_extension.cmake

@@ -1,4 +1,5 @@
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+set(CMAKE_CXX_STANDARD 17)
 
 #
 # Define environment variables for special configurations
@@ -83,12 +84,7 @@ endif()
 
 message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
 
-list(APPEND LIBS "numa")
-
-
-#
-# Define extension targets
-#
+list(APPEND LIBS dnnl numa)
 
 #
 # _C extension
@@ -102,6 +98,15 @@ set(APHRODITE_EXT_SRC
     "kernels/cpu/pos_encoding.cpp"
     "kernels/cpu/torch_bindings.cpp")
 
+if (AVX512_FOUND AND NOT AVX512_DISABLED)
+    set(APHRODITE_EXT_SRC
+        "kernels/cpu/quant.cpp"
+        ${APHRODITE_EXT_SRC})
+endif()
+#
+# Define extension targets
+#
+
 define_gpu_extension_target(
     _C
     DESTINATION aphrodite

+ 51 - 2
kernels/cpu/cpu_types_x86.hpp

@@ -24,8 +24,8 @@ namespace vec_op {
 #define CPU_KERNEL_GUARD_OUT(NAME)
 #else
 #define CPU_KERNEL_GUARD_IN(NAME)                                              \
-  std::cout << #NAME << " invoked." << std::endl;
-#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
+  RECORD_FUNCTION(#NAME, c10::ArrayRef<c10::IValue>({}));
+#define CPU_KERNEL_GUARD_OUT(NAME)
 #endif
 
 #define FORCE_INLINE __attribute__((always_inline)) inline
@@ -106,6 +106,11 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
   explicit BF16Vec16(const FP32Vec16 &);
 
   void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
+  void save(void* ptr, const int elem_num) const {
+    constexpr uint32_t M = 0xFFFFFFFF;
+    __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
+    _mm256_mask_storeu_epi16(ptr, mask, reg);
+  }
 };
 
 #ifdef __AVX512F__
@@ -313,8 +318,25 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
     return FP32Vec16(_mm512_div_ps(reg, b.reg));
   }
 
+  FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
+    return FP32Vec16(_mm512_min_ps(max.reg, _mm512_max_ps(min.reg, reg)));
+  }
+  FP32Vec16 max(const FP32Vec16& b) const {
+    return FP32Vec16(_mm512_max_ps(reg, b.reg));
+  }
+  FP32Vec16 max(const FP32Vec16& b, const int elem_num) const {
+    constexpr uint32_t M = 0xFFFFFFFF;
+    __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
+    return FP32Vec16(_mm512_mask_max_ps(reg, mask, reg, b.reg));
+  }
+  FP32Vec16 abs() const {
+    return FP32Vec16(_mm512_abs_ps(reg));
+  } 
+
   float reduce_sum() const { return _mm512_reduce_add_ps(reg); }
 
+  float reduce_max() const { return _mm512_reduce_max_ps(reg); }
+
   template <int group_size> float reduce_sub_sum(int idx) {
     static_assert(VEC_ELEM_NUM % group_size == 0);
     constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
@@ -323,6 +345,11 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
   }
 
   void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
+  void save(float* ptr, const int elem_num) const {
+    constexpr uint32_t M = 0xFFFFFFFF;
+    __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
+    _mm512_mask_storeu_ps(ptr, mask, reg);
+  }
 };
 #else
 struct FP32Vec16 : public Vec<FP32Vec16> {
@@ -433,6 +460,28 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
 };
 #endif
 
+#ifdef __AVX512F__
+struct INT8Vec16: public Vec<INT8Vec16> {
+  constexpr static int VEC_ELEM_NUM = 16;
+  union AliasReg {
+    __m128i reg;
+    int8_t values[VEC_ELEM_NUM];
+  };
+  __m128i reg;
+  
+  explicit INT8Vec16(const FP32Vec16& vec) : reg(
+    _mm512_cvtepi32_epi8(_mm512_cvt_roundps_epi32(vec.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
+  ) {}
+  void save(int8_t* ptr) const {
+    _mm_storeu_epi8(ptr, reg);
+  }
+  void save(int8_t* ptr, const int elem_num) const {
+    constexpr uint32_t M = 0xFFFFFFFF;
+    __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
+    _mm_mask_storeu_epi8(ptr, mask, reg);
+  }
+};
+#endif
 template <typename T> struct VecType { using vec_type = void; };
 
 template <typename T> using vec_t = typename VecType<T>::vec_type;

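The save(ptr, elem_num) and max(b, elem_num) overloads exist so the tail of a row whose length is not a multiple of 16 can be handled without touching memory past the end. A NumPy analogue of the masked store, illustrative rather than a faithful model of the intrinsics:

    # Only the first elem_num lanes of a 16-wide vector are written, so a
    # 20-element row finishes with a 4-lane partial store instead of an
    # out-of-bounds access.
    import numpy as np

    def masked_store(dst: np.ndarray, offset: int, vec: np.ndarray, elem_num: int) -> None:
        dst[offset:offset + elem_num] = vec[:elem_num]

    row = np.zeros(20, dtype=np.float32)
    vec = np.arange(16, dtype=np.float32)
    masked_store(row, 16, vec, 20 - 16)   # writes row[16:20] = vec[:4]
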
+ 149 - 0
kernels/cpu/dnnl_helper.hpp

@@ -0,0 +1,149 @@
+#ifndef DNNL_HELPER_HPP
+#define DNNL_HELPER_HPP
+#include <c10/util/BFloat16.h>
+#include "oneapi/dnnl/dnnl.hpp"
+namespace {
+template <typename T>
+struct DNNLType {
+  static constexpr dnnl::memory::data_type type =
+      dnnl::memory::data_type::undef;
+};
+template <>
+struct DNNLType<int8_t> {
+  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8;
+};
+template <>
+struct DNNLType<int32_t> {
+  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32;
+};
+template <>
+struct DNNLType<float> {
+  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32;
+};
+template <>
+struct DNNLType<c10::BFloat16> {
+  static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16;
+};
+template <typename T>
+constexpr inline dnnl::memory::data_type get_dnnl_type() {
+  return DNNLType<std::decay_t<T>>::type;
+}
+};  // namespace
+template <bool InputNoScale>
+class DNNLPrimitiveHelper {
+ public:
+  // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias)
+  // A: [M, K], row-major
+  // B: [K, N], column-major
+  // C: [M, N], row-major
+  // bias: [N], row-major, optional
+  // a_scales: [MS]
+  // b_scales: [NS]
+  // Note: Due to the limitation of oneDNN
+  // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
+  // not supported.
+  template <typename OutputT, typename BiasT>
+  static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
+                            const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
+                            dnnl_dim_t K, const float* a_scales,
+                            const float* b_scales, dnnl_dim_t MS,
+                            dnnl_dim_t NS) {
+    auto&& OutputType = get_dnnl_type<OutputT>();
+    auto&& BiasType = get_dnnl_type<BiasT>();
+    dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1});
+    dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K});
+    dnnl::memory::desc c_md({M, N}, OutputType, {N, 1});
+    dnnl::primitive_attr attr;
+    if constexpr (!InputNoScale) {
+      if (MS == 1) {
+        // per-tensor
+        attr.set_scales_mask(DNNL_ARG_SRC, 0);
+      } else {
+        // per-token
+        TORCH_CHECK(false, "per-token quantization is unsupported.");
+      }
+    }
+    if (NS == 1) {
+      // per-tensor
+      attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
+    } else {
+      // per-channel
+      attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
+    }
+    dnnl::matmul::primitive_desc matmul_pd;
+    if (bias) {
+      dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
+      matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
+                                               bias_md, c_md, attr);
+    } else {
+      matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
+                                               c_md, attr);
+    }
+    dnnl::matmul matmul(matmul_pd);
+    auto& engine = default_engine();
+    dnnl::memory a_m(a_md, engine, (void*)a);
+    dnnl::memory b_m(b_md, engine, (void*)b);
+    dnnl::memory c_m(c_md, engine, (void*)c);
+    dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine,
+                            (void*)a_scales);
+    dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine,
+                            (void*)b_scales);
+    auto& stream = default_stream();
+    if constexpr (InputNoScale) {
+      if (bias) {
+        dnnl::memory::desc bias_md({N}, BiasType, {1});
+        dnnl::memory bias_m(bias_md, engine, (void*)bias);
+        matmul.execute(
+            stream, {
+                        {DNNL_ARG_SRC, a_m},
+                        {DNNL_ARG_WEIGHTS, b_m},
+                        {DNNL_ARG_BIAS, bias_m},
+                        {DNNL_ARG_DST, c_m},
+                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
+                    });
+      } else {
+        matmul.execute(
+            stream, {
+                        {DNNL_ARG_SRC, a_m},
+                        {DNNL_ARG_WEIGHTS, b_m},
+                        {DNNL_ARG_DST, c_m},
+                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
+                    });
+      }
+    } else {
+      if (bias) {
+        dnnl::memory::desc bias_md({N}, BiasType, {1});
+        dnnl::memory bias_m(bias_md, engine, (void*)bias);
+        matmul.execute(
+            stream, {
+                        {DNNL_ARG_SRC, a_m},
+                        {DNNL_ARG_WEIGHTS, b_m},
+                        {DNNL_ARG_BIAS, bias_m},
+                        {DNNL_ARG_DST, c_m},
+                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
+                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
+                    });
+      } else {
+        matmul.execute(
+            stream, {
+                        {DNNL_ARG_SRC, a_m},
+                        {DNNL_ARG_WEIGHTS, b_m},
+                        {DNNL_ARG_DST, c_m},
+                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
+                        {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
+                    });
+      }
+    }
+    stream.wait();
+  }
+ private:
+  static dnnl::engine& default_engine() {
+    static dnnl::engine engine(dnnl::engine::kind::cpu, 0);
+    return engine;
+  }
+  static dnnl::stream& default_stream() {
+    static dnnl::stream stream(default_engine());
+    return stream;
+  }
+};
+#endif

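For reference, the math gemm_s8s8_jit performs, written out in NumPy following the shape/scale comment at the top of the class; all shapes and values below are made up for illustration:

    # C = a_scale * A @ (b_scales * B) + bias, int8 A/B, fp32 accumulation.
    import numpy as np

    M, K, N = 4, 8, 6
    rng = np.random.default_rng(0)
    A = rng.integers(-128, 128, size=(M, K), dtype=np.int8)   # activations, row-major
    B = rng.integers(-128, 128, size=(K, N), dtype=np.int8)   # weights
    a_scale = np.float32(0.05)                                # per-tensor activation scale
    b_scales = rng.random(N, dtype=np.float32)                # per-output-channel scales
    bias = rng.random(N, dtype=np.float32)

    C = (A.astype(np.float32) * a_scale) @ (B.astype(np.float32) * b_scales) + bias
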
+ 260 - 0
kernels/cpu/quant.cpp

@@ -0,0 +1,260 @@
+#include "cpu_types.hpp"
+#include "dnnl_helper.hpp"
+namespace {
+template <typename scalar_t>
+struct KernelVecType {
+  using load_vec_type = void;
+  using cvt_vec_type = void;
+};
+template <>
+struct KernelVecType<float> {
+  using load_vec_type = vec_op::FP32Vec16;
+  using cvt_vec_type = vec_op::FP32Vec16;
+};
+template <>
+struct KernelVecType<c10::BFloat16> {
+  using load_vec_type = vec_op::BF16Vec16;
+  using cvt_vec_type = vec_op::FP32Vec16;
+};
+#ifdef __AVX512F__
+template <typename scalar_t>
+void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
+                                   const float* scale, const int num_tokens,
+                                   const int hidden_size) {
+  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
+  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
+  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
+  constexpr float i8_min =
+      static_cast<float>(std::numeric_limits<int8_t>::min());
+  constexpr float i8_max =
+      static_cast<float>(std::numeric_limits<int8_t>::max());
+  const cvt_vec_t inv_scale(1.0 / *scale);
+  const cvt_vec_t i8_min_vec(i8_min);
+  const cvt_vec_t i8_max_vec(i8_max);
+  #pragma omp parallel for
+  for (int i = 0; i < num_tokens; ++i) {
+    int j = 0;
+    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
+      load_vec_t elems(input + i * hidden_size + j);
+      cvt_vec_t elems_fp32(elems);
+      elems_fp32 = (elems_fp32 * inv_scale).clamp(i8_min_vec, i8_max_vec);
+      vec_op::INT8Vec16 elems_int8(elems_fp32);
+      elems_int8.save(output + i * hidden_size + j);
+    }
+    load_vec_t elems(input + i * hidden_size + j);
+    cvt_vec_t elems_fp32(elems);
+    elems_fp32 = (elems_fp32 * inv_scale).clamp(i8_min_vec, i8_max_vec);
+    vec_op::INT8Vec16 elems_int8(elems_fp32);
+    if (j + vec_elem_num == hidden_size) {
+      elems_int8.save(output + i * hidden_size + j);
+    } else {
+      elems_int8.save(output + i * hidden_size + j, hidden_size - j);
+    }
+  }
+}
+template <typename scalar_t>
+void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
+                                    float* scale, const int num_tokens,
+                                    const int hidden_size) {
+  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
+  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
+  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
+  #pragma omp parallel for
+  for (int i = 0; i < num_tokens; ++i) {
+    cvt_vec_t max_abs(0.0);
+    {
+      int j = 0;
+      for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
+        load_vec_t elems(input + i * hidden_size + j);
+        cvt_vec_t elems_fp32(elems);
+        max_abs = max_abs.max(elems_fp32.abs());
+      }
+      load_vec_t elems(input + i * hidden_size + j);
+      cvt_vec_t elems_fp32(elems);
+      if (j + vec_elem_num == hidden_size) {
+        max_abs = max_abs.max(elems_fp32.abs());
+      } else {
+        max_abs = max_abs.max(elems_fp32.abs(), hidden_size - j);
+      }
+    }
+    float scale_val = max_abs.reduce_max() / 127.0f;
+    scale[i] = scale_val;
+    const cvt_vec_t inv_scale(1.0 / scale_val);
+    {
+      int j = 0;
+      for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
+        load_vec_t elems(input + i * hidden_size + j);
+        cvt_vec_t elems_fp32(elems);
+        elems_fp32 = (elems_fp32 * inv_scale);
+        vec_op::INT8Vec16 elems_int8(elems_fp32);
+        elems_int8.save(output + i * hidden_size + j);
+      }
+      load_vec_t elems(input + i * hidden_size + j);
+      cvt_vec_t elems_fp32(elems);
+      elems_fp32 = (elems_fp32 * inv_scale);
+      vec_op::INT8Vec16 elems_int8(elems_fp32);
+      if (j + vec_elem_num == hidden_size) {
+        elems_int8.save(output + i * hidden_size + j);
+      } else {
+        elems_int8.save(output + i * hidden_size + j, hidden_size - j);
+      }
+    }
+  }
+}
+template <bool Bias, typename scalar_t>
+void dynamic_output_scale_impl(const float* input, scalar_t* output,
+                               const float* scale, const scalar_t* bias,
+                               const int num_tokens, const int hidden_size) {
+  CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
+  using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
+  using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
+  constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
+  #pragma omp parallel for
+  for (int i = 0; i < num_tokens; ++i) {
+    int j = 0;
+    cvt_vec_t token_scale_vec(scale[i]);
+    for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
+      cvt_vec_t elems_fp32(input + i * hidden_size + j);
+      elems_fp32 = elems_fp32 * token_scale_vec;
+      if constexpr (Bias) {
+        load_vec_t bias_vec(bias + j);
+        cvt_vec_t bias_vec_fp32(bias_vec);
+        elems_fp32 = elems_fp32 + bias_vec_fp32;
+      }
+      load_vec_t elems_out(elems_fp32);
+      elems_out.save(output + i * hidden_size + j);
+    }
+    cvt_vec_t elems_fp32(input + i * hidden_size + j);
+    elems_fp32 = elems_fp32 * token_scale_vec;
+    if constexpr (Bias) {
+      load_vec_t bias_vec(bias + j);
+      cvt_vec_t bias_vec_fp32(bias_vec);
+      elems_fp32 = elems_fp32 + bias_vec_fp32;
+    }
+    load_vec_t elems_out(elems_fp32);
+    if (j + vec_elem_num == hidden_size) {
+      elems_out.save(output + i * hidden_size + j);
+    } else {
+      elems_out.save(output + i * hidden_size + j, hidden_size - j);
+    }
+  }
+}
+#else
+template <typename scalar_t>
+void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
+                                   const float* scale, const int num_tokens,
+                                   const int hidden_size) {
+  TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.")
+}
+template <typename scalar_t>
+void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
+                                    float* scale, const int num_tokens,
+                                    const int hidden_size) {
+  TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.")
+}
+template <typename scalar_t>
+void dynamic_output_scale_impl() {
+  TORCH_CHECK(false, "dynamic_output_scale_impl requires AVX512 support.")
+}
+#endif
+}  // namespace
+void int8_scaled_mm(torch::Tensor& c,               // [M, OC], row-major
+                    const torch::Tensor& a,         // [M, IC], row-major
+                    const torch::Tensor& b,         // [IC, OC], column-major
+                    const torch::Tensor& a_scales,  // [1] or [M]
+                    const torch::Tensor& b_scales,  // [1] or [OC]
+                    const c10::optional<torch::Tensor>& bias  // [OC]
+) {
+  CPU_KERNEL_GUARD_IN(cutlass_scaled_mm)
+  // Checks for conformality
+  TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
+              "int8_scaled_mm only supports INT8 inputs.")
+  TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
+  TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
+              b.size(1) == c.size(1));
+  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
+  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
+  // Check for strides and alignment
+  TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1);  // Row-major
+  TORCH_CHECK(b.stride(0) == 1);                      // Column-major
+  TORCH_CHECK(c.stride(0) % 16 == 0 &&
+              b.stride(1) % 16 == 0);  // 16 Byte Alignment
+  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
+  if (bias) {
+    TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
+                bias->dim() == 1);
+  }
+  APHRODITE_DISPATCH_FLOATING_TYPES(c.scalar_type(), "cutlass_scaled_mm", [&] {
+    if (a_scales.numel() != 1) {
+      // per-token
+      // Note: oneDNN doesn't support per-token activation quantization
+      torch::Tensor tmp_fp32_out =
+          torch::empty_like(c, ::at::ScalarType::Float);
+      DNNLPrimitiveHelper<true>::gemm_s8s8_jit(
+          a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
+          tmp_fp32_out.data_ptr<float>(), (void*)(0), a.size(0), b.size(1),
+          a.size(1), (float*)(0), b_scales.data_ptr<float>(), 0,
+          b_scales.numel());
+      if (bias.has_value()) {
+        dynamic_output_scale_impl<true>(
+            tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
+            a_scales.data_ptr<float>(), bias->data_ptr<scalar_t>(), c.size(0),
+            c.size(1));
+      } else {
+        dynamic_output_scale_impl<false>(
+            tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
+            a_scales.data_ptr<float>(), (scalar_t*)(0), c.size(0), c.size(1));
+      }
+    } else {
+      // per-tensor
+      if (bias.has_value()) {
+        DNNLPrimitiveHelper<false>::gemm_s8s8_jit(
+            a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), c.data_ptr<scalar_t>(),
+            bias->data_ptr<scalar_t>(), a.size(0), b.size(1), a.size(1),
+            a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
+            a_scales.numel(), b_scales.numel());
+      } else {
+        DNNLPrimitiveHelper<false>::gemm_s8s8_jit(
+            a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), c.data_ptr<scalar_t>(),
+            (void*)(0), a.size(0), b.size(1), a.size(1),
+            a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
+            a_scales.numel(), b_scales.numel());
+      }
+    }
+  });
+}
+// static-per-tensor quantization.
+void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
+                              const torch::Tensor& input,  // [..., hidden_size]
+                              const torch::Tensor& scale) {
+  CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(scale.numel() == 1);
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
+  APHRODITE_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "static_scaled_int8_quant_impl", [&] {
+        static_scaled_int8_quant_impl(
+            input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+            scale.data_ptr<float>(), num_tokens, hidden_size);
+      });
+}
+// dynamic-per-token quantization.
+void dynamic_scaled_int8_quant(
+    torch::Tensor& out,          // [..., hidden_size]
+    const torch::Tensor& input,  // [..., hidden_size]
+    torch::Tensor& scale         // [..., 1]
+) {
+  CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(out.is_contiguous());
+  int const hidden_size = input.size(-1);
+  int const num_tokens = input.numel() / hidden_size;
+  APHRODITE_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] {
+        dynamic_scaled_int8_quant_impl(
+            input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+            scale.data_ptr<float>(), num_tokens, hidden_size);
+      });
+}

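For reference, the two quantization paths above in plain NumPy; the round-to-nearest conversion and the max-abs/127 dynamic scale mirror the AVX512 code, but this is a sketch rather than a drop-in replacement:

    import numpy as np

    def static_scaled_int8_quant_ref(x: np.ndarray, scale: float) -> np.ndarray:
        # clamp(round(x / scale), -128, 127), as in static_scaled_int8_quant_impl
        q = np.rint(x / scale)
        return np.clip(q, -128, 127).astype(np.int8)

    def dynamic_scaled_int8_quant_ref(x: np.ndarray):
        # per-token scale = max|x| / 127; results land in [-127, 127]
        scale = np.abs(x).max(axis=-1, keepdims=True) / 127.0
        return np.rint(x / scale).astype(np.int8), scale.astype(np.float32)
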
+ 27 - 3
kernels/cpu/torch_bindings.cpp

@@ -3,8 +3,11 @@
 #include "core/registration.h"
 
 #include <torch/library.h>
-
-void init_cpu_threads_env(const std::string& cpu_ids);
+std::string init_cpu_threads_env(const std::string& cpu_ids);
+void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
+                    const torch::Tensor& b, const torch::Tensor& a_scales,
+                    const torch::Tensor& b_scales,
+                    const c10::optional<torch::Tensor>& bias);
 
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Aphrodite custom ops
@@ -84,6 +87,27 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "                 Tensor! key, int head_size,"
       "                 Tensor cos_sin_cache, bool is_neox) -> ()");
   ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
+  // Quantization
+#ifdef __AVX512F__
+  // Compute int8 quantized tensor for given scaling factor.
+  ops.def(
+      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
+      "()");
+  ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
+  // Compute int8 quantized tensor and scaling factor
+  ops.def(
+      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
+      "()");
+  ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
+           &dynamic_scaled_int8_quant);
+  // W8A8 GEMM, supporting symmetric per-tensor or per-row/column
+  // quantization.
+  ops.def(
+      "cutlass_scaled_mm(Tensor! out, Tensor a,"
+      "                  Tensor b, Tensor a_scales,"
+      "                  Tensor b_scales, Tensor? bias) -> ()");
+  ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm);
+#endif
 }
 
 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
@@ -111,7 +135,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
 
 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
   // CPU utils
-  utils.def("init_cpu_threads_env(str cpu_ids) -> ()", &init_cpu_threads_env);
+  utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
 }
 
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

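Once the extension is built with AVX512, these schemas are callable from Python through torch.ops. The "_C" namespace below follows TORCH_EXTENSION_NAME, consistent with torch.ops._C_utils used elsewhere in this diff, but treat the exact name as an assumption:

    # Hedged usage sketch of the newly registered CPU quantization op.
    import torch

    x = torch.randn(4, 64)                        # [num_tokens, hidden_size]
    out = torch.empty_like(x, dtype=torch.int8)
    scales = torch.empty(4, 1, dtype=torch.float32)

    torch.ops._C.dynamic_scaled_int8_quant(out, x, scales)  # fills out and scales
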
+ 26 - 7
kernels/cpu/utils.cpp

@@ -5,7 +5,7 @@
 
 #include "cpu_types.hpp"
 
-void init_cpu_threads_env(const std::string& cpu_ids) {
+std::string init_cpu_threads_env(const std::string& cpu_ids) {
   bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str());
   TORCH_CHECK(omp_cpu_mask->size > 0);
   std::vector<int> omp_cpu_ids;
@@ -51,15 +51,34 @@ void init_cpu_threads_env(const std::string& cpu_ids) {
   torch::set_num_threads((int)omp_cpu_ids.size());
   TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads());
   TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads());
+  std::vector<std::pair<int, int>> thread_core_mapping;
+  thread_core_mapping.reserve(omp_cpu_ids.size());
+  omp_lock_t writelock;
+  omp_init_lock(&writelock);
 #pragma omp parallel for schedule(static, 1)
   for (size_t i = 0; i < omp_cpu_ids.size(); ++i) {
-    cpu_set_t* mask = CPU_ALLOC(omp_cpu_mask->size);
-    size_t size = CPU_ALLOC_SIZE(omp_cpu_mask->size);
-    CPU_ZERO_S(size, mask);
-    CPU_SET_S(omp_cpu_ids[i], size, mask);
-    sched_setaffinity(0, sizeof(cpu_set_t), mask);
-    CPU_FREE(mask);
+    cpu_set_t mask;
+    CPU_ZERO(&mask);
+    CPU_SET(omp_cpu_ids[i], &mask);
+    int ret = sched_setaffinity(0, sizeof(cpu_set_t), &mask);
+    if (ret == -1) {
+      TORCH_CHECK(false,
+                  "sched_setaffinity failed. errno: " + std::to_string(errno));
+    }
+    omp_set_lock(&writelock);
+    thread_core_mapping.emplace_back(gettid(), omp_cpu_ids[i]);
+    omp_unset_lock(&writelock);
   }
 
+  omp_destroy_lock(&writelock);
   numa_free_nodemask(omp_cpu_mask);
+  std::stringstream ss;
+  ss << "OMP threads binding of Process " << getpid() << ":\n";
+  std::sort(thread_core_mapping.begin(), thread_core_mapping.end(),
+            [](auto&& a, auto&& b) { return a.second < b.second; });
+  for (auto&& item : thread_core_mapping) {
+    ss << "\t"
+       << "OMP tid: " << item.first << ", core " << item.second << "\n";
+  }
+  return ss.str();
 }
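
init_cpu_threads_env now returns a human-readable binding report, which CPUWorker.init_device logs (see the cpu_worker.py hunk above). A hedged example of driving the binding and what the log roughly looks like; the core range, PID and TIDs are placeholders:

    import os

    # Must be set before the CPU worker initializes its devices.
    os.environ["APHRODITE_CPU_OMP_THREADS_BIND"] = "0-3"   # pin OMP threads to cores 0-3

    # Logged by init_device(), formatted by the stringstream above:
    #   OMP threads binding of Process 4242:
    #           OMP tid: 4243, core 0
    #           OMP tid: 4244, core 1
    #           OMP tid: 4245, core 2
    #           OMP tid: 4246, core 3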