
feat: add fast hadamard transformation kernels (#232)

* add hadamard kernels

* fix protected access

* formatting

* add failsafe

* add hadamard transform interface

* formatting

* get rid of useless comments
AlpinDale committed 1 year ago
parent
commit 5d288aa76c

+ 19 - 5
aphrodite/modeling/layers/quantization/quip_utils.py

@@ -3,13 +3,29 @@ from pathlib import Path
 
 import scipy
 import torch
-import fast_hadamard_transform
 from safetensors.torch import load_file
 
+try:
+    import aphrodite._hadamard_C as hadamard_C
+except ImportError:
+    pass
+
 HADA_TENSORS = load_file(
     Path(__file__).resolve().parent / "hadamard.safetensors")
 
 
+class HadamardTransformFn(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x, scale=1.0):
+        ctx._hadamard_transform_scale = scale  # pylint: disable=protected-access
+        return hadamard_C.fast_hadamard_transform(x, scale)
+
+
+def hadamard_transform(x, scale=1.0):
+    return HadamardTransformFn.apply(x, scale)
+
+
 def int2mask(i, int_map):
     return ((i & int_map) > 0).int()
 
@@ -110,14 +126,12 @@ def matmul_hadU_cuda(X, hadK, K, n, scale=None, transpose=False):
     had_scale = 1 / math.sqrt(n // K) if scale is None else scale / math.sqrt(
         n // K)
     if K == 1:
-        return fast_hadamard_transform.hadamard_transform(X.contiguous(),
-                                                          scale=had_scale)
+        return hadamard_transform(X.contiguous(), scale=had_scale)
 
     if transpose:
         hadK = hadK.T.contiguous()
     input = X.view(-1, K, n // K)  # pylint: disable=redefined-builtin
-    input = fast_hadamard_transform.hadamard_transform(input.contiguous(),
-                                                       scale=had_scale)
+    input = hadamard_transform(input.contiguous(), scale=had_scale)
     input = hadK @ input
     return input.reshape(X.shape)
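
For reference, a minimal pure-PyTorch sketch of the unnormalized Walsh-Hadamard butterfly the extension computes (scale * H_n x, with H_n the +/-1 Sylvester matrix). It is not part of this commit; `_hadamard_transform_ref` is a hypothetical helper that could act as a CPU/testing fallback when the `aphrodite._hadamard_C` import above fails, assuming a power-of-two last dimension.

# Hypothetical reference implementation, not part of this commit.
import torch

def _hadamard_transform_ref(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    n = x.shape[-1]
    assert n > 0 and n & (n - 1) == 0, "reference path assumes a power-of-2 dim"
    y = x.float().reshape(-1, n)
    h = 1
    while h < n:
        # Butterfly with stride h, same recurrence as the CUDA kernels:
        # first half of each 2h block gets a+b, second half gets a-b.
        y = y.view(-1, n // (2 * h), 2, h)
        a, b = y[:, :, 0, :], y[:, :, 1, :]
        y = torch.stack((a + b, a - b), dim=2).reshape(-1, n)
        h *= 2
    return (y * scale).to(x.dtype).reshape(x.shape)
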
 

+ 247 - 0
kernels/hadamard/fast_hadamard_transform.cpp

@@ -0,0 +1,247 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/extension.h>
+#include <vector>
+
+#include "fast_hadamard_transform.h"
+
+#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
+
+#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...)                    \
+    if (ITYPE == at::ScalarType::Half) {                                            \
+        using input_t = at::Half;                                                   \
+        __VA_ARGS__();                                                              \
+    } else if (ITYPE == at::ScalarType::BFloat16) {                                 \
+        using input_t = at::BFloat16;                                               \
+        __VA_ARGS__();                                                              \
+    } else if (ITYPE == at::ScalarType::Float) {                                    \
+        using input_t = float;                                                      \
+        __VA_ARGS__();                                                              \
+    } else {                                                                        \
+        AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
+    }
+
+template<typename input_t>
+void fast_hadamard_transform_cuda(HadamardParamsBase &params, cudaStream_t stream);
+
+template<typename input_t>
+void fast_hadamard_transform_12N_cuda(HadamardParamsBase &params, cudaStream_t stream);
+
+template<typename input_t>
+void fast_hadamard_transform_20N_cuda(HadamardParamsBase &params, cudaStream_t stream);
+
+template<typename input_t>
+void fast_hadamard_transform_28N_cuda(HadamardParamsBase &params, cudaStream_t stream);
+
+void set_hadamard_params(HadamardParamsBase &params,
+                         // sizes
+                         const size_t batch,
+                         const size_t dim,
+                         const size_t multiple,
+                         // device pointers
+                         const at::Tensor x,
+                         const at::Tensor out,
+                         float scale
+                         ) {
+
+    // Reset the parameters
+    memset(&params, 0, sizeof(params));
+
+    params.batch = batch;
+    params.dim = dim;
+    params.log_N = int(ceil(std::log2(dim / multiple)));
+
+    // Set the pointers and strides.
+    params.x_ptr = x.data_ptr();
+    params.out_ptr = out.data_ptr();
+    // All stride are in elements, not bytes.
+    params.x_batch_stride = x.stride(0);
+    params.out_batch_stride = out.stride(0);
+
+    params.scale = scale;
+}
+
+
+at::Tensor
+fast_hadamard_transform(at::Tensor &x, float scale) {
+    auto input_type = x.scalar_type();
+    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
+
+    TORCH_CHECK(x.is_cuda());
+
+    const auto shapes_og = x.sizes();
+    const int dim_og = x.size(-1);
+    x = x.reshape({-1, dim_og});
+    if (x.stride(-1) != 1) { x = x.contiguous(); }
+    const auto sizes = x.sizes();
+    const int batch_size = sizes[0];
+
+    CHECK_SHAPE(x, batch_size, dim_og);
+    TORCH_CHECK(x.stride(1) == 1);
+
+    if (dim_og % 8 != 0) {
+        x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, 8 - dim_og % 8}));
+    }
+    const int dim = x.size(1);
+
+    TORCH_CHECK(dim % 8 == 0, "fast_hadamard_transform only supports hidden dimension divisible by 8 for now");
+    TORCH_CHECK(dim <= 32768, "fast_hadamard_transform only supports hidden dimension at most 32768 for now");
+
+    at::Tensor out = torch::empty_like(x);
+
+    HadamardParamsBase params;
+    set_hadamard_params(params, batch_size, dim, 1, x, out, scale);
+
+    // Otherwise the kernel will be launched from cuda:0 device
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "fast_hadamard_transform", [&] {
+        fast_hadamard_transform_cuda<input_t>(params, stream);
+    });
+    if (dim_og % 8 != 0) {
+        out = out.index({torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, dim_og)});
+    }
+    return out.reshape(shapes_og);
+}
+
+at::Tensor
+fast_hadamard_transform_12N(at::Tensor &x, float scale) {
+    auto input_type = x.scalar_type();
+    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
+
+    TORCH_CHECK(x.is_cuda());
+
+    const auto shapes_og = x.sizes();
+    const int dim_og = x.size(-1);
+    x = x.reshape({-1, dim_og});
+    if (x.stride(-1) != 1) { x = x.contiguous(); }
+    const auto sizes = x.sizes();
+    const int batch_size = sizes[0];
+
+    CHECK_SHAPE(x, batch_size, dim_og);
+    TORCH_CHECK(x.stride(1) == 1);
+
+    if (dim_og % (4 * 12) != 0) {
+        x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, (4 * 12) - dim_og % (4 * 12)}));
+    }
+    const int dim = x.size(1);
+
+    TORCH_CHECK(dim % (4 * 12) == 0, "fast_hadamard_transform_12N only supports hidden dimension divisible by 48 for now");
+    TORCH_CHECK(dim <= 12 * 1024, "fast_hadamard_transform_12N only supports hidden dimension at most 12288 for now");
+
+    at::Tensor out = torch::empty_like(x);
+
+    HadamardParamsBase params;
+    set_hadamard_params(params, batch_size, dim, 12, x, out, scale);
+
+    // Otherwise the kernel will be launched from cuda:0 device
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "fast_hadamard_transform", [&] {
+        fast_hadamard_transform_12N_cuda<input_t>(params, stream);
+    });
+    if (dim_og % (4 * 12) != 0) {
+        out = out.index({torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, dim_og)});
+    }
+    return out.reshape(shapes_og);
+}
+
+at::Tensor
+fast_hadamard_transform_20N(at::Tensor &x, float scale) {
+    auto input_type = x.scalar_type();
+    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
+
+    TORCH_CHECK(x.is_cuda());
+
+    const auto shapes_og = x.sizes();
+    const int dim_og = x.size(-1);
+    x = x.reshape({-1, dim_og});
+    if (x.stride(-1) != 1) { x = x.contiguous(); }
+    const auto sizes = x.sizes();
+    const int batch_size = sizes[0];
+
+    CHECK_SHAPE(x, batch_size, dim_og);
+    TORCH_CHECK(x.stride(1) == 1);
+
+    if (dim_og % (4 * 20) != 0) {
+        x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, (4 * 20) - dim_og % (4 * 20)}));
+    }
+    const int dim = x.size(1);
+
+    TORCH_CHECK(dim % (4 * 20) == 0, "fast_hadamard_transform_20N only supports hidden dimension divisible by 80 for now");
+    TORCH_CHECK(dim <= 20 * 1024, "fast_hadamard_transform_20N only supports hidden dimension at most 20480 for now");
+
+    at::Tensor out = torch::empty_like(x);
+
+    HadamardParamsBase params;
+    set_hadamard_params(params, batch_size, dim, 20, x, out, scale);
+
+    // Otherwise the kernel will be launched from cuda:0 device
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "fast_hadamard_transform", [&] {
+        fast_hadamard_transform_20N_cuda<input_t>(params, stream);
+    });
+    if (dim_og % (4 * 20) != 0) {
+        out = out.index({torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, dim_og)});
+    }
+    return out.reshape(shapes_og);
+}
+
+at::Tensor
+fast_hadamard_transform_28N(at::Tensor &x, float scale) {
+    auto input_type = x.scalar_type();
+    TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
+
+    TORCH_CHECK(x.is_cuda());
+
+    const auto shapes_og = x.sizes();
+    const int dim_og = x.size(-1);
+    x = x.reshape({-1, dim_og});
+    if (x.stride(-1) != 1) { x = x.contiguous(); }
+    const auto sizes = x.sizes();
+    const int batch_size = sizes[0];
+
+    CHECK_SHAPE(x, batch_size, dim_og);
+    TORCH_CHECK(x.stride(1) == 1);
+
+    if (dim_og % (4 * 28) != 0) {
+        x = torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, (4 * 28) - dim_og % (4 * 28)}));
+    }
+    const int dim = x.size(1);
+
+    TORCH_CHECK(dim % (4 * 28) == 0, "fast_hadamard_transform_28N only supports hidden dimension divisible by 112 for now");
+    TORCH_CHECK(dim <= 28 * 1024, "fast_hadamard_transform_28N only supports hidden dimension at most 28672 for now");
+
+    at::Tensor out = torch::empty_like(x);
+
+    HadamardParamsBase params;
+    set_hadamard_params(params, batch_size, dim, 28, x, out, scale);
+
+    // Otherwise the kernel will be launched from cuda:0 device
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)x.get_device()};
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+    DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "fast_hadamard_transform", [&] {
+        fast_hadamard_transform_28N_cuda<input_t>(params, stream);
+    });
+    if (dim_og % (4 * 28) != 0) {
+        out = out.index({torch::indexing::Slice(), torch::indexing::Slice(torch::indexing::None, dim_og)});
+    }
+    return out.reshape(shapes_og);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("fast_hadamard_transform", &fast_hadamard_transform, "Fast Hadamard transform");
+    m.def("fast_hadamard_transform_12N", &fast_hadamard_transform_20N, "Fast Hadamard transform with dimension = 12 * power of 2");
+    m.def("fast_hadamard_transform_20N", &fast_hadamard_transform_20N, "Fast Hadamard transform with dimension = 20 * power of 2");
+    m.def("fast_hadamard_transform_28N", &fast_hadamard_transform_28N, "Fast Hadamard transform with dimension = 28 * power of 2");
+}
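
A hedged sketch of the pad/slice contract implemented above: inputs whose last dimension is not a multiple of the required block size are zero-padded, transformed, and sliced back, so the output always has the input's shape. This assumes the extension compiles and is importable as `aphrodite._hadamard_C` on a CUDA device.

# Illustration only; assumes the extension built from this file is available.
import torch
import aphrodite._hadamard_C as hadamard_C

x = torch.randn(4, 100, device="cuda", dtype=torch.float16)  # 100 % 8 != 0
y = hadamard_C.fast_hadamard_transform(x, 1.0)  # padded to 104 internally
assert y.shape == x.shape                       # sliced back to (4, 100)
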

+ 22 - 0
kernels/hadamard/fast_hadamard_transform.h

@@ -0,0 +1,22 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct HadamardParamsBase {
+    using index_t = int64_t;
+
+    int batch, dim, log_N;
+
+    index_t x_batch_stride;
+    index_t out_batch_stride;
+
+    float scale;
+
+    // Common data pointers.
+    void *__restrict__ x_ptr;
+    void *__restrict__ out_ptr;
+};

+ 194 - 0
kernels/hadamard/fast_hadamard_transform_common.h

@@ -0,0 +1,194 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+#pragma once
+
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
+#define FULL_MASK 0xffffffff
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct uint8 {
+    uint4 u;
+    uint4 v;
+};
+
+template<int BYTES> struct BytesToType {};
+
+template<>
+struct BytesToType<32> {
+    using Type = uint8;
+    static_assert(sizeof(Type) == 32);
+};
+
+template<> struct BytesToType<16> {
+    using Type = uint4;
+    static_assert(sizeof(Type) == 16);
+};
+
+template<> struct BytesToType<8> {
+    using Type = uint64_t;
+    static_assert(sizeof(Type) == 8);
+};
+
+template<> struct BytesToType<4> {
+    using Type = uint32_t;
+    static_assert(sizeof(Type) == 4);
+};
+
+template<> struct BytesToType<2> {
+    using Type = uint16_t;
+    static_assert(sizeof(Type) == 2);
+};
+
+template<> struct BytesToType<1> {
+    using Type = uint8_t;
+    static_assert(sizeof(Type) == 1);
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+struct SumOp {
+__device__ inline T operator()(T const & x, T const & y) { return x + y; }
+};
+
+template<int THREADS>
+struct Allreduce {
+    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
+    template<typename T, typename Operator>
+    static __device__ inline T run(T x, Operator &op) {
+        constexpr int OFFSET = THREADS / 2;
+        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
+        return Allreduce<OFFSET>::run(x, op);
+    }
+};
+
+template<>
+struct Allreduce<2> {
+template<typename T, typename Operator>
+static __device__ inline T run(T x, Operator &op) {
+    x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
+    return x;
+}
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// https://stackoverflow.com/questions/35311711/whats-the-right-way-to-compute-integral-base-2-logarithms-at-compile-time
+constexpr int cilog2(int val) { return val > 0 ? 1 + cilog2(val >> 1) : -1; }
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<int kLogN, int kNChunks>
+__device__ __forceinline__ void hadamard_mult_thread(float x[kNChunks][1 << kLogN]) {
+    constexpr int N = 1 << kLogN;
+    #pragma unroll
+    for (int i = 0; i < kLogN; ++i) {
+        const int stride = 1 << i;
+        #pragma unroll
+        for (int j = 0; j < N / 2; ++j) {
+            const int lo = j & (stride - 1);
+            const int idx = (j - lo) * 2 + lo;
+            #pragma unroll
+            for (int c = 0; c < kNChunks; ++c) {
+                const float a = x[c][idx];
+                const float b = x[c][idx + stride];
+                x[c][idx] = a + b;
+                x[c][idx + stride] = a - b;
+            }
+        }
+    }
+}
+
+template<int kLogWarpSize, int kStepStart, int kNChunks, int kNItems>
+__device__ __forceinline__ void hadamard_mult_warp(float x[kNChunks][kNItems]) {
+    constexpr int N = 1 << kLogWarpSize;
+    int lane_id = threadIdx.x % N;
+    #pragma unroll
+    for (int step = kStepStart; step < kLogWarpSize; ++step) {
+        const int lane_mask = 1 << step;
+        const float sign = (lane_id & lane_mask) ? -1.f : 1.f;
+        #pragma unroll
+        for (int c = 0; c < kNChunks; ++c) {
+            #pragma unroll
+            for (int i = 0; i < kNItems; ++i) {
+                float x_val_other = __shfl_xor_sync(FULL_MASK, x[c][i], lane_mask);
+                x[c][i] = sign * x[c][i] + x_val_other;
+            }
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <int kNChunks, int kNElts, typename input_t>
+inline __device__ void load_input(input_t *x, float x_vals[kNChunks][kNElts], int dim) {
+    using vec_t = typename BytesToType<sizeof(input_t) * kNElts>::Type;
+    input_t x_vals_load[kNChunks][kNElts] = {0};
+    #pragma unroll
+    for (int c = 0; c < kNChunks; ++c) {
+        if ((c * blockDim.x + threadIdx.x) * kNElts < dim) {
+            reinterpret_cast<vec_t*>(x_vals_load)[c] = reinterpret_cast<const vec_t*>(x)[c * blockDim.x + threadIdx.x];
+        }
+    }
+    #pragma unroll
+    for (int c = 0; c < kNChunks; ++c) {
+        #pragma unroll
+        for (int i = 0; i < kNElts; ++i) { x_vals[c][i] = float(x_vals_load[c][i]); }
+    }
+}
+
+
+template <int kNChunks, int kNElts, typename output_t>
+inline __device__ void store_output(output_t *out, float out_vals[kNChunks][kNElts], int dim, float scale=1.f) {
+    using vec_t = typename BytesToType<sizeof(output_t) * kNElts>::Type;
+    output_t out_vals_store[kNChunks][kNElts];
+    #pragma unroll
+    for (int c = 0; c < kNChunks; ++c) {
+        #pragma unroll
+        for (int i = 0; i < kNElts; ++i) { out_vals_store[c][i] = out_vals[c][i] * scale; }
+    }
+    #pragma unroll
+    for (int c = 0; c < kNChunks; ++c) {
+        if ((c * blockDim.x + threadIdx.x) * kNElts < dim) {
+            reinterpret_cast<vec_t*>(out)[c * blockDim.x + threadIdx.x] = reinterpret_cast<const vec_t*>(out_vals_store)[c];
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Pre=true means the exchange before the hadamard_mult_warp, Pre=false means after.
+template <int kNChunks, int kChunksPerExchange, int kNElts, int kWarpSize, int kNWarps, bool Pre, typename vec_t>
+inline __device__ void exchange_smem_pre(float x_vals[kNChunks][kNElts], vec_t *smem) {
+    constexpr int kNThreads = kWarpSize * kNWarps;
+    constexpr int kNExchangePerVec = kNElts / (sizeof(vec_t) / sizeof(float));
+    const int warp_id = threadIdx.x / kWarpSize;
+    const int lane_id = threadIdx.x % kWarpSize;
+    const int row_t = threadIdx.x % kNWarps;
+    const int col_t = threadIdx.x / kNWarps;
+    // We use the XOR swizzle trick (new_col = col ^ row) to avoid / reduce smem bank conflicts.
+    #pragma unroll
+    for (int c0 = 0; c0 < kNChunks / kChunksPerExchange; ++c0) {
+        __syncthreads();
+        #pragma unroll
+        for (int c1 = 0; c1 < kChunksPerExchange; ++c1) {
+            #pragma unroll
+            for (int r = 0; r < kNExchangePerVec; ++r) {
+                smem[(c1 * kNExchangePerVec + r) * kNThreads + (Pre ? warp_id * kWarpSize + lane_id ^ warp_id : row_t * kWarpSize + col_t ^ row_t)] = reinterpret_cast<vec_t*>(x_vals[c0 * kChunksPerExchange + c1])[r];
+            }
+        }
+        __syncthreads();
+        #pragma unroll
+        for (int c1 = 0; c1 < kChunksPerExchange; ++c1) {
+            #pragma unroll
+            for (int r = 0; r < kNExchangePerVec; ++r) {
+                reinterpret_cast<vec_t*>(x_vals[c0 * kChunksPerExchange + c1])[r] = smem[(c1 * kNExchangePerVec + r) * kNThreads + (Pre ? row_t * kWarpSize + col_t ^ row_t : warp_id * kWarpSize + lane_id ^ warp_id)];
+            }
+        }
+    }
+}
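
The staged butterflies above (per-thread over kNElts registers, per-warp across lanes via __shfl_xor_sync, and across warps after the shared-memory exchange) compose into one large transform because Sylvester-Hadamard matrices factor as Kronecker products, H_{a*b} = kron(H_a, H_b). A small NumPy check of that identity, for illustration only:

# Illustration of the Kronecker factorization the staged butterflies rely on.
import numpy as np
from scipy.linalg import hadamard

elts, lanes = 8, 32  # e.g. kNElts = 8 (fp16) and kWarpSize = 32
assert np.array_equal(hadamard(elts * lanes),
                      np.kron(hadamard(lanes), hadamard(elts)))
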

+ 370 - 0
kernels/hadamard/fast_hadamard_transform_cuda.cu

@@ -0,0 +1,370 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+// #pragma once
+
+#include <c10/util/BFloat16.h>
+#include <c10/util/Half.h>
+#include <c10/cuda/CUDAException.h>  // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
+
+#include "fast_hadamard_transform.h"
+#include "fast_hadamard_transform_common.h"
+#include "fast_hadamard_transform_special.h"
+#include "static_switch.h"
+
+template<int kNThreads_, int kLogN_, typename input_t_>
+struct fast_hadamard_transform_kernel_traits {
+    using input_t = input_t_;
+    static constexpr int kNThreads = kNThreads_;
+    static constexpr int kLogN = kLogN_;
+    static constexpr int N = 1 << kLogN;
+    static constexpr int kNBytes = sizeof(input_t);
+    static_assert(kNBytes == 2 || kNBytes == 4);
+    static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
+    // It's possible that we need to do 2 rounds of exchange if input_t is 16 bits
+    // (since then we'd have 8 values of float, and each round we can exchange 4 floats).
+    static constexpr int kNExchangePerVec = sizeof(float) / sizeof(input_t);
+    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
+    static constexpr int kNChunks = N / (kNElts * kNThreads);
+    // We don't want to use more than 32 KB of shared memory.
+    static constexpr int kSmemExchangeSize = std::min(N * 4, 32 * 1024);
+    static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize;
+    static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4);
+    static constexpr int kSmemSize = kSmemExchangeSize;
+};
+
+template<int kNThreads_, int kLogN_, typename input_t_>
+struct fast_hadamard_transform_12N_kernel_traits {
+    using input_t = input_t_;
+    static constexpr int kNThreads = kNThreads_;
+    static constexpr int kLogN = kLogN_;
+    static constexpr int N = (1 << kLogN) * 12;
+    static_assert(N <= 12 * 1024, "fast_hadamard_transform_12 only supports dim <= 12288");
+    static constexpr int kNBytes = sizeof(input_t);
+    static_assert(kNBytes == 2 || kNBytes == 4);
+    static constexpr int kNElts = 4;
+    // It's possible that we need to do 2 rounds of exchange if input_t is 16 bits
+    // (since then we'd have 8 values of float, and each round we can exchange 4 floats).
+    static constexpr int kNExchangePerVec = sizeof(float) / sizeof(input_t);
+    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
+    static constexpr int kNChunks = N / (kNElts * kNThreads);
+    static_assert(kNChunks == 12);
+    // We don't want to use more than 24 KB of shared memory.
+    static constexpr int kSmemExchangeSize = std::min(N * 4, 24 * 1024);
+    static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize;
+    static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4);
+    static constexpr int kSmemSize = kSmemExchangeSize;
+};
+
+template<int kNThreads_, int kLogN_, typename input_t_>
+struct fast_hadamard_transform_20N_kernel_traits {
+    using input_t = input_t_;
+    static constexpr int kNThreads = kNThreads_;
+    static constexpr int kLogN = kLogN_;
+    static constexpr int N = (1 << kLogN) * 20;
+    static_assert(N <= 20 * 1024, "fast_hadamard_transform_20 only supports dim <= 20480");
+    static constexpr int kNBytes = sizeof(input_t);
+    static_assert(kNBytes == 2 || kNBytes == 4);
+    static constexpr int kNElts = 4;
+    // It's possible that we need to do 2 rounds of exchange if input_t is 16 bits
+    // (since then we'd have 8 values of float, and each round we can exchange 4 floats).
+    static constexpr int kNExchangePerVec = sizeof(float) / sizeof(input_t);
+    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
+    static constexpr int kNChunks = N / (kNElts * kNThreads);
+    static_assert(kNChunks == 20);
+    // We don't want to use more than 40 KB of shared memory.
+    static constexpr int kSmemExchangeSize = std::min(N * 4, 40 * 1024);
+    static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize;
+    static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4);
+    static constexpr int kSmemSize = kSmemExchangeSize;
+};
+
+template<int kNThreads_, int kLogN_, typename input_t_>
+struct fast_hadamard_transform_28N_kernel_traits {
+    using input_t = input_t_;
+    static constexpr int kNThreads = kNThreads_;
+    static constexpr int kLogN = kLogN_;
+    static constexpr int N = (1 << kLogN) * 28;
+    static_assert(N <= 28 * 1024, "fast_hadamard_transform_28 only supports dim <= 28672");
+    static constexpr int kNBytes = sizeof(input_t);
+    static_assert(kNBytes == 2 || kNBytes == 4);
+    static constexpr int kNElts = 4;
+    // It's possible that we need to do 2 rounds of exchange if input_t is 16 bits
+    // (since then we'd have 8 values of float, and each round we can exchange 4 floats).
+    static constexpr int kNExchangePerVec = sizeof(float) / sizeof(input_t);
+    using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
+    static constexpr int kNChunks = N / (kNElts * kNThreads);
+    static_assert(kNChunks == 28);
+    // We don't want to use more than 28 KB of shared memory.
+    static constexpr int kSmemExchangeSize = std::min(N * 4, 28 * 1024);
+    static constexpr int kNExchangeRounds = N * 4 / kSmemExchangeSize;
+    static_assert(kNExchangeRounds * kSmemExchangeSize == N * 4);
+    static constexpr int kSmemSize = kSmemExchangeSize;
+};
+
+template <int kNChunks>
+__device__ __forceinline__ void hadamard_mult_thread_chunk_12(float x[kNChunks][12]) {
+    #pragma unroll
+    for (int c = 0; c < kNChunks; ++c) { hadamard_mult_thread_12(x[c]); }
+}
+
+template <int kNChunks>
+__device__ __forceinline__ void hadamard_mult_thread_chunk_20(float x[kNChunks][20]) {
+    #pragma unroll
+    for (int c = 0; c < kNChunks; ++c) { hadamard_mult_thread_20(x[c]); }
+}
+
+template <int kNChunks>
+__device__ __forceinline__ void hadamard_mult_thread_chunk_28(float x[kNChunks][28]) {
+    #pragma unroll
+    for (int c = 0; c < kNChunks; ++c) { hadamard_mult_thread_28(x[c]); }
+}
+
+template<typename Ktraits>
+__global__ __launch_bounds__(Ktraits::kNThreads)
+void fast_hadamard_transform_kernel(HadamardParamsBase params) {
+    constexpr int kNThreads = Ktraits::kNThreads;
+    constexpr int kNElts = Ktraits::kNElts;
+    constexpr int kNExchangePerVec = Ktraits::kNExchangePerVec;
+    constexpr int kNExchangeRounds = Ktraits::kNExchangeRounds;
+    constexpr int kNChunks = Ktraits::kNChunks;
+    using input_t = typename Ktraits::input_t;
+    using vec_t = typename Ktraits::vec_t;
+
+    constexpr int kLogNElts = cilog2(Ktraits::kNElts);
+    static_assert(1 << kLogNElts == kNElts, "kNElts must be a power of 2");
+    constexpr int kWarpSize = std::min(kNThreads, 32);
+    constexpr int kLogWarpSize = cilog2(kWarpSize);
+    static_assert(1 << kLogWarpSize == kWarpSize, "Warp size must be a power of 2");
+    constexpr int kNWarps = kNThreads / kWarpSize;
+    constexpr int kLogNWarps = cilog2(kNWarps);
+    static_assert(1 << kLogNWarps == kNWarps, "kNWarps must be a power of 2");
+    constexpr int kLoadsPerExchange = Ktraits::kSmemExchangeSize / (sizeof(vec_t) * kNThreads);
+    static_assert(kLoadsPerExchange * sizeof(vec_t) * kNThreads == Ktraits::kSmemExchangeSize, "kSmemExchangeSize should be a power of 2");
+    static_assert(kNExchangeRounds * kLoadsPerExchange * sizeof(vec_t) == kNChunks * kNElts * sizeof(float));
+
+    constexpr int kChunksPerExchange = Ktraits::kSmemExchangeSize / (sizeof(vec_t) * kNExchangePerVec * kNThreads);
+    static_assert(kChunksPerExchange * sizeof(vec_t) * kNExchangePerVec * kNThreads == Ktraits::kSmemExchangeSize);
+    constexpr int kNExchanges = kNChunks / kChunksPerExchange;
+    static_assert(kNExchanges * kChunksPerExchange == kNChunks);
+
+    // Shared memory.
+    extern __shared__ char smem_[];
+    vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_);
+
+    const int batch_id = blockIdx.x;
+    input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride;
+    input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride;
+
+    float x_vals[kNChunks][kNElts];
+    load_input<kNChunks, kNElts, input_t>(x, x_vals, params.dim);
+
+    hadamard_mult_thread<kLogNElts, kNChunks>(x_vals);
+    hadamard_mult_warp<kLogWarpSize, 0, kNChunks, kNElts>(x_vals);
+
+    if constexpr (kNWarps > 1) {
+        exchange_smem_pre<kNChunks, kChunksPerExchange, kNElts, kWarpSize, kNWarps, true, vec_t>(x_vals, smem_exchange);
+        hadamard_mult_warp<kLogNWarps, 0, kNChunks, kNElts>(x_vals);
+        exchange_smem_pre<kNChunks, kChunksPerExchange, kNElts, kWarpSize, kNWarps, false, vec_t>(x_vals, smem_exchange);
+    }
+
+    if constexpr (kNChunks > 1) {
+        float x_vals_transposed[kNElts][kNChunks];
+        #pragma unroll
+        for (int c = 0; c < kNChunks; ++c) {
+            #pragma unroll
+            for (int i = 0; i < kNElts; ++i) { x_vals_transposed[i][c] = x_vals[c][i]; }
+        }
+        if constexpr (kNChunks == 12) {
+            hadamard_mult_thread_chunk_12<kNElts>(x_vals_transposed);
+        } else if constexpr (kNChunks == 20) {
+            hadamard_mult_thread_chunk_20<kNElts>(x_vals_transposed);
+        } else if constexpr (kNChunks == 28) {
+            hadamard_mult_thread_chunk_28<kNElts>(x_vals_transposed);
+        } else {
+            constexpr int kLogNChunks = cilog2(kNChunks);
+            static_assert(1 << kLogNChunks == kNChunks, "kNChunks must be a power of 2");
+            hadamard_mult_thread<kLogNChunks, kNElts>(x_vals_transposed);
+        }
+        #pragma unroll
+        for (int c = 0; c < kNChunks; ++c) {
+            #pragma unroll
+            for (int i = 0; i < kNElts; ++i) { x_vals[c][i] = x_vals_transposed[i][c]; }
+        }
+    }
+
+    store_output<kNChunks, kNElts, input_t>(out, x_vals, params.dim, params.scale);
+}
+
+template<int kNThreads, int kLogN, typename input_t>
+void fast_hadamard_transform_launch(HadamardParamsBase &params, cudaStream_t stream) {
+    using Ktraits = fast_hadamard_transform_kernel_traits<kNThreads, kLogN, input_t>;
+    constexpr int kSmemSize = Ktraits::kSmemSize;
+    dim3 grid(params.batch);
+    auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
+    if (kSmemSize >= 48 * 1024) {
+        C10_CUDA_CHECK(cudaFuncSetAttribute(
+            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+        }
+    kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template<typename input_t>
+void fast_hadamard_transform_cuda(HadamardParamsBase &params, cudaStream_t stream) {
+    if (params.log_N == 3) {
+        fast_hadamard_transform_launch<1, 3, input_t>(params, stream);
+    } else if (params.log_N == 4) {
+        fast_hadamard_transform_launch<2, 4, input_t>(params, stream);
+    } else if (params.log_N == 5) {
+        fast_hadamard_transform_launch<4, 5, input_t>(params, stream);
+    } else if (params.log_N == 6) {
+        fast_hadamard_transform_launch<8, 6, input_t>(params, stream);
+    } else if (params.log_N == 7) {
+        fast_hadamard_transform_launch<16, 7, input_t>(params, stream);
+    } else if (params.log_N == 8) {
+        fast_hadamard_transform_launch<32, 8, input_t>(params, stream);
+    } else if (params.log_N == 9) {
+        fast_hadamard_transform_launch<32, 9, input_t>(params, stream);
+    } else if (params.log_N == 10) {
+        fast_hadamard_transform_launch<128, 10, input_t>(params, stream);
+    } else if (params.log_N == 11) {
+        fast_hadamard_transform_launch<256, 11, input_t>(params, stream);
+    } else if (params.log_N == 12) {
+        fast_hadamard_transform_launch<256, 12, input_t>(params, stream);
+    } else if (params.log_N == 13) {
+        fast_hadamard_transform_launch<256, 13, input_t>(params, stream);
+    } else if (params.log_N == 14) {
+        fast_hadamard_transform_launch<256, 14, input_t>(params, stream);
+    } else if (params.log_N == 15) {
+        fast_hadamard_transform_launch<256, 15, input_t>(params, stream);
+    }
+}
+
+template<int kNThreads, int kLogN, typename input_t>
+void fast_hadamard_transform_12N_launch(HadamardParamsBase &params, cudaStream_t stream) {
+    using Ktraits = fast_hadamard_transform_12N_kernel_traits<kNThreads, kLogN, input_t>;
+    constexpr int kSmemSize = Ktraits::kSmemSize;
+    dim3 grid(params.batch);
+    auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
+    if (kSmemSize >= 48 * 1024) {
+        C10_CUDA_CHECK(cudaFuncSetAttribute(
+            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+        }
+    kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template<typename input_t>
+void fast_hadamard_transform_12N_cuda(HadamardParamsBase &params, cudaStream_t stream) {
+    if (params.log_N == 2) {
+        fast_hadamard_transform_12N_launch<1, 2, input_t>(params, stream);
+    } else if (params.log_N == 3) {
+        fast_hadamard_transform_12N_launch<2, 3, input_t>(params, stream);
+    } else if (params.log_N == 4) {
+        fast_hadamard_transform_12N_launch<4, 4, input_t>(params, stream);
+    } else if (params.log_N == 5) {
+        fast_hadamard_transform_12N_launch<8, 5, input_t>(params, stream);
+    } else if (params.log_N == 6) {
+        fast_hadamard_transform_12N_launch<16, 6, input_t>(params, stream);
+    } else if (params.log_N == 7) {
+        fast_hadamard_transform_12N_launch<32, 7, input_t>(params, stream);
+    } else if (params.log_N == 8) {
+        fast_hadamard_transform_12N_launch<64, 8, input_t>(params, stream);
+    } else if (params.log_N == 9) {
+        fast_hadamard_transform_12N_launch<128, 9, input_t>(params, stream);
+    } else if (params.log_N == 10) {
+        fast_hadamard_transform_12N_launch<256, 10, input_t>(params, stream);
+    }
+}
+
+template<int kNThreads, int kLogN, typename input_t>
+void fast_hadamard_transform_20N_launch(HadamardParamsBase &params, cudaStream_t stream) {
+    using Ktraits = fast_hadamard_transform_20N_kernel_traits<kNThreads, kLogN, input_t>;
+    constexpr int kSmemSize = Ktraits::kSmemSize;
+    dim3 grid(params.batch);
+    auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
+    if (kSmemSize >= 48 * 1024) {
+        C10_CUDA_CHECK(cudaFuncSetAttribute(
+            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+        }
+    kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template<typename input_t>
+void fast_hadamard_transform_20N_cuda(HadamardParamsBase &params, cudaStream_t stream) {
+    if (params.log_N == 2) {
+        fast_hadamard_transform_20N_launch<1, 2, input_t>(params, stream);
+    } else if (params.log_N == 3) {
+        fast_hadamard_transform_20N_launch<2, 3, input_t>(params, stream);
+    } else if (params.log_N == 4) {
+        fast_hadamard_transform_20N_launch<4, 4, input_t>(params, stream);
+    } else if (params.log_N == 5) {
+        fast_hadamard_transform_20N_launch<8, 5, input_t>(params, stream);
+    } else if (params.log_N == 6) {
+        fast_hadamard_transform_20N_launch<16, 6, input_t>(params, stream);
+    } else if (params.log_N == 7) {
+        fast_hadamard_transform_20N_launch<32, 7, input_t>(params, stream);
+    } else if (params.log_N == 8) {
+        fast_hadamard_transform_20N_launch<64, 8, input_t>(params, stream);
+    } else if (params.log_N == 9) {
+        fast_hadamard_transform_20N_launch<128, 9, input_t>(params, stream);
+    } else if (params.log_N == 10) {
+        fast_hadamard_transform_20N_launch<256, 10, input_t>(params, stream);
+    }
+}
+
+template<int kNThreads, int kLogN, typename input_t>
+void fast_hadamard_transform_28N_launch(HadamardParamsBase &params, cudaStream_t stream) {
+    using Ktraits = fast_hadamard_transform_28N_kernel_traits<kNThreads, kLogN, input_t>;
+    constexpr int kSmemSize = Ktraits::kSmemSize;
+    dim3 grid(params.batch);
+    auto kernel = &fast_hadamard_transform_kernel<Ktraits>;
+    if (kSmemSize >= 48 * 1024) {
+        C10_CUDA_CHECK(cudaFuncSetAttribute(
+            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
+        }
+    kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template<typename input_t>
+void fast_hadamard_transform_28N_cuda(HadamardParamsBase &params, cudaStream_t stream) {
+    if (params.log_N == 2) {
+        fast_hadamard_transform_28N_launch<1, 2, input_t>(params, stream);
+    } else if (params.log_N == 3) {
+        fast_hadamard_transform_28N_launch<2, 3, input_t>(params, stream);
+    } else if (params.log_N == 4) {
+        fast_hadamard_transform_28N_launch<4, 4, input_t>(params, stream);
+    } else if (params.log_N == 5) {
+        fast_hadamard_transform_28N_launch<8, 5, input_t>(params, stream);
+    } else if (params.log_N == 6) {
+        fast_hadamard_transform_28N_launch<16, 6, input_t>(params, stream);
+    } else if (params.log_N == 7) {
+        fast_hadamard_transform_28N_launch<32, 7, input_t>(params, stream);
+    } else if (params.log_N == 8) {
+        fast_hadamard_transform_28N_launch<64, 8, input_t>(params, stream);
+    } else if (params.log_N == 9) {
+        fast_hadamard_transform_28N_launch<128, 9, input_t>(params, stream);
+    } else if (params.log_N == 10) {
+        fast_hadamard_transform_28N_launch<256, 10, input_t>(params, stream);
+    }
+}
+
+template void fast_hadamard_transform_cuda<float>(HadamardParamsBase &params, cudaStream_t stream);
+template void fast_hadamard_transform_cuda<at::Half>(HadamardParamsBase &params, cudaStream_t stream);
+template void fast_hadamard_transform_cuda<at::BFloat16>(HadamardParamsBase &params, cudaStream_t stream);
+
+template void fast_hadamard_transform_12N_cuda<float>(HadamardParamsBase &params, cudaStream_t stream);
+template void fast_hadamard_transform_12N_cuda<at::Half>(HadamardParamsBase &params, cudaStream_t stream);
+template void fast_hadamard_transform_12N_cuda<at::BFloat16>(HadamardParamsBase &params, cudaStream_t stream);
+
+template void fast_hadamard_transform_20N_cuda<float>(HadamardParamsBase &params, cudaStream_t stream);
+template void fast_hadamard_transform_20N_cuda<at::Half>(HadamardParamsBase &params, cudaStream_t stream);
+template void fast_hadamard_transform_20N_cuda<at::BFloat16>(HadamardParamsBase &params, cudaStream_t stream);
+
+template void fast_hadamard_transform_28N_cuda<float>(HadamardParamsBase &params, cudaStream_t stream);
+template void fast_hadamard_transform_28N_cuda<at::Half>(HadamardParamsBase &params, cudaStream_t stream);
+template void fast_hadamard_transform_28N_cuda<at::BFloat16>(HadamardParamsBase &params, cudaStream_t stream);
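
The dispatchers above pick a launch specialization from params.log_N, which set_hadamard_params computes as ceil(log2(dim / multiple)). A small sketch of that arithmetic for a few typical hidden sizes, mirroring the C++ (illustration only):

# Mirrors params.log_N = ceil(log2(dim / multiple)) from set_hadamard_params.
import math

def log_N(dim: int, multiple: int = 1) -> int:
    return int(math.ceil(math.log2(dim / multiple)))

# dim 4096, plain kernel: log_N(4096)     == 12 -> fast_hadamard_transform_launch<256, 12>
# dim 3072, 12N kernel:   log_N(3072, 12) == 8  -> fast_hadamard_transform_12N_launch<64, 8>
# dim 5120, 20N kernel:   log_N(5120, 20) == 8  -> fast_hadamard_transform_20N_launch<64, 8>
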

+ 90 - 0
kernels/hadamard/fast_hadamard_transform_special.h

@@ -0,0 +1,90 @@
+
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+// This file is auto-generated. See "generator.py"
+
+
+#pragma once
+
+
+__device__ __forceinline__ void hadamard_mult_thread_12(float x[12]) {
+    float out[12];
+    out[0] = + x[0] - x[1] + x[2] + x[3] + x[4] + x[5] + x[6] + x[7] + x[8] + x[9] + x[10] + x[11];
+    out[1] = - x[0] - x[1] + x[2] - x[3] + x[4] - x[5] + x[6] - x[7] + x[8] - x[9] + x[10] - x[11];
+    out[2] = + x[0] + x[1] + x[2] - x[3] + x[4] + x[5] - x[6] - x[7] - x[8] - x[9] + x[10] + x[11];
+    out[3] = + x[0] - x[1] - x[2] - x[3] + x[4] - x[5] - x[6] + x[7] - x[8] + x[9] + x[10] - x[11];
+    out[4] = + x[0] + x[1] + x[2] + x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] - x[11];
+    out[5] = + x[0] - x[1] + x[2] - x[3] - x[4] - x[5] + x[6] - x[7] - x[8] + x[9] - x[10] + x[11];
+    out[6] = + x[0] + x[1] - x[2] - x[3] + x[4] + x[5] + x[6] - x[7] + x[8] + x[9] - x[10] - x[11];
+    out[7] = + x[0] - x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] + x[8] - x[9] - x[10] + x[11];
+    out[8] = + x[0] + x[1] - x[2] - x[3] - x[4] - x[5] + x[6] + x[7] + x[8] - x[9] + x[10] + x[11];
+    out[9] = + x[0] - x[1] - x[2] + x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - x[9] + x[10] - x[11];
+    out[10] = + x[0] + x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] + x[8] + x[9] + x[10] - x[11];
+    out[11] = + x[0] - x[1] + x[2] - x[3] - x[4] + x[5] - x[6] + x[7] + x[8] - x[9] - x[10] - x[11];
+    #pragma unroll
+    for (int i = 0; i < 12; i++) { x[i] = out[i]; }
+}
+
+
+__device__ __forceinline__ void hadamard_mult_thread_20(float x[20]) {
+    float out[20];
+    out[0] = + x[0] - x[1] - x[2] - x[3] - x[4] + x[5] - x[6] - x[7] - x[8] - x[9] + x[10] + x[11] - x[12] - x[13] + x[14] + x[15] - x[16] + x[17] + x[18] - x[19];
+    out[1] = - x[0] + x[1] - x[2] - x[3] - x[4] - x[5] + x[6] - x[7] - x[8] - x[9] + x[10] + x[11] + x[12] - x[13] - x[14] - x[15] + x[16] - x[17] + x[18] + x[19];
+    out[2] = - x[0] - x[1] + x[2] - x[3] - x[4] - x[5] - x[6] + x[7] - x[8] - x[9] - x[10] + x[11] + x[12] + x[13] - x[14] + x[15] - x[16] + x[17] - x[18] + x[19];
+    out[3] = - x[0] - x[1] - x[2] + x[3] - x[4] - x[5] - x[6] - x[7] + x[8] - x[9] - x[10] - x[11] + x[12] + x[13] + x[14] + x[15] + x[16] - x[17] + x[18] - x[19];
+    out[4] = - x[0] - x[1] - x[2] - x[3] + x[4] - x[5] - x[6] - x[7] - x[8] + x[9] + x[10] - x[11] - x[12] + x[13] + x[14] - x[15] + x[16] + x[17] - x[18] + x[19];
+    out[5] = - x[0] + x[1] + x[2] + x[3] + x[4] + x[5] - x[6] - x[7] - x[8] - x[9] - x[10] + x[11] - x[12] - x[13] + x[14] + x[15] + x[16] - x[17] - x[18] + x[19];
+    out[6] = + x[0] - x[1] + x[2] + x[3] + x[4] - x[5] + x[6] - x[7] - x[8] - x[9] + x[10] - x[11] + x[12] - x[13] - x[14] + x[15] + x[16] + x[17] - x[18] - x[19];
+    out[7] = + x[0] + x[1] - x[2] + x[3] + x[4] - x[5] - x[6] + x[7] - x[8] - x[9] - x[10] + x[11] - x[12] + x[13] - x[14] - x[15] + x[16] + x[17] + x[18] - x[19];
+    out[8] = + x[0] + x[1] + x[2] - x[3] + x[4] - x[5] - x[6] - x[7] + x[8] - x[9] - x[10] - x[11] + x[12] - x[13] + x[14] - x[15] - x[16] + x[17] + x[18] + x[19];
+    out[9] = + x[0] + x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] - x[8] + x[9] + x[10] - x[11] - x[12] + x[13] - x[14] + x[15] - x[16] - x[17] + x[18] + x[19];
+    out[10] = - x[0] - x[1] + x[2] + x[3] - x[4] + x[5] - x[6] + x[7] + x[8] - x[9] + x[10] - x[11] - x[12] - x[13] - x[14] - x[15] + x[16] + x[17] + x[18] + x[19];
+    out[11] = - x[0] - x[1] - x[2] + x[3] + x[4] - x[5] + x[6] - x[7] + x[8] + x[9] - x[10] + x[11] - x[12] - x[13] - x[14] + x[15] - x[16] + x[17] + x[18] + x[19];
+    out[12] = + x[0] - x[1] - x[2] - x[3] + x[4] + x[5] - x[6] + x[7] - x[8] + x[9] - x[10] - x[11] + x[12] - x[13] - x[14] + x[15] + x[16] - x[17] + x[18] + x[19];
+    out[13] = + x[0] + x[1] - x[2] - x[3] - x[4] + x[5] + x[6] - x[7] + x[8] - x[9] - x[10] - x[11] - x[12] + x[13] - x[14] + x[15] + x[16] + x[17] - x[18] + x[19];
+    out[14] = - x[0] + x[1] + x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] + x[9] - x[10] - x[11] - x[12] - x[13] + x[14] + x[15] + x[16] + x[17] + x[18] - x[19];
+    out[15] = - x[0] + x[1] - x[2] - x[3] + x[4] - x[5] - x[6] + x[7] + x[8] - x[9] + x[10] - x[11] - x[12] - x[13] - x[14] + x[15] - x[16] - x[17] - x[18] - x[19];
+    out[16] = + x[0] - x[1] + x[2] - x[3] - x[4] - x[5] - x[6] - x[7] + x[8] + x[9] - x[10] + x[11] - x[12] - x[13] - x[14] - x[15] + x[16] - x[17] - x[18] - x[19];
+    out[17] = - x[0] + x[1] - x[2] + x[3] - x[4] + x[5] - x[6] - x[7] - x[8] + x[9] - x[10] - x[11] + x[12] - x[13] - x[14] - x[15] - x[16] + x[17] - x[18] - x[19];
+    out[18] = - x[0] - x[1] + x[2] - x[3] + x[4] + x[5] + x[6] - x[7] - x[8] - x[9] - x[10] - x[11] - x[12] + x[13] - x[14] - x[15] - x[16] - x[17] + x[18] - x[19];
+    out[19] = + x[0] - x[1] - x[2] + x[3] - x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] - x[11] - x[12] - x[13] + x[14] - x[15] - x[16] - x[17] - x[18] + x[19];
+    #pragma unroll
+    for (int i = 0; i < 20; i++) { x[i] = out[i]; }
+}
+
+
+__device__ __forceinline__ void hadamard_mult_thread_28(float x[28]) {
+    float out[28];
+    out[0] = + x[0] - x[1] - x[2] - x[3] - x[4] - x[5] - x[6] + x[7] + x[8] - x[9] - x[10] - x[11] - x[12] + x[13] + x[14] - x[15] + x[16] - x[17] - x[18] + x[19] - x[20] + x[21] - x[22] - x[23] + x[24] + x[25] - x[26] - x[27];
+    out[1] = - x[0] + x[1] - x[2] - x[3] - x[4] - x[5] - x[6] + x[7] + x[8] + x[9] - x[10] - x[11] - x[12] - x[13] - x[14] + x[15] - x[16] + x[17] - x[18] - x[19] + x[20] - x[21] + x[22] - x[23] - x[24] + x[25] + x[26] - x[27];
+    out[2] = - x[0] - x[1] + x[2] - x[3] - x[4] - x[5] - x[6] - x[7] + x[8] + x[9] + x[10] - x[11] - x[12] - x[13] + x[14] - x[15] + x[16] - x[17] + x[18] - x[19] - x[20] - x[21] - x[22] + x[23] - x[24] - x[25] + x[26] + x[27];
+    out[3] = - x[0] - x[1] - x[2] + x[3] - x[4] - x[5] - x[6] - x[7] - x[8] + x[9] + x[10] + x[11] - x[12] - x[13] - x[14] + x[15] - x[16] + x[17] - x[18] + x[19] - x[20] + x[21] - x[22] - x[23] + x[24] - x[25] - x[26] + x[27];
+    out[4] = - x[0] - x[1] - x[2] - x[3] + x[4] - x[5] - x[6] - x[7] - x[8] - x[9] + x[10] + x[11] + x[12] - x[13] - x[14] - x[15] + x[16] - x[17] + x[18] - x[19] + x[20] + x[21] + x[22] - x[23] - x[24] + x[25] - x[26] - x[27];
+    out[5] = - x[0] - x[1] - x[2] - x[3] - x[4] + x[5] - x[6] - x[7] - x[8] - x[9] - x[10] + x[11] + x[12] + x[13] + x[14] - x[15] - x[16] + x[17] - x[18] + x[19] - x[20] - x[21] + x[22] + x[23] - x[24] - x[25] + x[26] - x[27];
+    out[6] = - x[0] - x[1] - x[2] - x[3] - x[4] - x[5] + x[6] + x[7] - x[8] - x[9] - x[10] - x[11] + x[12] + x[13] - x[14] + x[15] - x[16] - x[17] + x[18] - x[19] + x[20] - x[21] - x[22] + x[23] + x[24] - x[25] - x[26] + x[27];
+    out[7] = - x[0] - x[1] + x[2] + x[3] + x[4] + x[5] - x[6] + x[7] - x[8] - x[9] - x[10] - x[11] - x[12] - x[13] - x[14] + x[15] + x[16] - x[17] - x[18] + x[19] + x[20] + x[21] - x[22] + x[23] - x[24] - x[25] + x[26] - x[27];
+    out[8] = - x[0] - x[1] - x[2] + x[3] + x[4] + x[5] + x[6] - x[7] + x[8] - x[9] - x[10] - x[11] - x[12] - x[13] + x[14] - x[15] + x[16] + x[17] - x[18] - x[19] + x[20] - x[21] + x[22] - x[23] + x[24] - x[25] - x[26] + x[27];
+    out[9] = + x[0] - x[1] - x[2] - x[3] + x[4] + x[5] + x[6] - x[7] - x[8] + x[9] - x[10] - x[11] - x[12] - x[13] + x[14] + x[15] - x[16] + x[17] + x[18] - x[19] - x[20] + x[21] - x[22] + x[23] - x[24] + x[25] - x[26] - x[27];
+    out[10] = + x[0] + x[1] - x[2] - x[3] - x[4] + x[5] + x[6] - x[7] - x[8] - x[9] + x[10] - x[11] - x[12] - x[13] - x[14] + x[15] + x[16] - x[17] + x[18] + x[19] - x[20] - x[21] + x[22] - x[23] + x[24] - x[25] + x[26] - x[27];
+    out[11] = + x[0] + x[1] + x[2] - x[3] - x[4] - x[5] + x[6] - x[7] - x[8] - x[9] - x[10] + x[11] - x[12] - x[13] - x[14] - x[15] + x[16] + x[17] - x[18] + x[19] + x[20] - x[21] - x[22] + x[23] - x[24] + x[25] - x[26] + x[27];
+    out[12] = + x[0] + x[1] + x[2] + x[3] - x[4] - x[5] - x[6] - x[7] - x[8] - x[9] - x[10] - x[11] + x[12] - x[13] + x[14] - x[15] - x[16] + x[17] + x[18] - x[19] + x[20] + x[21] - x[22] - x[23] + x[24] - x[25] + x[26] - x[27];
+    out[13] = - x[0] + x[1] + x[2] + x[3] + x[4] - x[5] - x[6] - x[7] - x[8] - x[9] - x[10] - x[11] - x[12] + x[13] + x[14] + x[15] - x[16] - x[17] + x[18] + x[19] - x[20] - x[21] + x[22] - x[23] - x[24] + x[25] - x[26] + x[27];
+    out[14] = - x[0] + x[1] - x[2] + x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - x[9] + x[10] + x[11] - x[12] - x[13] + x[14] - x[15] - x[16] - x[17] - x[18] - x[19] - x[20] - x[21] - x[22] + x[23] + x[24] + x[25] + x[26] - x[27];
+    out[15] = + x[0] - x[1] + x[2] - x[3] + x[4] + x[5] - x[6] - x[7] + x[8] - x[9] - x[10] + x[11] + x[12] - x[13] - x[14] + x[15] - x[16] - x[17] - x[18] - x[19] - x[20] - x[21] - x[22] - x[23] + x[24] + x[25] + x[26] + x[27];
+    out[16] = - x[0] + x[1] - x[2] + x[3] - x[4] + x[5] + x[6] - x[7] - x[8] + x[9] - x[10] - x[11] + x[12] + x[13] - x[14] - x[15] + x[16] - x[17] - x[18] - x[19] - x[20] + x[21] - x[22] - x[23] - x[24] + x[25] + x[26] + x[27];
+    out[17] = + x[0] - x[1] + x[2] - x[3] + x[4] - x[5] + x[6] + x[7] - x[8] - x[9] + x[10] - x[11] - x[12] + x[13] - x[14] - x[15] - x[16] + x[17] - x[18] - x[19] - x[20] + x[21] + x[22] - x[23] - x[24] - x[25] + x[26] + x[27];
+    out[18] = + x[0] + x[1] - x[2] + x[3] - x[4] + x[5] - x[6] + x[7] + x[8] - x[9] - x[10] + x[11] - x[12] - x[13] - x[14] - x[15] - x[16] - x[17] + x[18] - x[19] - x[20] + x[21] + x[22] + x[23] - x[24] - x[25] - x[26] + x[27];
+    out[19] = - x[0] + x[1] + x[2] - x[3] + x[4] - x[5] + x[6] - x[7] + x[8] + x[9] - x[10] - x[11] + x[12] - x[13] - x[14] - x[15] - x[16] - x[17] - x[18] + x[19] - x[20] + x[21] + x[22] + x[23] + x[24] - x[25] - x[26] - x[27];
+    out[20] = + x[0] - x[1] + x[2] + x[3] - x[4] + x[5] - x[6] - x[7] - x[8] + x[9] + x[10] - x[11] - x[12] + x[13] - x[14] - x[15] - x[16] - x[17] - x[18] - x[19] + x[20] - x[21] + x[22] + x[23] + x[24] + x[25] - x[26] - x[27];
+    out[21] = - x[0] + x[1] + x[2] - x[3] - x[4] + x[5] + x[6] - x[7] + x[8] - x[9] + x[10] + x[11] - x[12] + x[13] + x[14] + x[15] - x[16] - x[17] - x[18] - x[19] + x[20] + x[21] - x[22] - x[23] - x[24] - x[25] - x[26] - x[27];
+    out[22] = + x[0] - x[1] + x[2] + x[3] - x[4] - x[5] + x[6] + x[7] - x[8] + x[9] - x[10] + x[11] + x[12] - x[13] + x[14] + x[15] + x[16] - x[17] - x[18] - x[19] - x[20] - x[21] + x[22] - x[23] - x[24] - x[25] - x[26] - x[27];
+    out[23] = + x[0] + x[1] - x[2] + x[3] + x[4] - x[5] - x[6] - x[7] + x[8] - x[9] + x[10] - x[11] + x[12] + x[13] - x[14] + x[15] + x[16] + x[17] - x[18] - x[19] - x[20] - x[21] - x[22] + x[23] - x[24] - x[25] - x[26] - x[27];
+    out[24] = - x[0] + x[1] + x[2] - x[3] + x[4] + x[5] - x[6] + x[7] - x[8] + x[9] - x[10] + x[11] - x[12] + x[13] - x[14] - x[15] + x[16] + x[17] + x[18] - x[19] - x[20] - x[21] - x[22] - x[23] + x[24] - x[25] - x[26] - x[27];
+    out[25] = - x[0] - x[1] + x[2] + x[3] - x[4] + x[5] + x[6] + x[7] + x[8] - x[9] + x[10] - x[11] + x[12] - x[13] - x[14] - x[15] - x[16] + x[17] + x[18] + x[19] - x[20] - x[21] - x[22] - x[23] - x[24] + x[25] - x[26] - x[27];
+    out[26] = + x[0] - x[1] - x[2] + x[3] + x[4] - x[5] + x[6] - x[7] + x[8] + x[9] - x[10] + x[11] - x[12] + x[13] - x[14] - x[15] - x[16] - x[17] + x[18] + x[19] + x[20] - x[21] - x[22] - x[23] - x[24] - x[25] + x[26] - x[27];
+    out[27] = + x[0] + x[1] - x[2] - x[3] + x[4] + x[5] - x[6] + x[7] - x[8] + x[9] + x[10] - x[11] + x[12] - x[13] + x[14] - x[15] - x[16] - x[17] - x[18] + x[19] + x[20] - x[21] - x[22] - x[23] - x[24] - x[25] - x[26] + x[27];
+    #pragma unroll
+    for (int i = 0; i < 28; i++) { x[i] = out[i]; }
+}

+ 124 - 0
kernels/hadamard/generator.py

@@ -0,0 +1,124 @@
+import math
+import re
+from pathlib import Path
+
+import numpy as np
+
+# From https://en.wikipedia.org/wiki/Paley_construction (construction II for q = 5)
+
+had_12_paley = """
++-++++++++++
+--+-+-+-+-+-
++++-++----++
++---+--+-++-
++++++-++----
++-+---+--+-+
+++--+++-++--
++--++---+--+
+++----+++-++
++--+-++---+-
+++++----+++-
++-+--+-++---
+""" 
+
+# From http://neilsloane.com/hadamard/
+
+had_20_will = """
++----+----++--++-++-
+-+----+---+++---+-++
+--+----+---+++-+-+-+
+---+----+---+++++-+-
+----+----++--++-++-+
+-+++++-----+--+++--+
++-+++-+---+-+--+++--
+++-++--+---+-+--+++-
++++-+---+---+-+--+++
+++++-----++--+-+--++
+--++-+-++-+-----++++
+---++-+-++-+---+-+++
++---++-+-+--+--++-++
+++---++-+----+-+++-+
+-++---++-+----+++++-
+-+--+--++-+----+----
++-+-----++-+----+---
+-+-+-+---+--+----+--
+--+-+++------+----+-
++--+--++------+----+
+"""
+
+
+had_28_will = """
++------++----++-+--+-+--++--
+-+-----+++-----+-+--+-+--++-
+--+-----+++---+-+-+----+--++
+---+-----+++---+-+-+-+--+--+
+----+-----+++---+-+-+++--+--
+-----+-----++++--+-+--++--+-
+------++----++-+--+-+--++--+
+--++++-+-------++--+++-+--+-
+---++++-+-----+-++--+-+-+--+
++---+++--+----++-++--+-+-+--
+++---++---+----++-++--+-+-+-
++++---+----+----++-++--+-+-+
+++++--------+-+--++-++--+-+-
+-++++--------+++--++--+--+-+
+-+-++-++--++--+--------++++-
++-+-++--+--++--+--------++++
+-+-+-++--+--++--+----+---+++
++-+-+-++--+--+---+---++---++
+++-+-+-++--+------+--+++---+
+-++-+-+-++--+------+-++++---
++-++-+---++--+------+-++++--
+-++--++-+-++-+++----++------
++-++--++-+-++-+++-----+-----
+++-++---+-+-++-+++-----+----
+-++-++-+-+-+-+--+++-----+---
+--++-++++-+-+----+++-----+--
++--++-+-++-+-+----+++-----+-
+++--++-+-++-+-+----++------+
+"""
+
+header = """
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+// This file is auto-generated. See "generator.py"\n
+
+#pragma once
+
+"""
+
+template = """
+__device__ __forceinline__ void hadamard_mult_thread_{N}(float x[{N}]) {
+    float out[{N}];
+    {code}
+    #pragma unroll
+    for (int i = 0; i < {N}; i++) { x[i] = out[i]; }
+}
+
+"""
+
+
+def string_to_array(string):
+    # Convert strings of + and - to bool arrays
+    string = string.strip().replace('+', '1').replace('-', '-1').split()
+    return np.stack([np.fromstring(" ".join(string[i]), dtype=np.int32, sep=' ') for i in range(len(string))])
+
+
+def array_code_gen(arr):
+    N = arr.shape[0]
+    assert arr.shape[0] == arr.shape[1]
+    out = []
+    for i in range(N):
+        out.append(f"out[{i}] = " + " ".join([f"{'+' if arr[i, j] == 1 else '-'} x[{j}]" for j in range(N)]) + ";")
+    return template.replace("{N}", str(N)).replace("{code}", '\n    '.join(out))
+
+
+
+def main():
+    output_dir = Path(__file__).parent / "fast_hadamard_transform_special.h"
+    output_dir.write_text(header + array_code_gen(string_to_array(had_12_paley)) + array_code_gen(string_to_array(had_20_will)) + array_code_gen(string_to_array(had_28_will)))
+
+if __name__ == '__main__':
+    main()
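
A hedged check, not part of the commit, that the sign patterns above really are Hadamard matrices: a +/-1 matrix H of size n is Hadamard iff H @ H.T == n * I, which is what makes the unnormalized transform invertible up to a 1/n factor.

# Hypothetical verification helper, usable against string_to_array's output.
import numpy as np

def is_hadamard(h: np.ndarray) -> bool:
    n = h.shape[0]
    return np.array_equal(h @ h.T, n * np.eye(n, dtype=h.dtype))

# assert is_hadamard(string_to_array(had_12_paley))
# assert is_hadamard(string_to_array(had_20_will))
# assert is_hadamard(string_to_array(had_28_will))
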

+ 25 - 0
kernels/hadamard/static_switch.h

@@ -0,0 +1,25 @@
+// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
+// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
+
+#pragma once
+
+/// @param COND       - a boolean expression to switch by
+/// @param CONST_NAME - a name given for the constexpr bool variable.
+/// @param ...       - code to execute for true and false
+///
+/// Usage:
+/// ```
+/// BOOL_SWITCH(flag, BoolConst, [&] {
+///     some_function<BoolConst>(...);
+/// });
+/// ```
+#define BOOL_SWITCH(COND, CONST_NAME, ...)                                           \
+    [&] {                                                                            \
+        if (COND) {                                                                  \
+            static constexpr bool CONST_NAME = true;                                 \
+            return __VA_ARGS__();                                                    \
+        } else {                                                                     \
+            static constexpr bool CONST_NAME = false;                                \
+            return __VA_ARGS__();                                                    \
+        }                                                                            \
+    }()

+ 19 - 0
setup.py

@@ -256,6 +256,25 @@ if _is_cuda():
                     "nvcc": NVCC_FLAGS_PUNICA,
                 },
             ))
+    
+    install_hadamard = bool(int(os.getenv("APHRODITE_INSTALL_HADAMARD_KERNELS", "1")))
+    device_count = torch.cuda.device_count()
+    for i in range(device_count):
+        major, minor = torch.cuda.get_device_capability(i)
+        if major < 7:
+            install_hadamard = False
+            break
+    if install_hadamard:
+        ext_modules.append(
+            CUDAExtension(
+                name="aphrodite._hadamard_C",
+                sources=["kernels/hadamard/fast_hadamard_transform.cpp",
+                         "kernels/hadamard/fast_hadamard_transform_cuda.cu"],
+                extra_compile_args={
+                    "cxx": CXX_FLAGS,
+                    "nvcc": NVCC_FLAGS,
+                },
+            ))
 
 elif _is_hip():
     amd_arch = get_amdgpu_offload_arch()
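
A hedged end-to-end smoke test (not part of the commit), assuming the extension built and a compute-capability >= 7.0 GPU per the gating above. With scale = 1/sqrt(n) the transform should be its own inverse, so applying it twice should recover the input up to fp16 rounding.

# Hypothetical smoke test for the built extension via the Python interface.
import torch
from aphrodite.modeling.layers.quantization.quip_utils import hadamard_transform

n = 4096
x = torch.randn(2, n, device="cuda", dtype=torch.float16)
y = hadamard_transform(x, scale=n ** -0.5)
assert torch.allclose(hadamard_transform(y, scale=n ** -0.5), x, atol=1e-2)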