# test_llm_generate.py
import pytest

from aphrodite import LLM, SamplingParams


  3. def test_multiple_sampling_params():
  4. llm = LLM(model='gpt2', max_num_batched_tokens=1024)
  5. prompts = [
  6. "Once upon a time",
  7. "In a galaxy far far away",
  8. "The quick brown fox jumps over the lazy dog",
  9. ]
  10. sampling_params = [
  11. SamplingParams(temperature=0.7, min_p=0.06),
  12. SamplingParams(temperature=0.8, min_p=0.07),
  13. SamplingParams(temperature=0.9, min_p=0.08),
  14. ]
  15. outputs = llm.generate(prompts, sampling_params=sampling_params)
  16. assert len(prompts) == len(outputs)
  17. with pytest.raises(ValueError):
  18. outputs = llm.generate(prompts, sampling_params=sampling_params[:2])
  19. single_sampling_params = SamplingParams(temperature=0.7, min_p=0.06)
  20. outputs = llm.generate(prompts, sampling_params=single_sampling_params)
  21. assert len(prompts) == len(outputs)
  22. outputs = llm.generate(prompts, sampling_params=None)
  23. assert len(prompts) == len(outputs)