import pathlib
from typing import Dict, List, Optional, Tuple, Type, Union

import pytest
import torch
from PIL.Image import Image

from aphrodite.common.config import ModelConfig
from aphrodite.inputs import InputContext, LLMInputs
from aphrodite.multimodal.base import MultiModalInputs
from aphrodite.multimodal.utils import (cached_get_tokenizer,
                                        rescale_image_size)

from ..conftest import (IMAGE_ASSETS, AphroditeRunner, HfRunner, ImageAsset,
                        PromptImageInput, _ImageAssets)
from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm

text_only_models = [
    "Qwen/Qwen-7B-Chat"  # Has no visual component
]

multimodal_models = ["Qwen/Qwen-VL"]

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "Picture 1: <img></img>\nWhat's the content of the image?: ",
    "cherry_blossom":
    "Picture 1: <img></img>\nWhat is the season?: ",
})

HF_MULTIIMAGE_IMAGE_PROMPT = "Picture 1: <img></img>\nPicture 2: <img></img>\nDescribe the two images in detail.\n"  # noqa: E501

### Multimodal preprocessing tests
SAMPLE_IMAGE = IMAGE_ASSETS[0].pil_image
# These values are specific to Qwen-VL/Chat; we can get these from the model
# config also, but they are hardcoded here to keep the parametrize/fixtures
# easy to read.
IMG_START_ID = 151857
IMG_END_ID = 151858
IMG_PAD_ID = 151859
TOKS_PER_IMG = 256
VIS_ENC_DIM = 4096
IMG_SIZE = 448


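# The comment above notes that these values could be read from the model
# config instead of being hardcoded. A minimal sketch of that approach is
# below; it is deliberately left unused so importing this module stays
# side-effect free. The "visual" section and the key names used here
# ("image_start_id", "image_size", "output_dim") are assumptions about
# Qwen-VL's config.json (loaded via its remote-code config class) and should
# be verified against the actual checkpoint before relying on them.
def _qwen_vl_image_constants_from_config(
        model_name: str = "Qwen/Qwen-VL") -> Dict[str, int]:
    from transformers import AutoConfig
    hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    # Depending on the remote config class, `visual` may already be a dict.
    visual = dict(hf_config.visual)
    return {
        "img_start_id": visual["image_start_id"],  # assumed key name
        "img_size": visual["image_size"],  # assumed key name
        "vis_enc_dim": visual["output_dim"],  # assumed key name
    }

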
def build_model_context(model_name: str,
                        tokenizer_name: Optional[str] = None,
                        trust_remote_code: bool = False):
    """Creates an InputContext for a given model.

    Args:
        model_name: Name of the model being considered.
        tokenizer_name: Name of the tokenizer being considered.
        trust_remote_code: Whether or not to allow loading remote code.

    Returns:
        InputContext for the model being considered.
    """
    if tokenizer_name is None:
        tokenizer_name = model_name
    model_config = ModelConfig(
        model_name,
        tokenizer_name,
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
        dtype="float32",
        seed=0,
    )
    return InputContext(model_config)


@pytest.fixture()
def input_mapper_for_qwen():
    # Lazy import to avoid initializing CUDA during test collection
    from aphrodite.modeling.models.qwen import input_mapper_for_qwen
    return input_mapper_for_qwen


@pytest.fixture()
def input_processor_for_qwen():
    # Lazy import to avoid initializing CUDA during test collection
    from aphrodite.modeling.models.qwen import input_processor_for_qwen
    return input_processor_for_qwen


@pytest.fixture()
def qwen_vl_context() -> InputContext:
    """Get an InputContext for Qwen-VL."""
    return build_model_context(model_name="Qwen/Qwen-VL",
                               trust_remote_code=True)


# Happy path tests for single/multi-image scenarios for the multimodal
# input processor and mapper, respectively
@pytest.mark.parametrize("num_images", [1, 2])
def test_input_processor_valid_mm_data(input_processor_for_qwen,
                                       qwen_vl_context: InputContext,
                                       num_images: int):
    """Happy cases for image inputs to Qwen's multimodal input processor."""
    prompt = "".join(
        [f"Picture {num}: <img></img>\n" for num in range(1, num_images + 1)])
    inputs = LLMInputs(
        prompt=prompt,
        # When processing multimodal data for a multimodal model, the qwen
        # input processor will overwrite the provided prompt_token_ids with
        # the image prompts
        prompt_token_ids=None,
        multi_modal_data={"image": torch.rand(num_images, TOKS_PER_IMG, 4096)},
    )
    proc_inputs = input_processor_for_qwen(qwen_vl_context, inputs)
    assert isinstance(proc_inputs, dict)

    # Each image should have one start / stop and a fixed context of 256
    proc_tokens = proc_inputs["prompt_token_ids"]
    assert proc_tokens.count(IMG_START_ID) == num_images
    assert proc_tokens.count(IMG_END_ID) == num_images
    assert proc_tokens.count(IMG_PAD_ID) == num_images * TOKS_PER_IMG


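# For reference: the count assertions above imply that, for each image, the
# processor splices a block like the one sketched below into the prompt token
# IDs. This is only a sketch derived from those assertions; the exact position
# of the block relative to the text tokens is not checked by this test.
def _expected_image_token_block() -> List[int]:
    return [IMG_START_ID] + [IMG_PAD_ID] * TOKS_PER_IMG + [IMG_END_ID]

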
@pytest.mark.parametrize(
    "img_data,expected_shape",
    [
        # single / multi-image
        (SAMPLE_IMAGE, (1, 3, IMG_SIZE, IMG_SIZE)),
        (2 * [SAMPLE_IMAGE], (2, 3, IMG_SIZE, IMG_SIZE)),
        # single / multi-image embeddings
        (torch.rand(
            (TOKS_PER_IMG, VIS_ENC_DIM)), (1, TOKS_PER_IMG, VIS_ENC_DIM)),
        (torch.rand(
            (1, TOKS_PER_IMG, VIS_ENC_DIM)), (1, TOKS_PER_IMG, VIS_ENC_DIM)),
        (torch.rand(
            (2, TOKS_PER_IMG, VIS_ENC_DIM)), (2, TOKS_PER_IMG, VIS_ENC_DIM)),
    ])
def test_input_mapper_valid_mm_data(input_mapper_for_qwen,
                                    qwen_vl_context: InputContext,
                                    img_data: Union[torch.Tensor, List[Image],
                                                    Image],
                                    expected_shape: List[int]):
    """Happy cases for image inputs to Qwen's multimodal input mapper."""
    mapped_img_data = input_mapper_for_qwen(qwen_vl_context, img_data)
    # Ensure that we get the appropriately shaped pixel_values
    # for images and image embeddings, respectively.
    assert isinstance(mapped_img_data, MultiModalInputs)
    assert "pixel_values" in mapped_img_data
    assert mapped_img_data["pixel_values"].shape == expected_shape


# Sad path tests for the multimodal input processor and mapper, respectively
@pytest.mark.parametrize("mm_data", [
    {
        "image": torch.rand((5))
    },
    {
        "image": torch.rand((5, 5, 5, 5, 5))
    },
])
def test_input_processor_invalid_mm_data(input_processor_for_qwen,
                                         qwen_vl_context: InputContext,
                                         mm_data: Dict[str, torch.Tensor]):
    """Test sad cases validated in Qwen's multimodal input processor."""
    tokenizer = cached_get_tokenizer(qwen_vl_context.model_config.tokenizer,
                                     trust_remote_code=True)
    prompt = "Picture 1: <img></img>\n"
    prompt_token_ids = tokenizer.encode(prompt)
    inputs = LLMInputs(prompt=prompt,
                       prompt_token_ids=prompt_token_ids,
                       multi_modal_data=mm_data)
    # Should fail since we have too many or too few dimensions for embeddings
    with pytest.raises(ValueError):
        input_processor_for_qwen(qwen_vl_context, inputs)


@pytest.mark.parametrize(
    "img_data",
    [
        # Wrong context length
        torch.rand((1, TOKS_PER_IMG + 10, VIS_ENC_DIM)),
        # Wrong visual encoder output size
        torch.rand((1, TOKS_PER_IMG, VIS_ENC_DIM + 10)),
    ])
def test_input_mapper_invalid_mm_data(
        input_mapper_for_qwen,
        qwen_vl_context: InputContext,
        img_data: Union[torch.Tensor, List[Image], Image],
):
    """Sad cases validated in Qwen VL's multimodal input mapper."""
    with pytest.raises(ValueError):
        input_mapper_for_qwen(qwen_vl_context, img_data)


### End-to-end generation tests
def get_prompt_with_path(tmp_path: pathlib.PosixPath, prompt: str,
                         assets: Union[_ImageAssets, List[ImageAsset]]) -> str:
    """Given a temporary dir path, export one or more image assets into the
    tempdir & replace the prompt's image placeholders with the local path
    strings so that the HF version of Qwen-VL can resolve the paths and load
    the images in its forward() call.

    Args:
        tmp_path: Tempdir for test under consideration.
        prompt: Prompt with image placeholders.
        assets: List of image assets whose len equals the num placeholders.
    """
    # Ensure that the number of placeholders matches the number of assets;
    # if this is not true, the test is probably written incorrectly.
    assert prompt.count("<img></img>") == len(assets)

    # Replace the placeholders with local paths to the exported assets
    for asset in assets:
        image_tmp_path = tmp_path / f"{asset.name}.jpg"
        asset.pil_image.save(image_tmp_path)
        prompt = prompt.replace(
            "<img></img>",
            f"<img>{image_tmp_path}</img>",
            1,
        )
    return prompt


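# Illustrative example of the substitution performed by get_prompt_with_path
# (the temporary path shown is hypothetical; pytest's tmp_path fixture
# supplies the real directory in the tests below, and the filename assumes an
# asset named "stop_sign"):
#     get_prompt_with_path(tmp_path, "Picture 1: <img></img>\n", [asset])
#     # -> "Picture 1: <img>/tmp/pytest-.../stop_sign.jpg</img>\n"

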
def run_test(
    hf_runner: Type[HfRunner],
    aphrodite_runner: Type[AphroditeRunner],
    inputs: List[Tuple[List[str], PromptImageInput]],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    mm_limit: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference results should be the same between hf and aphrodite.

    All the image fixtures for the test are under tests/images.
    For the huggingface runner, we provide the PIL images as input.
    For the aphrodite runner, we provide MultiModalDataDict objects
    and the corresponding MultiModalConfig as input.
    Note that the text input is also adjusted to abide by the aphrodite
    contract, and the text output is sanitized so it can be compared with hf.
    """

    # NOTE: take care of the order. Run Aphrodite first, and then run HF.
    # Aphrodite needs a fresh new process without cuda initialization;
    # if we run HF first, cuda will already be initialized, which hurts the
    # multiprocessing backend with the fork method (the default method).

    # max_model_len should be greater than image_feature_size;
    # Qwen encodes each image into a fixed context of 256 tokens.
    with aphrodite_runner(
            model,
            max_model_len=1024,
            max_num_seqs=1,
            dtype=dtype,
            limit_mm_per_prompt={"image": mm_limit},
            tensor_parallel_size=tensor_parallel_size,
            distributed_executor_backend=distributed_executor_backend,
            enforce_eager=True) as aphrodite_model:
        aphrodite_outputs_per_image = [
            aphrodite_model.generate_greedy_logprobs(prompts,
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     images=images)
            for prompts, images in inputs
        ]

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs_per_image = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=images)
            for prompts, images in inputs
        ]

    for hf_outputs, aphrodite_outputs in zip(hf_outputs_per_image,
                                             aphrodite_outputs_per_image):
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=aphrodite_outputs,
            name_0="hf",
            name_1="aphrodite",
        )


@pytest.mark.parametrize("model", multimodal_models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("num_logprobs", [5])
def test_multimodal_models_single_image(tmp_path: pathlib.PosixPath,
                                        hf_runner: Type[HfRunner],
                                        aphrodite_runner: Type[AphroditeRunner],
                                        image_assets: _ImageAssets, model: str,
                                        size_factors: List[float], dtype: str,
                                        max_tokens: int,
                                        num_logprobs: int) -> None:
    """Tests multimodal models with single image prompts."""
    images = [asset.pil_image for asset in image_assets]
    prompts = [
        get_prompt_with_path(tmp_path, prompt, [asset])
        for prompt, asset in zip(HF_IMAGE_PROMPTS, image_assets)
    ]
    inputs = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, prompts)]

    run_test(
        hf_runner,
        aphrodite_runner,
        inputs,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=1,
        tensor_parallel_size=1,
    )


@pytest.mark.parametrize("model", multimodal_models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_multimodal_models_multi_image(tmp_path: pathlib.PosixPath,
                                       hf_runner: Type[HfRunner],
                                       aphrodite_runner: Type[AphroditeRunner],
                                       image_assets: _ImageAssets, model: str,
                                       size_factors: List[float], dtype: str,
                                       max_tokens: int,
                                       num_logprobs: int) -> None:
    """Tests multimodal models with multi-image prompts."""
    images = [asset.pil_image for asset in image_assets]
    # Put all of the images into one prompt.
    prompt = get_prompt_with_path(tmp_path, HF_MULTIIMAGE_IMAGE_PROMPT,
                                  image_assets)
    inputs = [([prompt for _ in size_factors],
               [[rescale_image_size(image, factor) for image in images]
                for factor in size_factors])]

    run_test(
        hf_runner,
        aphrodite_runner,
        inputs,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=2,
        tensor_parallel_size=1,
    )


# Ensure that a text-only Qwen model can still be loaded and
# used for inference in Aphrodite without throwing.
@pytest.mark.parametrize("model", text_only_models)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_text_only_qwen_model_can_be_loaded_and_run(
    aphrodite_runner: Type[AphroditeRunner],
    example_prompts: List[str],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
):
    with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
        aphrodite_model.generate_greedy_logprobs(
            example_prompts,
            max_tokens,
            num_logprobs=num_logprobs,
        )