conftest.py

import contextlib
import gc
import json
import os
import sys
import tempfile
from collections import UserList
from enum import Enum
from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
                    TypeVar, Union)

import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from loguru import logger
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
                          BatchFeature)

from aphrodite import LLM, SamplingParams
from aphrodite.assets.image import ImageAsset
from aphrodite.assets.video import VideoAsset
from aphrodite.common.config import TokenizerPoolConfig
from aphrodite.common.outputs import RequestOutput
from aphrodite.common.utils import (STR_DTYPE_TO_TORCH_DTYPE,
                                    cuda_device_count_stateless, identity,
                                    is_cpu)
from aphrodite.connections import global_http_connection
from aphrodite.distributed import (destroy_distributed_environment,
                                   destroy_model_parallel,
                                   init_distributed_environment,
                                   initialize_model_parallel)
from aphrodite.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                              to_enc_dec_tuple_list, zip_enc_dec_prompts)
from tests.models.utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)

_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]

PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]]
PromptAudioInput = Union[List[Tuple[np.ndarray, int]],
                         List[List[Tuple[np.ndarray, int]]]]
PromptVideoInput = Union[List[np.ndarray], List[List[np.ndarray]]]


def _read_prompts(filename: str) -> List[str]:
    with open(filename, "r") as f:
        prompts = f.readlines()
        return prompts


class _ImageAssetPrompts(TypedDict):
    stop_sign: str
    cherry_blossom: str


if sys.version_info < (3, 9):
    # UserList cannot be subscripted
    class _ImageAssetsBase(UserList):
        pass
else:
    class _ImageAssetsBase(UserList[ImageAsset]):
        pass


class _ImageAssets(_ImageAssetsBase):

    def __init__(self) -> None:
        super().__init__([
            ImageAsset("stop_sign"),
            ImageAsset("cherry_blossom"),
        ])

    def prompts(self, prompts: _ImageAssetPrompts) -> List[str]:
        """
        Convenience method to define the prompt for each test image.

        The order of the returned prompts matches the order of the
        assets when iterating through this object.
        """
        return [prompts["stop_sign"], prompts["cherry_blossom"]]


class _VideoAssetPrompts(TypedDict):
    sample_demo_1: str


if sys.version_info < (3, 9):
    # UserList cannot be subscripted
    class _VideoAssetsBase(UserList):
        pass
else:
    class _VideoAssetsBase(UserList[VideoAsset]):
        pass


class _VideoAssets(_VideoAssetsBase):

    def __init__(self) -> None:
        super().__init__([
            VideoAsset("sample_demo_1.mp4"),
        ])

    def prompts(self, prompts: _VideoAssetPrompts) -> List[str]:
        return [prompts["sample_demo_1"]]


IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""

VIDEO_ASSETS = _VideoAssets()
"""Singleton instance of :class:`_VideoAssets`."""


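# Example (hypothetical, for illustration): prompts() keeps prompt order
# aligned with asset order, so a multimodal test can write
#
#     sign_prompt, blossom_prompt = IMAGE_ASSETS.prompts({
#         "stop_sign": "What does this sign say?",
#         "cherry_blossom": "What season is it?",
#     })
#
# and iterate over IMAGE_ASSETS in the same order. The prompt strings above
# are made up for illustration.

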
@pytest.fixture(autouse=True)
def init_test_http_connection():
    # pytest_asyncio may use a different event loop per test
    # so we need to make sure the async client is created anew
    global_http_connection.reuse_client = False


@pytest.fixture
def dist_init():
    temp_file = tempfile.mkstemp()[1]
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=f"file://{temp_file}",
        local_rank=0,
        backend="nccl",
    )
    initialize_model_parallel(1, 1)
    yield
    cleanup()


def cleanup():
    destroy_model_parallel()
    destroy_distributed_environment()
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    if not is_cpu():
        torch.cuda.empty_cache()


@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Allow subdirectories to skip global cleanup by overriding this fixture.
    This can provide a ~10x speedup for non-GPU unit tests since they don't
    need to initialize torch.
    """
    if request.node.get_closest_marker("skip_global_cleanup"):
        return False
    return True


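# Example (hypothetical, for illustration): a test can opt out of the global
# cleanup by carrying the marker checked above.
#
#     @pytest.mark.skip_global_cleanup
#     def test_pure_python_helper():
#         ...  # no GPU / torch state to tear down

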
@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    yield
    if should_do_global_cleanup_after_test:
        cleanup()


@pytest.fixture
def example_prompts() -> List[str]:
    prompts = []
    for filename in _TEST_PROMPTS:
        prompts += _read_prompts(filename)
    return prompts


class DecoderPromptType(Enum):
    """For encoder/decoder models only."""
    CUSTOM = 1
    NONE = 2
    EMPTY_STR = 3


@pytest.fixture
def example_encoder_decoder_prompts(
) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
    '''
    Returns an encoder prompt list and a decoder prompt list, wherein each
    pair of same-index entries in both lists corresponds to an
    (encoder prompt, decoder prompt) tuple.

    Returns:

    * Encoder prompt list
    * Decoder prompt list (reverse of encoder prompt list)
    '''
    encoder_prompts = []
    for filename in _TEST_PROMPTS:
        encoder_prompts += _read_prompts(filename)

    custom_decoder_prompts = encoder_prompts[::-1]
    empty_str_decoder_prompts = [""] * len(encoder_prompts)
    none_decoder_prompts = [None] * len(encoder_prompts)

    # Map each decoder prompt type to its (encoder, decoder) prompt pairs
    return {
        DecoderPromptType.NONE:
        zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts),
        DecoderPromptType.EMPTY_STR:
        zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts),
        DecoderPromptType.CUSTOM:
        zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts),
    }


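# Sketch of the pairing above, assuming zip_enc_dec_prompts yields one
# ExplicitEncoderDecoderPrompt per (encoder, decoder) pair with
# "encoder_prompt"/"decoder_prompt" keys (assumption, not verified here):
#
#     zip_enc_dec_prompts(["enc A", "enc B"], [None, None])
#     # -> [{"encoder_prompt": "enc A", "decoder_prompt": None},
#     #     {"encoder_prompt": "enc B", "decoder_prompt": None}]

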
@pytest.fixture
def example_long_prompts() -> List[str]:
    prompts = []
    for filename in _LONG_PROMPTS:
        prompts += _read_prompts(filename)
    return prompts


@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
    return IMAGE_ASSETS


@pytest.fixture(scope="session")
def video_assets() -> _VideoAssets:
    return VIDEO_ASSETS


_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature)


class HfRunner:

    def wrap_device(self, input: _T) -> _T:
        if not is_cpu():
            # Check if the input is already on the GPU
            if hasattr(input, 'device') and input.device.type == "cuda":
                return input  # Already on GPU, no need to move
            return input.to("cuda")
        else:
            # Check if the input is already on the CPU
            if hasattr(input, 'device') and input.device.type == "cpu":
                return input  # Already on CPU, no need to move
            return input.to("cpu")

    def __init__(
        self,
        model_name: str,
        dtype: str = "half",
        *,
        model_kwargs: Optional[Dict[str, Any]] = None,
        is_embedding_model: bool = False,
        auto_cls=AutoModelForCausalLM,
        postprocess_inputs: Callable[[BatchEncoding],
                                     BatchEncoding] = identity,
    ) -> None:
        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]

        self.model_name = model_name

        if is_embedding_model:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer
            self.model = self.wrap_device(
                SentenceTransformer(
                    model_name,
                    device="cpu",
                ).to(dtype=torch_dtype))
        else:
            model_kwargs = model_kwargs if model_kwargs is not None else {}
            self.model = self.wrap_device(
                auto_cls.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    trust_remote_code=True,
                    **model_kwargs,
                ))

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=True,
        )

        try:
            # don't put this import at the top level
            # it will call torch.cuda.device_count()
            from transformers import AutoProcessor  # noqa: F401
            self.processor = AutoProcessor.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            )
        except Exception as exc:
            logger.warning(
                f"Unable to auto-load HuggingFace processor for model "
                f"({model_name}). Using tokenizer instead. Reason: {exc}")
            self.processor = self.tokenizer

        self.postprocess_inputs = postprocess_inputs

    def generate(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[List[np.ndarray]] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        if images:
            assert len(prompts) == len(images)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]
            if videos is not None and videos[i] is not None:
                processor_kwargs["videos"] = videos[i]

            inputs = self.processor(**processor_kwargs)
            inputs = self.postprocess_inputs(inputs)

            output_ids = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                **kwargs,
            )
            output_str = self.processor.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))

        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        **kwargs: Any,
    ) -> List[Tuple[List[int], str]]:
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images,
                                **kwargs)

        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        for i in range(len(outputs)):
            output_ids, output_str = outputs[i]
            for j in range(len(output_ids)):
                output_ids[j] = [
                    x for x in output_ids[j]
                    if x != self.tokenizer.pad_token_id
                ]
            outputs[i] = (output_ids, output_str)
        return outputs

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[List[np.ndarray]] = None,
        **kwargs: Any,
    ) -> List[List[torch.Tensor]]:
        all_logprobs: List[List[torch.Tensor]] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]
            if videos is not None and videos[i] is not None:
                processor_kwargs["videos"] = videos[i]

            inputs = self.processor(**processor_kwargs)
            inputs = self.postprocess_inputs(inputs)

            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )
            seq_logprobs: List[torch.Tensor] = []
            for hidden_states in output.hidden_states:
                last_hidden_states = hidden_states[-1][0]
                logits = torch.matmul(
                    last_hidden_states,
                    self.model.get_output_embeddings().weight.t(),
                )
                if self.model.get_output_embeddings().bias is not None:
                    logits += self.model.get_output_embeddings(
                    ).bias.unsqueeze(0)
                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                seq_logprobs.append(logprobs)
            all_logprobs.append(seq_logprobs)
        return all_logprobs

    def _hidden_states_to_logprobs(
        self,
        hidden_states,
        num_logprobs,
    ) -> Tuple[List[Dict[int, float]], int]:
        seq_logprobs: List[torch.Tensor] = []
        output_len = len(hidden_states)
        for _, hidden_state in enumerate(hidden_states):
            last_hidden_states = hidden_state[-1][0]
            logits = torch.matmul(
                last_hidden_states,
                self.model.get_output_embeddings().weight.t(),
            )
            if getattr(self.model.get_output_embeddings(), "bias",
                       None) is not None:
                logits += self.model.get_output_embeddings().bias.unsqueeze(0)
            logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
            seq_logprobs.append(logprobs)

        # convert to dict
        seq_logprobs_lst: List[Dict[int, float]] = []
        for tok_idx, tok_logprobs in enumerate(seq_logprobs):
            # drop prompt logprobs
            if tok_idx == 0:
                tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
            topk = tok_logprobs.topk(num_logprobs)

            tok_logprobs_dct = {}
            for token_id, logprob in zip(topk.indices[0], topk.values[0]):
                tok_logprobs_dct[token_id.item()] = logprob.item()

            seq_logprobs_lst.append(tok_logprobs_dct)

        return (
            seq_logprobs_lst,
            output_len,
        )

    def generate_greedy_logprobs_limit(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[List[np.ndarray]] = None,
        **kwargs: Any,
    ) -> List[TokensTextLogprobs]:
        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for i, prompt in enumerate(prompts):
            processor_kwargs: Dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and images[i] is not None:
                processor_kwargs["images"] = images[i]
            if audios is not None:
                audio, sr = audios[i]
                processor_kwargs["audio"] = audio
                processor_kwargs["sampling_rate"] = sr
            if videos is not None:
                processor_kwargs["videos"] = videos[i]

            inputs = self.processor(**processor_kwargs)
            inputs = self.postprocess_inputs(inputs)

            output = self.model.generate(
                **self.wrap_device(inputs),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            (
                seq_logprobs_lst,
                output_len,
            ) = self._hidden_states_to_logprobs(output.hidden_states,
                                                num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            output_len = len(seq_logprobs_lst)
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def generate_encoder_decoder_greedy_logprobs_limit(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        **kwargs: Any,
    ) -> List[TokensTextLogprobs]:
        '''
        Greedy logprobs generation, used as the HuggingFace reference for
        Aphrodite encoder/decoder model tests
        '''

        all_logprobs: List[List[Dict[int, float]]] = []
        all_output_ids: List[List[int]] = []
        all_output_strs: List[str] = []

        for (encoder_prompt,
             decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts):
            encoder_input_ids = self.wrap_device(
                self.tokenizer(encoder_prompt, return_tensors="pt").input_ids)
            decoder_input_ids = (
                None if decoder_prompt is None else self.wrap_device(
                    self.tokenizer(decoder_prompt,
                                   return_tensors="pt").input_ids))

            output = self.model.generate(
                encoder_input_ids,
                decoder_input_ids=decoder_input_ids,
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
                **kwargs,
            )

            (
                seq_logprobs_lst,
                output_len,
            ) = self._hidden_states_to_logprobs(output.decoder_hidden_states,
                                                num_logprobs)

            all_logprobs.append(seq_logprobs_lst)
            seq_ids = output.sequences[0]
            output_ids = seq_ids[-output_len:]
            all_output_ids.append(output_ids.tolist())
            all_output_strs.append(self.tokenizer.decode(output_ids))

        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]:
        return self.model.encode(prompts)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup()


@pytest.fixture(scope="session")
def hf_runner():
    return HfRunner


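# Example (hypothetical, for illustration): the fixture returns the class
# itself, so tests instantiate it as a context manager and rely on __exit__
# to free the model and run cleanup().
#
#     def test_example(hf_runner, example_prompts):
#         with hf_runner("facebook/opt-125m", dtype="half") as hf_model:
#             hf_outputs = hf_model.generate_greedy(example_prompts,
#                                                   max_tokens=32)

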
class AphroditeRunner:

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        # Use a smaller max model length; otherwise bigger models cannot run
        # due to the KV cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = False,
        **kwargs,
    ) -> None:
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        if images is not None:
            assert len(prompts) == len(images)

        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        if images is not None:
            for i, image in enumerate(images):
                inputs[i]["multi_modal_data"] = {"image": image}

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: List[List[int]] = []
            req_sample_output_strs: List[str] = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append(prompt_str + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: List[RequestOutput],
    ) -> List[TokensTextLogprobsPromptLogprobs]:
        outputs: List[TokensTextLogprobsPromptLogprobs] = []

        for req_output in req_outputs:
            assert len(req_output.outputs) > 0

            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                output_logprobs = sample.logprobs

            outputs.append((output_ids, output_str, output_logprobs,
                            req_output.prompt_logprobs))

        return outputs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        assert sampling_params.logprobs is not None

        if images is not None:
            assert len(prompts) == len(images)

        if videos is not None:
            assert len(prompts) == len(videos)

        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        if images is not None:
            for i, image in enumerate(images):
                inputs[i]["multi_modal_data"] = {"image": image}

        if audios is not None:
            for i, audio in enumerate(audios):
                inputs[i]["multi_modal_data"] = {"audio": audio}

        if videos is not None:
            for i, video in enumerate(videos):
                inputs[i]["multi_modal_data"] = {"video": video}

        print(f"[INPUTS!!!!]: {inputs}, {sampling_params}")

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        toks_str_logsprobs_prompt_logprobs = (
            self._final_steps_generate_w_logprobs(req_outputs))
        # Omit prompt logprobs if not required by sampling params
        return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
                if sampling_params.prompt_logprobs is None else
                toks_str_logsprobs_prompt_logprobs)

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        '''
        Logprobs generation for Aphrodite encoder/decoder models
        '''

        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(encoder_decoder_prompts,
                                          sampling_params=sampling_params)
        toks_str_logsprobs_prompt_logprobs = (
            self._final_steps_generate_w_logprobs(req_outputs))
        # Omit prompt logprobs if not required by sampling params
        return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
                if sampling_params.prompt_logprobs is None else
                toks_str_logsprobs_prompt_logprobs)

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
    ) -> List[Tuple[List[int], str]]:
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params, images=images)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[List[int]] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids)

        return self.generate_w_logprobs(prompts,
                                        greedy_logprobs_params,
                                        images=images,
                                        audios=audios,
                                        videos=videos)

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        '''
        Greedy logprobs generation for Aphrodite encoder/decoder models
        '''
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            use_beam_search=False,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
        )

        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        beam_search_params = SamplingParams(n=beam_width,
                                            use_beam_search=True,
                                            temperature=0.0,
                                            max_tokens=max_tokens)
        outputs = self.generate(prompts, beam_search_params)
        return outputs

    def encode(self, prompts: List[str]) -> List[List[float]]:
        req_outputs = self.model.encode(prompts)
        outputs = []
        for req_output in req_outputs:
            embedding = req_output.outputs.embedding
            outputs.append(embedding)
        return outputs

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup()


@pytest.fixture(scope="session")
def aphrodite_runner():
    return AphroditeRunner


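# Example (hypothetical, for illustration): AphroditeRunner mirrors HfRunner,
# so the same prompts can be run through both and the outputs compared.
#
#     def test_example(aphrodite_runner, example_prompts):
#         with aphrodite_runner("facebook/opt-125m", dtype="half") as model:
#             outputs = model.generate_greedy(example_prompts, max_tokens=32)

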
def get_tokenizer_pool_config(tokenizer_group_type):
    if tokenizer_group_type is None:
        return None
    if tokenizer_group_type == "ray":
        return TokenizerPoolConfig(pool_size=1,
                                   pool_type="ray",
                                   extra_config={})
    if isinstance(tokenizer_group_type, type):
        return TokenizerPoolConfig(pool_size=1,
                                   pool_type=tokenizer_group_type,
                                   extra_config={})
    raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")


@pytest.fixture()
def temporary_enable_log_propagate():
    import logging
    logger = logging.getLogger("aphrodite")
    logger.propagate = True
    yield
    logger.propagate = False


@pytest.fixture()
def caplog_aphrodite(temporary_enable_log_propagate, caplog):
    # To capture aphrodite logs, we temporarily enable propagate=True
    # because caplog depends on logs being propagated to the root logger.
    yield caplog


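# Example (hypothetical, for illustration): with propagation re-enabled by the
# fixture above, assertions can go through the standard caplog API.
#
#     def test_logging(caplog_aphrodite):
#         ...  # run code that logs through the "aphrodite" logger
#         assert "expected message" in caplog_aphrodite.text

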
@pytest.fixture(scope="session")
def num_gpus_available():
    """Get number of GPUs without initializing the CUDA context
    in current process."""
    return cuda_device_count_stateless()


temp_dir = tempfile.gettempdir()
_dummy_path = os.path.join(temp_dir, "dummy_opt")


@pytest.fixture
def dummy_opt_path():
    json_path = os.path.join(_dummy_path, "config.json")
    if not os.path.exists(_dummy_path):
        snapshot_download(repo_id="facebook/opt-125m",
                          local_dir=_dummy_path,
                          ignore_patterns=[
                              "*.bin", "*.bin.index.json", "*.pt", "*.h5",
                              "*.msgpack"
                          ])
        assert os.path.exists(json_path)
        with open(json_path, "r") as f:
            config = json.load(f)
        config["architectures"] = ["MyOPTForCausalLM"]
        with open(json_path, "w") as f:
            json.dump(config, f)
    return _dummy_path