1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768 |
- from typing import Callable, Iterable, Optional
- import pytest
- from aphrodite import LLM
- from aphrodite.modeling.utils import set_random_seed
- from ....conftest import cleanup
@pytest.fixture()
def baseline_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                           baseline_llm_kwargs, seed):
    """Provide a lazy generator for the *baseline* LLM of a comparison test.

    The parameters are themselves pytest fixtures, resolved by name. They are
    presumably defined or parametrized by each test module (TODO confirm).
    The baseline-specific kwargs are layered on top of the common and
    per-test kwargs.

    Returns:
        The generator produced by ``create_llm_generator``. No LLM is built
        until a caller starts iterating over it.
    """
    return create_llm_generator(
        common_llm_kwargs=common_llm_kwargs,
        per_test_common_llm_kwargs=per_test_common_llm_kwargs,
        distinct_llm_kwargs=baseline_llm_kwargs,
        seed=seed,
    )
@pytest.fixture()
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):
    """Provide a lazy generator for the *test* (candidate) LLM.

    This mirrors ``baseline_llm_generator``, except that ``test_llm_kwargs``
    is the highest-precedence kwargs layer. The parameters are pytest
    fixtures resolved by name, presumably supplied per test module (TODO
    confirm).

    Returns:
        The generator produced by ``create_llm_generator``. No LLM is built
        until a caller starts iterating over it.
    """
    return create_llm_generator(
        common_llm_kwargs=common_llm_kwargs,
        per_test_common_llm_kwargs=per_test_common_llm_kwargs,
        distinct_llm_kwargs=test_llm_kwargs,
        seed=seed,
    )
def create_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                         distinct_llm_kwargs, seed):
    """Lazily construct an ``LLM`` from layered kwargs and yield it once.

    This is a generator function, so no ``LLM`` is constructed until the
    caller starts iterating. The keyword arguments are merged with
    increasing precedence: ``common_llm_kwargs``, then
    ``per_test_common_llm_kwargs``, then ``distinct_llm_kwargs``. Later
    layers override earlier ones on key collisions. The random seed is set
    to ``seed`` after the engine is built.

    Teardown happens when the caller resumes iteration after the single
    yield. It also happens if the generator is closed early, for example
    because the consuming test raised. In both cases the local LLM reference
    is dropped and ``cleanup()`` is called.

    Args:
        common_llm_kwargs: Base ``LLM`` constructor kwargs.
        per_test_common_llm_kwargs: Kwargs that override the base layer.
        distinct_llm_kwargs: Baseline- or test-specific kwargs that
            override everything else.
        seed: Value passed to ``set_random_seed``.

    Yields:
        The constructed ``LLM``, exactly once.
    """
    kwargs = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **distinct_llm_kwargs,
    }

    llm = LLM(**kwargs)
    try:
        # Seeding happens after construction, presumably so any RNG use
        # during model loading does not affect generation (TODO confirm).
        set_random_seed(seed)
        yield llm
    finally:
        # `finally` also runs on generator close(). Without it, an early
        # exit would skip cleanup. The caller must drop its own reference as
        # well (the get_*_from_llm_generator helpers `del llm`) before the
        # model can actually be freed.
        del llm
        cleanup()
def get_text_from_llm_generator(llm_generator: Iterable["LLM"],
                                prompts,
                                sampling_params,
                                llm_cb: Optional[Callable[["LLM"],
                                                          None]] = None):
    """Generate completions with the LLM(s) yielded by ``llm_generator``.

    This is meant for the generators returned by the ``*_llm_generator``
    fixtures. Each of those yields a single LLM and tears it down when
    iteration resumes.

    Args:
        llm_generator: Iterable yielding ``LLM`` instances.
        prompts: Prompts forwarded to ``LLM.generate``.
        sampling_params: Sampling parameters forwarded to ``LLM.generate``.
        llm_cb: Optional hook called with each LLM before generation.

    Returns:
        A list holding the text of the first completion for each prompt.
        If several LLMs are yielded, the result from the last one is
        returned.

    Raises:
        ValueError: If ``llm_generator`` yields no LLM at all.
    """
    text = None
    for llm in llm_generator:
        if llm_cb is not None:
            llm_cb(llm)
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        text = [output.outputs[0].text for output in outputs]
        # Drop our reference before the generator resumes. A generator that
        # frees the LLM on resume (like create_llm_generator) then holds the
        # last reference.
        del llm
    if text is None:
        raise ValueError("llm_generator did not yield any LLM")
    return text
def get_token_ids_from_llm_generator(llm_generator: Iterable["LLM"],
                                     prompts,
                                     sampling_params,
                                     llm_cb: Optional[Callable[["LLM"],
                                                               None]] = None):
    """Generate completions and return their token ids.

    This is the token-id counterpart of ``get_text_from_llm_generator``.
    It is meant for the generators returned by the ``*_llm_generator``
    fixtures, which yield a single LLM and tear it down when iteration
    resumes.

    Args:
        llm_generator: Iterable yielding ``LLM`` instances.
        prompts: Prompts forwarded to ``LLM.generate``.
        sampling_params: Sampling parameters forwarded to ``LLM.generate``.
        llm_cb: Optional hook called with each LLM before generation.
            Defaults to ``None``, so existing callers are unaffected.

    Returns:
        A list holding the ``token_ids`` of the first completion for each
        prompt. If several LLMs are yielded, the result from the last one
        is returned.

    Raises:
        ValueError: If ``llm_generator`` yields no LLM at all.
    """
    token_ids = None
    for llm in llm_generator:
        if llm_cb is not None:
            llm_cb(llm)
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        token_ids = [output.outputs[0].token_ids for output in outputs]
        # Drop our reference before the generator resumes. A generator that
        # frees the LLM on resume (like create_llm_generator) then holds the
        # last reference.
        del llm
    if token_ids is None:
        raise ValueError("llm_generator did not yield any LLM")
    return token_ids
|