test_logits_processor.py

import pytest
import torch

from aphrodite import SamplingParams

MODELS = ["facebook/opt-125m"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_logits_processor_force_generate(
    aphrodite_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
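    """Check that a per-request logits processor can force the engine to
    generate a fixed token sequence, and that it composes with
    prompt_logprobs and with other requests in the same batch."""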
    with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
        tokenizer = aphrodite_model.model.get_tokenizer()
        repeat_times = 2
        enforced_answers = " aphrodite"
        aphrodite_token_ids = tokenizer.encode(enforced_answers,
                                               add_special_tokens=False)
        max_tokens = len(aphrodite_token_ids) * repeat_times
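
        # Logits processor that cycles through the enforced token ids: each
        # step it sets the logit of the next enforced token to the dtype's
        # maximum, so decoding always emits the enforced sequence.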
        def pick_aphrodite(token_ids, logits):
            token_id = aphrodite_token_ids[len(token_ids) %
                                           len(aphrodite_token_ids)]
            logits[token_id] = torch.finfo(logits.dtype).max
            return logits

        params_with_logprobs = SamplingParams(
            logits_processors=[pick_aphrodite],
            prompt_logprobs=3,
            max_tokens=max_tokens,
        )

        # test logits_processors when prompt_logprobs is not None
        aphrodite_model.model._add_request(
            example_prompts[0],
            params=params_with_logprobs,
        )

        # test prompt_logprobs is not None
        aphrodite_model.model._add_request(
            example_prompts[1],
            params=SamplingParams(
                prompt_logprobs=3,
                max_tokens=max_tokens,
            ),
        )

        # test grouped requests
        aphrodite_model.model._add_request(
            example_prompts[2],
            params=SamplingParams(max_tokens=max_tokens),
        )

        outputs = aphrodite_model.model._run_engine(use_tqdm=False)
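
        # The first request ran with the logits processor, so its completion
        # must be exactly the enforced answer repeated `repeat_times` times.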
        assert outputs[0].outputs[0].text == enforced_answers * repeat_times