# conftest.py
# Standard library
from typing import Callable, Iterable, Optional

# Third-party
import pytest

# Project-local
from aphrodite import LLM
from aphrodite.modeling.utils import set_random_seed

# Relative import from the test-suite root conftest.
from ....conftest import cleanup
  6. @pytest.fixture
  7. def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
  8. baseline_llm_kwargs, seed):
  9. return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
  10. baseline_llm_kwargs, seed)
  11. @pytest.fixture
  12. def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
  13. test_llm_kwargs, seed):
  14. return create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
  15. test_llm_kwargs, seed)
  16. def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
  17. distinct_llm_kwargs, seed):
  18. kwargs = {
  19. **common_llm_kwargs,
  20. **per_test_common_llm_kwargs,
  21. **distinct_llm_kwargs,
  22. }
  23. def generator_inner():
  24. llm = LLM(**kwargs)
  25. set_random_seed(seed)
  26. yield llm
  27. del llm
  28. cleanup()
  29. for llm in generator_inner():
  30. yield llm
  31. del llm
  32. def get_text_from_llm_generator(llm_generator: Iterable[LLM],
  33. prompts,
  34. sampling_params,
  35. llm_cb: Optional[Callable[[LLM],
  36. None]] = None):
  37. for llm in llm_generator:
  38. if llm_cb:
  39. llm_cb(llm)
  40. outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
  41. text = [output.outputs[0].text for output in outputs]
  42. del llm
  43. return text
  44. def get_token_ids_from_llm_generator(llm_generator, prompts, sampling_params):
  45. for llm in llm_generator:
  46. outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
  47. token_ids = [output.outputs[0].token_ids for output in outputs]
  48. del llm
  49. return token_ids