# pretrained.py
  1. import os
  2. from functools import partial
  3. import torch
  4. from safetensors.torch import load_file as safe_load_file
  5. from transformers.utils import (
  6. SAFE_WEIGHTS_INDEX_NAME,
  7. SAFE_WEIGHTS_NAME,
  8. WEIGHTS_INDEX_NAME,
  9. WEIGHTS_NAME,
  10. )
  11. from transformers.utils.hub import cached_file, get_checkpoint_shard_files
  12. def state_dict_from_pretrained(model_name, device=None, dtype=None):
  13. # If not fp32, then we don't want to load directly to the GPU
  14. mapped_device = "cpu" if dtype not in [torch.float32, None] else device
  15. is_sharded = False
  16. load_safe = False
  17. resolved_archive_file = None
  18. weights_path = os.path.join(model_name, WEIGHTS_NAME)
  19. weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME)
  20. safe_weights_path = os.path.join(model_name, SAFE_WEIGHTS_NAME)
  21. safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)
  22. if os.path.isfile(weights_path):
  23. resolved_archive_file = cached_file(
  24. model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
  25. )
  26. elif os.path.isfile(weights_index_path):
  27. resolved_archive_file = cached_file(
  28. model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
  29. )
  30. is_sharded = True
  31. elif os.path.isfile(safe_weights_path):
  32. resolved_archive_file = cached_file(
  33. model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False
  34. )
  35. load_safe = True
  36. elif os.path.isfile(safe_weights_index_path):
  37. resolved_archive_file = cached_file(
  38. model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
  39. )
  40. is_sharded = True
  41. load_safe = True
  42. else: # Try loading from HF hub instead of from local files
  43. resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
  44. _raise_exceptions_for_missing_entries=False)
  45. if resolved_archive_file is None:
  46. resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME,
  47. _raise_exceptions_for_missing_entries=False)
  48. if resolved_archive_file is not None:
  49. is_sharded = True
  50. if resolved_archive_file is None:
  51. raise EnvironmentError(f"Model name {model_name} was not found.")
  52. if load_safe:
  53. loader = partial(safe_load_file, device=mapped_device)
  54. else:
  55. loader = partial(torch.load, map_location=mapped_device)
  56. if is_sharded:
  57. # resolved_archive_file becomes a list of files that point to the different
  58. # checkpoint shards in this case.
  59. resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
  60. model_name, resolved_archive_file
  61. )
  62. state_dict = {}
  63. for sharded_file in resolved_archive_file:
  64. state_dict.update(loader(sharded_file))
  65. else:
  66. state_dict = loader(resolved_archive_file)
  67. # Convert dtype before moving to GPU to save memory
  68. if dtype is not None:
  69. state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
  70. state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
  71. return state_dict