1
0

gguf_to_torch.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. from accelerate import init_empty_weights
  2. from loguru import logger
  3. from transformers import AutoConfig, AutoModelForCausalLM
  4. from aphrodite.modeling.hf_downloader import convert_gguf_to_state_dict
  5. from aphrodite.transformers_utils.config import extract_gguf_config
  6. from aphrodite.transformers_utils.tokenizer import convert_gguf_to_tokenizer
  7. def convert_save_model(checkpoint, save_dir, max_shard_size):
  8. try:
  9. config = AutoConfig.from_pretrained(save_dir)
  10. except Exception:
  11. logger.warning(
  12. f"Unable to load config from {save_dir}, extracting from GGUF")
  13. config = extract_gguf_config(checkpoint)
  14. with init_empty_weights():
  15. model = AutoModelForCausalLM.from_config(config)
  16. state_dict = convert_gguf_to_state_dict(checkpoint, config)
  17. logger.info(f"Saving model to {save_dir}...")
  18. model.save_pretrained(save_dir,
  19. state_dict=state_dict,
  20. max_shard_size=max_shard_size)
  21. def convert_save_tokenizer(checkpoint, save_dir):
  22. logger.info("Converting tokenizer...")
  23. tokenizer = convert_gguf_to_tokenizer(checkpoint)
  24. tokenizer.save_pretrained(save_dir)
  25. if __name__ == '__main__':
  26. import argparse
  27. parser = argparse.ArgumentParser(
  28. description='Convert GGUF checkpoints to torch')
  29. parser.add_argument('--input', type=str, help='The path to GGUF file')
  30. parser.add_argument('--output',
  31. type=str,
  32. help='The path to output directory')
  33. parser.add_argument(
  34. '--tokenizer',
  35. action='store_true',
  36. help='Extract the tokenizer from GGUF file. Only llama is supported')
  37. parser.add_argument(
  38. '--max-shard-size',
  39. default="5GB",
  40. type=str,
  41. help='Shard the model in specified shard size, e.g. 5GB')
  42. args = parser.parse_args()
  43. convert_save_model(args.input, args.output, args.max_shard_size)
  44. if args.tokenizer:
  45. convert_save_tokenizer(args.input, args.output)