test_sharded_state_loader.py

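"""Tests for the sharded state loader.

Verifies that saving a model's per-worker (sharded) state and reloading it
with ``load_format="sharded_state"`` reproduces the same generations as
loading the original checkpoint.
"""
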
import multiprocessing as mp
import os
import shutil
from fnmatch import fnmatch
from tempfile import TemporaryDirectory

import pytest
import torch
from huggingface_hub import snapshot_download

from aphrodite import LLM, SamplingParams
from aphrodite.modeling.model_loader.loader import ShardedStateLoader

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create a sampling params object with greedy (temperature=0) sampling so
# outputs from different runs can be compared exactly.
sampling_params = SamplingParams(
    temperature=0,
    max_tokens=256,
    ignore_eos=True,
)


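# `ShardedStateLoader._filter_subtensors` should drop any tensor that is a
# view into another tensor's storage, keeping only the backing tensors.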
def test_filter_subtensors():
    state_dict = {
        "a": torch.empty(2),
        "b": torch.empty((2, 4)),
        "c": torch.empty((2, 4, 8)),
    }
    state_dict.update({
        "x": state_dict["b"],
        "y": state_dict["c"][1, 2, :],
        "z": state_dict["c"][1, :, 4],
    })
    filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict)
    assert tuple(filtered_state_dict.keys()) == ("a", "b", "c")
    for key, tensor in filtered_state_dict.items():
        # NOTE: don't use `equal` here, as the tensor might contain NaNs
        assert tensor is state_dict[key]


@pytest.fixture(scope="module")
def llama_2_7b_files():
    with TemporaryDirectory() as cache_dir:
        input_dir = snapshot_download("meta-llama/Llama-2-7b-hf",
                                      cache_dir=cache_dir,
                                      ignore_patterns="*.bin*")
        yield input_dir


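# Loads the model normally, saves the per-worker (sharded) state to
# `output_dir`, and copies the non-weight metadata files alongside it.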
def _run_writer(input_dir, output_dir, weights_patterns, **kwargs):
    llm_sharded_writer = LLM(model=input_dir, **kwargs)
    # Dump worker states to output directory
    llm_sharded_writer.llm_engine.model_executor.save_sharded_state(
        path=output_dir)
    # Copy metadata files (everything that is not a weight file) to the
    # output directory. `weights_patterns` holds glob-style patterns, so
    # match with fnmatch rather than str.endswith.
    for file in os.listdir(input_dir):
        if not any(fnmatch(file, pattern) for pattern in weights_patterns):
            shutil.copy(f"{input_dir}/{file}", output_dir)


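# Generates completions for the shared prompts and sends the raw outputs
# back to the parent process through `queue` for comparison.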
def _run_generate(input_dir, queue: mp.Queue, **kwargs):
    llm = LLM(model=input_dir, **kwargs)
    gen = llm.generate(prompts, sampling_params)
    queue.put([g.outputs[0].__dict__ for g in gen])
    queue.close()
    queue.join_thread()


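# End-to-end check: write sharded state from the original checkpoint, then
# verify that generations from the sharded-state load match those from the
# original checkpoint, with and without LoRA and across TP sizes.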
@pytest.mark.parametrize("enable_lora", [False, True])
@pytest.mark.parametrize("tp_size", [1, 2])
def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available,
                              llama_2_7b_files):
    if num_gpus_available < tp_size:
        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

    weights_patterns = ("*.safetensors", )
    gpu_memory_utilization = 0.8
    input_dir = llama_2_7b_files
    ctx = mp.get_context("spawn")

    # Run in separate processes for memory & CUDA isolation
    with TemporaryDirectory() as output_dir:
        p = ctx.Process(target=_run_writer,
                        args=(input_dir, output_dir, weights_patterns),
                        kwargs=dict(
                            tensor_parallel_size=tp_size,
                            distributed_executor_backend="mp",
                            gpu_memory_utilization=gpu_memory_utilization,
                            enforce_eager=True,
                        ))
        p.start()
        p.join()

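        # Generate with the original checkpoint to produce reference outputs.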
        queue = ctx.Queue()

        p = ctx.Process(target=_run_generate,
                        args=(input_dir, queue),
                        kwargs=dict(
                            distributed_executor_backend="mp",
                            enable_lora=enable_lora,
                            gpu_memory_utilization=gpu_memory_utilization,
                            tensor_parallel_size=tp_size,
                        ))
        p.start()
        p.join()
        out_before = queue.get()

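        # Reload from the sharded state files written above and generate again.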
        p = ctx.Process(target=_run_generate,
                        args=(output_dir, queue),
                        kwargs=dict(
                            distributed_executor_backend="mp",
                            enable_lora=enable_lora,
                            gpu_memory_utilization=gpu_memory_utilization,
                            tensor_parallel_size=tp_size,
                            load_format="sharded_state",
                        ))
        p.start()
        p.join()
        out_after = queue.get()

        assert out_before == out_after