
feat: AMD ROCm support (#95)

* feat: add ROCM kernels

* WIP integration into aphrodite

* fix compatibility layer

* fix kernels

* add xformers rocm patch

* update setup.py and add requirements-rocm.txt

* update engine args and ray utilities

* modify attention and quantization code

* formatting pt1

* manual merge and format fixes

* patch script

* rocm dockerfile

* simplify operations

* add custom DevFuncAttribute

* use custom function instead of torch.version

* it's 74 for amd

* no fp32 and fp16 for rocm

* clean up args

* get rid of torch

* partially support mistral for now

* fix formatting

* finishing touches

* revert back to rocm FA2

* why was that cpp
AlpinDale 1 year ago
parent
commit
1334a833a4

+ 5 - 1
.gitignore

@@ -192,4 +192,8 @@ _build/
 
 # vim swap files
 *.swo
-*.swp
+*.swp
+
+# HIP files generated by PyTorch
+*.hip
+*_hip*

+ 34 - 4
aphrodite/common/config.py

@@ -5,7 +5,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.common.logger import init_logger
 from aphrodite.transformers_utils.config import get_config
-from aphrodite.common.utils import get_cpu_memory
+from aphrodite.common.utils import get_cpu_memory, is_hip
 
 logger = init_logger(__name__)
 
@@ -81,12 +81,26 @@ class ModelConfig:
 
     def _verify_load_format(self) -> None:
         load_format = self.load_format.lower()
-        if load_format not in [
-                "auto", "pt", "safetensors", "npcache", "dummy"
-        ]:
+        supported_load_format = [
+            "auto", "pt", "safetensors", "npcache", "dummy"
+        ]
+        rocm_not_supported_load_format = ["safetensors"]
+        if load_format not in supported_load_format:
             raise ValueError(
                 f"Unknown load format: {self.load_format}. Must be one of "
                 "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
+        if is_hip():
+            if load_format in ["safetensors"]:
+                rocm_supported_load_format = [
+                    f for f in supported_load_format
+                    if (f not in rocm_not_supported_load_format)
+                ]
+                raise ValueError(
+                    f"load format {load_format} is not supported on ROCm. "
+                    f"Must be one of {rocm_supported_load_format}.")
+            # force ROCm to load from pt weights if nothing is set
+            if load_format == "auto":
+                load_format = "pt"
         self.load_format = load_format
 
     def _verify_tokenizer_mode(self) -> None:
@@ -99,6 +113,7 @@ class ModelConfig:
 
     def _verify_quantization(self) -> None:
         supported_quantization = ["awq", "squeezellm", "gptq"]
+        rocm_not_supported_quantization = ["awq"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
 
@@ -117,6 +132,11 @@ class ModelConfig:
                 raise ValueError(
                     f"Unknown quantization method: {self.quantization}. "
                     f"Must be one of {supported_quantization}.")
+            if is_hip(
+            ) and self.quantization in rocm_not_supported_quantization:
+                raise ValueError(
+                    f"{self.quantization} quantization method is currently "
+                    "not supported in ROCm.")
         if self.quantization is not None:
             logger.warning(f"{self.quantization} quantization is not fully "
                            "optimized yet. The speed can be slower than "
@@ -333,6 +353,8 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
     "bfloat16": torch.bfloat16,
 }
 
+_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
+
 
 def _get_and_verify_dtype(
     config: PretrainedConfig,
@@ -363,6 +385,14 @@ def _get_and_verify_dtype(
             f"Unknown dtype: {dtype}. Must be either a string or a torch "
             "dtype.")
 
+    if is_hip() and torch_dtype == torch.float32:
+        rocm_supported_dtypes = [
+            k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
+            if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
+        ]
+        raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
+                         f"Supported dtypes are {rocm_supported_dtypes}")
+
     # Verify the dtype.
     if torch_dtype != config_dtype:
         if torch_dtype == torch.float32:
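
Taken together, the config changes gate three things when running on ROCm: the weight load format, the quantization method, and the dtype. A condensed, illustrative sketch of the combined behaviour (check_rocm_support is a hypothetical helper, not part of this diff):

    from typing import Optional
    import torch

    def is_hip() -> bool:
        return torch.version.hip is not None

    def check_rocm_support(load_format: str, quantization: Optional[str],
                           dtype: torch.dtype) -> str:
        """Hypothetical condensation of the ROCm checks added to ModelConfig."""
        if is_hip():
            if load_format == "safetensors":
                raise ValueError("safetensors loading is not supported on ROCm")
            if load_format == "auto":
                load_format = "pt"   # ROCm falls back to PyTorch .bin weights
            if quantization == "awq":
                raise ValueError("AWQ is not supported on ROCm")
            if dtype == torch.float32:
                raise ValueError("float32 is not supported on ROCm")
        return load_format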

+ 6 - 1
aphrodite/common/utils.py

@@ -29,10 +29,15 @@ class Counter:
         self.counter = 0
 
 
+def is_hip() -> bool:
+    return torch.version.hip is not None
+
+
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
     # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
-    cudaDevAttrMaxSharedMemoryPerBlockOptin = 97  # pylint: disable=invalid-name
+    # pylint: disable=invalid-name
+    cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
     max_shared_mem = cuda_utils.get_device_attribute(
         cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
     return int(max_shared_mem)
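
Two details worth noting above: is_hip() just checks whether this PyTorch build carries a HIP version string, and the device-attribute constant differs between runtimes (97 is cudaDevAttrMaxSharedMemoryPerBlockOptin on CUDA; per this change, the corresponding attribute value on HIP is 74). A minimal illustration:

    import torch

    def is_hip() -> bool:
        # torch.version.hip is a version string (e.g. "5.7.x") on ROCm builds
        # and None on CUDA builds, so the same wheel can branch at runtime.
        return torch.version.hip is not None

    # Same query, different attribute enum value depending on the runtime.
    attribute = 97 if not is_hip() else 74
    # max_shared_mem = cuda_utils.get_device_attribute(attribute, gpu)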

+ 3 - 3
aphrodite/engine/args_tools.py

@@ -81,11 +81,11 @@ class EngineArgs:
             default=EngineArgs.load_format,
             choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
             help='The format of the model weights to load. '
-            '"auto" will try to load the weights in the safetensors format '
-            'and fall back to the pytorch bin format if safetensors format '
+            '"auto" will try to load the weights in the safetensors '
+            'and fall back to the pytorch bin format if safetensors '
             'is not available. '
             '"pt" will load the weights in the pytorch bin format. '
-            '"safetensors" will load the weights in the safetensors format. '
+            '"safetensors" will load the weights in the safetensors. '
             '"npcache" will load the weights in pytorch format and store '
             'a numpy cache to speed up the loading. '
             '"dummy" will initialize the weights with random values, '

+ 7 - 1
aphrodite/engine/ray_tools.py

@@ -5,6 +5,7 @@ from typing import Optional, Tuple, TYPE_CHECKING
 
 from aphrodite.common.config import ParallelConfig
 from aphrodite.common.logger import init_logger
+from aphrodite.common.utils import is_hip
 
 logger = init_logger(__name__)
 
@@ -77,7 +78,12 @@ def initialize_cluster(
                 "Ray is not installed. Please install Ray to use distributed "
                 "serving.")
         # Connect to a ray cluster.
-        ray.init(address=ray_address, ignore_reinit_error=True)
+        if is_hip():
+            ray.init(address=ray_address,
+                     ignore_reinit_error=True,
+                     num_gpus=parallel_config.world_size)
+        else:
+            ray.init(address=ray_address, ignore_reinit_error=True)
 
     if not parallel_config.worker_use_ray:
         # Initialize cluster locally.
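
The extra argument on the HIP branch tells Ray explicitly how many GPU resources this node provides (one per worker in the tensor-parallel group) instead of relying on automatic GPU detection; that reading is an assumption, since the diff only shows the branch itself. A self-contained sketch of the HIP call:

    # Sketch of the HIP branch in isolation (world_size stands in for
    # parallel_config.world_size; the value here is illustrative).
    import ray

    world_size = 2  # tensor-parallel degree, example only
    ray.init(address=None, ignore_reinit_error=True, num_gpus=world_size)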

+ 3 - 0
aphrodite/modeling/layers/attention.py

@@ -13,6 +13,7 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
 from aphrodite._C import ops as attention_ops
 from aphrodite._C import cache_ops
 from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.common.utils import is_hip
 
 _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
@@ -171,6 +172,8 @@ class PagedAttention(nn.Module):
                 attn_bias=input_metadata.attn_bias,
                 p=0.0,
                 scale=self.scale,
+                op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
+                (is_hip()) else None,
             )
             output = out.view_as(query)
         else:
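
The only functional change here is the op argument: on ROCm the flash-attention forward operator is selected explicitly (MemoryEfficientAttentionFlashAttentionOp is a (forward, backward) op pair, hence the [0]), while on CUDA passing None keeps xformers' automatic dispatch. An illustrative sketch of the resulting call, with arbitrary shapes:

    import torch
    import xformers.ops as xops

    q = k = v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
    on_rocm = torch.version.hip is not None
    out = xops.memory_efficient_attention_forward(
        q, k, v,
        p=0.0,
        scale=None,
        # Index 0 picks the forward op; on ROCm this routes through the
        # patched flash-attention build instead of xformers' auto-dispatch.
        op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if on_rocm else None,
    )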

+ 5 - 2
aphrodite/modeling/layers/quantization/__init__.py

@@ -1,17 +1,20 @@
 from typing import Type
 
-from aphrodite.modeling.layers.quantization.awq import AWQConfig
 from aphrodite.modeling.layers.quantization.squeezellm import SqueezeLLMConfig
 from aphrodite.modeling.layers.quantization.gptq import GPTQConfig
 from aphrodite.modeling.layers.quantization.base_config import (
     QuantizationConfig)
+from aphrodite.common.utils import is_hip
 
 _QUANTIZATION_CONFIG_REGISTRY = {
-    "awq": AWQConfig,
     "squeezellm": SqueezeLLMConfig,
     "gptq": GPTQConfig,
 }
 
+if not is_hip():
+    from aphrodite.modeling.layers.quantization.awq import AWQConfig
+    _QUANTIZATION_CONFIG_REGISTRY["awq"] = AWQConfig
+
 
 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization not in _QUANTIZATION_CONFIG_REGISTRY:
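
Because AWQ is only registered when the platform is not HIP, the registry itself becomes the source of truth for what is available. Illustrative behaviour on a ROCm build (hypothetical session):

    from aphrodite.modeling.layers.quantization import (
        _QUANTIZATION_CONFIG_REGISTRY, get_quantization_config)

    sorted(_QUANTIZATION_CONFIG_REGISTRY)   # ['gptq', 'squeezellm'] on ROCm
    get_quantization_config("gptq")         # available on both backends
    get_quantization_config("awq")          # fails the registry check above on ROCm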

+ 9 - 2
aphrodite/modeling/layers/quantization/awq.py

@@ -2,11 +2,18 @@ from typing import Any, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
-
-from aphrodite._C import ops as quantization_ops
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
 from aphrodite.modeling.layers.quantization.base_config import QuantizationConfig
+from aphrodite.common.logger import init_logger
+from aphrodite.common.utils import is_hip
+
+logger = init_logger(__name__)
+
+if is_hip():
+    logger.warning("AWQ is not supported on ROCm.")
+else:
+    from aphrodite._C import ops as quantization_ops
 
 
 class AWQConfig(QuantizationConfig):

+ 12 - 4
aphrodite/modeling/layers/quantization/squeezellm.py

@@ -7,6 +7,7 @@ from aphrodite._C import ops as quantization_ops
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
 from aphrodite.modeling.layers.quantization.base_config import QuantizationConfig
+from aphrodite.common.utils import is_hip
 
 
 class SqueezeLLMConfig(QuantizationConfig):
@@ -117,10 +118,17 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
         lookup_table = weights["lookup_table"]
         out_shape = x.shape[:-1] + (qweight.shape[-1], )
         reshaped_x = x.reshape(-1, x.shape[-1])
-        # NOTE: The output tensor should be zero-initialized.
-        out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
-        quantization_ops.squeezellm_gemm(reshaped_x, qweight, out,
-                                         lookup_table)
+        if is_hip():
+            out_f = torch.zeros(out_shape, device="cuda", dtype=torch.float)
+            quantization_ops.squeezellm_gemm(reshaped_x, qweight, out_f,
+                                             lookup_table)
+            out = out_f.to(dtype=torch.float16)
+            # do something specific for HIP
+        else:
+            # NOTE: The output tensor should be zero-initialized.
+            out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
+            quantization_ops.squeezellm_gemm(reshaped_x, qweight, out,
+                                             lookup_table)
 
         if bias is not None:
             out = out + bias
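
The HIP path exists because the ROCm version of the SqueezeLLM kernel accumulates its result with float atomics (see the float2 changes to quant_cuda_kernel.cu further down), so the Python side allocates a float32 output buffer and downcasts afterwards. A sketch of just the dtype handling, with the kernel call stubbed out:

    import torch

    out_shape = (1, 4096)                 # arbitrary example shape
    if torch.version.hip is not None:
        out_f = torch.zeros(out_shape, device="cuda", dtype=torch.float)
        # quantization_ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
        out = out_f.to(dtype=torch.float16)
    else:
        out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
        # quantization_ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)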

+ 24 - 0
aphrodite/modeling/loader.py

@@ -10,6 +10,10 @@ from aphrodite.common.config import ModelConfig
 from aphrodite.modeling.models import *  # pylint: disable=wildcard-import
 from aphrodite.modeling.hf_downloader import (get_quant_config,
                                               initialize_dummy_weights)
+from aphrodite.common.utils import is_hip
+from aphrodite.common.logger import init_logger
+
+logger = init_logger(__name__)
 
 # TODO: Lazy-load the model classes.
 _MODEL_REGISTRY = {
@@ -22,6 +26,18 @@ _MODEL_REGISTRY = {
     "PhiForCausalLM": PhiForCausalLM,
 }
 
+# Models to be disabled in ROCm
+_ROCM_UNSUPPORTED_MODELS = []
+if is_hip():
+    for rocm_model in _ROCM_UNSUPPORTED_MODELS:
+        del _MODEL_REGISTRY[rocm_model]
+
+# Models partially supported in ROCm
+_ROCM_PARTIALLY_SUPPORTED_MODELS = {
+    "MistralForCausalLM":
+    "Sliding window attention is not supported in ROCm's flash attention",
+}
+
 
 @contextlib.contextmanager
 def _set_default_torch_dtype(dtype: torch.dtype):
@@ -36,7 +52,15 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
     architectures = getattr(config, "architectures", [])
     for arch in architectures:
         if arch in _MODEL_REGISTRY:
+            if is_hip() and arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
+                logger.warning(
+                    f"{arch} is not fully supported in ROCm. Reason: "
+                    f"{_ROCM_PARTIALLY_SUPPORTED_MODELS[arch]}")
             return _MODEL_REGISTRY[arch]
+        elif arch in _ROCM_UNSUPPORTED_MODELS:
+            raise ValueError(
+                f"Model architecture {arch} is not supported by ROCm for now."
+                f"\nSupported architectures {list(_MODEL_REGISTRY.keys())}")
     raise ValueError(
         f"Model architectures {architectures} are not supported for now. "
         f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")

+ 62 - 0
docker/Dockerfile.rocm

@@ -0,0 +1,62 @@
+FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1
+
+# Install some basic utilities
+RUN apt-get update && apt-get install python3 python3-pip -y
+
+# Install some basic utilities
+RUN apt-get update && apt-get install -y \
+    curl \
+    ca-certificates \
+    sudo \
+    git \
+    bzip2 \
+    libx11-6 \
+    build-essential \
+    wget \
+    unzip \
+    nvidia-cuda-toolkit \
+    tmux \
+ && rm -rf /var/lib/apt/lists/*
+
+### Mount Point ###
+# When launching the container, mount the code directory to /app
+ARG APP_MOUNT=/app
+VOLUME [ ${APP_MOUNT} ]
+WORKDIR ${APP_MOUNT}
+
+RUN python3 -m pip install --upgrade pip
+RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
+
+ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
+ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
+ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
+ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
+
+# Install ROCm flash-attention
+RUN mkdir libs \
+    && cd libs \
+    && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
+    && cd flash-attention \
+    && git checkout 3d2b6f5 \
+    && git submodule update --init \
+    && export GPU_ARCHS=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) \
+    && patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \
+    && python3 setup.py install \
+    && cd ..
+
+COPY ./ /app/aphrodite-engine
+
+RUN python3 -m pip install --upgrade pip
+RUN pip install xformers==0.0.22.post7 --no-deps
+
+RUN cd /app \
+    && cd aphrodite-engine \
+    && pip install -U -r requirements-rocm.txt \
+    && bash patch_xformers-0.0.22.post7.rocm.sh \
+    && python3 setup.py install \
+    && cd ..
+
+RUN python3 -m pip install --upgrade pip
+RUN python3 -m pip install --no-cache-dir ray[all]
+
+CMD ["/bin/bash"]

+ 4 - 3
kernels/activation_kernels.cu

@@ -1,6 +1,7 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 
+#include "cuda_compat.h"
 #include "dispatch_utils.h"
 
 namespace aphrodite {
@@ -18,8 +19,8 @@ __global__ void silu_and_mul_kernel(
   const int d) {
   const int64_t token_idx = blockIdx.x;
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-    const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
-    const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
+    const scalar_t x = APHRODITE_LDG(&input[token_idx * 2 * d + idx]);
+    const scalar_t y = APHRODITE_LDG(&input[token_idx * 2 * d + d + idx]);
     out[token_idx * d + idx] = silu(x) * y;
   }
 }
@@ -57,7 +58,7 @@ __global__ void activation_kernel(
   const int d) {
   const int64_t token_idx = blockIdx.x;
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-    const scalar_t x = __ldg(&input[token_idx * d + idx]);
+    const scalar_t x = APHRODITE_LDG(&input[token_idx * d + idx]);
     out[token_idx * d + idx] = ACT_FN(x);
   }
 }

+ 24 - 14
kernels/attention/attention_kernels.cu

@@ -16,6 +16,10 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#ifdef USE_ROCM
+#include <hip/hip_runtime.h>
+#endif
+
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 
@@ -24,7 +28,11 @@
 
 #include <algorithm>
 
+#ifndef USE_ROCM
 #define WARP_SIZE 32
+#else
+#define WARP_SIZE warpSize
+#endif
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
@@ -41,7 +49,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
   // Compute the sum per warp.
 #pragma unroll
   for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
-    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
+    sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
   }
 
   // Warp leaders store the data to shared memory.
@@ -60,11 +68,11 @@ inline __device__ float block_sum(float* red_smem, float sum) {
   // Parallel reduction inside the warp.
 #pragma unroll
   for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
-    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
+    sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
   }
 
   // Broadcast to other threads.
-  return __shfl_sync(uint32_t(-1), sum, 0);
+  return APHRODITE_SHFL_SYNC(sum, 0);
 }
 
 // TODO: Merge the last two dimensions of the grid.
@@ -221,7 +229,7 @@ __device__ void paged_attention_kernel(
   // The 0-th thread of each thread group already has its max qk value.
 #pragma unroll
   for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
-    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
+    qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
   }
   if (lane == 0) {
     red_smem[warp_idx] = qk_max;
@@ -233,10 +241,10 @@ __device__ void paged_attention_kernel(
   qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
 #pragma unroll
   for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
-    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
+    qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
   }
   // Broadcast the max qk value to all threads.
-  qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
+  qk_max = APHRODITE_SHFL_SYNC(qk_max, 0);
 
   // Get the sum of the exp values.
   float exp_sum = 0.f;
@@ -320,7 +328,7 @@ __device__ void paged_attention_kernel(
     float acc = accs[i];
 #pragma unroll
     for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
-      acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
+      acc += APHRODITE_SHFL_XOR_SYNC(acc, mask);
     }
     accs[i] = acc;
   }
@@ -486,7 +494,7 @@ __global__ void paged_attention_v2_reduce_kernel(
   // Reduce within the warp.
 #pragma unroll
   for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
-    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
+    max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
   }
   if (lane == 0) {
     red_smem[warp_idx] = max_logit;
@@ -496,10 +504,10 @@ __global__ void paged_attention_v2_reduce_kernel(
   max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
 #pragma unroll
   for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
-    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
+    max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
   }
   // Broadcast the max value to all threads.
-  max_logit = __shfl_sync(uint32_t(-1), max_logit, 0);
+  max_logit = APHRODITE_SHFL_SYNC(max_logit, 0);
 
   // Load rescaled exp sums to shared memory.
   float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
@@ -532,11 +540,12 @@ __global__ void paged_attention_v2_reduce_kernel(
 
 } // namespace aphrodite
 
+
 #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                                  \
-  cudaFuncSetAttribute(                                                                       \
-    aphrodite::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>,                   \
-    cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size);                            \
-  aphrodite::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>                      \
+  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                                  \
+    ((void*)aphrodite::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>),     \
+    shared_mem_size);                                                                         \
+  aphrodite::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>                 \
   <<<grid, block, shared_mem_size, stream>>>(                                                 \
     out_ptr,                                                                                  \
     query_ptr,                                                                                \
@@ -552,6 +561,7 @@ __global__ void paged_attention_v2_reduce_kernel(
     kv_block_stride,                                                                          \
     kv_head_stride);
 
+
 // TODO: Tune NUM_THREADS.
 template<
   typename T,

+ 2 - 1
kernels/attention/attention_utils.cuh

@@ -18,6 +18,7 @@
  */
 #pragma once
 
+#include "../cuda_compat.h"
 #include "attention_dtypes.h"
 
 #include <float.h>
@@ -40,7 +41,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
   float qk = sum(qk_vec);
 #pragma unroll
   for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
-    qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
+    qk += APHRODITE_SHFL_XOR_SYNC(qk, mask);
   }
   return qk;
 }

+ 16 - 3
kernels/attention/dtype_bfloat16.cuh

@@ -22,8 +22,17 @@
 #include "attention_generic.cuh"
 #include "dtype_float32.cuh"
 
-#include <cuda_bf16.h>
-#include <cuda_fp16.h>
+#ifndef USE_ROCM
+  #include <cuda_bf16.h>
+  #include <cuda_fp16.h>
+#else
+  #include <hip/hip_bf16.h>
+  #include <hip/hip_fp16.h>
+
+  typedef __hip_bfloat162 __nv_bfloat162;
+  typedef __hip_bfloat16 __nv_bfloat16;
+#endif
+
 #include <stdint.h>
 
 namespace aphrodite {
@@ -99,7 +108,11 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
   assert(false);
 #else
-  return a + b;
+  #ifndef USE_ROCM
+    return a + b;
+  #else
+    return __hadd(a, b);
+  #endif
 #endif
 }
 

+ 63 - 5
kernels/attention/dtype_float16.cuh

@@ -22,6 +22,10 @@
 #include "attention_generic.cuh"
 #include "dtype_float32.cuh"
 
+#ifdef USE_ROCM
+  #include <hip/hip_fp16.h>
+#endif
+
 #include <stdint.h>
 
 namespace aphrodite {
@@ -64,21 +68,47 @@ struct FloatVec<uint4> {
 
 // Utility functions for type conversions.
 inline __device__ uint32_t h0_h0(uint16_t a) {
+#ifndef USE_ROCM
   uint32_t b;
   asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
   return b;
+#else
+  union {
+   uint32_t u32;
+   uint16_t u16[2];
+  } tmp;
+  tmp.u16[0] = a;
+  tmp.u16[1] = a;
+  return tmp.u32;
+#endif
 }
 
 inline __device__ float half_to_float(uint16_t h) {
   float f;
+#ifndef USE_ROCM
   asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
+#else
+  asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
+#endif
   return f;
 }
 
 inline __device__ float2 half2_to_float2(uint32_t v) {
+#ifndef USE_ROCM
   uint16_t lo, hi;
   asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
   return make_float2(half_to_float(lo), half_to_float(hi));
+#else
+  union {
+    uint32_t u32;
+    uint16_t u16[2];
+  } tmp;
+  tmp.u32 = v;
+  float2 ret;
+  ret.x = half_to_float(tmp.u16[0]);
+  ret.y = half_to_float(tmp.u16[1]);
+  return ret;
+#endif
 }
 
 inline __device__ uint16_t float_to_half(float f) {
@@ -86,7 +116,11 @@ inline __device__ uint16_t float_to_half(float f) {
     uint32_t u32;
     uint16_t u16[2];
   } tmp;
+#ifndef USE_ROCM
   asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
+#else
+  asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
+#endif
   return tmp.u16[0];
 }
 
@@ -95,12 +129,16 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
     uint32_t u32;
     uint16_t u16[2];
   } tmp;
-
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-  asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
+#ifndef USE_ROCM
+  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+    asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
+  #else
+    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
+    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
+  #endif
 #else
-  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
-  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
+  tmp.u16[0] = float_to_half(f.x);
+  tmp.u16[1] = float_to_half(f.y);
 #endif
   return tmp.u32;
 }
@@ -108,13 +146,21 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
 // Vector addition.
 inline __device__ uint16_t add(uint16_t a, uint16_t b) {
   uint16_t c;
+#ifndef USE_ROCM
   asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
+#else
+  asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+#endif
   return c;
 }
 
 inline __device__ uint32_t add(uint32_t a, uint32_t b) {
   uint32_t c;
+#ifndef USE_ROCM
   asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
+#else
+  asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+#endif
   return c;
 }
 
@@ -159,14 +205,22 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
 template<>
 inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
   uint16_t c;
+#ifndef USE_ROCM
   asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
+#else
+  asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+#endif
   return c;
 }
 
 template<>
 inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
   uint32_t c;
+#ifndef USE_ROCM
   asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
+#else
+  asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+#endif
   return c;
 }
 
@@ -273,7 +327,11 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
 // Vector fused multiply-add.
 inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
   uint32_t d;
+#ifndef USE_ROCM
   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
+#else
+  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
+#endif
   return d;
 }
 

+ 7 - 6
kernels/cache_kernels.cu

@@ -1,6 +1,7 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 
+#include "cuda_compat.h"
 #include "dispatch_utils.h"
 
 #include <algorithm>
@@ -28,8 +29,8 @@ void swap_blocks(
     TORCH_CHECK(false, "Invalid device combination");
   }
 
-  void *src_ptr = src.data_ptr();
-  void *dst_ptr = dst.data_ptr();
+  char *src_ptr = static_cast<char*>(src.data_ptr());
+  char *dst_ptr = static_cast<char*>(dst.data_ptr());
 
   const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -266,8 +267,8 @@ __global__ void gather_cached_kv_kernel(
                                 + head_offset * block_size
                                 + block_offset;
 
-      key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
-      value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
+      key[tgt_key_idx] = APHRODITE_LDG(&key_cache[src_key_idx]);
+      value[tgt_value_idx] = APHRODITE_LDG(&value_cache[src_value_idx]);
     }
 }
 
@@ -332,8 +333,8 @@ __global__ void gather_cached_kv_kernel_optimized(
             src_key_indices[j] = src_key_idx;
             src_value_indices[j] = src_value_idx;
 
-            keys_to_store[j] = __ldg(&key_cache[src_key_idx]);
-            values_to_store[j] = __ldg(&value_cache[src_value_idx]);
+            keys_to_store[j] = APHRODITE_LDG(&key_cache[src_key_idx]);
+            values_to_store[j] = APHRODITE_LDG(&value_cache[src_value_idx]);
         }
 
         #pragma unroll

+ 27 - 0
kernels/cuda_compat.h

@@ -0,0 +1,27 @@
+#pragma once
+
+#ifndef USE_ROCM
+  #define APHRODITE_LDG(arg) __ldg(arg)
+#else
+  #define APHRODITE_LDG(arg) *(arg)
+#endif
+
+#ifndef USE_ROCM
+  #define APHRODITE_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
+#else
+  #define APHRODITE_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
+#endif
+
+#ifndef USE_ROCM
+  #define APHRODITE_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
+#else
+  #define APHRODITE_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
+#endif
+
+#ifndef USE_ROCM
+  #define APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
+    cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
+#else
+  #define APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
+    hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
+#endif

+ 4 - 0
kernels/cuda_utils_kernels.cu

@@ -1,3 +1,7 @@
+#ifdef USE_ROCM
+    #include <hip/hip_runtime.h>
+#endif
+
 int get_device_attribute(
     int attribute,
     int device_id)

+ 3 - 0
kernels/ops.h

@@ -62,12 +62,15 @@ void gelu_fast(
   torch::Tensor& out,
   torch::Tensor& input);
 
+// The AWQ kernels are only available on CUDA
+#ifndef USE_ROCM
 torch::Tensor awq_gemm(
   torch::Tensor _in_feats,
   torch::Tensor _kernel,
   torch::Tensor _scaling_factors,
   torch::Tensor _zeros,
   int split_k_iters);
+#endif
 
 void squeezellm_gemm(
   torch::Tensor vec,

+ 5 - 4
kernels/pos_encoding_kernels.cu

@@ -1,6 +1,7 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
 
+#include "cuda_compat.h"
 #include "dispatch_utils.h"
 
 namespace aphrodite {
@@ -19,14 +20,14 @@ inline __device__ void apply_rotary_embedding(
     // GPT-NeoX style rotary embedding.
     x_index = rot_offset;
     y_index = embed_dim + rot_offset;
-    cos = __ldg(cos_ptr + x_index);
-    sin = __ldg(sin_ptr + x_index);
+    cos = APHRODITE_LDG(cos_ptr + x_index);
+    sin = APHRODITE_LDG(sin_ptr + x_index);
   } else {
     // GPT-J style rotary embedding.
     x_index = 2 * rot_offset;
     y_index = 2 * rot_offset + 1;
-    cos = __ldg(cos_ptr + x_index / 2);
-    sin = __ldg(sin_ptr + x_index / 2);
+    cos = APHRODITE_LDG(cos_ptr + x_index / 2);
+    sin = APHRODITE_LDG(sin_ptr + x_index / 2);
   }
 
   const scalar_t x = arr[x_index];

+ 2 - 0
kernels/pybind.cpp

@@ -49,7 +49,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
 
   // Quantization ops
+  #ifndef USE_ROCM
   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
+  #endif
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
   ops.def("make_q_matrix", &make_q_matrix, "make_q_matrix");
   ops.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half");

+ 74 - 0
kernels/quantization/squeezellm/quant_cuda_kernel.cu

@@ -20,9 +20,17 @@ __device__ inline unsigned int as_unsigned(int i) {
 
 // 4-bit matvec kernel (LUT-based)
 __global__ void NUQ4MatMulKernel(
+#ifndef USE_ROCM
     const  half2* __restrict__ vec,
+#else
+    const  __half2* __restrict__ vec,
+#endif
     const    int* __restrict__ mat,
+#ifndef USE_ROCM
            half2* __restrict__ mul,
+#else
+          float2* __restrict__ mul,
+#endif
     const  __half* __restrict__ lookup_table,
     int height,
     int width,
@@ -35,7 +43,11 @@ __global__ void NUQ4MatMulKernel(
   int row = BLOCKHEIGHT4 * blockIdx.x;
   int col =  BLOCKWIDTH * blockIdx.y + threadIdx.x;
 
+#ifndef USE_ROCM
   __shared__ half2 blockvec[blockwidth2];
+#else
+  __shared__ __half2 blockvec[blockwidth2];
+#endif
 
   __shared__ __half deq2[16][BLOCKWIDTH];
   int off = threadIdx.x;
@@ -46,8 +58,13 @@ __global__ void NUQ4MatMulKernel(
   }
 
   __half res;
+#ifndef USE_ROCM
   half2 res2;
   half2 tmp2;
+#else
+  __half2 res2;
+  __half2 tmp2;
+#endif
 
   int i;
   int k;
@@ -68,48 +85,96 @@ __global__ void NUQ4MatMulKernel(
     while (k < blockwidth2) {
       tmp1 = as_unsigned(mat[i]);
 
+#ifndef USE_ROCM
       res2 = {};
       tmp2 = {};
+#else
+      res2.x = __half_as_ushort(__float2half(0));
+      res2.y = __half_as_ushort(__float2half(0));
+      tmp2.x = __half_as_ushort(__float2half(0));
+      tmp2.y = __half_as_ushort(__float2half(0));
+#endif
 
       lut_index1 = tmp1 & 0xF;
       lut_index2 = (tmp1 >> 4) & 0xF;
+#ifndef USE_ROCM
       tmp2.x = deq2[lut_index1][off];
       tmp2.y = deq2[lut_index2][off];
+#else
+      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
+      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
+#endif
       res2 = __hfma2(tmp2, blockvec[k + 0], res2);
 
       lut_index1 = (tmp1 >> 8) & 0xF;
       lut_index2 = (tmp1 >> 12) & 0xF;
+#ifndef USE_ROCM
       tmp2.x = deq2[lut_index1][off];
       tmp2.y = deq2[lut_index2][off];
+#else
+      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
+      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
+#endif
       res2 = __hfma2(tmp2, blockvec[k + 1], res2);
 
       lut_index1 = (tmp1 >> 16) & 0xF;
       lut_index2 = (tmp1 >> 20) & 0xF;
+#ifndef USE_ROCM
       tmp2.x = deq2[lut_index1][off];
       tmp2.y = deq2[lut_index2][off];
+#else
+      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
+      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
+#endif
       res2 = __hfma2(tmp2, blockvec[k + 2], res2);
 
       lut_index1 = (tmp1 >> 24) & 0xF;
       lut_index2 = (tmp1 >> 28) & 0xF;
+#ifndef USE_ROCM
       tmp2.x = deq2[lut_index1][off];
       tmp2.y = deq2[lut_index2][off];
+#else
+      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
+      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
+#endif
       res2 = __hfma2(tmp2, blockvec[k + 3], res2);
 
+#ifndef USE_ROCM
       res = __hadd(__hadd(res2.x, res2.y), res);
+#else
+      res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res);
+#endif
 
       i += width;
       k += 4;
     }
 
     // col%2 -> only set one of the two values
+#ifndef USE_ROCM
     half2 res3 = {};
     if (col % 2 == 0) {
       res3.x = res;
     } else {
       res3.y = res;
     }
+#else
+    __half2 res3;
+    res3.x = __half_as_ushort(__float2half(0));
+    res3.y = __half_as_ushort(__float2half(0));
+    if (col % 2 == 0) {
+      res3.x = __half_as_ushort(res);
+    } else {
+      res3.y = __half_as_ushort(res);
+    }
+#endif
 
+#ifndef USE_ROCM
     atomicAdd(&mul[b * width / 2 + col / 2], res3);
+#else
+    int tmp_addr = b * width / 2 + col / 2;
+    atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
+    atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
+#endif
   }
 }
 
@@ -136,10 +201,19 @@ void squeezellm_gemm(
   dim3 threads(BLOCKWIDTH);
 
   aphrodite::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
+#ifndef USE_ROCM
     (half2*) vec.data<at::Half>(),
+#else
+    (__half2*) vec.data_ptr<at::Half>(),
+#endif
     mat.data_ptr<int>(),
+#ifndef USE_ROCM
     (half2*) mul.data<at::Half>(),
     (__half*) lookup_table.data<at::Half>(),
+#else
+    (float2*) mul.data_ptr<float>(),
+    (__half*) lookup_table.data_ptr<at::Half>(),
+#endif
     height, width, batch, vec_height
   );
 }

+ 3 - 1
kernels/reduction.cuh

@@ -18,13 +18,15 @@
  */
 #pragma once
 
+#include "cuda_compat.h"
+
 namespace aphrodite {
 
 template<typename T>
 __inline__ __device__ T warpReduceSum(T val) {
 #pragma unroll
   for (int mask = 16; mask > 0; mask >>= 1)
-    val += __shfl_xor_sync(0xffffffff, val, mask, 32);
+    val += APHRODITE_SHFL_XOR_SYNC(val, mask);
   return val;
 }
 

+ 22 - 0
patch_xformers-0.0.22.post7.rocm.sh

@@ -0,0 +1,22 @@
+#!/bin/bash
+export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)')
+export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)')
+
+echo $XFORMERS_FMHA_FLASH_PATH
+echo $XFORMERS_FMHA_COMMON_PATH
+
+if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch"; then
+    echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}"
+    patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch"
+    echo "Successfully patch ${XFORMERS_FMHA_FLASH_PATH}"
+else
+    echo "${XFORMERS_FMHA_FLASH_PATH} was patched before"
+fi
+
+if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch"; then
+    echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}"
+    patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch"
+    echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}"
+else
+    echo "${XFORMERS_FMHA_COMMON_PATH} was patched before"
+fi

+ 15 - 0
requirements-rocm.txt

@@ -0,0 +1,15 @@
+ninja  # For faster builds.
+typing-extensions>=4.8.0
+psutil
+ray >= 2.5.1
+pandas  # Required for Ray data.
+pyarrow  # Required for Ray data.
+sentencepiece  # Required for LLaMA tokenizer.
+numpy
+tokenizers>=0.15.0
+huggingface_hub<0.18,>=0.16.4
+einops  # Required for phi-1_5
+transformers >= 4.34.0  # Required for Mistral.
+fastapi
+uvicorn[standard]
+pydantic == 1.10.13  # Required for OpenAI server.

+ 1 - 0
requirements.txt

@@ -8,6 +8,7 @@ transformers >= 4.34.0
 uvicorn
 openai # for fastapi's openai proxy emulation
 xformers >= 0.0.22
+einops  # Required for phi-1_5
 fschat >= 0.2.23
 pydantic == 1.10.13
 pyarrow # needed for ray

+ 13 - 0
rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch

@@ -0,0 +1,13 @@
+--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/common.py	2023-11-29 03:17:03.930103539 +0000
++++ common.py	2023-11-28 16:14:19.846233146 +0000
+@@ -298,8 +298,8 @@
+         dtype = d.query.dtype
+         if device_type not in cls.SUPPORTED_DEVICES:
+             reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})")
+-        if device_type == "cuda" and not _built_with_cuda:
+-            reasons.append("xFormers wasn't build with CUDA support")
++        #if device_type == "cuda" and not _built_with_cuda:
++        #    reasons.append("xFormers wasn't build with CUDA support")
+         if device_type == "cuda":
+             device_capability = torch.cuda.get_device_capability(d.device)
+             if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:

+ 134 - 0
rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch

@@ -0,0 +1,134 @@
+--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/flash.py	2023-11-29 03:17:03.930103539 +0000
++++ flash.py	2023-11-28 16:14:25.206128903 +0000
+@@ -31,39 +31,39 @@
+ 
+ FLASH_VERSION = "0.0.0"
+ try:
+-    try:
+-        from ... import _C_flashattention  # type: ignore[attr-defined]
+-        from ..._cpp_lib import _build_metadata
+-
+-        if _build_metadata is not None:
+-            FLASH_VERSION = _build_metadata.flash_version
+-    except ImportError:
+-        import flash_attn
+-        from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
+-
+-        FLASH_VERSION = flash_attn.__version__
+-        flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
+-        if flash_ver_parsed < (2, 3):
+-            raise ImportError("Requires 2.3 for sliding window support")
++    #try:
++    #    from ... import _C_flashattention  # type: ignore[attr-defined]
++    #    from ..._cpp_lib import _build_metadata
++
++    #    if _build_metadata is not None:
++    #        FLASH_VERSION = _build_metadata.flash_version
++    #except ImportError:
++    import flash_attn
++    from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
++
++    FLASH_VERSION = flash_attn.__version__
++    #    flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
++    #    if flash_ver_parsed < (2, 3):
++    #        raise ImportError("Requires 2.3 for sliding window support")
+ 
+     # create library so that flash-attn goes through the PyTorch Dispatcher
+-    _flash_lib = torch.library.Library("xformers_flash", "DEF")
++    #_flash_lib = torch.library.Library("xformers_flash", "DEF")
+ 
+-    _flash_lib.define(
+-        "flash_fwd(Tensor query, Tensor key, Tensor value, "
+-        "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
+-        "int max_seqlen_q, int max_seqlen_k, "
+-        "float p, float softmax_scale, "
+-        "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+-    )
+-
+-    _flash_lib.define(
+-        "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
+-        "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
+-        "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
+-        "int max_seqlen_q, int max_seqlen_k, "
+-        "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+-    )
++    #_flash_lib.define(
++    #    "flash_fwd(Tensor query, Tensor key, Tensor value, "
++    #    "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
++    #    "int max_seqlen_q, int max_seqlen_k, "
++    #    "float p, float softmax_scale, "
++    #    "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
++    #)
++
++    #_flash_lib.define(
++    #    "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
++    #    "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
++    #    "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
++    #    "int max_seqlen_q, int max_seqlen_k, "
++    #    "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
++    #)
+ 
+     def _flash_fwd(
+         query,
+@@ -98,8 +98,8 @@
+                 p,
+                 softmax_scale,
+                 is_causal,
+-                window_size - 1,  # window_size_left
+-                -1,  # window_size_right
++        #        window_size - 1,  # window_size_left
++        #        -1,  # window_size_right
+                 return_softmax,
+                 None,  # rng
+             )
+@@ -127,8 +127,8 @@
+                 softmax_scale,
+                 False,
+                 is_causal,
+-                window_size - 1,  # window_size_left
+-                -1,  # window_size_right
++         #       window_size - 1,  # window_size_left
++         #       -1,  # window_size_right
+                 return_softmax,
+                 None,
+             )
+@@ -169,8 +169,8 @@
+                 p,
+                 softmax_scale,
+                 is_causal,
+-                window_size - 1,  # window_size_left
+-                -1,  # window_size_right
++        #        window_size - 1,  # window_size_left
++        #        -1,  # window_size_right
+                 None,
+                 rng_state,
+             )
+@@ -193,15 +193,15 @@
+                 softmax_scale,
+                 False,  # zero_tensors
+                 is_causal,
+-                window_size - 1,  # window_size_left
+-                -1,  # window_size_right
++        #        window_size - 1,  # window_size_left
++        #        -1,  # window_size_right
+                 None,
+                 rng_state,
+             )
+         return dq, dk, dv
+ 
+-    _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
+-    _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
++    #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
++    #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
+ except ImportError:
+     pass
+ 
+@@ -348,7 +348,7 @@
+         implementation.
+     """
+ 
+-    OPERATOR = get_operator("xformers_flash", "flash_fwd")
++    OPERATOR = _flash_fwd # get_operator("xformers_flash", "flash_fwd")
+     SUPPORTED_DEVICES: Set[str] = {"cuda"}
+     CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
+     SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}

+ 172 - 88
setup.py

@@ -8,30 +8,85 @@ import warnings
 from packaging.version import parse, Version
 import setuptools
 import torch
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
+from torch.utils.cpp_extension import (
+    BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME)
 
 ROOT_DIR = os.path.dirname(__file__)
 
 MAIN_CUDA_VERSION = "11.8"
 
 # Supported NVIDIA GPU architectures.
-SUPPORTED_ARCHS = {
+NVIDIA_SUPPORTED_ARCHS = {
     "6.0", "6.1", "7.0", "7.5", "8.0", "8.6", "8.9", "9.0"
 }
+ROCM_SUPPORTED_ARCHS = {
+    "gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"
+}
+
+def _is_hip() -> bool:
+    return torch.version.hip is not None
+
+def _is_cuda() -> bool:
+    return torch.version.cuda is not None
+
 
 # Compiler flags.
 CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
 # TODO: Should we use -O3?
 NVCC_FLAGS = ["-O2", "-std=c++17"]
 
-ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
-CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
-NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
 
-if CUDA_HOME is None:
+if _is_hip():
+    if ROCM_HOME is None:
+        raise RuntimeError(
+            "Cannot find ROCM_HOME. ROCm must be available to build the "
+            "package.")
+    NVCC_FLAGS += ["-DUSE_ROCM"]
+
+if _is_cuda() and CUDA_HOME is None:
     raise RuntimeError(
         "Cannot find CUDA_HOME. CUDA must be available to build the package.")
 
+ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
+CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
+NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
+
+def get_amdgpu_offload_arch():
+    command = "/opt/rocm/llvm/bin/amdgpu-offload-arch"
+    try:
+        output = subprocess.check_output([command])
+        return output.decode('utf-8').strip()
+    except subprocess.CalledProcessError as e:
+        error_message = f"Error: {e}"
+        raise RuntimeError(error_message) from e
+    except FileNotFoundError as e:
+        # If the command is not found, print an error message
+        error_message = f"The command {command} was not found."
+        raise RuntimeError(error_message) from e
+
+    return None
+
+
+def get_hipcc_rocm_version():
+    # Run the hipcc --version command
+    result = subprocess.run(['hipcc', '--version'],
+                            stdout=subprocess.PIPE,
+                            stderr=subprocess.STDOUT,
+                            text=True)
+
+    # Check if the command was executed successfully
+    if result.returncode != 0:
+        print("Error running 'hipcc --version'")
+        return None
+
+    # Extract the version using a regular expression
+    match = re.search(r'HIP version: (\S+)', result.stdout)
+    if match:
+        # Return the version string
+        return match.group(1)
+    else:
+        print("Could not find HIP version in the output")
+        return None
 
 def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     """Get the CUDA version from nvcc.
@@ -63,20 +118,22 @@ def get_torch_arch_list() -> Set[str]:
         return set()
 
     # Filter out the invalid architectures and print a warning.
-    valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS})
+    valid_archs = NVIDIA_SUPPORTED_ARCHS.union(
+        {s + "+PTX"
+         for s in NVIDIA_SUPPORTED_ARCHS})
     arch_list = torch_arch_list.intersection(valid_archs)
     # If none of the specified architectures are valid, raise an error.
     if not arch_list:
         raise RuntimeError(
-            "None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
-            f"variable ({env_arch_list}) is supported. "
+            "None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` "
+            f"env variable ({env_arch_list}) is supported. "
             f"Supported CUDA architectures are: {valid_archs}.")
     invalid_arch_list = torch_arch_list - valid_archs
     if invalid_arch_list:
         warnings.warn(
-            f"Unsupported CUDA architectures ({invalid_arch_list}) are "
+            f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are "
             "excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
-            f"({env_arch_list}). Supported CUDA architectures are: "
+            f"({env_arch_list}). Supported CUDA/ROCM architectures are: "
             f"{valid_archs}.",
             stacklevel=2)
     return arch_list
@@ -84,7 +141,7 @@ def get_torch_arch_list() -> Set[str]:
 
 # First, check the TORCH_CUDA_ARCH_LIST environment variable.
 compute_capabilities = get_torch_arch_list()
-if not compute_capabilities:
+if _is_cuda() and not compute_capabilities:
     # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
     # GPUs on the current machine.
     device_count = torch.cuda.device_count()
@@ -95,72 +152,87 @@ if not compute_capabilities:
                 "GPUs with compute capability below 6.0 are not supported.")
         compute_capabilities.add(f"{major}.{minor}")
 
-nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
-if not compute_capabilities:
-    # If no GPU is specified nor available, add all supported architectures
-    # based on the NVCC CUDA version.
-    compute_capabilities = SUPPORTED_ARCHS.copy()
-    if nvcc_cuda_version < Version("11.1"):
-        compute_capabilities.remove("8.6")
+if _is_cuda():
+    nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
+    if not compute_capabilities:
+        # If no GPU is specified nor available, add all supported architectures
+        # based on the NVCC CUDA version.
+        compute_capabilities = NVIDIA_SUPPORTED_ARCHS.copy()
+        if nvcc_cuda_version < Version("11.1"):
+            compute_capabilities.remove("8.6")
+        if nvcc_cuda_version < Version("11.8"):
+            compute_capabilities.remove("8.9")
+            compute_capabilities.remove("9.0")
+    # Validate the NVCC CUDA version.
+    if nvcc_cuda_version < Version("11.0"):
+        raise RuntimeError(
+            "CUDA 11.0 or higher is required to build the package.")
+    if (nvcc_cuda_version < Version("11.1")
+            and any(cc.startswith("8.6") for cc in compute_capabilities)):
+        raise RuntimeError(
+            "CUDA 11.1 or higher is required for compute capability 8.6.")
     if nvcc_cuda_version < Version("11.8"):
-        compute_capabilities.remove("8.9")
-        compute_capabilities.remove("9.0")
-
-# Validate the NVCC CUDA version.
-if nvcc_cuda_version < Version("11.0"):
-    raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
-if (nvcc_cuda_version < Version("11.1")
-        and any(cc.startswith("8.6") for cc in compute_capabilities)):
-    raise RuntimeError(
-        "CUDA 11.1 or higher is required for compute capability 8.6.")
-if nvcc_cuda_version < Version("11.8"):
-    if any(cc.startswith("8.9") for cc in compute_capabilities):
-        # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
-        # However, GPUs with compute capability 8.9 can also run the code generated by
-        # the previous versions of CUDA 11 and targeting compute capability 8.0.
-        # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
-        # instead of 8.9.
-        warnings.warn(
-            "CUDA 11.8 or higher is required for compute capability 8.9. "
-            "Targeting compute capability 8.0 instead.",
-            stacklevel=2)
-        compute_capabilities = set(cc for cc in compute_capabilities
-                                   if not cc.startswith("8.9"))
-        compute_capabilities.add("8.0+PTX")
-    if any(cc.startswith("9.0") for cc in compute_capabilities):
+        if any(cc.startswith("8.9") for cc in compute_capabilities):
+            # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
+            # However, GPUs with compute capability 8.9 can also run the code generated by
+            # the previous versions of CUDA 11 and targeting compute capability 8.0.
+            # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
+            # instead of 8.9.
+            warnings.warn(
+                "CUDA 11.8 or higher is required for compute capability 8.9. "
+                "Targeting compute capability 8.0 instead.",
+                stacklevel=2)
+            compute_capabilities = set(cc for cc in compute_capabilities
+                                       if not cc.startswith("8.9"))
+            compute_capabilities.add("8.0+PTX")
+        if any(cc.startswith("9.0") for cc in compute_capabilities):
+            raise RuntimeError(
+                "CUDA 11.8 or higher is required for compute capability 9.0.")
+
+    # Add target compute capabilities to NVCC flags.
+    for capability in compute_capabilities:
+        num = capability[0] + capability[2]
+        NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
+        if capability.endswith("+PTX"):
+            NVCC_FLAGS += [
+                "-gencode", f"arch=compute_{num},code=compute_{num}"
+            ]
+
+    # Use NVCC threads to parallelize the build.
+    if nvcc_cuda_version >= Version("11.2"):
+        num_threads = min(os.cpu_count(), 8)
+        NVCC_FLAGS += ["--threads", str(num_threads)]
+
+elif _is_hip():
+    amd_arch = get_amdgpu_offload_arch()
+    if amd_arch not in ROCM_SUPPORTED_ARCHS:
         raise RuntimeError(
-            "CUDA 11.8 or higher is required for compute capability 9.0.")
+            f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
+            f"amdgpu_arch_found: {amd_arch}")
 
-# Add target compute capabilities to NVCC flags.
-for capability in compute_capabilities:
-    num = capability[0] + capability[2]
-    NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
-    if capability.endswith("+PTX"):
-        NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
+ext_modules = []
 
-# Use NVCC threads to parallelize the build.
-if nvcc_cuda_version >= Version("11.2"):
-    num_threads = min(os.cpu_count(), 8)
-    NVCC_FLAGS += ["--threads", str(num_threads)]
+aphrodite_extension_sources = [
+    "kernels/cache_kernels.cu",
+    "kernels/attention/attention_kernels.cu",
+    "kernels/pos_encoding_kernels.cu",
+    "kernels/activation_kernels.cu",
+    "kernels/layernorm_kernels.cu",
+    "kernels/quantization/squeezellm/quant_cuda_kernel.cu",
+    "kernels/quantization/gptq/exllama_ext.cpp",
+    "kernels/quantization/gptq/q_matrix.cu",
+    "kernels/quantization/gptq/q_gemm.cu",
+    "kernels/quantization/gptq/old_matmul_kernel.cu",
+    "kernels/cuda_utils_kernels.cu",
+    "kernels/pybind.cpp",
+]
+
+if _is_cuda():
+    aphrodite_extension_sources.append("kernels/quantization/awq/gemm_kernels.cu")
 
-ext_modules = []
 aphrodite_extension = CUDAExtension(
     name="aphrodite._C",
-    sources=[
-        "kernels/cache_kernels.cu",
-        "kernels/attention/attention_kernels.cu",
-        "kernels/pos_encoding_kernels.cu",
-        "kernels/activation_kernels.cu",
-        "kernels/layernorm_kernels.cu",
-        "kernels/quantization/awq/gemm_kernels.cu",
-        "kernels/quantization/squeezellm/quant_cuda_kernel.cu",
-        "kernels/quantization/gptq/exllama_ext.cpp",
-        "kernels/quantization/gptq/q_matrix.cu",
-        "kernels/quantization/gptq/q_gemm.cu",
-        "kernels/quantization/gptq/old_matmul_kernel.cu",
-        "kernels/cuda_utils_kernels.cu",
-        "kernels/pybind.cpp",
-    ],
+    sources=aphrodite_extension_sources,
     extra_compile_args={
         "cxx": CXX_FLAGS,
         "nvcc": NVCC_FLAGS,
@@ -188,21 +260,29 @@ def find_version(filepath: str) -> str:
 
 def get_aphrodite_version() -> str:
     version = find_version(get_path("aphrodite-engine", "__init__.py"))
-    cuda_version = str(nvcc_cuda_version)
     
-    # Split the version into numerical and suffix parts
-    version_parts = version.split('-')
-    version_num = version_parts[0]
-    version_suffix = version_parts[1] if len(version_parts) > 1 else ''
-    
-    if cuda_version != MAIN_CUDA_VERSION:
-        cuda_version_str = cuda_version.replace(".", "")[:3]
-        version_num += f"+cu{cuda_version_str}"
-    
-    # Reassemble the version string with the suffix, if any
-    version = version_num + ('-' + version_suffix if version_suffix else '')
-    
-    return version
+    if _is_hip():
+        # get the HIP version
+
+        hipcc_version = get_hipcc_rocm_version()
+        if hipcc_version != MAIN_CUDA_VERSION:
+            rocm_version_str = hipcc_version.replace(".", "")[:3]
+            version += f"+rocm{rocm_version_str}"
+    else:
+        cuda_version = str(nvcc_cuda_version)
+        # Split the version into numerical and suffix parts
+        version_parts = version.split('-')
+        version_num = version_parts[0]
+        version_suffix = version_parts[1] if len(version_parts) > 1 else ''
+        
+        if cuda_version != MAIN_CUDA_VERSION:
+            cuda_version_str = cuda_version.replace(".", "")[:3]
+            version_num += f"+cu{cuda_version_str}"
+        
+        # Reassemble the version string with the suffix, if any
+        version = version_num + ('-' + version_suffix if version_suffix else '')
+        
+        return version
 
 
 def read_readme() -> str:
@@ -216,8 +296,12 @@ def read_readme() -> str:
 
 def get_requirements() -> List[str]:
     """Get Python package dependencies from requirements.txt."""
-    with open(get_path("requirements.txt")) as f:
-        requirements = f.read().strip().split("\n")
+    if _is_hip():
+        with open(get_path("requirements-rocm.txt")) as f:
+            requirements = f.read().strip().split("\n")
+    else:
+        with open(get_path("requirements.txt")) as f:
+            requirements = f.read().strip().split("\n")
     return requirements
 
 
@@ -251,4 +335,4 @@ setuptools.setup(
     ext_modules=ext_modules,
     cmdclass={"build_ext": BuildExtension},
     package_data={"aphrodite-engine": ["py.typed"]},
-)
+)
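
A worked example of the new version tagging on ROCm (all values below are illustrative): get_hipcc_rocm_version() parses the "HIP version:" line out of `hipcc --version`, and the first three digits become a local version suffix, mirroring the existing +cuXXX scheme.

    # Illustrative values only; the real ones come from hipcc and __init__.py.
    hipcc_version = "5.7.31921"                              # from `hipcc --version`
    base_version = "0.4.2"                                   # engine version (example)
    rocm_version_str = hipcc_version.replace(".", "")[:3]    # -> "573"
    version = base_version + f"+rocm{rocm_version_str}"      # -> "0.4.2+rocm573"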