import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union

from aphrodite.common.config import ModelConfig
from aphrodite.common.sequence import Logprob, PromptLogprobs, SampleLogprobs
from aphrodite.inputs import InputContext

TokensText = Tuple[List[int], str]


def check_outputs_equal(
    *,
    outputs_0_lst: Sequence[TokensText],
    outputs_1_lst: Sequence[TokensText],
    name_0: str,
    name_1: str,
):
    """
    Compare the two sequences generated by different models,
    which should be equal.
    """
    assert len(outputs_0_lst) == len(outputs_1_lst)

    for prompt_idx, (outputs_0,
                     outputs_1) in enumerate(zip(outputs_0_lst,
                                                 outputs_1_lst)):
        output_ids_0, output_str_0 = outputs_0
        output_ids_1, output_str_1 = outputs_1

        # The text and token outputs should exactly match
        fail_msg = (f"Test{prompt_idx}:"
                    f"\n{name_0}:\t{output_str_0!r}"
                    f"\n{name_1}:\t{output_str_1!r}")

        assert output_str_0 == output_str_1, fail_msg
        assert output_ids_0 == output_ids_1, fail_msg
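

# A minimal usage sketch for `check_outputs_equal`. The token IDs, strings,
# and runner names below are invented for illustration only; they do not
# come from any real model run.
def _example_check_outputs_equal() -> None:
    outputs_hf: List[TokensText] = [([101, 2009, 2003], "it is")]
    outputs_aphrodite: List[TokensText] = [([101, 2009, 2003], "it is")]
    check_outputs_equal(
        outputs_0_lst=outputs_hf,
        outputs_1_lst=outputs_aphrodite,
        name_0="hf",
        name_1="aphrodite",
    )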


# Representation of generated sequence as a tuple of
# * Token ID list
# * String
# * List of top sample logprobs for each sampled token
#
# Assumes prompt logprobs were not requested.
TokensTextLogprobs = Tuple[List[int], str,
                           Optional[Union[List[Dict[int, float]],
                                          SampleLogprobs]]]

# Allow for tokens to be represented as strings rather than IDs;
# tuple of
# * Token string representations list
# * String
# * Optional list of top sample logprobs for each sampled token
#
# Assumes prompt logprobs were not requested.
TextTextLogprobs = Tuple[List[str], str,
                         Optional[Union[List[Dict[str, float]],
                                        List[Dict[str, Logprob]]]]]

# Representation of generated sequence as a tuple of
# * Token ID list
# * String
# * Optional list of top sample logprobs for each sampled token
# * Optional list of top prompt logprobs for each prompt token
#
# Allows prompt logprobs to be requested.
TokensTextLogprobsPromptLogprobs = Tuple[
    List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]],
    Optional[Union[List[Optional[Dict[int, float]]], PromptLogprobs]]]
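
# Illustrative values for the aliases above; every number is made up for
# demonstration and carries no meaning beyond showing the expected shapes.
_EXAMPLE_SAMPLE_OUTPUT: TokensTextLogprobs = (
    [42, 7],                # sampled token IDs
    "hi there",             # detokenized output string
    [{42: -0.1, 13: -2.3},  # top logprobs at the first sampled offset
     {7: -0.5, 99: -1.7}],  # top logprobs at the second sampled offset
)
_EXAMPLE_OUTPUT_WITH_PROMPT_LOGPROBS: TokensTextLogprobsPromptLogprobs = (
    [42, 7],
    "hi there",
    [{42: -0.1}, {7: -0.5}],
    [None,                  # the first prompt token has no logprobs
     {3: -0.2, 5: -1.1}],   # top logprobs for the second prompt token
)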


def check_logprobs_close(
    *,
    outputs_0_lst: Sequence[Union[TokensTextLogprobs,
                                  TokensTextLogprobsPromptLogprobs,
                                  TextTextLogprobs]],
    outputs_1_lst: Sequence[Union[TokensTextLogprobs,
                                  TokensTextLogprobsPromptLogprobs,
                                  TextTextLogprobs]],
    name_0: str,
    name_1: str,
    num_outputs_0_skip_tokens: int = 0,
    warn_on_mismatch: bool = True,
    always_check_logprobs: bool = False,
) -> None:
    """Compare the logprobs of two sequences generated by different models,
    which should be similar but not necessarily equal.

    How sample logprobs are compared:
    * `always_check_logprobs == True`: set of highest-logprob token ids
      must match between seq0 and seq1 at all sampled token offsets
    * `always_check_logprobs == False`: highest-logprob token ids are
      only compared at sampled token offsets for which generated token
      ids don't match

    Prompt logprobs must be provided either for both input sequences, or
    for neither. If prompt logprobs are provided, then highest-logprob
    prompt token ids must match between seq0 and seq1 at all prompt token
    offsets.

    Args:
        outputs_0_lst: First sequence to compare
        outputs_1_lst: Second sequence to compare
        name_0: sequence #0 name
        name_1: sequence #1 name
        num_outputs_0_skip_tokens: If > 0, specifies the number of initial
                                   sequence #0 tokens & logprobs to discard
                                   before comparison, i.e. all
                                   of sequence #1 will be compared to
                                   sequence #0 beginning at index
                                   num_outputs_0_skip_tokens
        warn_on_mismatch: Issue a warning if there is token-wise or text-wise
                          mismatch between the two sequences
        always_check_logprobs: If True, check logprobs even when tokens match
    """
    assert len(outputs_0_lst) == len(outputs_1_lst)

    # Loop through responses to each prompt.
    for prompt_idx, (outputs_0,
                     outputs_1) in enumerate(zip(outputs_0_lst,
                                                 outputs_1_lst)):
        assert len(outputs_0) == len(outputs_1)
        if len(outputs_0) == 3:
            assert len(outputs_1) == 3
            # Break out tokens, text & sample logprobs
            # (prompt logprobs were not provided)
            output_ids_0, output_str_0, logprobs_0 = outputs_0
            output_ids_1, output_str_1, logprobs_1 = outputs_1
        elif len(outputs_0) == 4:
            assert len(outputs_1) == 4
            # Break out tokens, text, sample logprobs & prompt logprobs
            (
                output_ids_0,
                output_str_0,
                logprobs_0,
                prompt_logprobs_0,
            ) = outputs_0
            (
                output_ids_1,
                output_str_1,
                logprobs_1,
                prompt_logprobs_1,
            ) = outputs_1

            # Test prompt logprobs closeness
            if (prompt_logprobs_0 is not None
                    and prompt_logprobs_1 is not None):
                # Both sequences' prompt logprobs lists are not `None`
                # (although individual list elements may be `None`);
                # for each token's logprobs:
                for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate(
                        zip(prompt_logprobs_0, prompt_logprobs_1)):
                    fail_msg = (
                        f"Prompt logprobs test:"
                        f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}"
                        f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}")

                    if logprobs_elem_0 is None:
                        # If the seq 0 token's logprobs are `None`,
                        # the seq 1 token's logprobs must be `None`
                        assert logprobs_elem_1 is None, fail_msg
                    else:
                        # If the seq 0 token's logprobs are not `None`,
                        # the seq 1 token's logprobs must not be `None`
                        assert logprobs_elem_1 is not None, fail_msg

                        # Logprobs check: top-k token choices must be the same
                        assert (set(logprobs_elem_0.keys()) == set(
                            logprobs_elem_1.keys())), fail_msg
            else:
                # Both sequence logprobs lists must be `None`
                fail_msg = (f"Prompt logprobs test:"
                            f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}"
                            f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}")

                assert (prompt_logprobs_0 is None
                        and prompt_logprobs_1 is None), fail_msg
        else:
            raise ValueError(f"Outputs tuple must have 3 or 4 elements but "
                             f"{len(outputs_0)} elements were provided: "
                             f"{outputs_0}")

        if logprobs_0 is None:
            logprobs_0 = [None] * len(output_ids_0)
        if logprobs_1 is None:
            logprobs_1 = [None] * len(output_ids_1)

        # Skip the specified number of initial sequence #0 tokens
        # & logprobs, leaving output text as-is for simplicity
        # (text mismatches may generate warnings but do not
        # cause the test to fail.)
        if num_outputs_0_skip_tokens < 0:
            raise ValueError("num_outputs_0_skip_tokens must be non-negative")
        output_ids_0 = output_ids_0[num_outputs_0_skip_tokens:]
        logprobs_0 = logprobs_0[num_outputs_0_skip_tokens:]

        # Loop through generated tokens.
        for idx, (output_id_0,
                  output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
            is_tok_mismatch = output_id_0 != output_id_1

            # If the generated tokens don't match, or it is desired to
            # always check logprobs, then compare the logprobs.
            if is_tok_mismatch or always_check_logprobs:
                logprobs_elem_0 = logprobs_0[idx]
                logprobs_elem_1 = logprobs_1[idx]

                # Each predicted token must be in top N logprobs of the other
                fail_msg = (
                    f"Test{prompt_idx}:"
                    f"\nMatched tokens:\t{output_ids_0[:idx]}"
                    f"\n{name_0}:\t{output_str_0!r}\t{logprobs_elem_0}"
                    f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}")

                assert logprobs_elem_0 is not None, fail_msg
                assert logprobs_elem_1 is not None, fail_msg
                assert output_id_0 in logprobs_elem_1, fail_msg
                assert output_id_1 in logprobs_elem_0, fail_msg

                if warn_on_mismatch and is_tok_mismatch:
                    with warnings.catch_warnings():
                        # This ensures that repeated warnings are shown
                        # in the output, not just the first occurrence
                        warnings.simplefilter("always")
                        warnings.warn(fail_msg, stacklevel=2)

                # Break out since sequences will now diverge.
                break
        else:
            if output_str_0 != output_str_1 and warn_on_mismatch:
                # The token outputs exactly match,
                # so the text outputs should exactly match as well
                fail_msg = (f"Test{prompt_idx}:"
                            f"\n{name_0}:\t{output_str_0!r}"
                            f"\n{name_1}:\t{output_str_1!r}")

                with warnings.catch_warnings():
                    # This ensures that repeated warnings are shown
                    # in the output, not just the first occurrence
                    warnings.simplefilter("always")
                    warnings.warn(fail_msg, stacklevel=2)
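

# A minimal sketch of calling `check_logprobs_close` on hand-written outputs.
# All values are hypothetical. The two sequences diverge at token offset 1,
# yet the check passes (with a warning) because each model's sampled token
# appears in the other model's top logprobs at that offset.
def _example_check_logprobs_close() -> None:
    outputs_0: List[TokensTextLogprobs] = [
        ([5, 8], "a b", [{5: -0.2, 6: -1.0}, {8: -0.3, 9: -0.9}]),
    ]
    outputs_1: List[TokensTextLogprobs] = [
        ([5, 9], "a c", [{5: -0.2, 6: -1.0}, {9: -0.4, 8: -0.8}]),
    ]
    check_logprobs_close(
        outputs_0_lst=outputs_0,
        outputs_1_lst=outputs_1,
        name_0="model_0",
        name_1="model_1",
    )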


def build_model_context(model_name: str,
                        tokenizer_name: Optional[str] = None,
                        trust_remote_code: bool = False,
                        mm_processor_kwargs: Optional[Dict] = None,
                        limit_mm_per_prompt: Optional[Dict] = None):
    """Creates an InputContext for a given model.

    Args:
        model_name: Name of the model being considered.
        tokenizer_name: Name of the tokenizer being considered.
        trust_remote_code: Whether or not to allow loading remote code.
        mm_processor_kwargs: Optional processor kwargs to be leveraged
            in the input processor, mapper, dummy data creation, etc.
        limit_mm_per_prompt: Multimodal limits.

    Returns:
        InputContext for the model being considered.
    """
    if tokenizer_name is None:
        tokenizer_name = model_name
    model_config = ModelConfig(
        model_name,
        tokenizer_name,
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
        dtype="float32",
        seed=0,
        mm_processor_kwargs=mm_processor_kwargs,
        limit_mm_per_prompt=limit_mm_per_prompt,
    )
    return InputContext(model_config)
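

# Hedged usage sketch for `build_model_context`. The model name and limits
# below are placeholders for illustration; substitute whatever model the
# test at hand actually targets.
def _example_build_model_context() -> InputContext:
    return build_model_context(
        "llava-hf/llava-1.5-7b-hf",
        limit_mm_per_prompt={"image": 1},
    )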