utils.py

# code borrowed from: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/utils/save_and_load.py#L420
import os
from typing import Optional

import torch
from huggingface_hub import file_exists, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from safetensors.torch import load_file as safe_load_file

WEIGHTS_NAME = "adapter_model.bin"
SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors"

# Get current device name based on available devices
def infer_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"

def load_peft_weights(
    model_id: str, device: Optional[str] = None, **hf_hub_download_kwargs
) -> dict:
    r"""
    A helper method to load the PEFT weights from the HuggingFace Hub or locally.

    Args:
        model_id (`str`):
            The local path to the adapter weights or the name of the adapter to
            load from the HuggingFace Hub.
        device (`str`):
            The device to load the weights onto.
        hf_hub_download_kwargs (`dict`):
            Additional arguments to pass to the `hf_hub_download` method when
            loading from the HuggingFace Hub.

    Returns:
        `dict`: The adapter state dict loaded from the resolved weights file.
    """
    # Resolve the local directory that may contain the adapter weights.
    path = (
        os.path.join(model_id, hf_hub_download_kwargs["subfolder"])
        if hf_hub_download_kwargs.get("subfolder", None) is not None
        else model_id
    )

    if device is None:
        device = infer_device()

    # Prefer a local safetensors file, then a local .bin file.
    if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)):
        filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME)
        use_safetensors = True
    elif os.path.exists(os.path.join(path, WEIGHTS_NAME)):
        filename = os.path.join(path, WEIGHTS_NAME)
        use_safetensors = False
    else:
        # No local weights found: fall back to downloading from the HuggingFace Hub.
        token = hf_hub_download_kwargs.get("token", None)
        if token is None:
            token = hf_hub_download_kwargs.get("use_auth_token", None)

        hub_filename = (
            os.path.join(
                hf_hub_download_kwargs["subfolder"], SAFETENSORS_WEIGHTS_NAME
            )
            if hf_hub_download_kwargs.get("subfolder", None) is not None
            else SAFETENSORS_WEIGHTS_NAME
        )
        has_remote_safetensors_file = file_exists(
            repo_id=model_id,
            filename=hub_filename,
            revision=hf_hub_download_kwargs.get("revision", None),
            repo_type=hf_hub_download_kwargs.get("repo_type", None),
            token=token,
        )
        use_safetensors = has_remote_safetensors_file

        if has_remote_safetensors_file:
            # Priority 1: load safetensors weights
            filename = hf_hub_download(
                model_id,
                SAFETENSORS_WEIGHTS_NAME,
                **hf_hub_download_kwargs,
            )
        else:
            try:
                filename = hf_hub_download(
                    model_id, WEIGHTS_NAME, **hf_hub_download_kwargs
                )
            except EntryNotFoundError:
                raise ValueError(  # noqa: B904
                    f"Can't find weights for {model_id} in {model_id} or "
                    f"in the Hugging Face Hub. "
                    f"Please check that the file {WEIGHTS_NAME} or "
                    f"{SAFETENSORS_WEIGHTS_NAME} is present at {model_id}."
                )

    # Load the adapter state dict with the loader matching the file format.
    if use_safetensors:
        adapters_weights = safe_load_file(filename, device=device)
    else:
        adapters_weights = torch.load(
            filename, map_location=torch.device(device)
        )

    return adapters_weights
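

# A minimal usage sketch (not part of the borrowed PEFT code). The adapter id
# below is a hypothetical placeholder; substitute a real Hub repo id or a local
# directory containing adapter_model.safetensors / adapter_model.bin.
if __name__ == "__main__":
    adapter_weights = load_peft_weights("my-org/my-lora-adapter", device="cpu")
    # Inspect a few entries of the returned state dict.
    for name, tensor in list(adapter_weights.items())[:5]:
        print(name, tuple(tensor.shape), tensor.dtype)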