# Helpers for comparing the outputs (token IDs, text, logprobs) generated by
# two different models.
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union

from aphrodite.common.sequence import Logprob, PromptLogprobs, SampleLogprobs
# Representation of a generated sequence as a tuple of
# * Token ID list
# * String
TokensText = Tuple[List[int], str]


def check_outputs_equal(
    *,
    outputs_0_lst: Sequence[TokensText],
    outputs_1_lst: Sequence[TokensText],
    name_0: str,
    name_1: str,
):
    """Compare the sequences generated by two models, which should be equal.

    Args:
        outputs_0_lst: Outputs of the first model, one `(token_ids, text)`
            pair per prompt.
        outputs_1_lst: Outputs of the second model, in the same prompt order.
        name_0: Label for the first model in failure messages.
        name_1: Label for the second model in failure messages.

    Raises:
        AssertionError: If the two lists differ in length, or if the text or
            the token IDs differ for any prompt.
    """
    assert len(outputs_0_lst) == len(outputs_1_lst), (
        f"{name_0} produced {len(outputs_0_lst)} outputs but "
        f"{name_1} produced {len(outputs_1_lst)}")

    for prompt_idx, (outputs_0,
                     outputs_1) in enumerate(zip(outputs_0_lst,
                                                 outputs_1_lst)):
        output_ids_0, output_str_0 = outputs_0
        output_ids_1, output_str_1 = outputs_1

        # The text and token outputs should exactly match
        fail_msg = (f"Test{prompt_idx}:"
                    f"\n{name_0}:\t{output_str_0!r}"
                    f"\n{name_1}:\t{output_str_1!r}")
        assert output_str_0 == output_str_1, fail_msg

        # The texts are equal when this assertion is reached, so the text-only
        # message cannot explain a failure here; show the token IDs as well.
        ids_fail_msg = (f"{fail_msg}"
                        f"\n{name_0} token ids:\t{output_ids_0}"
                        f"\n{name_1} token ids:\t{output_ids_1}")
        assert output_ids_0 == output_ids_1, ids_fail_msg
# Representation of generated sequence as a tuple of
# * Token ID list
# * String
# * Optional list of top sample logprobs for each sampled token
#
# Assumes prompt logprobs were not requested.
TokensTextLogprobs = Tuple[List[int], str,
                           Optional[Union[List[Dict[int, float]],
                                          SampleLogprobs]]]
# Allow for tokens to be represented as strings rather than IDs;
# tuple of
# * Token string representations list
# * String
# * Optional list of top sample logprobs for each sampled token
#
# Assumes prompt logprobs were not requested.
TextTextLogprobs = Tuple[List[str], str,
                         Optional[Union[List[Dict[str, float]],
                                        List[Dict[str, Logprob]]]]]
# Representation of generated sequence as a tuple of
# * Token ID list
# * String
# * Optional list of top sample logprobs for each sampled token
# * Optional list of top prompt logprobs for each prompt token
#   (individual entries of this list may themselves be `None`)
#
# Allows prompt logprobs to be requested.
TokensTextLogprobsPromptLogprobs = Tuple[
    List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]],
    Optional[Union[List[Optional[Dict[int, float]]], PromptLogprobs]]]
def _warn_always(message: str) -> None:
    """Emit `message` as a warning even if an identical one was shown before."""
    with warnings.catch_warnings():
        # This ensures that repeated warnings are shown
        # in the output, not just the first occurrence
        warnings.simplefilter("always")
        # stacklevel=3 skips this helper and `check_logprobs_close`, so the
        # warning is attributed to the code that called the checker.
        warnings.warn(message, stacklevel=3)


def _check_prompt_logprobs_close(
    prompt_logprobs_0: Optional[Sequence[Optional[Dict]]],
    prompt_logprobs_1: Optional[Sequence[Optional[Dict]]],
    name_0: str,
    name_1: str,
) -> None:
    """Assert that two prompt-logprobs lists offer the same top tokens."""
    if prompt_logprobs_0 is None or prompt_logprobs_1 is None:
        # Both sequence logprobs lists must be `None`
        fail_msg = (f"Prompt logprobs test:"
                    f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}"
                    f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}")
        assert (prompt_logprobs_0 is None
                and prompt_logprobs_1 is None), fail_msg
        return

    # Both sequences' prompt logprobs lists are not `None`
    # (although individual list elements may be `None`);
    # for each token's logprobs:
    for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate(
            zip(prompt_logprobs_0, prompt_logprobs_1)):
        fail_msg = (
            f"Prompt logprobs test:"
            f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}"
            f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}")

        if logprobs_elem_0 is None:
            # If the seq 0 token's logprobs are `None`,
            # the seq 1 token's logprobs must be `None`
            assert logprobs_elem_1 is None, fail_msg
        else:
            # If the seq 0 token's logprobs are not `None`,
            # the seq 1 token's logprobs must not be `None`
            assert logprobs_elem_1 is not None, fail_msg
            # Logprobs check: top-k token choices must be the same
            assert (set(logprobs_elem_0.keys()) == set(
                logprobs_elem_1.keys())), fail_msg


def check_logprobs_close(
    *,
    outputs_0_lst: Sequence[Union[TokensTextLogprobs,
                                  TokensTextLogprobsPromptLogprobs,
                                  TextTextLogprobs]],
    outputs_1_lst: Sequence[Union[TokensTextLogprobs,
                                  TokensTextLogprobsPromptLogprobs,
                                  TextTextLogprobs]],
    name_0: str,
    name_1: str,
    num_outputs_0_skip_tokens: int = 0,
    warn_on_mismatch: bool = True,
    always_check_logprobs: bool = False,
) -> None:
    """Compare the logprobs of two sequences generated by different models,
    which should be similar but not necessarily equal.

    How sample logprobs are compared: at a checked offset, each sequence's
    sampled token must appear among the other sequence's logprobs for that
    offset.

    * `always_check_logprobs == True`: every sampled token offset is
      checked, up to and including the first offset where the generated
      tokens differ
    * `always_check_logprobs == False`: only the first offset where the
      generated tokens differ is checked

    Comparison stops at the first token mismatch, because the two sequences
    diverge from that point on.

    Prompt logprobs must be provided either for both input sequences, or
    for neither. If prompt logprobs are provided, then the sets of
    highest-logprob prompt token ids must match between seq0 and seq1 at
    all compared prompt token offsets.

    Args:
        outputs_0_lst: First sequence to compare
        outputs_1_lst: Second sequence to compare
        name_0: sequence #0 name
        name_1: sequence #1 name
        num_outputs_0_skip_tokens: If > 0, specifies the number of initial
                                   sequence #0 tokens & logprobs to discard
                                   before comparison, i.e. all
                                   of sequence #1 will be compared to
                                   sequence #0 beginning at index
                                   num_outputs_0_skip_tokens
        warn_on_mismatch: Issue a warning if there is token-wise or text-wise
                          mismatch between the two sequences
        always_check_logprobs: If true, check logprobs even when tokens match

    Raises:
        ValueError: If `num_outputs_0_skip_tokens` is negative, or an output
            tuple does not have 3 or 4 elements.
        AssertionError: If the inputs differ in length or shape, or any
            logprobs check fails.
    """
    assert len(outputs_0_lst) == len(outputs_1_lst)

    # Loop-invariant argument validation, done once up front.
    if num_outputs_0_skip_tokens < 0:
        raise ValueError("num_outputs_0_skip_tokens must be non-negative")

    # Loop through responses to each prompt.
    for prompt_idx, (outputs_0,
                     outputs_1) in enumerate(zip(outputs_0_lst,
                                                 outputs_1_lst)):
        assert len(outputs_0) == len(outputs_1)

        if len(outputs_0) == 3:
            # Break out tokens, text & sample logprobs
            # (prompt logprobs were not provided)
            output_ids_0, output_str_0, logprobs_0 = outputs_0
            output_ids_1, output_str_1, logprobs_1 = outputs_1
        elif len(outputs_0) == 4:
            # Break out tokens, text, sample logprobs & prompt logprobs
            (
                output_ids_0,
                output_str_0,
                logprobs_0,
                prompt_logprobs_0,
            ) = outputs_0
            (
                output_ids_1,
                output_str_1,
                logprobs_1,
                prompt_logprobs_1,
            ) = outputs_1

            # Test prompt logprobs closeness
            _check_prompt_logprobs_close(prompt_logprobs_0,
                                         prompt_logprobs_1, name_0, name_1)
        else:
            raise ValueError(f"Outputs tuple must have 3 or 4 elements but "
                             f"{len(outputs_0)} elements were provided: "
                             f"{outputs_0}")

        # Missing sample logprobs become per-token `None` placeholders so
        # that indexing below works; they fail the check only if consulted.
        if logprobs_0 is None:
            logprobs_0 = [None] * len(output_ids_0)
        if logprobs_1 is None:
            logprobs_1 = [None] * len(output_ids_1)

        # Skip specified number of initial sequence #0 tokens
        # & logprobs, leaving output text as-is for simplicity
        # (text mismatches may generate warnings but do not
        # cause the test to fail.)
        output_ids_0 = output_ids_0[num_outputs_0_skip_tokens:]
        logprobs_0 = logprobs_0[num_outputs_0_skip_tokens:]

        # Loop through generated tokens.
        for idx, (output_id_0,
                  output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):

            is_tok_mismatch = output_id_0 != output_id_1

            # Check logprobs if generated tokens don't match
            # or it is desired to always check logprobs.
            if is_tok_mismatch or always_check_logprobs:
                logprobs_elem_0 = logprobs_0[idx]
                logprobs_elem_1 = logprobs_1[idx]

                # Each predicted token must be in top N logprobs of the other
                fail_msg = (
                    f"Test{prompt_idx}:"
                    f"\nMatched tokens:\t{output_ids_0[:idx]}"
                    f"\n{name_0}:\t{output_str_0!r}\t{logprobs_elem_0}"
                    f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}")

                assert logprobs_elem_0 is not None, fail_msg
                assert logprobs_elem_1 is not None, fail_msg
                assert output_id_0 in logprobs_elem_1, fail_msg
                assert output_id_1 in logprobs_elem_0, fail_msg

                if warn_on_mismatch and is_tok_mismatch:
                    _warn_always(fail_msg)

            if is_tok_mismatch:
                # Break out since sequences will now diverge.
                break
        else:
            # Reached only when no token mismatch was found (for/else).
            if output_str_0 != output_str_1 and warn_on_mismatch:
                # The token outputs exactly match,
                # so the text outputs should exactly match as well
                fail_msg = (f"Test{prompt_idx}:"
                            f"\n{name_0}:\t{output_str_0!r}"
                            f"\n{name_1}:\t{output_str_1!r}")
                _warn_always(fail_msg)