# Dockerfile.rocm
  1. # Default ROCm 6.1 base image
  2. ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
  3. # Default ROCm ARCHes to build Aphrodite for.
  4. ARG PYTORCH_ROCM_ARCH="gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100"
  5. # Whether to install CK-based flash-attention
  6. # If 0, will not install flash-attention
  7. ARG BUILD_FA="1"
  8. # If `TRY_FA_WHEEL=1`, we will try installing flash-attention from `FA_WHEEL_URL`
  9. # If this succeeds, we use the downloaded wheel and skip building flash-attention.
  10. # Otherwise, ROCm flash-attention from `FA_BRANCH` will be built for the
  11. # architectures specified in `FA_GFX_ARCHS`
  12. ARG TRY_FA_WHEEL="1"
  13. ARG FA_WHEEL_URL="https://github.com/ROCm/flash-attention/releases/download/v2.5.9post1-cktile-vllm/flash_attn-2.5.9.post1-cp39-cp39-linux_x86_64.whl"
  14. ARG FA_GFX_ARCHS="gfx90a;gfx942"
  15. ARG FA_BRANCH="23a2b1c2"
  16. # Whether to build triton on rocm
  17. ARG BUILD_TRITON="1"
  18. ARG TRITON_BRANCH="e0fc12c"
  19. ### Base image build stage
  20. FROM $BASE_IMAGE AS base
  21. # Import arg(s) defined before this build stage
  22. ARG PYTORCH_ROCM_ARCH
  23. # Install some basic utilities
  24. RUN apt-get update && apt-get install python3 python3-pip -y
  25. RUN apt-get update && apt-get install -y \
  26. curl \
  27. ca-certificates \
  28. sudo \
  29. git \
  30. bzip2 \
  31. libx11-6 \
  32. build-essential \
  33. wget \
  34. unzip \
  35. tmux \
  36. ccache \
  37. && rm -rf /var/lib/apt/lists/*
  38. # When launching the container, mount the code directory to /aphrodite-workspace
  39. ARG APP_MOUNT=/aphrodite-workspace
  40. WORKDIR ${APP_MOUNT}
  41. RUN python3 -m pip install --upgrade pip
  42. # Remove sccache so it doesn't interfere with ccache
  43. # TODO: implement sccache support across components
  44. RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)"
  45. # Install torch == 2.5.0 on ROCm
  46. RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
  47. *"rocm-6.1"*) \
  48. python3 -m pip uninstall -y torch torchaudio torchvision \
  49. && python3 -m pip install --no-cache-dir --pre \
  50. torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \
  51. torchvision==0.20.0.dev20240710 \
  52. --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
  53. *) ;; esac
  54. ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
  55. ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
  56. ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
  57. ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
  58. ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
  59. ENV CCACHE_DIR=/root/.cache/ccache
  60. ### AMD-SMI build stage
  61. FROM base AS build_amdsmi
  62. # Build amdsmi wheel always
  63. RUN cd /opt/rocm/share/amd_smi \
  64. && python3 -m pip wheel . --wheel-dir=/install
### Flash-Attention wheel build stage
FROM base AS build_fa
# Re-declare build args so they are visible inside this stage.
ARG BUILD_FA
ARG TRY_FA_WHEEL
ARG FA_WHEEL_URL
ARG FA_GFX_ARCHS
ARG FA_BRANCH
# Build ROCm flash-attention wheel if `BUILD_FA = 1`.
# Contract: /install always exists afterwards and contains zero or one
# flash-attn wheel; the final stage bind-mounts it and installs what it finds.
RUN --mount=type=cache,target=${CCACHE_DIR} \
    if [ "$BUILD_FA" = "1" ]; then \
        # `pip install <url>` doubles as a compatibility probe: if the prebuilt
        # wheel installs cleanly in this (discarded) stage, fetch it into
        # /install for the final stage instead of compiling from source.
        if [ "${TRY_FA_WHEEL}" = "1" ] && python3 -m pip install "${FA_WHEEL_URL}"; then \
            # If a suitable wheel exists, we download it instead of building FA
            mkdir -p /install && wget -N "${FA_WHEEL_URL}" -P /install; \
        else \
            # Fallback: build from source at the pinned FA_BRANCH commit for
            # the architectures in FA_GFX_ARCHS.
            mkdir -p libs \
            && cd libs \
            && git clone https://github.com/ROCm/flash-attention.git \
            && cd flash-attention \
            && git checkout "${FA_BRANCH}" \
            && git submodule update --init \
            && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
        fi; \
    # Create an empty directory otherwise as later build stages expect one
    else mkdir -p /install; \
    fi
  90. ### Triton wheel build stage
  91. FROM base AS build_triton
  92. ARG BUILD_TRITON
  93. ARG TRITON_BRANCH
  94. # Build triton wheel if `BUILD_TRITON = 1`
  95. RUN --mount=type=cache,target=${CCACHE_DIR} \
  96. if [ "$BUILD_TRITON" = "1" ]; then \
  97. mkdir -p libs \
  98. && cd libs \
  99. && git clone https://github.com/OpenAI/triton.git \
  100. && cd triton \
  101. && git checkout "${TRITON_BRANCH}" \
  102. && cd python \
  103. && python3 setup.py bdist_wheel --dist-dir=/install; \
  104. # Create an empty directory otherwise as later build stages expect one
  105. else mkdir -p /install; \
  106. fi
  107. ### Final Aphrodite build stage
  108. FROM base AS final
  109. # Import the Aphrodite development directory from the build context
  110. COPY . .
  111. # Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
  112. # Manually remove it so that later steps of numpy upgrade can continue
  113. RUN case "$(which python3)" in \
  114. *"/opt/conda/envs/py_3.9"*) \
  115. rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \
  116. *) ;; esac
  117. # Package upgrades for useful functionality or to avoid dependency issues
  118. RUN --mount=type=cache,target=/root/.cache/pip \
  119. python3 -m pip install --upgrade numba scipy huggingface-hub[cli]
  120. # TODO: Figure out what's wrong with punica kernels on ROCm.
  121. ENV APHRODITE_INSTALL_PUNICA_KERNELS=0
  122. # Workaround for ray >= 2.10.0
  123. ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
  124. # Silences the HF Tokenizers warning
  125. ENV TOKENIZERS_PARALLELISM=false
# Install ROCm requirements, apply the ROCm 6.1 HIP-graph patch when running
# on a rocm-6.1 install, then build Aphrodite in develop (editable) mode.
RUN --mount=type=cache,target=${CCACHE_DIR} \
    --mount=type=cache,target=/root/.cache/pip \
    python3 -m pip install -Ur requirements-rocm.txt \
    && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
    *"rocm-6.1"*) \
        # Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM
        wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P /opt/rocm/lib \
        # Prevent interference if torch bundles its own HIP runtime.
        # `|| true` makes this patch step best-effort: a failure here must not
        # abort the whole build.
        && rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \
    *) ;; esac \
    && python3 setup.py clean --all \
    && python3 setup.py develop
  138. # Copy amdsmi wheel into final image
  139. RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
  140. mkdir -p libs \
  141. && cp /install/*.whl libs \
  142. # Preemptively uninstall to avoid same-version no-installs
  143. && python3 -m pip uninstall -y amdsmi;
  144. # Copy triton wheel(s) into final image if they were built
  145. RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
  146. mkdir -p libs \
  147. && if ls /install/*.whl; then \
  148. cp /install/*.whl libs \
  149. # Preemptively uninstall to avoid same-version no-installs
  150. && python3 -m pip uninstall -y triton; fi
  151. # Copy flash-attn wheel(s) into final image if they were built
  152. RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
  153. mkdir -p libs \
  154. && if ls /install/*.whl; then \
  155. cp /install/*.whl libs \
  156. # Preemptively uninstall to avoid same-version no-installs
  157. && python3 -m pip uninstall -y flash-attn; fi
  158. # Install wheels that were built to the final image
  159. RUN --mount=type=cache,target=/root/.cache/pip \
  160. if ls libs/*.whl; then \
  161. python3 -m pip install libs/*.whl; fi
  162. CMD ["/bin/bash"]