# syntax=docker/dockerfile:1
# Dockerfile.rocm — ROCm build of Aphrodite.
# Requires BuildKit (uses RUN --mount=type=cache and --mount=type=bind).
  1. # Default ROCm 6.1 base image
  2. ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
  3. # Tested and supported base rocm/pytorch images
  4. ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu20.04_py3.9_pytorch_2.0.1" \
  5. ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" \
  6. ROCM_6_1_BASE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
  7. # Default ROCm ARCHes to build Aphrodite for.
  8. ARG PYTORCH_ROCM_ARCH="gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100"
  9. # Whether to build CK-based flash-attention
  10. # If 0, will not build flash attention
  11. # This is useful for gfx target where flash-attention is not supported
  12. # (i.e. those that do not appear in `FA_GFX_ARCHS`)
  13. # Triton FA is used by default on ROCm now so this is unnecessary.
  14. ARG BUILD_FA="1"
  15. ARG FA_GFX_ARCHS="gfx90a;gfx942"
  16. ARG FA_BRANCH="ae7928c"
  17. # Whether to build triton on rocm
  18. ARG BUILD_TRITON="1"
  19. ARG TRITON_BRANCH="0ef1848"
  20. ### Base image build stage
  21. FROM $BASE_IMAGE AS base
  22. # Import arg(s) defined before this build stage
  23. ARG PYTORCH_ROCM_ARCH
  24. # Install some basic utilities
  25. RUN apt-get update && apt-get install python3 python3-pip -y
  26. RUN apt-get update && apt-get install -y \
  27. curl \
  28. ca-certificates \
  29. sudo \
  30. git \
  31. bzip2 \
  32. libx11-6 \
  33. build-essential \
  34. wget \
  35. unzip \
  36. tmux \
  37. ccache \
  38. && rm -rf /var/lib/apt/lists/*
  39. # When launching the container, mount the code directory to /aphrodite-workspace
  40. ARG APP_MOUNT=/aphrodite-workspace
  41. WORKDIR ${APP_MOUNT}
  42. RUN pip install --upgrade pip
  43. # Remove sccache so it doesn't interfere with ccache
  44. # TODO: implement sccache support across components
  45. RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)"
  46. # Install torch == 2.4.0 on ROCm
  47. RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
  48. *"rocm-5.7"*) \
  49. pip uninstall -y torch torchaudio torchvision \
  50. && pip install --no-cache-dir --pre \
  51. torch==2.4.0.dev20240612 torchaudio==2.4.0.dev20240612 \
  52. torchvision==0.19.0.dev20240612 \
  53. --index-url https://download.pytorch.org/whl/nightly/rocm5.7;; \
  54. *"rocm-6.0"*) \
  55. pip uninstall -y torch torchaudio torchvision \
  56. && pip install --no-cache-dir --pre \
  57. torch==2.4.0.dev20240612 torchaudio==2.4.0.dev20240612 \
  58. torchvision==0.19.0.dev20240612 \
  59. --index-url https://download.pytorch.org/whl/nightly/rocm6.0;; \
  60. *"rocm-6.1"*) \
  61. pip uninstall -y torch torchaudio torchvision \
  62. && pip install --no-cache-dir --pre \
  63. torch==2.4.0.dev20240612 torchaudio==2.4.0.dev20240612 \
  64. torchvision==0.19.0.dev20240612 \
  65. --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
  66. *) ;; esac
  67. ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
  68. ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
  69. ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
  70. ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
  71. ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH}
  72. ENV CCACHE_DIR=/root/.cache/ccache
  73. ### AMD-SMI build stage
  74. FROM base AS build_amdsmi
  75. # Build amdsmi wheel always
  76. RUN cd /opt/rocm/share/amd_smi \
  77. && pip wheel . --wheel-dir=/install
### Flash-Attention wheel build stage
FROM base AS build_fa
# Re-declare args so they are visible inside this stage.
ARG BUILD_FA
ARG FA_GFX_ARCHS
ARG FA_BRANCH
# Build ROCm flash-attention wheel if `BUILD_FA = 1`; the wheel lands in
# /install, which the final stage bind-mounts. The ccache cache mount
# (CCACHE_DIR is set in the base stage) speeds up repeated native builds.
RUN --mount=type=cache,target=${CCACHE_DIR} \
if [ "$BUILD_FA" = "1" ]; then \
mkdir -p libs \
&& cd libs \
&& git clone https://github.com/ROCm/flash-attention.git \
&& cd flash-attention \
&& git checkout "${FA_BRANCH}" \
&& git submodule update --init \
&& case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
*"rocm-5.7"*) \
# On ROCm 5.7 only: patch torch's bundled hipify script with the patch
# shipped in the flash-attention checkout before building.
export APHRODITE_TORCH_PATH="$(python3 -c 'import torch; print(torch.__path__[0])')" \
&& patch "${APHRODITE_TORCH_PATH}"/utils/hipify/hipify_python.py hipify_patch.patch;; \
*) ;; esac \
&& GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
# Create an empty directory otherwise as later build stages expect one
else mkdir -p /install; \
fi
  101. ### Triton wheel build stage
  102. FROM base AS build_triton
  103. ARG BUILD_TRITON
  104. ARG TRITON_BRANCH
  105. # Build triton wheel if `BUILD_TRITON = 1`
  106. RUN --mount=type=cache,target=${CCACHE_DIR} \
  107. if [ "$BUILD_TRITON" = "1" ]; then \
  108. mkdir -p libs \
  109. && cd libs \
  110. && git clone https://github.com/OpenAI/triton.git \
  111. && cd triton \
  112. && git checkout "${TRITON_BRANCH}" \
  113. && cd python \
  114. && python3 setup.py bdist_wheel --dist-dir=/install; \
  115. # Create an empty directory otherwise as later build stages expect one
  116. else mkdir -p /install; \
  117. fi
### Final Aphrodite build stage
FROM base AS final
# Import the Aphrodite development directory from the build context into
# ${APP_MOUNT} (the WORKDIR inherited from the base stage).
COPY . .
# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
# Manually remove it so that later steps of numpy upgrade can continue.
# Only applies when the base image's conda py_3.9 env provides python3;
# other interpreters are left untouched.
RUN case "$(which python3)" in \
*"/opt/conda/envs/py_3.9"*) \
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \
*) ;; esac
  128. # Package upgrades for useful functionality or to avoid dependency issues
  129. RUN --mount=type=cache,target=/root/.cache/pip \
  130. pip install --upgrade numba scipy huggingface-hub[cli]
  131. # Make sure punica kernels are built (for LoRA)
  132. ENV APHRODITE_INSTALL_PUNICA_KERNELS=1
  133. # Workaround for ray >= 2.10.0
  134. ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
  135. # Silences the HF Tokenizers warning
  136. ENV TOKENIZERS_PARALLELISM=false
# Install Python requirements, apply per-ROCm-version patches, then build and
# install Aphrodite in development mode. ccache and the pip cache live on
# BuildKit cache mounts, so they persist across builds without bloating layers.
RUN --mount=type=cache,target=${CCACHE_DIR} \
--mount=type=cache,target=/root/.cache/pip \
pip install -U -r requirements-rocm.txt \
&& case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
*"rocm-6.0"*) \
# ROCm 6.0: patch the HIP bf16 header in place before building.
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h rocm_patch/rocm_bf16.patch;; \
*"rocm-6.1"*) \
# Bring in upgrades to HIP graph earlier than ROCm 6.2 for Aphrodite
wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P rocm_patch \
&& cp rocm_patch/libamdhip64.so.6 /opt/rocm/lib/libamdhip64.so.6 \
# Prevent interference if torch bundles its own HIP runtime
# (best-effort: `|| true` keeps the build going if torch has no such libs).
&& rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \
*) ;; esac \
&& python3 setup.py clean --all \
&& python3 setup.py develop
  152. # Copy amdsmi wheel into final image
  153. RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
  154. mkdir -p libs \
  155. && cp /install/*.whl libs \
  156. # Preemptively uninstall to avoid same-version no-installs
  157. && pip uninstall -y amdsmi;
  158. # Copy triton wheel(s) into final image if they were built
  159. RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
  160. mkdir -p libs \
  161. && if ls /install/*.whl; then \
  162. cp /install/*.whl libs \
  163. # Preemptively uninstall to avoid same-version no-installs
  164. && pip uninstall -y triton; fi
  165. # Copy flash-attn wheel(s) into final image if they were built
  166. RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
  167. mkdir -p libs \
  168. && if ls /install/*.whl; then \
  169. cp /install/*.whl libs \
  170. # Preemptively uninstall to avoid same-version no-installs
  171. && pip uninstall -y flash-attn; fi
  172. # Install wheels that were built to the final image
  173. RUN --mount=type=cache,target=/root/.cache/pip \
  174. if ls libs/*.whl; then \
  175. pip install libs/*.whl; fi
  176. CMD ["/bin/bash"]