patch_xformers.rocm.sh 1.5 KB

123456789101112131415161718192021222324252627282930313233
  1. #!/bin/bash
  2. set -e
  3. XFORMERS_VERSION="0.0.23"
  4. export XFORMERS_INSTALLED_VERSION=$(python -c 'import xformers; print(xformers.__version__)')
  5. if [ "$XFORMERS_INSTALLED_VERSION" != "$XFORMERS_VERSION" ]; then
  6. echo "ERROR: xformers version must be ${XFORMERS_VERSION}. ${XFORMERS_INSTALLED_VERSION} is installed"
  7. exit 1
  8. fi
  9. export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)')
  10. export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)')
  11. echo "XFORMERS_FMHA_FLASH_PATH = ${XFORMERS_FMHA_FLASH_PATH}"
  12. echo "XFORMERS_FMHA_COMMON_PATH = ${XFORMERS_FMHA_COMMON_PATH}"
  13. if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then
  14. echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}"
  15. patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"
  16. echo "Successfully patch ${XFORMERS_FMHA_FLASH_PATH}"
  17. else
  18. echo "${XFORMERS_FMHA_FLASH_PATH} was patched before"
  19. fi
  20. if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then
  21. echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}"
  22. patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"
  23. echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}"
  24. else
  25. echo "${XFORMERS_FMHA_COMMON_PATH} was patched before"
  26. fi