123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960 |
- import pytest
- import torch
- from aphrodite import SamplingParams
# Model identifiers the tests in this module are parametrized over.
MODELS = [
    "facebook/opt-125m",
]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_logits_processor_force_generate(
    aphrodite_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    """Check that a logits processor can dictate the generated text.

    A processor is installed that, at every decoding step, raises the logit
    of the next token of a fixed target string to the largest finite value
    of the logits dtype. Three requests are queued on the same engine: one
    with the processor and prompt logprobs, one with prompt logprobs only,
    and one with neither. After running the engine, the text of the first
    request must equal the target string repeated ``repeats`` times.
    """
    with aphrodite_runner(model, dtype=dtype) as runner:
        forced_text = " aphrodite"
        repeats = 2

        forced_ids = runner.model.get_tokenizer().encode(
            forced_text, add_special_tokens=False)
        # Allow exactly enough new tokens to spell the target `repeats` times.
        token_budget = len(forced_ids) * repeats

        def force_next_token(generated_ids, logits):
            # Cycle through the target ids based on how many tokens have
            # been produced so far, and make that id the dominant choice.
            step = len(generated_ids) % len(forced_ids)
            logits[forced_ids[step]] = torch.finfo(logits.dtype).max
            return logits

        queued_requests = (
            # test logits_processors when prompt_logprobs is not None
            (example_prompts[0],
             SamplingParams(
                 logits_processors=[force_next_token],
                 prompt_logprobs=3,
                 max_tokens=token_budget,
             )),
            # test prompt_logprobs is not None
            (example_prompts[1],
             SamplingParams(
                 prompt_logprobs=3,
                 max_tokens=token_budget,
             )),
            # test grouped requests
            (example_prompts[2],
             SamplingParams(max_tokens=token_budget)),
        )
        for prompt, sampling in queued_requests:
            runner.model._add_request(prompt, params=sampling)

        results = runner.model._run_engine(use_tqdm=False)

        # Only the first request used the forcing processor.
        generated_text = results[0].outputs[0].text
        assert generated_text == forced_text * repeats
|