# Default ROCm 6.1 base image
ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"

# Default ROCm ARCHes to build Aphrodite for.
ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"

# Whether to install CK-based flash-attention
# If 0, flash-attention will not be installed
ARG BUILD_FA="1"
# If `TRY_FA_WHEEL=1`, we will try installing flash-attention from `FA_WHEEL_URL`.
# If this succeeds, we use the downloaded wheel and skip building flash-attention.
# Otherwise, ROCm flash-attention from `FA_BRANCH` will be built for the
# architectures specified in `FA_GFX_ARCHS`.
ARG TRY_FA_WHEEL="1"
ARG FA_WHEEL_URL="https://github.com/ROCm/flash-attention/releases/download/v2.5.9post1-cktile-vllm/flash_attn-2.5.9.post1-cp39-cp39-linux_x86_64.whl"
ARG FA_GFX_ARCHS="gfx90a;gfx942"
ARG FA_BRANCH="23a2b1c2"

# Whether to build Triton on ROCm
ARG BUILD_TRITON="1"
ARG TRITON_BRANCH="e0fc12c"
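
# Example build invocation (a sketch: the `Dockerfile.rocm` file name and the
# `aphrodite-rocm` image tag are placeholders, adjust them to your checkout).
# BuildKit is required for the `RUN --mount` cache and bind mounts used below:
#   DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm \
#       --build-arg PYTORCH_ROCM_ARCH="gfx90a;gfx942" \
#       --build-arg BUILD_TRITON=0 \
#       -t aphrodite-rocm .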

### Base image build stage
FROM $BASE_IMAGE AS base

# Import arg(s) defined before this build stage
ARG PYTORCH_ROCM_ARCH

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y
RUN apt-get update && apt-get install -y \
    curl \
    ca-certificates \
    sudo \
    git \
    bzip2 \
    libx11-6 \
    build-essential \
    wget \
    unzip \
    tmux \
    ccache \
    && rm -rf /var/lib/apt/lists/*

# When launching the container, mount the code directory to /aphrodite-workspace
ARG APP_MOUNT=/aphrodite-workspace
WORKDIR ${APP_MOUNT}
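
# Example run command (a sketch: the image tag and host path are placeholders).
# ROCm containers typically need access to /dev/kfd and /dev/dri:
#   docker run -it --rm \
#       --device=/dev/kfd --device=/dev/dri \
#       -v /path/to/aphrodite-engine:/aphrodite-workspace \
#       aphrodite-rocm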

RUN python3 -m pip install --upgrade pip

# Remove sccache so it doesn't interfere with ccache
# TODO: implement sccache support across components
RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)"

# Install the torch 2.5.0 nightly (dev20240726) build for ROCm 6.1
RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
        *"rocm-6.1"*) \
            python3 -m pip uninstall -y torch torchvision \
            && python3 -m pip install --no-cache-dir --pre \
                torch==2.5.0.dev20240726 \
                torchvision==0.20.0.dev20240726 \
                --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
        *) ;; esac

ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:

ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
ENV CCACHE_DIR=/root/.cache/ccache

### AMD-SMI build stage
FROM base AS build_amdsmi

# Always build the amdsmi wheel
RUN cd /opt/rocm/share/amd_smi \
    && python3 -m pip wheel . --wheel-dir=/install

### Flash-Attention wheel build stage
FROM base AS build_fa
ARG BUILD_FA
ARG TRY_FA_WHEEL
ARG FA_WHEEL_URL
ARG FA_GFX_ARCHS
ARG FA_BRANCH

# Build ROCm flash-attention wheel if `BUILD_FA = 1`
RUN --mount=type=cache,target=${CCACHE_DIR} \
    if [ "$BUILD_FA" = "1" ]; then \
        if [ "${TRY_FA_WHEEL}" = "1" ] && python3 -m pip install "${FA_WHEEL_URL}"; then \
            # If a suitable wheel exists, we download it instead of building FA
            mkdir -p /install && wget -N "${FA_WHEEL_URL}" -P /install; \
        else \
            mkdir -p libs \
            && cd libs \
            && git clone https://github.com/ROCm/flash-attention.git \
            && cd flash-attention \
            && git checkout "${FA_BRANCH}" \
            && git submodule update --init \
            && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
        fi; \
    # Otherwise, create an empty /install; later build stages expect it to exist
    else mkdir -p /install; \
    fi
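
# To force a from-source flash-attention build instead of the prebuilt wheel,
# override the args at build time, e.g. (illustrative values):
#   --build-arg TRY_FA_WHEEL=0 --build-arg FA_GFX_ARCHS="gfx90a;gfx942"
# Setting BUILD_FA=0 skips flash-attention entirely.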

### Triton wheel build stage
FROM base AS build_triton
ARG BUILD_TRITON
ARG TRITON_BRANCH

# Build triton wheel if `BUILD_TRITON = 1`
RUN --mount=type=cache,target=${CCACHE_DIR} \
    if [ "$BUILD_TRITON" = "1" ]; then \
        mkdir -p libs \
        && cd libs \
        && git clone https://github.com/OpenAI/triton.git \
        && cd triton \
        && git checkout "${TRITON_BRANCH}" \
        && cd python \
        && python3 setup.py bdist_wheel --dist-dir=/install; \
    # Otherwise, create an empty /install; later build stages expect it to exist
    else mkdir -p /install; \
    fi

### Final Aphrodite build stage
FROM base AS final

# Import the Aphrodite development directory from the build context
COPY . .

# Upgrade packages used for extra functionality and to avoid dependency conflicts
RUN --mount=type=cache,target=/root/.cache/pip \
    python3 -m pip install --upgrade numba scipy huggingface-hub[cli]

# Clang needs to be patched for ROCm
RUN chmod +x patches/amd.patch
RUN patch -p1 < patches/amd.patch

# Workaround for ray >= 2.10.0
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
# Silences the HF Tokenizers warning
ENV TOKENIZERS_PARALLELISM=false

RUN --mount=type=cache,target=${CCACHE_DIR} \
    --mount=type=cache,target=/root/.cache/pip \
    python3 -m pip install -Ur requirements-rocm.txt \
    && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
        *"rocm-6.1"*) \
            # Bring in upgrades to HIP graph earlier than ROCm 6.2 for Aphrodite
            wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P /opt/rocm/lib \
            # Prevent interference if torch bundles its own HIP runtime
            && rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \
        *) ;; esac \
    && python3 setup.py clean --all \
    && python3 setup.py develop
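
# Optional sanity check (a sketch; the importable module name `aphrodite` is an
# assumption). Run inside the container to confirm the editable install resolved:
#   python3 -c "import aphrodite"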

# Copy the amdsmi wheel into the final image
RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
    mkdir -p libs \
    && cp /install/*.whl libs \
    # Uninstall first so pip does not skip reinstalling a wheel of the same version
    && python3 -m pip uninstall -y amdsmi;

# Copy triton wheel(s) into the final image if they were built
RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
    mkdir -p libs \
    && if ls /install/*.whl; then \
        cp /install/*.whl libs \
        # Uninstall first so pip does not skip reinstalling a wheel of the same version
        && python3 -m pip uninstall -y triton; fi

# Copy flash-attn wheel(s) into the final image if they were built
RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
    mkdir -p libs \
    && if ls /install/*.whl; then \
        cp /install/*.whl libs \
        # Uninstall first so pip does not skip reinstalling a wheel of the same version
        && python3 -m pip uninstall -y flash-attn; fi

# Install the wheels that were built into the final image
RUN --mount=type=cache,target=/root/.cache/pip \
    if ls libs/*.whl; then \
        python3 -m pip install libs/*.whl; fi

CMD ["/bin/bash"]