pretrained.py

import torch

from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from transformers.utils import is_remote_url
from transformers.modeling_utils import load_state_dict
from transformers.utils.hub import cached_file, get_checkpoint_shard_files


def state_dict_from_pretrained(model_name, device=None, dtype=None):
    """Load the raw state dict of a Hugging Face Hub checkpoint, sharded or not."""
    is_sharded = False
    # First try the single-file checkpoint (WEIGHTS_NAME, i.e. pytorch_model.bin).
    resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
                                        _raise_exceptions_for_missing_entries=False)
    if resolved_archive_file is None:
        # Fall back to the shard index (WEIGHTS_INDEX_NAME, i.e. pytorch_model.bin.index.json).
        resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME,
                                            _raise_exceptions_for_missing_entries=False)
        if resolved_archive_file is not None:
            is_sharded = True
    if resolved_archive_file is None:
        raise EnvironmentError(f"Model name {model_name} was not found.")
    if is_sharded:
        # resolved_archive_file becomes a list of files that point to the different
        # checkpoint shards in this case.
        resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
            model_name, resolved_archive_file
        )
        # Merge all shards into a single state dict.
        state_dict = {}
        for sharded_file in resolved_archive_file:
            state_dict.update(torch.load(sharded_file, map_location=device))
    else:
        state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
    if dtype is not None:
        # Cast every tensor to the requested dtype.
        state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
    return state_dict
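
# --- Usage sketch (not part of the original file) ---
# A minimal example of calling state_dict_from_pretrained. The "gpt2" checkpoint
# name, the CPU device, and the fp16 cast are illustrative assumptions; any Hub
# repo that ships pytorch_model.bin (or a sharded index) should resolve the same way.
if __name__ == "__main__":
    weights = state_dict_from_pretrained("gpt2", device="cpu", dtype=torch.float16)
    # Print a quick summary to confirm the tensors were loaded and cast.
    print(f"loaded {len(weights)} tensors, e.g. {list(weights)[:3]}")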