12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455 |
- import pytest
- from aphrodite import SamplingParams
- MODELS = ["facebook/opt-125m"]
- @pytest.mark.parametrize("model", MODELS)
- @pytest.mark.parametrize("dtype", ["half"])
- def test_ranks(
- aphrodite_runner,
- model,
- dtype,
- example_prompts,
- ):
- max_tokens = 5
- num_top_logprobs = 5
- num_prompt_logprobs = 5
- with aphrodite_runner(model, dtype=dtype,
- max_logprobs=num_top_logprobs) as aphrodite_model:
- ## Test greedy logprobs ranks
- aphrodite_sampling_params = SamplingParams(
- temperature=0.0,
- top_p=1.0,
- max_tokens=max_tokens,
- logprobs=num_top_logprobs,
- prompt_logprobs=num_prompt_logprobs)
- aphrodite_results = aphrodite_model.generate_w_logprobs(example_prompts,
- aphrodite_sampling_params)
- ## Test non-greedy logprobs ranks
- sampling_params = SamplingParams(temperature=1.0,
- top_p=1.0,
- max_tokens=max_tokens,
- logprobs=num_top_logprobs,
- prompt_logprobs=num_prompt_logprobs)
- res = aphrodite_model.generate_w_logprobs(example_prompts,
- sampling_params)
- for result in aphrodite_results:
- assert result[2] is not None
- assert len(result[2]) == len(result[0])
- # check whether all chosen tokens have ranks = 1
- for token, logprobs in zip(result[0], result[2]):
- assert token in logprobs
- assert logprobs[token].rank == 1
- for result in res:
- assert result[2] is not None
- assert len(result[2]) == len(result[0])
- # check whether all chosen tokens have ranks
- for token, logprobs in zip(result[0], result[2]):
- assert logprobs[token].rank >= 1
|