save_sharded_state.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
r"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.

Example usage:

    python save_sharded_state.py \
        --model /path/to/load \
        --quantization deepspeedfp \
        --tensor-parallel-size 8 \
        --output /path/to/save

Then, the model can be loaded with:

    llm = LLM(
        model="/path/to/save",
        load_format="sharded_state",
        quantization="deepspeedfp",
        tensor_parallel_size=8,
    )
"""
  19. import argparse
  20. import dataclasses
  21. import os
  22. import shutil
  23. from pathlib import Path
  24. from aphrodite import LLM, EngineArgs
  25. parser = argparse.ArgumentParser()
  26. EngineArgs.add_cli_args(parser)
  27. parser.add_argument("--output",
  28. "-o",
  29. required=True,
  30. type=str,
  31. help="path to output checkpoint")
  32. parser.add_argument("--file-pattern",
  33. type=str,
  34. help="string pattern of saved filenames")
  35. parser.add_argument("--max-file-size",
  36. type=str,
  37. default=5 * 1024**3,
  38. help="max size (in bytes) of each safetensors file")
  39. def main(args):
  40. engine_args = EngineArgs.from_cli_args(args)
  41. if engine_args.enable_lora:
  42. raise ValueError("Saving with enable_lora=True is not supported!")
  43. model_path = engine_args.model
  44. if not Path(model_path).is_dir():
  45. raise ValueError("model path must be a local directory")
  46. # Create LLM instance from arguments
  47. llm = LLM(**dataclasses.asdict(engine_args))
  48. # Prepare output directory
  49. Path(args.output).mkdir(exist_ok=True)
  50. # Dump worker states to output directory
  51. model_executor = llm.llm_engine.model_executor
  52. model_executor.save_sharded_state(path=args.output,
  53. pattern=args.file_pattern,
  54. max_size=args.max_file_size)
  55. # Copy metadata files to output directory
  56. for file in os.listdir(model_path):
  57. if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
  58. if os.path.isdir(os.path.join(model_path, file)):
  59. shutil.copytree(os.path.join(model_path, file),
  60. os.path.join(args.output, file))
  61. else:
  62. shutil.copy(os.path.join(model_path, file), args.output)
  63. if __name__ == "__main__":
  64. args = parser.parse_args()
  65. main(args)