from itertools import cycle
from typing import List, Optional, Tuple

import pytest

from aphrodite import LLM, SamplingParams
from aphrodite.modeling.utils import set_random_seed

from ...conftest import cleanup
from ...models.utils import check_logprobs_close, check_outputs_equal
from ...utils import RemoteOpenAIServer

PROMPTS = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
    "San Francisco is known for its",
    "Facebook was created in 2004 by",
    "Curious George is a",
    "Python 3.11 brings improvements to its",
]


@pytest.fixture
def test_llm_generator(common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):
    """Return a generator that builds an ``LLM`` from the combined kwargs,
    seeds it if a seed is given, and cleans up once the caller is done."""

    def generate():
        kwargs = {
            **common_llm_kwargs,
            **per_test_common_llm_kwargs,
            **test_llm_kwargs,
        }

        llm = LLM(**kwargs)

        if seed is not None:
            set_random_seed(seed)

        yield llm

        del llm
        cleanup()

    return generate


def maybe_assert_ngram_worker(llm):
    # Verify the proposer worker is ngram if ngram is specified.
    if (llm.llm_engine.speculative_config is not None
            and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
        from aphrodite.spec_decode.ngram_worker import NGramWorker
        assert isinstance(
            llm.llm_engine.model_executor.driver_worker.proposer_worker,
            NGramWorker)


def get_output_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> Tuple[List[str], List[List[int]], float]:
    """Generate from each LLM yielded by ``llm_generator`` and return the
    output texts, token ids, and the draft acceptance rate (-1.0 when the
    metrics logger is unavailable)."""
    tokens: List[str] = []
    token_ids: List[List[int]] = []
    acceptance_rate: float = -1.0
    for llm in llm_generator():
        maybe_assert_ngram_worker(llm)

        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)

        token_ids = [output.outputs[0].token_ids for output in outputs]
        tokens = [output.outputs[0].text for output in outputs]

        # Fetch the draft acceptance rate if metrics logging is enabled.
        if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None):
            stat_logger = stat_loggers["prometheus"]
            acceptance_rate = (stat_logger.metrics.
                               gauge_spec_decode_draft_acceptance_rate.labels(
                                   **stat_logger.labels)._value.get())
        del llm

    return tokens, token_ids, acceptance_rate
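

# Illustrative sketch (not collected as a test): how a test module typically
# parametrizes the kwargs consumed by the `test_llm_generator` fixture and
# drives it through `get_output_from_llm_generator`. The kwargs here feed
# `LLM(**kwargs)` directly; the model name and speculative settings are
# assumptions chosen only for this example.
@pytest.mark.parametrize("common_llm_kwargs", [{
    "model": "JackFram/llama-68m",
    "enforce_eager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": "JackFram/llama-68m",
    "num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("seed", [1])
def example_generator_usage(test_llm_generator):
    sampling_params = SamplingParams(max_tokens=32, ignore_eos=True)
    tokens, token_ids, _ = get_output_from_llm_generator(
        test_llm_generator, PROMPTS, sampling_params)
    assert len(tokens) == len(PROMPTS)
    assert len(token_ids) == len(PROMPTS)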


def run_logprob_correctness_test(aphrodite_runner,
                                 common_llm_kwargs,
                                 per_test_common_llm_kwargs,
                                 baseline_llm_kwargs,
                                 test_llm_kwargs,
                                 batch_size: int,
                                 max_output_len: int,
                                 seed: Optional[int] = 0,
                                 temperature: float = 0.0,
                                 logprobs: int = 1):
    """Run the baseline and speculative-decoding configs on the same prompts
    and assert that their logprobs are close."""
    org_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **baseline_llm_kwargs,
    }

    sd_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **test_llm_kwargs,
    }

    prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]

    sampling_params = SamplingParams(temperature=temperature,
                                     max_tokens=max_output_len,
                                     seed=seed,
                                     logprobs=logprobs)

    with aphrodite_runner(**org_args) as aphrodite_model:
        org_outputs = aphrodite_model.generate_w_logprobs(prompts,
                                                          sampling_params)

    with aphrodite_runner(**sd_args) as aphrodite_model:
        sd_outputs = aphrodite_model.generate_w_logprobs(prompts,
                                                         sampling_params)

    check_logprobs_close(outputs_0_lst=org_outputs,
                         outputs_1_lst=sd_outputs,
                         name_0="org",
                         name_1="sd")
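

# Illustrative sketch (not collected as a test): wiring the parametrized kwargs
# and the `aphrodite_runner` fixture from the top-level conftest into
# `run_logprob_correctness_test`. The model name, the `model_name` key, and the
# speculative settings are assumptions chosen only for this example.
@pytest.mark.parametrize("common_llm_kwargs", [{
    "model_name": "JackFram/llama-68m",
    "enforce_eager": True,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "speculative_model": "JackFram/llama-68m",
    "num_speculative_tokens": 3,
}])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def example_logprob_correctness(aphrodite_runner, common_llm_kwargs,
                                per_test_common_llm_kwargs,
                                baseline_llm_kwargs, test_llm_kwargs,
                                batch_size, seed):
    run_logprob_correctness_test(aphrodite_runner,
                                 common_llm_kwargs,
                                 per_test_common_llm_kwargs,
                                 baseline_llm_kwargs,
                                 test_llm_kwargs,
                                 batch_size,
                                 max_output_len=32,
                                 seed=seed,
                                 logprobs=2)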


def run_equality_correctness_test(
        aphrodite_runner,
        common_llm_kwargs,
        per_test_common_llm_kwargs,
        baseline_llm_kwargs,
        test_llm_kwargs,
        batch_size: int,
        max_output_len: int,
        seed: Optional[int] = 0,
        temperature: float = 0.0,
        disable_seed: bool = False,
        ignore_eos: bool = True,
        ensure_all_accepted: bool = False,
        expected_acceptance_rate: Optional[float] = None):
    """Run the baseline and speculative-decoding configs on the same prompts
    and assert that their outputs are identical. Optionally checks the draft
    acceptance rate reported by the Prometheus stat logger."""
    org_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **baseline_llm_kwargs,
    }

    sd_args = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **test_llm_kwargs,
    }

    prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]

    if disable_seed:
        seed = None

    sampling_params = SamplingParams(temperature=temperature,
                                     max_tokens=max_output_len,
                                     seed=seed,
                                     ignore_eos=ignore_eos)

    with aphrodite_runner(**org_args) as aphrodite_model:
        org_outputs = aphrodite_model.generate(prompts, sampling_params)

    with aphrodite_runner(**sd_args) as aphrodite_model:
        if ensure_all_accepted or expected_acceptance_rate is not None:
            # Force the log interval to be negative so every metric is caught.
            stat_logger = aphrodite_model.model.llm_engine.stat_loggers[
                'prometheus']
            stat_logger.local_interval = -100

        sd_outputs = aphrodite_model.generate(prompts, sampling_params)

        if ensure_all_accepted or expected_acceptance_rate is not None:
            acceptance_rate = (stat_logger.metrics.
                               gauge_spec_decode_draft_acceptance_rate.labels(
                                   **stat_logger.labels)._value.get())

            if ensure_all_accepted:
                assert True
                # FIXME: ci fails to log acceptance rate.
                # It works locally.
                # assert acceptance_rate == 1.0

            if expected_acceptance_rate is not None:
                assert acceptance_rate >= expected_acceptance_rate - 1e-2

    check_outputs_equal(outputs_0_lst=org_outputs,
                        outputs_1_lst=sd_outputs,
                        name_0="org",
                        name_1="sd")
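

# Illustrative sketch (not collected as a test): a greedy equality check that
# also asserts a minimum draft acceptance rate. Parametrization follows the
# same six names as the sketch above; the threshold and output length are
# assumptions chosen only for this example.
def example_greedy_equality(aphrodite_runner, common_llm_kwargs,
                            per_test_common_llm_kwargs, baseline_llm_kwargs,
                            test_llm_kwargs, batch_size, seed):
    run_equality_correctness_test(aphrodite_runner,
                                  common_llm_kwargs,
                                  per_test_common_llm_kwargs,
                                  baseline_llm_kwargs,
                                  test_llm_kwargs,
                                  batch_size,
                                  max_output_len=32,
                                  seed=seed,
                                  temperature=0.0,
                                  expected_acceptance_rate=0.7)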


def run_equality_correctness_test_tp(model,
                                     common_llm_kwargs,
                                     per_test_common_llm_kwargs,
                                     baseline_llm_kwargs,
                                     test_llm_kwargs,
                                     batch_size: int,
                                     max_output_len: int,
                                     seed: int = 0,
                                     temperature: float = 0.0):
    """Helper method that compares the outputs of both the baseline LLM and
    the test LLM. It asserts greedy equality, i.e. that the outputs are exactly
    the same when temperature is zero.
    """
    arg1 = common_llm_kwargs + per_test_common_llm_kwargs + baseline_llm_kwargs
    arg2 = common_llm_kwargs + per_test_common_llm_kwargs + test_llm_kwargs
    env1 = env2 = None
    max_wait_seconds = 240
    results = []

    prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]

    for args, env in ((arg1, env1), (arg2, env2)):
        with RemoteOpenAIServer(model,
                                args,
                                env_dict=env,
                                max_wait_seconds=max_wait_seconds) as server:
            client = server.get_client()

            completion = client.completions.create(model=model,
                                                   prompt=prompts,
                                                   max_tokens=max_output_len,
                                                   seed=seed,
                                                   temperature=temperature)

            results.append({
                "test": "seeded_sampling",
                "text": [choice.text for choice in completion.choices],
                "finish_reason":
                [choice.finish_reason for choice in completion.choices],
                "usage": completion.usage,
            })

    n = len(results) // 2
    arg1_results = results[:n]
    arg2_results = results[n:]
    for arg1_result, arg2_result in zip(arg1_results, arg2_results):
        assert arg1_result == arg2_result, (
            f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
            f"{arg1_result=} != {arg2_result=}")