
fix: navi support (#283)

* fix: arch checks for ROCm

* use flash attention for ROCm

* update patch and Dockerfile

* formatting

Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
AlpinDale · 1 year ago · commit 13d850334e

+ 43 - 0
aphrodite/modeling/layers/attention.py

@@ -1,6 +1,7 @@
 """Multi-head attention."""
 from typing import List, Optional
 
+import importlib
 import torch
 import torch.nn as nn
 from xformers import ops as xops
@@ -58,6 +59,40 @@ class PagedAttention(nn.Module):
             raise ValueError(f"head_size ({self.head_size}) is not supported. "
                              f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
 
+        self.use_ref_attention = self.check_use_ref_attention()
+
+    def check_use_ref_attention(self) -> bool:
+        if not is_hip():
+            return False
+        # For ROCm, check whether flash attention is installed or not.
+        # if not, use_ref_attention needs to be True
+        return importlib.util.find_spec("flash_attn") is None
+
+    def ref_masked_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+    ) -> torch.Tensor:
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        seq_len, _, _ = query.shape
+        attn_mask = torch.triu(torch.ones(seq_len,
+                                          seq_len,
+                                          dtype=query.dtype,
+                                          device=query.device),
+                               diagonal=1)
+        attn_mask = attn_mask * torch.finfo(query.dtype).min
+
+        attn_weights = self.scale * torch.einsum("qhd,khd->hqk", query,
+                                                 key).float()
+        attn_weights = attn_weights + attn_mask.float()
+        attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
+        out = torch.einsum("hqk,khd->qhd", attn_weights, value)
+        return out
+
     def forward(
         self,
         query: torch.Tensor,
@@ -137,6 +172,14 @@ class PagedAttention(nn.Module):
                             self.alibi_slopes, self.num_kv_heads, batch_size,
                             seq_len, query.dtype)
 
+                if self.use_ref_attention:
+                    output = self.ref_masked_attention(
+                        query,
+                        key,
+                        value,
+                    )
+                    return output.reshape(batch_size, seq_len, hidden_size)
+
                 # TODO: Too many view operations. Let's try to reduce
                 # them in the future for code readability.
                 if self.alibi_slopes is None:

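The new fallback path is plain einsum attention with an additive causal mask. The following is a minimal standalone sketch of the math in ref_masked_attention (it assumes num_heads == num_kv_heads, which the shared "h" subscript in the einsums already requires), cross-checked against PyTorch's scaled_dot_product_attention, which is not part of this change:

    import torch

    seq_len, num_heads, head_size = 8, 4, 64
    scale = head_size ** -0.5  # the usual 1/sqrt(head_size) softmax scale

    q = torch.randn(seq_len, num_heads, head_size)
    k = torch.randn(seq_len, num_heads, head_size)
    v = torch.randn(seq_len, num_heads, head_size)

    # Additive causal mask: strictly upper-triangular entries get dtype-min.
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=q.dtype), diagonal=1)
    mask = mask * torch.finfo(q.dtype).min

    weights = scale * torch.einsum("qhd,khd->hqk", q, k).float()
    weights = torch.softmax(weights + mask.float(), dim=-1).to(v.dtype)
    out = torch.einsum("hqk,khd->qhd", weights, v)

    # Cross-check against the built-in SDPA (head-first layout, causal).
    ref = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1), is_causal=True
    ).transpose(0, 1)
    assert torch.allclose(out, ref, atol=1e-5)
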
+ 39 - 9
docker/Dockerfile.rocm

@@ -1,4 +1,24 @@
-FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1
+# default base image
+ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
+
+FROM $BASE_IMAGE
+
+RUN echo "Base image is $BASE_IMAGE"
+
+# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
+# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
+
+
+ARG FA_GFX_ARCHS="gfx90a;gfx942;gfx1100"
+RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
+
+ARG FA_BRANCH="3d2b6f5"
+RUN echo "FA_BRANCH is $FA_BRANCH"
+
+# whether to build flash-attention
+# if 0, will not build flash attention
+# this is useful for gfx target where flash-attention is not supported
+ARG BUILD_FA="1"
 
 # Install some basic utilities
 RUN apt-get update && apt-get install python3 python3-pip -y
@@ -33,26 +53,36 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
 ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
 
 # Install ROCm flash-attention
-RUN mkdir libs \
+RUN if [ "$BUILD_FA" = "1" ]; then \
+    mkdir libs \
     && cd libs \
-    && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
+    && git clone https://github.com/ROCm/flash-attention.git \
     && cd flash-attention \
-    && git checkout 3d2b6f5 \
+    && git checkout ${FA_BRANCH} \
     && git submodule update --init \
-    && export GPU_ARCHS=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) \
-    && patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \
+    && export GPU_ARCHS=${FA_GFX_ARCHS} \
+    && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
+        patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
     && python3 setup.py install \
-    && cd ..
+    && cd ..; \
+    fi
 
 COPY ./ /app/aphrodite-engine
 
 RUN python3 -m pip install --upgrade pip
-RUN pip install xformers==0.0.22.post7 --no-deps
+RUN python3 -m pip install xformers==0.0.23 --no-deps
+
+# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
+# Manually removed it so that later steps of numpy upgrade can continue
+RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
+    rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
 
 RUN cd /app \
     && cd aphrodite-engine \
     && pip install -U -r requirements-rocm.txt \
-    && bash patch_xformers-0.0.22.post7.rocm.sh \
+    && if [ "$BUILD_FA" = "1" ]; then \
+       bash patch_xformers.rocm.sh; fi \
+    && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/aphrodite-engine/rocm_patch/rocm_bf16.patch \
     && python3 setup.py install \
     && cd ..
 

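BUILD_FA only controls whether flash-attention is compiled into the image; whether the engine actually uses it is decided at runtime by the check added in attention.py above. A minimal sketch of that decision, assuming is_hip() reduces to checking torch.version.hip (the same test _is_hip() uses in setup.py):

    import importlib.util
    import torch

    def will_use_ref_attention() -> bool:
        # Mirrors PagedAttention.check_use_ref_attention(): on a ROCm build,
        # fall back to the reference path when flash_attn is absent, e.g. in
        # an image built with --build-arg BUILD_FA=0 for an unsupported arch.
        if torch.version.hip is None:  # CUDA or CPU build of PyTorch
            return False
        return importlib.util.find_spec("flash_attn") is None

    print("reference attention fallback:", will_use_ref_attention())
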
+ 17 - 6
patch_xformers-0.0.22.post7.rocm.sh → patch_xformers.rocm.sh

@@ -1,21 +1,32 @@
 #!/bin/bash
+set -e
+
+XFORMERS_VERSION="0.0.23"
+
+export XFORMERS_INSTALLED_VERSION=$(python -c 'import xformers; print(xformers.__version__)')
+
+if [ "$XFORMERS_INSTALLED_VERSION" != "$XFORMERS_VERSION" ]; then
+    echo "ERROR: xformers version must be ${XFORMERS_VERSION}. ${XFORMERS_INSTALLED_VERSION} is installed"
+    exit 1
+fi
+
 export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)')
 export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)')
 
-echo $XFORMERS_FMHA_FLASH_PATH
-echo $XFORMERS_FMHA_COMMON_PATH
+echo "XFORMERS_FMHA_FLASH_PATH = ${XFORMERS_FMHA_FLASH_PATH}"
+echo "XFORMERS_FMHA_COMMON_PATH = ${XFORMERS_FMHA_COMMON_PATH}"
 
-if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch"; then
+if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then
     echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}"
-    patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch"
+    patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"
     echo "Successfully patch ${XFORMERS_FMHA_FLASH_PATH}"
 else
     echo "${XFORMERS_FMHA_FLASH_PATH} was patched before"
 fi
 
-if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch"; then
+if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then
     echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}"
-    patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch"
+    patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"
     echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}"
 else
     echo "${XFORMERS_FMHA_COMMON_PATH} was patched before"

+ 0 - 0
rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch → rocm_patch/commonpy_xformers-0.0.23.rocm.patch


+ 57 - 39
rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch → rocm_patch/flashpy_xformers-0.0.23.rocm.patch

@@ -1,6 +1,6 @@
---- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/flash.py	2023-11-29 03:17:03.930103539 +0000
-+++ flash.py	2023-11-28 16:14:25.206128903 +0000
-@@ -31,39 +31,39 @@
+--- flash_ori.py	2023-12-13 05:43:31.530752623 +0000
++++ flash_patch.py	2023-12-13 06:00:45.962403104 +0000
+@@ -36,44 +36,44 @@
  
  FLASH_VERSION = "0.0.0"
  try:
@@ -15,9 +15,12 @@
 -        from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
 -
 -        FLASH_VERSION = flash_attn.__version__
--        flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
--        if flash_ver_parsed < (2, 3):
--            raise ImportError("Requires 2.3 for sliding window support")
+-        flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
+-        if (
+-            flash_ver_parsed != (2, 3, 6)
+-            and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
+-        ):
+-            raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
 +    #try:
 +    #    from ... import _C_flashattention  # type: ignore[attr-defined]
 +    #    from ..._cpp_lib import _build_metadata
@@ -29,35 +32,41 @@
 +    from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
 +
 +    FLASH_VERSION = flash_attn.__version__
-+    #    flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
-+    #    if flash_ver_parsed < (2, 3):
-+    #        raise ImportError("Requires 2.3 for sliding window support")
++    #    flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
++    #    if (
++    #        flash_ver_parsed != (2, 3, 6)
++    #        and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
++    #    ):
++    #        raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
  
      # create library so that flash-attn goes through the PyTorch Dispatcher
 -    _flash_lib = torch.library.Library("xformers_flash", "DEF")
-+    #_flash_lib = torch.library.Library("xformers_flash", "DEF")
- 
+-
 -    _flash_lib.define(
 -        "flash_fwd(Tensor query, Tensor key, Tensor value, "
--        "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
+-        "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
 -        "int max_seqlen_q, int max_seqlen_k, "
 -        "float p, float softmax_scale, "
--        "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+-        "bool is_causal, int window_left, "
+-        "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
 -    )
--
++    #_flash_lib = torch.library.Library("xformers_flash", "DEF")
+ 
 -    _flash_lib.define(
 -        "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
 -        "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
 -        "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
 -        "int max_seqlen_q, int max_seqlen_k, "
--        "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+-        "float p, float softmax_scale, bool is_causal, "
+-        "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
 -    )
 +    #_flash_lib.define(
 +    #    "flash_fwd(Tensor query, Tensor key, Tensor value, "
-+    #    "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
++    #    "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
 +    #    "int max_seqlen_q, int max_seqlen_k, "
 +    #    "float p, float softmax_scale, "
-+    #    "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
++    #    "bool is_causal, int window_left, "
++    #    "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
 +    #)
 +
 +    #_flash_lib.define(
@@ -65,52 +74,61 @@
 +    #    "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
 +    #    "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
 +    #    "int max_seqlen_q, int max_seqlen_k, "
-+    #    "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
++    #    "float p, float softmax_scale, bool is_causal, "
++    #    "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
 +    #)
  
      def _flash_fwd(
          query,
-@@ -98,8 +98,8 @@
+@@ -111,8 +111,8 @@
                  p,
                  softmax_scale,
                  is_causal,
--                window_size - 1,  # window_size_left
--                -1,  # window_size_right
-+        #        window_size - 1,  # window_size_left
-+        #        -1,  # window_size_right
+-                window_left,  # window_size_left
+-                window_right,  # window_size_right
++        #        window_left,  # window_size_left
++        #        window_right,  # window_size_right
                  return_softmax,
                  None,  # rng
              )
-@@ -127,8 +127,8 @@
+@@ -134,15 +134,15 @@
+                 out,
+                 cu_seq_lens_q,
+                 cu_seq_lens_k,
+-                seqused_k,
++         #       seqused_k,
+                 max_seq_len_q,
+                 max_seq_len_k,
+                 p,
                  softmax_scale,
                  False,
                  is_causal,
--                window_size - 1,  # window_size_left
--                -1,  # window_size_right
-+         #       window_size - 1,  # window_size_left
-+         #       -1,  # window_size_right
+-                window_left,
+-                window_right,
++         #       window_left,
++         #       window_right,
                  return_softmax,
                  None,
              )
-@@ -169,8 +169,8 @@
+@@ -184,8 +184,8 @@
                  p,
                  softmax_scale,
                  is_causal,
--                window_size - 1,  # window_size_left
--                -1,  # window_size_right
-+        #        window_size - 1,  # window_size_left
-+        #        -1,  # window_size_right
+-                window_left,
+-                window_right,
++        #        window_left,
++        #        window_right,
                  None,
                  rng_state,
              )
-@@ -193,15 +193,15 @@
+@@ -208,15 +208,15 @@
                  softmax_scale,
                  False,  # zero_tensors
                  is_causal,
--                window_size - 1,  # window_size_left
--                -1,  # window_size_right
-+        #        window_size - 1,  # window_size_left
-+        #        -1,  # window_size_right
+-                window_left,
+-                window_right,
++        #        window_left,
++        #        window_right,
                  None,
                  rng_state,
              )
@@ -123,7 +141,7 @@
  except ImportError:
      pass
  
-@@ -348,7 +348,7 @@
+@@ -400,7 +400,7 @@
          implementation.
      """
  

+ 15 - 0
rocm_patch/rocm_bf16.patch

@@ -0,0 +1,15 @@
+--- amd_hip_bf16.h	2024-02-06 18:28:58.268699142 +0000
++++ amd_hip_bf16.h.new	2024-02-06 18:28:31.988647133 +0000
+@@ -90,10 +90,10 @@
+ #include "math_fwd.h"              // ocml device functions
+ 
+ #if defined(__HIPCC_RTC__)
+-#define __HOST_DEVICE__ __device__
++#define __HOST_DEVICE__ __device__ static
+ #else
+ #include <climits>
+-#define __HOST_DEVICE__ __host__ __device__
++#define __HOST_DEVICE__ __host__ __device__ static inline
+ #endif
+ 
+ // Since we are using unsigned short to represent data in bfloat16, it can be of different sizes on

+ 40 - 28
setup.py

@@ -22,9 +22,7 @@ MAIN_CUDA_VERSION = "12.1"
 NVIDIA_SUPPORTED_ARCHS = {
     "6.1", "7.0", "7.5", "8.0", "8.6", "8.9", "9.0"
 }
-ROCM_SUPPORTED_ARCHS = {
-    "gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"
-}
+ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942", "gfx1100"}
 
 def _is_hip() -> bool:
     return torch.version.hip is not None
@@ -54,21 +52,6 @@ ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
 CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
 NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
 
-def get_amdgpu_offload_arch():
-    command = "/opt/rocm/llvm/bin/amdgpu-offload-arch"
-    try:
-        output = subprocess.check_output([command])
-        return output.decode('utf-8').strip()
-    except subprocess.CalledProcessError as e:
-        error_message = f"Error: {e}"
-        raise RuntimeError(error_message) from e
-    except FileNotFoundError as e:
-        # If the command is not found, print an error message
-        error_message = f"The command {command} was not found."
-        raise RuntimeError(error_message) from e
-
-    return None
-
 
 def get_hipcc_rocm_version():
     # Run the hipcc --version command
@@ -107,6 +90,39 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     nvcc_cuda_version = parse(output[release_idx].split(",")[0])
     return nvcc_cuda_version
 
+def get_pytorch_rocm_arch() -> Set[str]:
+    env_arch_list = os.environ.get("PYTORCH_ROCM_ARCH", None)
+
+    # If we don't have PYTORCH_ROCM_ARCH specified pull the list from rocm_agent_enumerator
+    if env_arch_list is None:
+        command = "rocm_agent_enumerator"
+        env_arch_list = subprocess.check_output([command]).decode('utf-8')\
+                        .strip().replace("\n", ";")
+        arch_source_str = "rocm_agent_enumerator"
+    else:
+        arch_source_str = "PYTORCH_ROCM_ARCH env variable"
+
+    # List are separated by ; or space.
+    pytorch_rocm_arch = set(env_arch_list.replace(" ", ";").split(";"))
+
+    # Filter out the invalid architectures and print a warning.
+    arch_list = pytorch_rocm_arch.intersection(ROCM_SUPPORTED_ARCHS)
+
+    # If none of the specified architectures are valid, raise an error.
+    if not arch_list:
+        raise RuntimeError(
+            f"None of the ROCM architectures in {arch_source_str} "
+            f"({env_arch_list}) is supported. "
+            f"Supported ROCM architectures are: {ROCM_SUPPORTED_ARCHS}.")
+    invalid_arch_list = pytorch_rocm_arch - ROCM_SUPPORTED_ARCHS
+    if invalid_arch_list:
+        warnings.warn(
+            f"Unsupported ROCM architectures ({invalid_arch_list}) are "
+            f"excluded from the {arch_source_str} output "
+            f"({env_arch_list}). Supported ROCM architectures are: "
+            f"{ROCM_SUPPORTED_ARCHS}.",
+            stacklevel=2)
+    return arch_list
 
 def get_torch_arch_list() -> Set[str]:
     # TORCH_CUDA_ARCH_LIST can have one or more architectures,
@@ -146,8 +162,12 @@ def get_torch_arch_list() -> Set[str]:
     return arch_list
 
 
-# First, check the TORCH_CUDA_ARCH_LIST environment variable.
-compute_capabilities = get_torch_arch_list()
+if _is_hip():
+    rocm_arches = get_pytorch_rocm_arch()
+    NVCC_FLAGS += ["--offload-arch=" + arch for arch in rocm_arches]
+else:
+    # First, check the TORCH_CUDA_ARCH_LIST environment variable.
+    compute_capabilities = get_torch_arch_list()
 if _is_cuda() and not compute_capabilities:
     # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
     # GPUs on the current machine.
@@ -276,14 +296,6 @@ if _is_cuda():
                 },
             ))
 
-elif _is_hip():
-    amd_arch = get_amdgpu_offload_arch()
-    if amd_arch not in ROCM_SUPPORTED_ARCHS:
-        raise RuntimeError(
-            f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
-            f"amdgpu_arch_found: {amd_arch}")
-
-
 aphrodite_extension_sources = [
     "kernels/cache_kernels.cu",
     "kernels/attention/attention_kernels.cu",