# gguf_to_torch.py

import torch
from loguru import logger
from transformers import AutoConfig, AutoModelForCausalLM

from aphrodite.modeling.hf_downloader import convert_gguf_to_state_dict
from aphrodite.transformers_utils.config import extract_gguf_config
from aphrodite.transformers_utils.tokenizer import convert_gguf_to_tokenizer

def convert_save_model(checkpoint, config_path, save_dir, max_shard_size):
    """Convert a GGUF checkpoint to a sharded torch checkpoint in save_dir."""
    if config_path is not None:
        config = AutoConfig.from_pretrained(config_path)
    else:
        logger.info("Extracting config from GGUF")
        config = extract_gguf_config(checkpoint)
    # Instantiate on the meta device so no real weight memory is allocated;
    # the actual tensors come from the converted state dict below.
    with torch.device("meta"):
        model = AutoModelForCausalLM.from_config(config)
    state_dict = convert_gguf_to_state_dict(checkpoint, config)
    logger.info(f"Saving model to {save_dir}...")
    model.save_pretrained(save_dir,
                          state_dict=state_dict,
                          max_shard_size=max_shard_size)

def convert_save_tokenizer(checkpoint, save_dir):
    """Extract the tokenizer from a GGUF checkpoint and save it to save_dir."""
    logger.info("Converting tokenizer...")
    tokenizer = convert_gguf_to_tokenizer(checkpoint)
    tokenizer.save_pretrained(save_dir)

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Convert GGUF checkpoints to torch')
    parser.add_argument('--input', type=str, help='The path to the GGUF file')
    parser.add_argument('--output',
                        type=str,
                        help='The path to the output directory')
    parser.add_argument(
        '--config-path',
        default=None,
        type=str,
        help='The path to the model config. This should point to the '
        'original unquantized repo of the model (not the GGUF file or repo).')
    parser.add_argument(
        '--tokenizer',
        action='store_true',
        help='Extract the tokenizer from the GGUF file. '
        'Only llama is supported.')
    parser.add_argument(
        '--max-shard-size',
        default="5GB",
        type=str,
        help='Shard the model into pieces of the specified size, e.g. 5GB')
    args = parser.parse_args()

    convert_save_model(args.input, args.config_path, args.output,
                       args.max_shard_size)
    if args.tokenizer:
        convert_save_tokenizer(args.input, args.output)
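
# A minimal usage sketch: the input/output paths and the config repo below
# are hypothetical placeholders; the flags are the ones defined above.
#
#   python gguf_to_torch.py \
#       --input /path/to/model.gguf \
#       --output ./converted-model \
#       --config-path meta-llama/Llama-2-7b-hf \
#       --tokenizer \
#       --max-shard-size 5GB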