Browse Source

feat: inference support for PowerPC ISA

Co-authored-by: Chip Kerchner <49959681+ChipKerchner@users.noreply.github.com>
AlpinDale 7 months ago
parent
commit
271a680026

+ 22 - 0
Dockerfile.ppc64le

@@ -0,0 +1,22 @@
+FROM mambaorg/micromamba
+ARG MAMBA_DOCKERFILE_ACTIVATE=1
+USER root
+
+RUN apt-get update  -y     && apt-get install -y git wget vim numactl gcc-12 g++-12 protobuf-compiler libprotobuf-dev     && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
+
+# Some packages in requirements-cpu are installed here
+# IBM provides optimized packages for ppc64le processors in the open-ce project for mamba
+# Currently these may not be available for venv or pip directly
+RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults     python=3.10     pytorch-cpu=2.1.2     torchvision-cpu=0.16.2    &&     micromamba clean --all --yes
+
+COPY ./ /workspace/aphrodite-engine
+
+WORKDIR /workspace/aphrodite-engine
+
+# These packages will be in rocketce eventually
+RUN pip install -v -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing
+
+RUN APHRODITE_TARGET_DEVICE=cpu python3 setup.py install
+
+WORKDIR /aphrodite-workspace
+ENTRYPOINT ["/opt/conda/bin/python3", "-m", "aphrodite.endpoints.openai.api_server"]

+ 10 - 1
cmake/cpu_extension.cmake

@@ -46,6 +46,8 @@ is_avx512_disabled(AVX512_DISABLED)
 
 find_isa(${CPUINFO} "avx2" AVX2_FOUND)
 find_isa(${CPUINFO} "avx512f" AVX512_FOUND)
+find_isa(${CPUINFO} "POWER10" POWER10_FOUND)
+find_isa(${CPUINFO} "POWER9" POWER9_FOUND)
 
 if (AVX512_FOUND AND NOT AVX512_DISABLED)
     list(APPEND CXX_COMPILE_FLAGS
@@ -68,8 +70,15 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
 elseif (AVX2_FOUND)
     list(APPEND CXX_COMPILE_FLAGS "-mavx2")
     message(WARNING "Aphrodite CPU backend using AVX2 ISA")
+elseif (POWER9_FOUND OR POWER10_FOUND)
+    message(STATUS "PowerPC detected")
+    # Check for PowerPC VSX support
+    list(APPEND CXX_COMPILE_FLAGS
+        "-mvsx"
+        "-mcpu=native"
+        "-mtune=native")
 else()
-    message(FATAL_ERROR "Aphrodite CPU backend requires AVX512 or AVX2 ISA support.")
+    message(FATAL_ERROR "Aphrodite CPU backend requires AVX512 or AVX2 or Power9+ ISA support.")
 endif()
 
 message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")

+ 8 - 507
kernels/cpu/cpu_types.hpp

@@ -1,514 +1,15 @@
+
 #ifndef CPU_TYPES_HPP
 #define CPU_TYPES_HPP
 
-#include <immintrin.h>
-#include <torch/all.h>
-
-#ifndef __AVX2__
-static_assert(false, "AVX2 must be supported for the current implementation.");
-#endif
-
-namespace vec_op {
-
-// FIXME: FP16 is not fully supported in Torch-CPU
-#define APHRODITE_DISPATCH_CASE_FLOATING_TYPES(...)                                 \
-  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)                         \
-  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
-
-#define APHRODITE_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)                          \
-  AT_DISPATCH_SWITCH(TYPE, NAME, APHRODITE_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
-
-#ifndef CPU_OP_GUARD
-#define CPU_KERNEL_GUARD_IN(NAME)
-#define CPU_KERNEL_GUARD_OUT(NAME)
+#if defined(__x86_64__)
+  //x86 implementation
+  #include "cpu_types_x86.hpp"
+#elif defined(__POWER9_VECTOR__)
+  //ppc implementation
+  #include "cpu_types_vsx.hpp"
 #else
-#define CPU_KERNEL_GUARD_IN(NAME)                                              \
-  std::cout << #NAME << " invoked." << std::endl;
-#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
-#endif
-
-#define FORCE_INLINE __attribute__((always_inline)) inline
-
-namespace {
-template <typename T, T... indexes, typename F>
-constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
-  (f(std::integral_constant<T, indexes>{}), ...);
-}
-}; // namespace
-
-template <typename T, T count, typename F,
-          typename = std::enable_if_t<std::is_invocable_v<F, T>>>
-constexpr void unroll_loop(F &&f) {
-  unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
-}
-
-template <typename T> struct Vec {
-  constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
-};
-
-struct FP32Vec8;
-struct FP32Vec16;
-
-#ifdef __AVX512FP16__
-struct FP16Vec8 : public Vec<FP16Vec8> {
-  constexpr static int VEC_ELEM_NUM = 8;
-
-  __m128h reg;
-
-  explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {}
-
-  explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {}
-
-  explicit FP16Vec8(__m128h data) : reg(data) {}
-
-  FP16Vec8 operator*(const FP16Vec8 &b) const {
-    return FP16Vec8(_mm_mul_ph(reg, b.reg));
-  }
-
-  FP16Vec8 operator+(const FP16Vec8 &b) const {
-    return FP16Vec8(_mm_add_ph(reg, b.reg));
-  }
-
-  FP16Vec8 operator-(const FP16Vec8 &b) const {
-    return FP16Vec8(_mm_sub_ph(reg, b.reg));
-  }
-
-  FP16Vec8 operator/(const FP16Vec8 &b) const {
-    return FP16Vec8(_mm_div_ph(reg, b.reg));
-  }
-
-  void save(void *ptr) const { _mm_storeu_ph(ptr, reg); }
-};
+  #warning "unsupported Aphrodite cpu implementation"
 #endif
 
-struct BF16Vec8 : public Vec<BF16Vec8> {
-  constexpr static int VEC_ELEM_NUM = 8;
-
-  __m128i reg;
-
-  explicit BF16Vec8(const void *ptr)
-      : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {}
-
-  explicit BF16Vec8(const FP32Vec8 &);
-
-  void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; }
-};
-
-struct BF16Vec16 : public Vec<BF16Vec16> {
-  constexpr static int VEC_ELEM_NUM = 16;
-
-  __m256i reg;
-
-  explicit BF16Vec16(const void *ptr)
-      : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {}
-
-  explicit BF16Vec16(const FP32Vec16 &);
-
-  void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
-};
-
-#ifdef __AVX512F__
-struct BF16Vec32 : public Vec<BF16Vec32> {
-  constexpr static int VEC_ELEM_NUM = 32;
-
-  __m512i reg;
-
-  explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}
-
-  explicit BF16Vec32(__m512i data) : reg(data) {}
-
-  explicit BF16Vec32(BF16Vec8 &vec8_data)
-      : reg((__m512i)_mm512_inserti32x4(
-            _mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512(
-                                                      (__m128i)vec8_data.reg),
-                                                  (__m128i)vec8_data.reg, 1),
-                               (__m128i)vec8_data.reg, 2),
-            (__m128i)vec8_data.reg, 3)) {}
-
-  void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; }
-};
-#else
-struct BF16Vec32 : public Vec<BF16Vec32> {
-  constexpr static int VEC_ELEM_NUM = 32;
-
-  __m256i reg_low;
-  __m256i reg_high;
-
-  explicit BF16Vec32(const void *ptr)
-      : reg_low(_mm256_loadu_si256((__m256i const *)ptr)),
-        reg_high(_mm256_loadu_si256((__m256i const *)ptr + 1)) {}
-
-  explicit BF16Vec32(__m256i low, __m256i high) : reg_low(low),
-                                                  reg_high(high) {}
-
-  explicit BF16Vec32(BF16Vec8 &vec8_data)
-      : reg_low((__m256i)_mm256_inserti32x4(
-                _mm256_castsi128_si256((__m128i)vec8_data.reg),
-                                       (__m128i)vec8_data.reg, 1)),
-        reg_high((__m256i)_mm256_inserti32x4(
-                _mm256_castsi128_si256((__m128i)vec8_data.reg),
-                                       (__m128i)vec8_data.reg, 1)) {}
-
-  void save(void *ptr) const {
-    *reinterpret_cast<__m256i *>(ptr) = reg_low;
-    *reinterpret_cast<__m256i *>((__m256i *)ptr + 1) = reg_high;
-  }
-};
-#endif
-
-struct FP32Vec4 : public Vec<FP32Vec4> {
-  constexpr static int VEC_ELEM_NUM = 4;
-  union AliasReg {
-    __m128 reg;
-    float values[VEC_ELEM_NUM];
-  };
-
-  __m128 reg;
-
-  explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {}
-
-  explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {}
-
-  explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {}
-
-  explicit FP32Vec4(__m128 data) : reg(data) {}
-
-  explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}
-};
-
-struct FP32Vec8 : public Vec<FP32Vec8> {
-  constexpr static int VEC_ELEM_NUM = 8;
-  union AliasReg {
-    __m256 reg;
-    float values[VEC_ELEM_NUM];
-  };
-
-  __m256 reg;
-
-  explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {}
-
-  explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {}
-
-  explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {}
-
-  explicit FP32Vec8(__m256 data) : reg(data) {}
-
-  explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}
-
-#ifdef __AVX512FP16__
-  explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {}
-#endif
-
-  explicit FP32Vec8(const BF16Vec8 &v)
-      : reg(_mm256_castsi256_ps(
-            _mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {}
-
-  float reduce_sum() const {
-    AliasReg ar;
-    ar.reg = reg;
-    float result = 0;
-    unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
-
-    return result;
-  }
-
-  FP32Vec8 exp() const {
-    AliasReg ar;
-    ar.reg = reg;
-    return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]),
-                                  expf(ar.values[5]), expf(ar.values[4]),
-                                  expf(ar.values[3]), expf(ar.values[2]),
-                                  expf(ar.values[1]), expf(ar.values[0])));
-  }
-
-  FP32Vec8 tanh() const {
-    AliasReg ar;
-    ar.reg = reg;
-    return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]),
-                                  tanhf(ar.values[5]), tanhf(ar.values[4]),
-                                  tanhf(ar.values[3]), tanhf(ar.values[2]),
-                                  tanhf(ar.values[1]), tanhf(ar.values[0])));
-  }
-
-  FP32Vec8 er() const {
-    AliasReg ar;
-    ar.reg = reg;
-    return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]),
-                                  erf(ar.values[5]), erf(ar.values[4]),
-                                  erf(ar.values[3]), erf(ar.values[2]),
-                                  erf(ar.values[1]), erf(ar.values[0])));
-  }
-
-  FP32Vec8 operator*(const FP32Vec8 &b) const {
-    return FP32Vec8(_mm256_mul_ps(reg, b.reg));
-  }
-
-  FP32Vec8 operator+(const FP32Vec8 &b) const {
-    return FP32Vec8(_mm256_add_ps(reg, b.reg));
-  }
-
-  FP32Vec8 operator-(const FP32Vec8 &b) const {
-    return FP32Vec8(_mm256_sub_ps(reg, b.reg));
-  }
-
-  FP32Vec8 operator/(const FP32Vec8 &b) const {
-    return FP32Vec8(_mm256_div_ps(reg, b.reg));
-  }
-
-  void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); }
-};
-
-#ifdef __AVX512F__
-struct FP32Vec16 : public Vec<FP32Vec16> {
-  constexpr static int VEC_ELEM_NUM = 16;
-  union AliasReg {
-    __m512 reg;
-    float values[VEC_ELEM_NUM];
-  };
-
-  __m512 reg;
-
-  explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {}
-
-  explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {}
-
-  explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {}
-
-  explicit FP32Vec16(__m512 data) : reg(data) {}
-
-  explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {}
-
-  explicit FP32Vec16(const FP32Vec4 &data)
-      : reg((__m512)_mm512_inserti32x4(
-            _mm512_inserti32x4(
-                _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg),
-                                   (__m128i)data.reg, 1),
-                (__m128i)data.reg, 2),
-            (__m128i)data.reg, 3)) {}
-
-  explicit FP32Vec16(const FP32Vec8 &data)
-      : reg((__m512)_mm512_inserti32x8(
-            _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {}
-
-  explicit FP32Vec16(const BF16Vec16 &v)
-      : reg(_mm512_castsi512_ps(
-            _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}
-
-  explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
-
-  FP32Vec16 operator*(const FP32Vec16 &b) const {
-    return FP32Vec16(_mm512_mul_ps(reg, b.reg));
-  }
-
-  FP32Vec16 operator+(const FP32Vec16 &b) const {
-    return FP32Vec16(_mm512_add_ps(reg, b.reg));
-  }
-
-  FP32Vec16 operator-(const FP32Vec16 &b) const {
-    return FP32Vec16(_mm512_sub_ps(reg, b.reg));
-  }
-
-  FP32Vec16 operator/(const FP32Vec16 &b) const {
-    return FP32Vec16(_mm512_div_ps(reg, b.reg));
-  }
-
-  float reduce_sum() const { return _mm512_reduce_add_ps(reg); }
-
-  template <int group_size> float reduce_sub_sum(int idx) {
-    static_assert(VEC_ELEM_NUM % group_size == 0);
-    constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
-    __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size));
-    return _mm512_mask_reduce_add_ps(mask, reg);
-  }
-
-  void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
-};
-#else
-struct FP32Vec16 : public Vec<FP32Vec16> {
-  constexpr static int VEC_ELEM_NUM = 16;
-
-  union AliasReg {
-    __m256 reg;
-    float values[8];
-  };
-
-  __m256 reg_low;
-  __m256 reg_high;
-
-  explicit FP32Vec16(float v) : reg_low(_mm256_set1_ps(v)),
-                                reg_high(_mm256_set1_ps(v)) {}
-
-  explicit FP32Vec16() : reg_low(_mm256_set1_ps(0.0)),
-                         reg_high(_mm256_set1_ps(0.0)) {}
-
-  explicit FP32Vec16(const float *ptr) : reg_low(_mm256_loadu_ps(ptr)),
-                                         reg_high(_mm256_loadu_ps(ptr + 8)) {}
-
-  explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {}
-
-  explicit FP32Vec16(const FP32Vec16 &data) : reg_low(data.reg_low),
-                                              reg_high(data.reg_high) {}
-
-  explicit FP32Vec16(const FP32Vec4 &data)
-      : reg_low((__m256)_mm256_inserti128_si256(
-                _mm256_castsi128_si256((__m128i)data.reg),
-                                       (__m128i)data.reg, 1)),
-        reg_high((__m256)_mm256_inserti128_si256(
-                 _mm256_castsi128_si256((__m128i)data.reg),
-                                       (__m128i)data.reg, 1)) {}
-
-  explicit FP32Vec16(const FP32Vec8 &data)
-      : reg_low(data.reg), reg_high(data.reg) {}
-
-  explicit FP32Vec16(const BF16Vec16 &v) {
-    __m128i low = _mm256_extractf128_si256(v.reg, 0);
-    __m128i high = _mm256_extractf128_si256(v.reg, 1);
-
-    __m256i v_low_epi32 = _mm256_cvtepu16_epi32(low);
-    __m256i v_high_epi32 = _mm256_cvtepu16_epi32(high);
-
-    __m256i v_low_shifted = _mm256_bslli_epi128(v_low_epi32, 2);
-    __m256i v_high_shifted = _mm256_bslli_epi128(v_high_epi32, 2);
-
-    reg_low = _mm256_castsi256_ps(v_low_shifted);
-    reg_high = _mm256_castsi256_ps(v_high_shifted);
-  }
-
-  explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
-
-  FP32Vec16 operator*(const FP32Vec16 &b) const {
-    return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low),
-                     _mm256_mul_ps(reg_high, b.reg_high));
-  }
-
-  FP32Vec16 operator+(const FP32Vec16 &b) const {
-    return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low),
-                     _mm256_add_ps(reg_high, b.reg_high));
-  }
-
-  FP32Vec16 operator-(const FP32Vec16 &b) const {
-    return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low),
-                     _mm256_sub_ps(reg_high, b.reg_high));
-  }
-
-  FP32Vec16 operator/(const FP32Vec16 &b) const {
-    return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low),
-                     _mm256_div_ps(reg_high, b.reg_high));
-  }
-
-  float reduce_sum() const {
-    FP32Vec8 low = FP32Vec8(reg_low);
-    FP32Vec8 high = FP32Vec8(reg_high);
-    return low.reduce_sum() + high.reduce_sum();
-  }
-
-  template <int group_size> float reduce_sub_sum(int idx) {
-    float sum = 0.0;
-    static_assert(VEC_ELEM_NUM % group_size == 0);
-    constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
-    uint32_t mask = base_mask << (idx * group_size);
-
-    AliasReg ar;
-
-    auto func = [&sum, &mask, &ar](int i) {
-      int flag = mask & 0x1;
-      mask = mask >> 1;
-      if (flag != 0) sum += ar.values[i];
-    };
-
-    ar.reg = reg_low;
-    unroll_loop<int, 8>(func);
-
-    ar.reg = reg_high;
-    unroll_loop<int, 8>(func);
-
-    return sum;
-  }
-
-  void save(float *ptr) const {
-    _mm256_storeu_ps(ptr, reg_low);
-    _mm256_storeu_ps(ptr + 8, reg_high);
-  }
-};
-#endif
-
-template <typename T> struct VecType { using vec_type = void; };
-
-template <typename T> using vec_t = typename VecType<T>::vec_type;
-
-template <> struct VecType<float> { using vec_type = FP32Vec8; };
-
-#ifdef __AVX512FP16__
-template <> struct VecType<c10::Half> { using vec_type = FP16Vec16; };
-#endif
-
-template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
-
-template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
-
-#ifdef __AVX512FP16__
-template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
-  *reinterpret_cast<_Float16 *>(ptr) = v;
-}
-#endif
-
-inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
-  acc = acc + a * b;
-}
-
-#ifdef __AVX512BF16__
-template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
-  *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);
-}
-
-inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
-    : reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {}
-
-inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
-    : reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {}
-
-inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
-  acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg);
-}
-#else
-template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
-  c10::BFloat16 __attribute__((__may_alias__)) *v_ptr =
-      reinterpret_cast<c10::BFloat16 *>(&v);
-  *ptr = *(v_ptr + 1);
-}
-
-#ifdef __AVX512F__
-inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
-    : reg(_mm256_cvtepi32_epi16(
-          _mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {}
-
-inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
-    : reg(_mm512_cvtepi32_epi16(
-          _mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {}
-#else
-namespace{
-__m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) {
-  __m256i ai = _mm256_castps_si256(a);
-  ai = _mm256_srli_epi32(ai, 16);
-  ai = _mm256_packus_epi32(ai, ai);
-  ai = _mm256_permute4x64_epi64(ai, 0b00111001);
-  return _mm256_extracti128_si256(ai, 0);
-}
-}
-
-inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
-    : reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {}
-
-inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
-  BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low));
-  BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high));
-  reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1);
-}
-#endif // __AVX512F__
-#endif // __AVX512BF16__
-
-inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); }
-
-}; // namespace vec_op
-
 #endif

+ 491 - 0
kernels/cpu/cpu_types_vsx.hpp

@@ -0,0 +1,491 @@
+
+#ifndef CPU_TYPES_VSX_HPP
+#define CPU_TYPES_VSX_HPP
+
+#include <altivec.h>
+#include <cmath>
+#include <torch/all.h>
+
+namespace vec_op {
+
+// FIXME: FP16 is not fully supported in Torch-CPU
+#define APHRODITE_DISPATCH_CASE_FLOATING_TYPES(...)                                 \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)                         \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define APHRODITE_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)                          \
+  AT_DISPATCH_SWITCH(TYPE, NAME, APHRODITE_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+
+#ifndef CPU_OP_GUARD
+#define CPU_KERNEL_GUARD_IN(NAME)
+#define CPU_KERNEL_GUARD_OUT(NAME)
+#else
+#define CPU_KERNEL_GUARD_IN(NAME)                                              \
+  std::cout << #NAME << " invoked." << std::endl;
+#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
+#endif
+
+#define FORCE_INLINE __attribute__((always_inline)) inline
+
+namespace {
+template <typename T, T... indexes, typename F>
+constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
+  (f(std::integral_constant<T, indexes>{}), ...);
+}
+}; // namespace
+
+template <typename T, T count, typename F,
+          typename = std::enable_if_t<std::is_invocable_v<F, T>>>
+constexpr void unroll_loop(F &&f) {
+  unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
+}
+
+template <typename T> struct Vec {
+  constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
+};
+
+typedef struct ss16x8x2_t {
+  __vector signed short val[2];
+} ss16x8x2_t;
+
+typedef struct ss16x8x4_t {
+  __vector signed short val[4];
+} ss16x8x4_t;
+
+typedef struct f32x4x2_t {
+  __vector float val[2];
+} f32x4x2_t;
+
+typedef struct f32x4x4_t {
+  __vector float val[4];
+} f32x4x4_t;
+
+struct FP32Vec8;
+struct FP32Vec16;
+
+struct BF16Vec8 : public Vec<BF16Vec8> {
+  constexpr static int VEC_ELEM_NUM = 8;
+
+  __vector signed short reg;
+
+  explicit BF16Vec8(const void *ptr)
+      : reg((__vector signed short)vec_xl(0, (__vector signed short *)ptr)) {}
+
+  explicit BF16Vec8(const FP32Vec8 &);
+
+  void save(void *ptr) const { *reinterpret_cast<__vector signed short *>(ptr) = reg; }
+};
+
+struct BF16Vec16 : public Vec<BF16Vec16> {
+  constexpr static int VEC_ELEM_NUM = 16;
+
+  ss16x8x2_t reg;
+
+  explicit BF16Vec16(const void *ptr) {
+    // Load 256 bits in two parts
+    reg.val[0] = (__vector signed short)vec_xl(0,  (signed short *)ptr);
+    reg.val[1] = (__vector signed short)vec_xl(16, (signed short *)ptr);
+  }
+
+  explicit BF16Vec16(const FP32Vec16 &);
+
+  void save(void *ptr) const {
+    // Save 256 bits in two parts
+    vec_xst(reg.val[0], 0, (signed short *)ptr);
+    vec_xst(reg.val[1], 16, (signed short *)ptr);
+  }
+};
+
+const static __vector signed short zero = vec_splats((signed short)0);
+
+struct BF16Vec32 : public Vec<BF16Vec32> {
+  constexpr static int VEC_ELEM_NUM = 32;
+
+  ss16x8x4_t reg;
+  explicit BF16Vec32(const void *ptr)
+      : reg(*reinterpret_cast<const ss16x8x4_t *>(ptr)) {}
+
+  explicit BF16Vec32(ss16x8x4_t data) : reg(data) {}
+
+  explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({
+    vec8_data.reg,
+    vec8_data.reg,
+    vec8_data.reg,
+    vec8_data.reg
+  }) {}
+
+  void save(void *ptr) const { *reinterpret_cast<ss16x8x4_t *>(ptr) = reg; }
+};
+
+struct FP32Vec4 : public Vec<FP32Vec4> {
+  constexpr static int VEC_ELEM_NUM = 4;
+  union AliasReg {
+    __vector float reg;
+    float values[VEC_ELEM_NUM];
+  };
+
+  __vector float reg;
+
+  explicit FP32Vec4(float v) : reg(vec_splats(v)) {}
+
+  explicit FP32Vec4() : reg(vec_splats(0.0f)) {}
+
+  explicit FP32Vec4(const float *ptr) : reg(vec_xl(0, ptr)) {}
+
+  explicit FP32Vec4(__vector float data) : reg(data) {}
+
+  explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}
+};
+
+struct FP32Vec8 : public Vec<FP32Vec8> {
+  constexpr static int VEC_ELEM_NUM = 8;
+  union AliasReg {
+    f32x4x2_t reg;
+    float values[VEC_ELEM_NUM];
+  };
+
+  f32x4x2_t reg;
+
+  explicit FP32Vec8(float v) {
+    reg.val[0] = vec_splats(v);
+    reg.val[1] = vec_splats(v);
+  }
+
+  explicit FP32Vec8() {
+    reg.val[0] = vec_splats(0.0f);
+    reg.val[1] = vec_splats(0.0f);
+  }
+
+  explicit FP32Vec8(const float *ptr) {
+    reg.val[0] = vec_xl(0, ptr);
+    reg.val[1] = vec_xl(16, ptr);
+  }
+
+  explicit FP32Vec8(f32x4x2_t data) : reg(data) {}
+
+  explicit FP32Vec8(const FP32Vec8 &data) {
+    reg.val[0] = data.reg.val[0];
+    reg.val[1] = data.reg.val[1];
+  }
+
+  explicit FP32Vec8(const BF16Vec8 &v) {
+    reg.val[0] = (__vector float)vec_mergeh(zero, v.reg);
+    reg.val[1] = (__vector float)vec_mergel(zero, v.reg);
+  }
+
+  float reduce_sum() const {
+    AliasReg ar;
+    ar.reg = reg;
+    float result = 0;
+    unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
+
+    return result;
+  }
+
+  FP32Vec8 exp() const {
+    // TODO: Vectorize this
+    AliasReg ar;
+    ar.reg = reg;
+    f32x4x4_t ret;
+    ret.val[0][0] = std::exp(ar.values[0]);
+    ret.val[0][1] = std::exp(ar.values[1]);
+    ret.val[0][2] = std::exp(ar.values[2]);
+    ret.val[0][3] = std::exp(ar.values[3]);
+    ret.val[1][0] = std::exp(ar.values[4]);
+    ret.val[1][1] = std::exp(ar.values[5]);
+    ret.val[1][2] = std::exp(ar.values[6]);
+    ret.val[1][3] = std::exp(ar.values[7]);
+    return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
+  }
+
+  FP32Vec8 tanh() const {
+    // TODO: Vectorize this
+    AliasReg ar;
+    ar.reg = reg;
+    f32x4x4_t ret;
+    ret.val[0][0] = std::tanh(ar.values[0]);
+    ret.val[0][1] = std::tanh(ar.values[1]);
+    ret.val[0][2] = std::tanh(ar.values[2]);
+    ret.val[0][3] = std::tanh(ar.values[3]);
+    ret.val[1][0] = std::tanh(ar.values[4]);
+    ret.val[1][1] = std::tanh(ar.values[5]);
+    ret.val[1][2] = std::tanh(ar.values[6]);
+    ret.val[1][3] = std::tanh(ar.values[7]);
+    return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
+  }
+
+  FP32Vec8 er() const {
+    // TODO: Vectorize this
+    AliasReg ar;
+    ar.reg = reg;
+    f32x4x4_t ret;
+    ret.val[0][0] = std::erf(ar.values[0]);
+    ret.val[0][1] = std::erf(ar.values[1]);
+    ret.val[0][2] = std::erf(ar.values[2]);
+    ret.val[0][3] = std::erf(ar.values[3]);
+    ret.val[1][0] = std::erf(ar.values[4]);
+    ret.val[1][1] = std::erf(ar.values[5]);
+    ret.val[1][2] = std::erf(ar.values[6]);
+    ret.val[1][3] = std::erf(ar.values[7]);
+    return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
+  }
+
+  FP32Vec8 operator*(const FP32Vec8 &b) const {
+    return FP32Vec8({vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])});
+  }
+
+  FP32Vec8 operator+(const FP32Vec8 &b) const {
+    return FP32Vec8({vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])});
+  }
+
+  FP32Vec8 operator-(const FP32Vec8 &b) const {
+    return FP32Vec8({vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])});
+  }
+
+  FP32Vec8 operator/(const FP32Vec8 &b) const {
+    return FP32Vec8({vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])});
+  }
+
+  void save(float *ptr) const {
+    vec_xst(reg.val[0], 0, ptr);
+    vec_xst(reg.val[1], 16, ptr);
+  }
+};
+
+struct FP32Vec16 : public Vec<FP32Vec16> {
+  constexpr static int VEC_ELEM_NUM = 16;
+  union AliasReg {
+    f32x4x4_t reg;
+    float values[VEC_ELEM_NUM];
+  };
+
+  f32x4x4_t reg;
+
+  explicit FP32Vec16(float v) {
+    reg.val[0] = vec_splats(v);
+    reg.val[1] = vec_splats(v);
+    reg.val[2] = vec_splats(v);
+    reg.val[3] = vec_splats(v);
+  }
+
+  explicit FP32Vec16() {
+    reg.val[0] = vec_splats(0.0f);
+    reg.val[1] = vec_splats(0.0f);
+    reg.val[2] = vec_splats(0.0f);
+    reg.val[3] = vec_splats(0.0f);
+  }
+
+  explicit FP32Vec16(const float *ptr) {
+    reg.val[0] = vec_xl(0, ptr);
+    reg.val[1] = vec_xl(16, ptr);
+    reg.val[2] = vec_xl(32, ptr);
+    reg.val[3] = vec_xl(48, ptr);
+  }
+
+  explicit FP32Vec16(f32x4x4_t data) : reg(data) {}
+
+  explicit FP32Vec16(const FP32Vec16 &data) {
+    reg.val[0] = data.reg.val[0];
+    reg.val[1] = data.reg.val[1];
+    reg.val[2] = data.reg.val[2];
+    reg.val[3] = data.reg.val[3];
+  }
+
+  explicit FP32Vec16(const FP32Vec4 &data) {
+    reg.val[0] = data.reg;
+    reg.val[1] = data.reg;
+    reg.val[2] = data.reg;
+    reg.val[3] = data.reg;
+  }
+
+  explicit FP32Vec16(const FP32Vec8 &data) {
+    reg.val[0] = data.reg.val[0];
+    reg.val[1] = data.reg.val[1];
+    reg.val[2] = data.reg.val[0];
+    reg.val[3] = data.reg.val[1];
+  }
+
+  explicit FP32Vec16(const BF16Vec16 &v) {
+    reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
+    reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
+    reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]);
+    reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]);
+  }
+
+  explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
+
+  FP32Vec16 operator*(const FP32Vec16 &b) const {
+    return FP32Vec16(f32x4x4_t({
+        vec_mul(reg.val[0], b.reg.val[0]),
+        vec_mul(reg.val[1], b.reg.val[1]),
+        vec_mul(reg.val[2], b.reg.val[2]),
+        vec_mul(reg.val[3], b.reg.val[3])}));
+  }
+
+  FP32Vec16 operator+(const FP32Vec16 &b) const {
+    return FP32Vec16(f32x4x4_t({
+        vec_add(reg.val[0], b.reg.val[0]),
+        vec_add(reg.val[1], b.reg.val[1]),
+        vec_add(reg.val[2], b.reg.val[2]),
+        vec_add(reg.val[3], b.reg.val[3])}));
+  }
+
+  FP32Vec16 operator-(const FP32Vec16 &b) const {
+    return FP32Vec16(f32x4x4_t({
+        vec_sub(reg.val[0], b.reg.val[0]),
+        vec_sub(reg.val[1], b.reg.val[1]),
+        vec_sub(reg.val[2], b.reg.val[2]),
+        vec_sub(reg.val[3], b.reg.val[3])}));
+  }
+
+  FP32Vec16 operator/(const FP32Vec16 &b) const {
+    return FP32Vec16(f32x4x4_t({
+        vec_div(reg.val[0], b.reg.val[0]),
+        vec_div(reg.val[1], b.reg.val[1]),
+        vec_div(reg.val[2], b.reg.val[2]),
+        vec_div(reg.val[3], b.reg.val[3])}));
+  }
+
+  float reduce_sum() const {
+    AliasReg ar;
+    ar.reg = reg;
+    float result = 0;
+    unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
+
+    return result;
+  }
+
+  template <int group_size> float reduce_sub_sum(int idx) {
+    static_assert(VEC_ELEM_NUM % group_size == 0);
+
+    AliasReg ar;
+    ar.reg = reg;
+    float result = 0;
+    const int start = idx * group_size;
+    unroll_loop<int, group_size>(
+        [&result, &start, ar](int i) { result += ar.values[start + i]; });
+
+    return result;
+  }
+
+  void save(float *ptr) const {
+    vec_xst(reg.val[0], 0, ptr);
+    vec_xst(reg.val[1], 16, ptr);
+    vec_xst(reg.val[2], 32, ptr);
+    vec_xst(reg.val[3], 48, ptr);
+  }
+};
+
+template <typename T> struct VecType { using vec_type = void; };
+
+template <typename T> using vec_t = typename VecType<T>::vec_type;
+
+template <> struct VecType<float> { using vec_type = FP32Vec8; };
+
+template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
+
+template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
+
+inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
+  acc = acc + a * b;
+}
+
+template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
+  c10::BFloat16 __attribute__((__may_alias__)) *v_ptr =
+      reinterpret_cast<c10::BFloat16 *>(&v);
+  *ptr = *(v_ptr + 1);
+}
+
+#ifndef __VEC_CLASS_FP_NAN
+#define __VEC_CLASS_FP_NAN (1 << 6)
+#endif
+
+const static __vector unsigned char omask = { 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 };
+#ifndef _ARCH_PWR10
+const static __vector unsigned int bias = { 0x00007fff, 0x00007fff, 0x00007fff, 0x00007fff };
+const static __vector unsigned int nan  = { 0x7fc00000, 0x7fc00000, 0x7fc00000, 0x7fc00000 };
+const static __vector unsigned int sh16 = { 16, 16, 16, 16 };
+const static __vector unsigned int one  = { 1, 1, 1, 1 };
+#endif
+
+inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) {
+#ifdef _ARCH_PWR10
+  __vector signed short ret[2];
+  ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]);
+  ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]);
+  reg = vec_perm(ret[0], ret[1], omask);
+#elif defined(_ARCH_PWR9)
+  __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
+  __vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
+  __vector unsigned int lsb0 = vec_sr(inp0, sh16);
+  __vector unsigned int lsb1 = vec_sr(inp1, sh16);
+  lsb0 = vec_and(lsb0, one);
+  lsb1 = vec_and(lsb1, one);
+  __vector unsigned int rnd0 = vec_add(lsb0, bias);
+  __vector unsigned int rnd1 = vec_add(lsb1, bias);
+  inp0 = vec_add(inp0, rnd0);
+  inp1 = vec_add(inp1, rnd1);
+  __vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
+  __vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
+  inp0 = vec_sel(inp0, nan, sel0);
+  inp1 = vec_sel(inp1, nan, sel1);
+  inp0 = vec_sr(inp0, sh16);
+  inp1 = vec_sr(inp1, sh16);
+  reg = (__vector signed short)vec_perm(inp0, inp1, omask);
+#endif
+}
+
+inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
+#ifdef _ARCH_PWR10
+  __vector signed short ret[4];
+  ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]);
+  ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]);
+  ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[2]);
+  ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[3]);
+  reg.val[0] = vec_perm(ret[0], ret[1], omask);
+  reg.val[1] = vec_perm(ret[2], ret[3], omask);
+#elif defined(_ARCH_PWR9)
+  __vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
+  __vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
+  __vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]);
+  __vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]);
+  __vector unsigned int lsb0 = vec_sr(inp0, sh16);
+  __vector unsigned int lsb1 = vec_sr(inp1, sh16);
+  __vector unsigned int lsb2 = vec_sr(inp2, sh16);
+  __vector unsigned int lsb3 = vec_sr(inp3, sh16);
+  lsb0 = vec_and(lsb0, one);
+  lsb1 = vec_and(lsb1, one);
+  lsb2 = vec_and(lsb2, one);
+  lsb3 = vec_and(lsb3, one);
+  __vector unsigned int rnd0 = vec_add(lsb0, bias);
+  __vector unsigned int rnd1 = vec_add(lsb1, bias);
+  __vector unsigned int rnd2 = vec_add(lsb2, bias);
+  __vector unsigned int rnd3 = vec_add(lsb3, bias);
+  inp0 = vec_add(inp0, rnd0);
+  inp1 = vec_add(inp1, rnd1);
+  inp2 = vec_add(inp2, rnd2);
+  inp3 = vec_add(inp3, rnd3);
+  __vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
+  __vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
+  __vector __bool int sel2 = vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN);
+  __vector __bool int sel3 = vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN);
+  inp0 = vec_sel(inp0, nan, sel0);
+  inp1 = vec_sel(inp1, nan, sel1);
+  inp2 = vec_sel(inp2, nan, sel2);
+  inp3 = vec_sel(inp3, nan, sel3);
+  inp0 = vec_sr(inp0, sh16);
+  inp1 = vec_sr(inp1, sh16);
+  inp2 = vec_sr(inp2, sh16);
+  inp3 = vec_sr(inp3, sh16);
+  reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask);
+  reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask);
+#endif
+}
+
+inline void prefetch(const void *addr) {
+  __asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory");
+}
+
+}; // namespace vec_op
+
+#endif

+ 515 - 0
kernels/cpu/cpu_types_x86.hpp

@@ -0,0 +1,515 @@
+
+#ifndef CPU_TYPES_X86_HPP
+#define CPU_TYPES_X86_HPP
+
+#include <immintrin.h>
+#include <torch/all.h>
+
+#ifndef __AVX2__
+static_assert(false, "AVX2 must be supported for the current implementation.");
+#endif
+
+namespace vec_op {
+
+// FIXME: FP16 is not fully supported in Torch-CPU
+#define APHRODITE_DISPATCH_CASE_FLOATING_TYPES(...)                                 \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)                         \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define APHRODITE_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)                          \
+  AT_DISPATCH_SWITCH(TYPE, NAME, APHRODITE_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+
+#ifndef CPU_OP_GUARD
+#define CPU_KERNEL_GUARD_IN(NAME)
+#define CPU_KERNEL_GUARD_OUT(NAME)
+#else
+#define CPU_KERNEL_GUARD_IN(NAME)                                              \
+  std::cout << #NAME << " invoked." << std::endl;
+#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
+#endif
+
+#define FORCE_INLINE __attribute__((always_inline)) inline
+
+namespace {
+template <typename T, T... indexes, typename F>
+constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
+  (f(std::integral_constant<T, indexes>{}), ...);
+}
+}; // namespace
+
+template <typename T, T count, typename F,
+          typename = std::enable_if_t<std::is_invocable_v<F, T>>>
+constexpr void unroll_loop(F &&f) {
+  unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
+}
+
+template <typename T> struct Vec {
+  constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
+};
+
+struct FP32Vec8;
+struct FP32Vec16;
+
+#ifdef __AVX512FP16__
+struct FP16Vec8 : public Vec<FP16Vec8> {
+  constexpr static int VEC_ELEM_NUM = 8;
+
+  __m128h reg;
+
+  explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {}
+
+  explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {}
+
+  explicit FP16Vec8(__m128h data) : reg(data) {}
+
+  FP16Vec8 operator*(const FP16Vec8 &b) const {
+    return FP16Vec8(_mm_mul_ph(reg, b.reg));
+  }
+
+  FP16Vec8 operator+(const FP16Vec8 &b) const {
+    return FP16Vec8(_mm_add_ph(reg, b.reg));
+  }
+
+  FP16Vec8 operator-(const FP16Vec8 &b) const {
+    return FP16Vec8(_mm_sub_ph(reg, b.reg));
+  }
+
+  FP16Vec8 operator/(const FP16Vec8 &b) const {
+    return FP16Vec8(_mm_div_ph(reg, b.reg));
+  }
+
+  void save(void *ptr) const { _mm_storeu_ph(ptr, reg); }
+};
+#endif
+
+struct BF16Vec8 : public Vec<BF16Vec8> {
+  constexpr static int VEC_ELEM_NUM = 8;
+
+  __m128i reg;
+
+  explicit BF16Vec8(const void *ptr)
+      : reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {}
+
+  explicit BF16Vec8(const FP32Vec8 &);
+
+  void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; }
+};
+
+struct BF16Vec16 : public Vec<BF16Vec16> {
+  constexpr static int VEC_ELEM_NUM = 16;
+
+  __m256i reg;
+
+  explicit BF16Vec16(const void *ptr)
+      : reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {}
+
+  explicit BF16Vec16(const FP32Vec16 &);
+
+  void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
+};
+
+#ifdef __AVX512F__
+struct BF16Vec32 : public Vec<BF16Vec32> {
+  constexpr static int VEC_ELEM_NUM = 32;
+
+  __m512i reg;
+
+  explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}
+
+  explicit BF16Vec32(__m512i data) : reg(data) {}
+
+  explicit BF16Vec32(BF16Vec8 &vec8_data)
+      : reg((__m512i)_mm512_inserti32x4(
+            _mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512(
+                                                      (__m128i)vec8_data.reg),
+                                                  (__m128i)vec8_data.reg, 1),
+                               (__m128i)vec8_data.reg, 2),
+            (__m128i)vec8_data.reg, 3)) {}
+
+  void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; }
+};
+#else
+struct BF16Vec32 : public Vec<BF16Vec32> {
+  constexpr static int VEC_ELEM_NUM = 32;
+
+  __m256i reg_low;
+  __m256i reg_high;
+
+  explicit BF16Vec32(const void *ptr)
+      : reg_low(_mm256_loadu_si256((__m256i const *)ptr)),
+        reg_high(_mm256_loadu_si256((__m256i const *)ptr + 1)) {}
+
+  explicit BF16Vec32(__m256i low, __m256i high) : reg_low(low),
+                                                  reg_high(high) {}
+
+  explicit BF16Vec32(BF16Vec8 &vec8_data)
+      : reg_low((__m256i)_mm256_inserti32x4(
+                _mm256_castsi128_si256((__m128i)vec8_data.reg),
+                                       (__m128i)vec8_data.reg, 1)),
+        reg_high((__m256i)_mm256_inserti32x4(
+                _mm256_castsi128_si256((__m128i)vec8_data.reg),
+                                       (__m128i)vec8_data.reg, 1)) {}
+
+  void save(void *ptr) const {
+    *reinterpret_cast<__m256i *>(ptr) = reg_low;
+    *reinterpret_cast<__m256i *>((__m256i *)ptr + 1) = reg_high;
+  }
+};
+#endif
+
+struct FP32Vec4 : public Vec<FP32Vec4> {
+  constexpr static int VEC_ELEM_NUM = 4;
+  union AliasReg {
+    __m128 reg;
+    float values[VEC_ELEM_NUM];
+  };
+
+  __m128 reg;
+
+  explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {}
+
+  explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {}
+
+  explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {}
+
+  explicit FP32Vec4(__m128 data) : reg(data) {}
+
+  explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}
+};
+
+struct FP32Vec8 : public Vec<FP32Vec8> {
+  constexpr static int VEC_ELEM_NUM = 8;
+  union AliasReg {
+    __m256 reg;
+    float values[VEC_ELEM_NUM];
+  };
+
+  __m256 reg;
+
+  explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {}
+
+  explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {}
+
+  explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {}
+
+  explicit FP32Vec8(__m256 data) : reg(data) {}
+
+  explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}
+
+#ifdef __AVX512FP16__
+  explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {}
+#endif
+
+  explicit FP32Vec8(const BF16Vec8 &v)
+      : reg(_mm256_castsi256_ps(
+            _mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {}
+
+  float reduce_sum() const {
+    AliasReg ar;
+    ar.reg = reg;
+    float result = 0;
+    unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
+
+    return result;
+  }
+
+  FP32Vec8 exp() const {
+    AliasReg ar;
+    ar.reg = reg;
+    return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]),
+                                  expf(ar.values[5]), expf(ar.values[4]),
+                                  expf(ar.values[3]), expf(ar.values[2]),
+                                  expf(ar.values[1]), expf(ar.values[0])));
+  }
+
+  FP32Vec8 tanh() const {
+    AliasReg ar;
+    ar.reg = reg;
+    return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]),
+                                  tanhf(ar.values[5]), tanhf(ar.values[4]),
+                                  tanhf(ar.values[3]), tanhf(ar.values[2]),
+                                  tanhf(ar.values[1]), tanhf(ar.values[0])));
+  }
+
+  FP32Vec8 er() const {
+    AliasReg ar;
+    ar.reg = reg;
+    return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]),
+                                  erf(ar.values[5]), erf(ar.values[4]),
+                                  erf(ar.values[3]), erf(ar.values[2]),
+                                  erf(ar.values[1]), erf(ar.values[0])));
+  }
+
+  FP32Vec8 operator*(const FP32Vec8 &b) const {
+    return FP32Vec8(_mm256_mul_ps(reg, b.reg));
+  }
+
+  FP32Vec8 operator+(const FP32Vec8 &b) const {
+    return FP32Vec8(_mm256_add_ps(reg, b.reg));
+  }
+
+  FP32Vec8 operator-(const FP32Vec8 &b) const {
+    return FP32Vec8(_mm256_sub_ps(reg, b.reg));
+  }
+
+  FP32Vec8 operator/(const FP32Vec8 &b) const {
+    return FP32Vec8(_mm256_div_ps(reg, b.reg));
+  }
+
+  void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); }
+};
+
+#ifdef __AVX512F__
+struct FP32Vec16 : public Vec<FP32Vec16> {
+  constexpr static int VEC_ELEM_NUM = 16;
+  union AliasReg {
+    __m512 reg;
+    float values[VEC_ELEM_NUM];
+  };
+
+  __m512 reg;
+
+  explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {}
+
+  explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {}
+
+  explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {}
+
+  explicit FP32Vec16(__m512 data) : reg(data) {}
+
+  explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {}
+
+  explicit FP32Vec16(const FP32Vec4 &data)
+      : reg((__m512)_mm512_inserti32x4(
+            _mm512_inserti32x4(
+                _mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg),
+                                   (__m128i)data.reg, 1),
+                (__m128i)data.reg, 2),
+            (__m128i)data.reg, 3)) {}
+
+  explicit FP32Vec16(const FP32Vec8 &data)
+      : reg((__m512)_mm512_inserti32x8(
+            _mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {}
+
+  explicit FP32Vec16(const BF16Vec16 &v)
+      : reg(_mm512_castsi512_ps(
+            _mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}
+
+  explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
+
+  FP32Vec16 operator*(const FP32Vec16 &b) const {
+    return FP32Vec16(_mm512_mul_ps(reg, b.reg));
+  }
+
+  FP32Vec16 operator+(const FP32Vec16 &b) const {
+    return FP32Vec16(_mm512_add_ps(reg, b.reg));
+  }
+
+  FP32Vec16 operator-(const FP32Vec16 &b) const {
+    return FP32Vec16(_mm512_sub_ps(reg, b.reg));
+  }
+
+  FP32Vec16 operator/(const FP32Vec16 &b) const {
+    return FP32Vec16(_mm512_div_ps(reg, b.reg));
+  }
+
+  float reduce_sum() const { return _mm512_reduce_add_ps(reg); }
+
+  template <int group_size> float reduce_sub_sum(int idx) {
+    static_assert(VEC_ELEM_NUM % group_size == 0);
+    constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
+    __mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size));
+    return _mm512_mask_reduce_add_ps(mask, reg);
+  }
+
+  void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
+};
+#else
+struct FP32Vec16 : public Vec<FP32Vec16> {
+  constexpr static int VEC_ELEM_NUM = 16;
+
+  union AliasReg {
+    __m256 reg;
+    float values[8];
+  };
+
+  __m256 reg_low;
+  __m256 reg_high;
+
+  explicit FP32Vec16(float v) : reg_low(_mm256_set1_ps(v)),
+                                reg_high(_mm256_set1_ps(v)) {}
+
+  explicit FP32Vec16() : reg_low(_mm256_set1_ps(0.0)),
+                         reg_high(_mm256_set1_ps(0.0)) {}
+
+  explicit FP32Vec16(const float *ptr) : reg_low(_mm256_loadu_ps(ptr)),
+                                         reg_high(_mm256_loadu_ps(ptr + 8)) {}
+
+  explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {}
+
+  explicit FP32Vec16(const FP32Vec16 &data) : reg_low(data.reg_low),
+                                              reg_high(data.reg_high) {}
+
+  explicit FP32Vec16(const FP32Vec4 &data)
+      : reg_low((__m256)_mm256_inserti128_si256(
+                _mm256_castsi128_si256((__m128i)data.reg),
+                                       (__m128i)data.reg, 1)),
+        reg_high((__m256)_mm256_inserti128_si256(
+                 _mm256_castsi128_si256((__m128i)data.reg),
+                                       (__m128i)data.reg, 1)) {}
+
+  explicit FP32Vec16(const FP32Vec8 &data)
+      : reg_low(data.reg), reg_high(data.reg) {}
+
+  explicit FP32Vec16(const BF16Vec16 &v) {
+    __m128i low = _mm256_extractf128_si256(v.reg, 0);
+    __m128i high = _mm256_extractf128_si256(v.reg, 1);
+
+    __m256i v_low_epi32 = _mm256_cvtepu16_epi32(low);
+    __m256i v_high_epi32 = _mm256_cvtepu16_epi32(high);
+
+    __m256i v_low_shifted = _mm256_bslli_epi128(v_low_epi32, 2);
+    __m256i v_high_shifted = _mm256_bslli_epi128(v_high_epi32, 2);
+
+    reg_low = _mm256_castsi256_ps(v_low_shifted);
+    reg_high = _mm256_castsi256_ps(v_high_shifted);
+  }
+
+  explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
+
+  FP32Vec16 operator*(const FP32Vec16 &b) const {
+    return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low),
+                     _mm256_mul_ps(reg_high, b.reg_high));
+  }
+
+  FP32Vec16 operator+(const FP32Vec16 &b) const {
+    return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low),
+                     _mm256_add_ps(reg_high, b.reg_high));
+  }
+
+  FP32Vec16 operator-(const FP32Vec16 &b) const {
+    return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low),
+                     _mm256_sub_ps(reg_high, b.reg_high));
+  }
+
+  FP32Vec16 operator/(const FP32Vec16 &b) const {
+    return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low),
+                     _mm256_div_ps(reg_high, b.reg_high));
+  }
+
+  float reduce_sum() const {
+    FP32Vec8 low = FP32Vec8(reg_low);
+    FP32Vec8 high = FP32Vec8(reg_high);
+    return low.reduce_sum() + high.reduce_sum();
+  }
+
+  template <int group_size> float reduce_sub_sum(int idx) {
+    float sum = 0.0;
+    static_assert(VEC_ELEM_NUM % group_size == 0);
+    constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
+    uint32_t mask = base_mask << (idx * group_size);
+
+    AliasReg ar;
+
+    auto func = [&sum, &mask, &ar](int i) {
+      int flag = mask & 0x1;
+      mask = mask >> 1;
+      if (flag != 0) sum += ar.values[i];
+    };
+
+    ar.reg = reg_low;
+    unroll_loop<int, 8>(func);
+
+    ar.reg = reg_high;
+    unroll_loop<int, 8>(func);
+
+    return sum;
+  }
+
+  void save(float *ptr) const {
+    _mm256_storeu_ps(ptr, reg_low);
+    _mm256_storeu_ps(ptr + 8, reg_high);
+  }
+};
+#endif
+
+template <typename T> struct VecType { using vec_type = void; };
+
+template <typename T> using vec_t = typename VecType<T>::vec_type;
+
+template <> struct VecType<float> { using vec_type = FP32Vec8; };
+
+#ifdef __AVX512FP16__
+template <> struct VecType<c10::Half> { using vec_type = FP16Vec16; };
+#endif
+
+template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
+
+template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
+
+#ifdef __AVX512FP16__
+template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
+  *reinterpret_cast<_Float16 *>(ptr) = v;
+}
+#endif
+
+inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
+  acc = acc + a * b;
+}
+
+#ifdef __AVX512BF16__
+template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
+  *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);
+}
+
+inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
+    : reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {}
+
+inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
+    : reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {}
+
+inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
+  acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg);
+}
+#else
+template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
+  c10::BFloat16 __attribute__((__may_alias__)) *v_ptr =
+      reinterpret_cast<c10::BFloat16 *>(&v);
+  *ptr = *(v_ptr + 1);
+}
+
+#ifdef __AVX512F__
+inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
+    : reg(_mm256_cvtepi32_epi16(
+          _mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {}
+
+inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
+    : reg(_mm512_cvtepi32_epi16(
+          _mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {}
+#else
+namespace{
+__m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) {
+  __m256i ai = _mm256_castps_si256(a);
+  ai = _mm256_srli_epi32(ai, 16);
+  ai = _mm256_packus_epi32(ai, ai);
+  ai = _mm256_permute4x64_epi64(ai, 0b00111001);
+  return _mm256_extracti128_si256(ai, 0);
+}
+}
+
+inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
+    : reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {}
+
+inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
+  BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low));
+  BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high));
+  reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1);
+}
+#endif // __AVX512F__
+#endif // __AVX512BF16__
+
+inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); }
+
+}; // namespace vec_op
+
+#endif

+ 1 - 0
kernels/ops.h

@@ -1,5 +1,6 @@
 #pragma once
 
+#include <optional>
 #include <torch/library.h>
 
 void paged_attention_v1(

+ 2 - 2
requirements-cpu.txt

@@ -2,6 +2,6 @@
 -r requirements-common.txt
 
 # Dependencies for x86_64 CPUs
-torch == 2.3.1+cpu
-torchvision == 0.18.1+cpu
+torch == 2.3.1+cpu; platform_machine != "ppc64le"
+torchvision == 0.18.1+cpu; platform_machine != "ppc64le"
 triton >= 2.2.0