gguf_to_torch.py

import torch
from loguru import logger
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from aphrodite.modeling.hf_downloader import convert_gguf_to_state_dict
from aphrodite.transformers_utils.config import extract_gguf_config
from aphrodite.transformers_utils.tokenizer import convert_gguf_to_tokenizer


def convert_save_model(checkpoint, unquantized_path, save_dir, max_shard_size):
    # Prefer the config shipped with the unquantized model; otherwise
    # reconstruct it from the GGUF metadata.
    if unquantized_path is not None:
        config = AutoConfig.from_pretrained(unquantized_path)
    else:
        config = extract_gguf_config(checkpoint)
    # Instantiate the model on the meta device so no real weight memory is
    # allocated; the converted state dict is supplied at save time instead.
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(config)
    state_dict = convert_gguf_to_state_dict(checkpoint, config)
    logger.info(f"Saving model to {save_dir}...")
    model.save_pretrained(save_dir,
                          state_dict=state_dict,
                          max_shard_size=max_shard_size)


def convert_save_tokenizer(checkpoint, unquantized_path, save_dir):
    # Copy the tokenizer from the unquantized model if one was given;
    # otherwise extract it from the GGUF checkpoint itself.
    if unquantized_path is not None:
        tokenizer = AutoTokenizer.from_pretrained(unquantized_path)
    else:
        tokenizer = convert_gguf_to_tokenizer(checkpoint)
    tokenizer.save_pretrained(save_dir)


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Convert GGUF checkpoints to torch')
    parser.add_argument('--input', type=str, help='The path to the GGUF file')
    parser.add_argument('--output',
                        type=str,
                        help='The path to the output directory')
    parser.add_argument(
        '--unquantized-path',
        default=None,
        type=str,
        help='The path to the unquantized model to copy the config and '
        'tokenizer from')
    parser.add_argument('--no-tokenizer',
                        action='store_true',
                        help='Do not try to copy or extract the tokenizer')
    parser.add_argument(
        '--max-shard-size',
        default="5GB",
        type=str,
        help='Shard the saved model into pieces of the given size, e.g. 5GB')
    args = parser.parse_args()

    convert_save_model(args.input, args.unquantized_path, args.output,
                       args.max_shard_size)
    if not args.no_tokenizer:
        convert_save_tokenizer(args.input, args.unquantized_path, args.output)
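
For reference, a typical invocation of the script looks like the following; the GGUF file, output directory, and unquantized model paths are illustrative, not taken from the source:

python gguf_to_torch.py \
    --input /models/llama-2-7b.Q4_K_M.gguf \
    --output /models/llama-2-7b-torch \
    --unquantized-path meta-llama/Llama-2-7b-hf \
    --max-shard-size 5GB

Passing --unquantized-path avoids reconstructing the config and tokenizer from GGUF metadata; omit it to extract both from the checkpoint, and add --no-tokenizer to skip tokenizer conversion entirely.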