conftest.py

import asyncio
import os
from itertools import cycle
from typing import Dict, List, Optional, Sequence, Tuple, Union

import pytest
import ray
import torch

from aphrodite import LLM
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.sequence import Logprob
from aphrodite.common.utils import Counter, random_uuid
from aphrodite.engine.args_tools import AsyncEngineArgs
from aphrodite.engine.async_aphrodite import AsyncAphrodite
from aphrodite.lora.request import LoRARequest
from aphrodite.modeling.utils import set_random_seed
from aphrodite.multimodal import MultiModalDataDict
from aphrodite.prompt_adapter.request import PromptAdapterRequest

from ...conftest import cleanup
from ...utils import wait_for_gpu_memory_to_clear


class AsyncLLM:
    """AsyncLLM

    Note: The current LLM class in aphrodite does not support async mode; for
    test purposes we implement an async one here. It could be moved to
    aphrodite/endpoints/llm.py in the future.

    The AsyncLLM below is borrowed directly from aphrodite/endpoints/llm.py,
    with changes to make it work in async mode.
    """

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        **kwargs,
    ) -> None:
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True

        # Needed so that engine_use_ray works as a deprecated feature;
        # otherwise the following constructor will raise an exception.
        os.environ["APHRODITE_ALLOW_ENGINE_USE_RAY"] = "1"

        engine_args = AsyncEngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            # For now use ray as the distributed back-end, since we rely on
            # engine_use_ray=True to avoid reinitializing CUDA in the same
            # process (driver worker).
            engine_use_ray=True,
            distributed_executor_backend="ray",
            disable_custom_all_reduce=disable_custom_all_reduce,
            **kwargs,
        )
        self.request_counter = Counter()
        self.llm_engine = AsyncAphrodite.from_engine_args(engine_args)

    def generate(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalDataDict] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> List[RequestOutput]:

        if prompts is None:
            raise ValueError("prompts must be provided.")
        if isinstance(prompts, str):
            # Convert a single prompt to a list.
            prompts = [prompts]

        if prompts is not None:
            num_requests = len(prompts)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()
        elif isinstance(sampling_params,
                        list) and len(sampling_params) != num_requests:
            raise ValueError("The lengths of prompts and "
                             "sampling_params must be the same.")

        async def get_output(prompt, sampling_param) -> RequestOutput:
            request_id = random_uuid()
            results_generator = self.llm_engine.generate(
                prompt, sampling_param, request_id)
            final_output = None
            async for request_output in results_generator:
                final_output = request_output
            assert final_output is not None
            return final_output

        outputs: List[RequestOutput] = []
        try:
            for i in range(num_requests):
                prompt = prompts[i] if prompts is not None else None
                params = sampling_params[i] if isinstance(
                    sampling_params, Sequence) else sampling_params
                res = asyncio.run(get_output(prompt, params))
                outputs.append(res)
        finally:
            ray.shutdown()
        return outputs
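

# A minimal, hedged usage sketch of AsyncLLM (illustrative only; the model
# name is a placeholder and is not what the fixtures below necessarily use):
#
#   llm = AsyncLLM(model="facebook/opt-125m", enforce_eager=True)
#   outputs = llm.generate(["Hello, my name is"],
#                          SamplingParams(temperature=0.0, max_tokens=8))
#   print(outputs[0].outputs[0].text)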


@pytest.fixture
def baseline_llm_generator(request, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           seed):
    return create_llm_generator("baseline", request, common_llm_kwargs,
                                per_test_common_llm_kwargs,
                                baseline_llm_kwargs, seed)


@pytest.fixture
def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):
    return create_llm_generator("test", request, common_llm_kwargs,
                                per_test_common_llm_kwargs, test_llm_kwargs,
                                seed)


def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
                         per_test_common_llm_kwargs, distinct_llm_kwargs,
                         seed):
    kwargs = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **distinct_llm_kwargs,
    }
    test_name = request.node.name

    model = kwargs["model"]
    draft_model = kwargs.get("speculative_model", None)
    same_draft_target_model = (draft_model is not None
                               and draft_model == model)

    def generator_inner():

        wait_for_gpu_memory_to_clear(
            devices=list(range(torch.cuda.device_count())),
            threshold_bytes=2 * 2**30,  # 2 GiB
            timeout_s=60,
        )

        use_async = False
        if "use_async" in kwargs:
            use_async = kwargs.pop("use_async")
        print(f'{use_async=}')

        print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
        llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)

        # Override the logging interval to 0 for the spec decode test run so
        # that all metrics are logged in time.
        if (baseline_or_test == "test" and not use_async
                and llm.llm_engine.log_stats):
            for stat_logger in llm.llm_engine.stat_loggers.values():
                stat_logger.local_interval = 0

        if seed is not None:
            set_random_seed(seed)

        yield llm
        del llm
        cleanup()

    def generator_outer():
        for llm in generator_inner():
            yield llm
            del llm

    # Set an attribute on the generator_outer function to allow us to
    # determine whether to further check the acceptance rate in tests.
    generator_outer.same_draft_target_model = same_draft_target_model  # type: ignore

    return generator_outer
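

# NOTE: the ``common_llm_kwargs``, ``per_test_common_llm_kwargs``,
# ``baseline_llm_kwargs``, ``test_llm_kwargs`` and ``seed`` arguments are not
# defined in this file; test modules are expected to supply them via
# ``pytest.mark.parametrize``. A hedged sketch (the model name and kwargs
# values are illustrative only):
#
#   @pytest.mark.parametrize("common_llm_kwargs",
#                            [{"model": "JackFram/llama-68m"}])
#   @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
#   @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
#   @pytest.mark.parametrize("test_llm_kwargs",
#                            [{"speculative_model": "JackFram/llama-68m",
#                              "num_speculative_tokens": 5}])
#   @pytest.mark.parametrize("seed", [1])
#   def test_example(baseline_llm_generator, test_llm_generator):
#       ...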


def maybe_assert_ngram_worker(llm):
    # Verify the proposer worker is ngram if ngram is specified.
    if (not isinstance(llm, AsyncLLM)
            and llm.llm_engine.speculative_config is not None
            and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
        from aphrodite.spec_decode.ngram_worker import NGramWorker
        assert isinstance(
            llm.llm_engine.model_executor.driver_worker.proposer_worker,
            NGramWorker)


def get_output_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> Tuple[List[str], List[List[int]], float]:
    tokens: List[str] = []
    token_ids: List[List[int]] = []
    acceptance_rate: float = -1.0
    for llm in llm_generator():
        maybe_assert_ngram_worker(llm)

        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)

        token_ids = [output.outputs[0].token_ids for output in outputs]
        tokens = [output.outputs[0].text for output in outputs]

        # Fetch acceptance rate if logging is enabled.
        if stat_loggers := getattr(llm.llm_engine, "stat_loggers", None):
            stat_logger = stat_loggers["prometheus"]
            acceptance_rate = (stat_logger.metrics.
                               gauge_spec_decode_draft_acceptance_rate.labels(
                                   **stat_logger.labels)._value.get())
        del llm

    return tokens, token_ids, acceptance_rate


def get_logprobs_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> List[List[Dict[int, Logprob]]]:
    """Returns, for each sequence in the batch, a list of
    {token_id: Logprob} dicts, one per generated position.
    """
    for llm in llm_generator():
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        logprobs = [output.outputs[0].logprobs[:] for output in outputs]
        del llm

    return logprobs


def run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len,
                                         force_output_len: bool,
                                         print_tokens: bool = False,
                                         ensure_all_accepted: bool = False):
    """Helper method that compares the outputs of both the baseline LLM and
    the test LLM. It asserts greedy equality, i.e. that the outputs are
    exactly the same when temperature is zero.
    """
    run_equality_correctness_test(baseline_llm_generator,
                                  test_llm_generator,
                                  batch_size,
                                  max_output_len,
                                  force_output_len,
                                  temperature=0.0,
                                  seeded=False,
                                  print_tokens=print_tokens,
                                  ensure_all_accepted=ensure_all_accepted)


def run_equality_correctness_test(
        baseline_llm_generator,
        test_llm_generator,
        batch_size,
        max_output_len,
        force_output_len: bool,
        temperature: float,
        seeded: bool,
        print_tokens: bool = False,
        ensure_all_accepted: bool = False,
        expected_acceptance_rate: Optional[float] = None):
    """Helper method that compares the outputs of both the baseline LLM and
    the test LLM. It asserts equality, i.e. that the outputs are exactly the
    same when temperature is zero (or when temperature is > 0 and seeded).
    """
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is known for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    # If the test requires that we generate max_output_len tokens, then set
    # the sampling params to ignore the eos token.
    ignore_eos = force_output_len

    if seeded:
        sampling_params = [
            SamplingParams(
                max_tokens=max_output_len,
                ignore_eos=ignore_eos,
                temperature=temperature,
                seed=i,
            ) for i in range(len(prompts))
        ]
    else:
        sampling_params = SamplingParams(
            max_tokens=max_output_len,
            ignore_eos=ignore_eos,
            temperature=temperature,
        )

    (spec_batch_tokens, spec_batch_token_ids,
     acceptance_rate) = get_output_from_llm_generator(test_llm_generator,
                                                      prompts, sampling_params)

    (baseline_batch_tokens, baseline_batch_token_ids,
     _) = get_output_from_llm_generator(baseline_llm_generator, prompts,
                                        sampling_params)

    assert len(baseline_batch_token_ids) == len(prompts)
    assert len(spec_batch_token_ids) == len(prompts)

    for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
            spec_tokens) in enumerate(
                zip(baseline_batch_token_ids, baseline_batch_tokens,
                    spec_batch_token_ids, spec_batch_tokens)):
        if print_tokens:
            print(f'{i=} {baseline_tokens=}')
            print(f'{i=} {spec_tokens=}')
        print(f'{i=} {baseline_token_ids=}')
        print(f'{i=} {spec_token_ids=}')
        assert baseline_token_ids == spec_token_ids

    print(f'{acceptance_rate=}')

    if ensure_all_accepted:
        assert acceptance_rate == 1.0

    if expected_acceptance_rate is not None:
        assert acceptance_rate >= expected_acceptance_rate - 1e-2
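

# A hedged sketch of how a test module typically drives the helpers above,
# building on the parametrization note earlier in this file (batch size and
# output length are illustrative):
#
#   def test_spec_decode_greedy_equality(baseline_llm_generator,
#                                        test_llm_generator):
#       run_greedy_equality_correctness_test(baseline_llm_generator,
#                                            test_llm_generator,
#                                            batch_size=8,
#                                            max_output_len=32,
#                                            force_output_len=True)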