# test_weight_loading.py (621 B)

  1. import os
  2. MAX_MODEL_LEN = 1024
  3. MODEL_NAME = os.environ.get(
  4. "MODEL_NAME", "robertgshaw2/zephyr-7b-beta-channelwise-gptq"
  5. )
  6. REVISION = os.environ.get("REVISION", "main")
  7. QUANTIZATION = os.environ.get("QUANTIZATION", "gptq_marlin")
  8. def test_weight_loading(aphrodite_runner):
  9. with aphrodite_runner(
  10. model_name=MODEL_NAME,
  11. revision=REVISION,
  12. dtype="auto",
  13. quantization=QUANTIZATION,
  14. max_model_len=MAX_MODEL_LEN,
  15. tensor_parallel_size=2,
  16. ) as model:
  17. output = model.generate_greedy("Hello world!", max_tokens=20)
  18. print(output)
  19. assert output