# utils.py — comparison helpers for model-output tests.
  1. import warnings
  2. from typing import Dict, List, Optional, Sequence, Tuple, Union
  3. from aphrodite.common.sequence import Logprob, SampleLogprobs
  4. TokensText = Tuple[List[int], str]
  5. def check_outputs_equal(
  6. *,
  7. outputs_0_lst: Sequence[TokensText],
  8. outputs_1_lst: Sequence[TokensText],
  9. name_0: str,
  10. name_1: str,
  11. ):
  12. """
  13. Compare the two sequences generated by different models,
  14. which should be equal.
  15. """
  16. assert len(outputs_0_lst) == len(outputs_1_lst)
  17. for prompt_idx, (outputs_0,
  18. outputs_1) in enumerate(zip(outputs_0_lst,
  19. outputs_1_lst)):
  20. output_ids_0, output_str_0 = outputs_0
  21. output_ids_1, output_str_1 = outputs_1
  22. # The text and token outputs should exactly match
  23. fail_msg = (f"Test{prompt_idx}:"
  24. f"\n{name_0}:\t{output_str_0!r}"
  25. f"\n{name_1}:\t{output_str_1!r}")
  26. assert output_str_0 == output_str_1, fail_msg
  27. assert output_ids_0 == output_ids_1, fail_msg
# A single generated output: (token IDs, decoded text, optional per-token
# logprobs). Logprobs, when present, are either plain {token_id: logprob}
# dicts or the engine's SampleLogprobs structure.
TokensTextLogprobs = Tuple[List[int], str,
                           Optional[Union[List[Dict[int, float]],
                                          SampleLogprobs]]]

# Allow for tokens to be represented as str's rather than IDs
TextTextLogprobs = Tuple[List[str], str,
                         Optional[Union[List[Dict[str, float]],
                                        List[Dict[str, Logprob]]]]]
  35. def check_logprobs_close(
  36. *,
  37. outputs_0_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
  38. outputs_1_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
  39. name_0: str,
  40. name_1: str,
  41. num_outputs_0_skip_tokens: int = 0,
  42. warn_on_mismatch: bool = True,
  43. always_check_logprobs: bool = False,
  44. ) -> None:
  45. """Compare the logprobs of two sequences generated by different models,
  46. which should be similar but not necessarily equal.
  47. Args:
  48. outputs_0_lst: First sequence to compare
  49. outputs_0_lst: Second sequence to compare
  50. name_0: sequence #0 name
  51. name_1: sequence #1 name
  52. num_outputs_0_skip_tokens: If > 0, specifies the number of initial
  53. sequence #0 tokens & logprobs to discard
  54. before comparison, i.e. all
  55. of sequence #1 will be compared to
  56. sequence #0 beginning at index
  57. num_outputs_0_skip_tokens
  58. warn_on_mismatch: Issue a warning if there is token-wise or text-wise
  59. mismatch between the two sequences
  60. always_check_logprobs: If true, check logprobs even when tokens match
  61. """
  62. assert len(outputs_0_lst) == len(outputs_1_lst)
  63. # Loop through responses to each prompt.
  64. for prompt_idx, (outputs_0,
  65. outputs_1) in enumerate(zip(outputs_0_lst,
  66. outputs_1_lst)):
  67. output_ids_0, output_str_0, logprobs_0 = outputs_0
  68. output_ids_1, output_str_1, logprobs_1 = outputs_1
  69. if logprobs_0 is None:
  70. logprobs_0 = [None] * len(output_ids_0)
  71. if logprobs_1 is None:
  72. logprobs_1 = [None] * len(output_ids_1)
  73. # Skip specified number of initial sequence #0 tokens
  74. # & logprobs, leaving output text as-is for simplicity
  75. # (text mismatches may generate warnings but do not
  76. # cause the test to fail.)
  77. if num_outputs_0_skip_tokens < 0:
  78. raise ValueError("num_outputs_0_skip_tokens must be non-negative")
  79. output_ids_0 = output_ids_0[num_outputs_0_skip_tokens:]
  80. logprobs_0 = logprobs_0[num_outputs_0_skip_tokens:]
  81. # Loop through generated tokens.
  82. for idx, (output_id_0,
  83. output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
  84. is_tok_mismatch = output_id_0 != output_id_1
  85. # If generated tokens don't match
  86. # or it is desired to always check logprobs,
  87. # then
  88. if is_tok_mismatch or always_check_logprobs:
  89. logprobs_elem_0 = logprobs_0[idx]
  90. logprobs_elem_1 = logprobs_1[idx]
  91. # Each predicted token must be in top N logprobs of the other
  92. fail_msg = (
  93. f"Test{prompt_idx}:"
  94. f"\nMatched tokens:\t{output_ids_0[:idx]}"
  95. f"\n{name_0}:\t{output_str_0!r}\t{logprobs_elem_0}"
  96. f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}")
  97. assert logprobs_elem_0 is not None, fail_msg
  98. assert logprobs_elem_1 is not None, fail_msg
  99. assert output_id_0 in logprobs_elem_1, fail_msg
  100. assert output_id_1 in logprobs_elem_0, fail_msg
  101. if warn_on_mismatch and is_tok_mismatch:
  102. with warnings.catch_warnings():
  103. # This ensures that repeated warnings are shown
  104. # in the output, not just the first occurrence
  105. warnings.simplefilter("always")
  106. warnings.warn(fail_msg, stacklevel=2)
  107. # Break out since sequences will now diverge.
  108. break
  109. else:
  110. if output_str_0 != output_str_1 and warn_on_mismatch:
  111. # The token outputs exactly match,
  112. # so the text outputs should exactly match as well
  113. fail_msg = (f"Test{prompt_idx}:"
  114. f"\n{name_0}:\t{output_str_0!r}"
  115. f"\n{name_1}:\t{output_str_1!r}")
  116. with warnings.catch_warnings():
  117. # This ensures that repeated warnings are shown
  118. # in the output, not just the first occurrence
  119. warnings.simplefilter("always")
  120. warnings.warn(fail_msg, stacklevel=2)