conftest.py

import contextlib
import gc
import os
from typing import List, Optional, Tuple

import pytest
import torch
from PIL import Image
from transformers import (AutoModelForCausalLM, AutoProcessor,
                          LlavaForConditionalGeneration)

from aphrodite import LLM, SamplingParams
from aphrodite.common.config import TokenizerPoolConfig, VisionLanguageConfig
from aphrodite.common.sequence import MultiModalData
from aphrodite.distributed import destroy_model_parallel
from aphrodite.transformers_utils.tokenizer import get_tokenizer
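
# Prompt files and pre-computed image inputs shipped with the test suite.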
_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")]

# Multi-modal related
_PIXEL_VALUES_FILES = [
    os.path.join(_TEST_DIR, "images", filename) for filename in
    ["stop_sign_pixel_values.pt", "cherry_blossom_pixel_values.pt"]
]
_IMAGE_FEATURES_FILES = [
    os.path.join(_TEST_DIR, "images", filename) for filename in
    ["stop_sign_image_features.pt", "cherry_blossom_image_features.pt"]
]
_IMAGE_FILES = [
    os.path.join(_TEST_DIR, "images", filename)
    for filename in ["stop_sign.jpg", "cherry_blossom.jpg"]
]
_IMAGE_PROMPTS = [
    "<image>\nUSER: What's the content of the image?\nASSISTANT:",
    "<image>\nUSER: What is the season?\nASSISTANT:"
]
assert len(_PIXEL_VALUES_FILES) == len(_IMAGE_FEATURES_FILES) == len(
    _IMAGE_FILES) == len(_IMAGE_PROMPTS)


def _read_prompts(filename: str) -> List[str]:
    with open(filename, "r") as f:
        prompts = f.readlines()
        return prompts
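

# Tear down model-parallel state and release GPU memory so that tests which
# instantiate several models in one pytest session do not run out of memory.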
def cleanup():
    destroy_model_parallel()
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    torch.cuda.empty_cache()


@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Allow subdirectories to skip global cleanup by overriding this
    fixture.

    This can provide a ~10x speedup for non-GPU unit tests since they don't
    need to initialize torch.
    """
    if request.node.get_closest_marker("skip_global_cleanup"):
        return False
    return True


@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    yield
    if should_do_global_cleanup_after_test:
        cleanup()


@pytest.fixture(scope="session")
def hf_image_prompts() -> List[str]:
    return _IMAGE_PROMPTS


@pytest.fixture(scope="session")
def hf_images() -> List[Image.Image]:
    return [Image.open(filename) for filename in _IMAGE_FILES]


@pytest.fixture()
def aphrodite_images(request) -> "torch.Tensor":
    vision_language_config = request.getfixturevalue("model_and_config")[1]
    all_images = []
    if vision_language_config.image_input_type == (
            VisionLanguageConfig.ImageInputType.IMAGE_FEATURES):
        filenames = _IMAGE_FEATURES_FILES
    else:
        filenames = _PIXEL_VALUES_FILES
    for filename in filenames:
        all_images.append(torch.load(filename))
    return torch.concat(all_images, dim=0)


@pytest.fixture()
def aphrodite_image_prompts(request) -> List[str]:
    vision_language_config = request.getfixturevalue("model_and_config")[1]
    return [
        "<image>" * (vision_language_config.image_feature_size - 1) + p
        for p in _IMAGE_PROMPTS
    ]


@pytest.fixture
def example_prompts() -> List[str]:
    prompts = []
    for filename in _TEST_PROMPTS:
        prompts += _read_prompts(filename)
    return prompts


@pytest.fixture
def example_long_prompts() -> List[str]:
    prompts = []
    for filename in _LONG_PROMPTS:
        prompts += _read_prompts(filename)
    return prompts


_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
}

_VISION_LANGUAGE_MODELS = {
    "llava-hf/llava-1.5-7b-hf": LlavaForConditionalGeneration,
}
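

# Reference runner: loads models through HuggingFace Transformers and runs
# generation directly on the GPU, giving a known-good baseline to compare
# Aphrodite outputs against.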
class HfRunner:

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        dtype: str = "half",
    ) -> None:
        assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
        self.model_name = model_name
        if model_name not in _VISION_LANGUAGE_MODELS:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            ).cuda()
            self.processor = None
        else:
            self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=True,
            ).cuda()
            self.processor = AutoProcessor.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
            )
        if tokenizer_name is None:
            tokenizer_name = model_name
        self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)

    def generate(
        self,
        prompts: List[str],
        images: Optional[List[Image.Image]] = None,
        **kwargs,
    ) -> List[Tuple[List[int], str]]:
        outputs: List[Tuple[List[int], str]] = []
        if images:
            assert len(prompts) == len(images)
        for i, prompt in enumerate(prompts):
            if self.model_name not in _VISION_LANGUAGE_MODELS:
                input_ids = self.tokenizer(prompt,
                                           return_tensors="pt").input_ids
                inputs = {"input_ids": input_ids.cuda()}
            else:
                image = images[i] if images else None
                inputs = self.processor(text=prompt,
                                        images=image,
                                        return_tensors="pt")
                inputs = {
                    key: value.cuda() if value is not None else None
                    for key, value in inputs.items()
                }
            output_ids = self.model.generate(
                **inputs,
                use_cache=True,
                **kwargs,
            )
            output_str = self.tokenizer.batch_decode(
                output_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_ids = output_ids.cpu().tolist()
            outputs.append((output_ids, output_str))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional["torch.Tensor"] = None,
    ) -> List[Tuple[List[int], str]]:
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                images=images)
        for i in range(len(outputs)):
            output_ids, output_str = outputs[i]
            outputs[i] = (output_ids[0], output_str[0])
        return outputs

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        outputs = self.generate(prompts,
                                do_sample=False,
                                max_new_tokens=max_tokens,
                                num_beams=beam_width,
                                num_return_sequences=beam_width)
        for i in range(len(outputs)):
            output_ids, output_str = outputs[i]
            for j in range(len(output_ids)):
                output_ids[j] = [
                    x for x in output_ids[j]
                    if x != self.tokenizer.pad_token_id
                ]
            outputs[i] = (output_ids, output_str)
        return outputs

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
    ) -> List[List[torch.Tensor]]:
        all_logprobs = []
        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output = self.model.generate(
                input_ids.cuda(),
                use_cache=True,
                do_sample=False,
                max_new_tokens=max_tokens,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            seq_logprobs = []
            for hidden_states in output.hidden_states:
                last_hidden_states = hidden_states[-1][0]
                logits = torch.matmul(
                    last_hidden_states,
                    self.model.get_output_embeddings().weight.t(),
                )
                if self.model.get_output_embeddings().bias is not None:
                    logits += self.model.get_output_embeddings(
                    ).bias.unsqueeze(0)
                logprobs = torch.nn.functional.log_softmax(logits,
                                                           dim=-1,
                                                           dtype=torch.float32)
                seq_logprobs.append(logprobs)
            all_logprobs.append(seq_logprobs)
        return all_logprobs

    def __del__(self):
        del self.model
        cleanup()


@pytest.fixture
def hf_runner():
    return HfRunner
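

# Runner for the engine under test: a thin wrapper around aphrodite.LLM that
# returns prompt-plus-completion token ids and strings in the same shape as
# HfRunner, so test helpers can compare the two directly.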
class AphroditeRunner:

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        # Use a smaller max model length; otherwise bigger models cannot run
        # due to the KV cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        **kwargs,
    ) -> None:
        self.model = LLM(
            model=model_name,
            tokenizer=tokenizer_name,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=0,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            **kwargs,
        )

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional["torch.Tensor"] = None,
    ) -> List[Tuple[List[int], str]]:
        if images is not None:
            assert len(prompts) == images.shape[0]
        req_outputs = self.model.generate(
            prompts,
            sampling_params=sampling_params,
            multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE,
                                            data=images)
            if images is not None else None)
        outputs = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids = []
            req_sample_output_strs = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = sample.token_ids
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append(prompt_str + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str]]:
        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(prompts,
                                          sampling_params=sampling_params)
        outputs = []
        for req_output in req_outputs:
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = sample.token_ids
                output_logprobs = sample.logprobs
                outputs.append((output_ids, output_str, output_logprobs))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[torch.Tensor] = None,
    ) -> List[Tuple[List[int], str]]:
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts, greedy_params, images=images)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str]]:
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs)
        outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params)
        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[int], str]]:
        beam_search_params = SamplingParams(n=beam_width,
                                            use_beam_search=True,
                                            temperature=0.0,
                                            max_tokens=max_tokens)
        outputs = self.generate(prompts, beam_search_params)
        return outputs

    def __del__(self):
        del self.model
        cleanup()


@pytest.fixture(scope="session")
def aphrodite_runner():
    return AphroditeRunner
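

# Helper for tests that exercise the tokenizer pool; returns None when no
# pool is requested and a Ray-backed TokenizerPoolConfig otherwise.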
def get_tokenizer_pool_config(tokenizer_group_type):
    if tokenizer_group_type is None:
        return None
    if tokenizer_group_type == "ray":
        return TokenizerPoolConfig(pool_size=1,
                                   pool_type="ray",
                                   extra_config={})
    raise ValueError(f"Unknown tokenizer_group_type: {tokenizer_group_type}")