weight_utils.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
import os
import tempfile

import huggingface_hub.constants
import pytest
from huggingface_hub.utils import LocalEntryNotFoundError

from aphrodite.modeling.model_loader.weight_utils import (
    download_weights_from_hf, enable_hf_transfer)
  8. def test_hf_transfer_auto_activation():
  9. if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ:
  10. # in case it is already set, we can't test the auto activation
  11. pytest.skip(
  12. "HF_HUB_ENABLE_HF_TRANSFER is set, can't test auto activation")
  13. enable_hf_transfer()
  14. try:
  15. # enable hf hub transfer if available
  16. import hf_transfer # type: ignore # noqa
  17. HF_TRANFER_ACTIVE = True
  18. except ImportError:
  19. HF_TRANFER_ACTIVE = False
  20. assert (huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER ==
  21. HF_TRANFER_ACTIVE)
  22. def test_download_weights_from_hf():
  23. with tempfile.TemporaryDirectory() as tmpdir:
  24. # assert LocalEntryNotFoundError error is thrown
  25. # if offline is set and model is not cached
  26. huggingface_hub.constants.HF_HUB_OFFLINE = True
  27. with pytest.raises(LocalEntryNotFoundError):
  28. download_weights_from_hf("facebook/opt-125m",
  29. allow_patterns=["*.safetensors", "*.bin"],
  30. cache_dir=tmpdir)
  31. # download the model
  32. huggingface_hub.constants.HF_HUB_OFFLINE = False
  33. download_weights_from_hf("facebook/opt-125m",
  34. allow_patterns=["*.safetensors", "*.bin"],
  35. cache_dir=tmpdir)
  36. # now it should work offline
  37. huggingface_hub.constants.HF_HUB_OFFLINE = True
  38. assert download_weights_from_hf(
  39. "facebook/opt-125m",
  40. allow_patterns=["*.safetensors", "*.bin"],
  41. cache_dir=tmpdir) is not None
  42. if __name__ == "__main__":
  43. test_hf_transfer_auto_activation()
  44. test_download_weights_from_hf()