import types
from typing import List, Optional, Tuple, Type, Union

import pytest
import torch
from PIL.Image import Image
from transformers import AutoConfig

from aphrodite.common.utils import is_cpu
from aphrodite.multimodal.utils import rescale_image_size

from ..conftest import (IMAGE_ASSETS, AphroditeRunner, HfRunner,
                        PromptImageInput, _ImageAssets)
from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "<|im_start|>User\n<image>\nWhat's the content in the center of the image?<|im_end|>\n<|im_start|>Assistant\n",  # noqa: E501
    "cherry_blossom":
    "<|im_start|>User\n<image>\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n",  # noqa: E501
})

HF_MULTIIMAGE_IMAGE_PROMPT = "<|im_start|>User\nImage-1: <image>\nImage-2: <image>\nDescribe the two images in detail.<|im_end|>\n<|im_start|>Assistant\n"  # noqa: E501

models = [
    "OpenGVLab/InternVL2-1B",
    "OpenGVLab/InternVL2-2B",
    # Broken due to outdated implementation of Phi-3
    # See: https://huggingface.co/OpenGVLab/InternVL2-4B/discussions/3
    # "OpenGVLab/InternVL2-4B",
]


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py
def generate(
    self,
    pixel_values: torch.FloatTensor,
    input_ids: torch.FloatTensor,
    attention_mask: Optional[torch.LongTensor] = None,
    **generate_kwargs,
) -> torch.LongTensor:
    """Generate method for InternVL2 model without fixed use_cache."""
    assert self.img_context_token_id is not None
    vit_embeds = self.extract_feature(pixel_values)
    input_embeds = self.language_model.get_input_embeddings()(input_ids)
    B, N, C = input_embeds.shape
    input_embeds = input_embeds.reshape(B * N, C)

    input_ids = input_ids.reshape(B * N)
    # Scatter the vision embeddings into the positions of the
    # <IMG_CONTEXT> placeholder tokens before generation.
    selected = (input_ids == self.img_context_token_id)
    assert selected.sum() != 0
    input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
    input_embeds = input_embeds.reshape(B, N, C)

    outputs = self.language_model.generate(
        inputs_embeds=input_embeds,
        attention_mask=attention_mask,
        **generate_kwargs,
    )

    return outputs
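
# NOTE: the patched ``generate`` above is bound onto the HF model with
# ``types.MethodType`` inside ``run_test`` below, replacing the upstream
# implementation that fixes ``use_cache``.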


def run_test(
    hf_runner: Type[HfRunner],
    aphrodite_runner: Type[AphroditeRunner],
    inputs: List[Tuple[List[str], PromptImageInput]],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    mm_limit: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference result should be the same between hf and aphrodite.

    All the image fixtures for the test are under tests/images.
    For the huggingface runner, we provide the PIL images as input.
    For the aphrodite runner, we provide MultiModalDataDict objects
    and the corresponding MultiModalConfig as input.
    Note that the text input is also adjusted to abide by the aphrodite
    contract, and the text output is sanitized so it can be compared with hf.
    """
    # NOTE: take care of the order. run Aphrodite first, and then run HF.
    # Aphrodite needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).
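
    # For reference, each element of ``inputs`` pairs a list of prompts with
    # the images for those prompts (illustrative only, single-image case):
    #   inputs = [(["<|im_start|>User\n<image>\n...", ...], [pil_image, ...])]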

    class InternVLProcessor:
        """A simple processor for InternVL2, which does not ship with one."""

        def __init__(self, hf_runner: HfRunner):
            self.num_image_token = hf_runner.model.num_image_token
            self.tokenizer = hf_runner.tokenizer
            self.dtype = hf_runner.model.dtype

            self.config = AutoConfig.from_pretrained(hf_runner.model_name)
            self.vision_config = self.config.vision_config
            self.use_thumbnail = self.config.use_thumbnail
            self.min_num = self.config.min_dynamic_patch
            self.max_num = self.config.max_dynamic_patch
            self.image_size = self.vision_config.image_size

        def __call__(self, text: str, images: Union[Image, List[Image]],
                     **kwargs):
            from aphrodite.modeling.models.internvl import (
                IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)
            images = [images] if isinstance(images, Image) else images
            pixel_values = [
                image_to_pixel_values(image, self.image_size, self.min_num,
                                      self.max_num,
                                      self.use_thumbnail).to(self.dtype)
                for image in images
            ]
            num_patches_list = [
                pixel_value.shape[0] for pixel_value in pixel_values
            ]
            pixel_values = torch.cat(pixel_values, dim=0)
            for num_patches in num_patches_list:
                context_tokens = IMG_CONTEXT * self.num_image_token \
                    * num_patches
                image_tokens = IMG_START + context_tokens + IMG_END
                text = text.replace('<image>', image_tokens, 1)
            prompt = self.tokenizer(text, return_tensors="pt")
            prompt.update({"pixel_values": pixel_values})
            return prompt
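
    # Illustrative note only: for an image split into ``num_patches`` tiles,
    # the processor above replaces one "<image>" placeholder in the prompt
    # with IMG_START + IMG_CONTEXT * num_image_token * num_patches + IMG_END
    # and stacks all tiles into a single ``pixel_values`` tensor.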
    # max_model_len should be greater than image_feature_size
    with aphrodite_runner(model,
                          max_model_len=4096,
                          dtype=dtype,
                          limit_mm_per_prompt={"image": mm_limit},
                          tensor_parallel_size=tensor_parallel_size,
                          distributed_executor_backend=distributed_executor_backend,
                          enforce_eager=True) as aphrodite_model:
        aphrodite_outputs_per_image = [
            aphrodite_model.generate_greedy_logprobs(prompts,
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     images=images)
            for prompts, images in inputs
        ]

    with hf_runner(model, dtype=dtype) as hf_model:
        img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids(
            "<IMG_CONTEXT>")
        hf_model.model.img_context_token_id = img_context_token_id
        hf_model.processor = InternVLProcessor(hf_model)
        hf_model.model.get_output_embeddings = lambda: \
            hf_model.model.language_model.get_output_embeddings()
        hf_model.model.generate = types.MethodType(generate, hf_model.model)
        eos_token_id = hf_model.tokenizer.eos_token_id
        hf_outputs_per_image = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=hf_images,
                                                    eos_token_id=eos_token_id)
            for prompts, hf_images in inputs
        ]

    for hf_outputs, aphrodite_outputs in zip(hf_outputs_per_image,
                                             aphrodite_outputs_per_image):
        # TODO: Check whether using original CLIPVisionModel can improve
        # consistency against HF
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=aphrodite_outputs,
            name_0="hf",
            name_1="aphrodite",
        )


def run_awq_test(
    aphrodite_runner: Type[AphroditeRunner],
    image_assets: _ImageAssets,
    models: Tuple[str, str],
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference result should be close between a source model and its AWQ
    quantized counterpart."""
    source_model, quant_model = models

    images = [asset.pil_image for asset in image_assets]
    inputs_per_image = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    # NOTE: take care of the order. run Aphrodite first, and then run HF.
    # Aphrodite needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).

    # max_model_len should be greater than image_feature_size
    with aphrodite_runner(source_model,
                          max_model_len=4096,
                          dtype=dtype,
                          tensor_parallel_size=tensor_parallel_size,
                          distributed_executor_backend=distributed_executor_backend,
                          enforce_eager=True) as aphrodite_model:
        source_outputs_per_image = [
            aphrodite_model.generate_greedy_logprobs(prompts,
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     images=images)
            for prompts, images in inputs_per_image
        ]

    with aphrodite_runner(quant_model,
                          quantization="awq",
                          max_model_len=4096,
                          dtype=dtype,
                          tensor_parallel_size=tensor_parallel_size,
                          distributed_executor_backend=distributed_executor_backend,
                          enforce_eager=True) as aphrodite_model:
        quant_outputs_per_image = [
            aphrodite_model.generate_greedy_logprobs(prompts,
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     images=images)
            for prompts, images in inputs_per_image
        ]

    for source_outputs, quant_outputs in zip(source_outputs_per_image,
                                             quant_outputs_per_image):
        # TODO: Check whether using original CLIPVisionModel can improve
        # consistency against HF
        check_logprobs_close(
            outputs_0_lst=source_outputs,
            outputs_1_lst=quant_outputs,
            name_0="source",
            name_1="awq",
        )


target_dtype = "half"
if is_cpu():
    target_dtype = "bfloat16"


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@torch.inference_mode()
def test_models(hf_runner, aphrodite_runner, image_assets, model, size_factors,
                dtype: str, max_tokens: int, num_logprobs: int) -> None:
    images = [asset.pil_image for asset in image_assets]
    inputs_per_image = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    run_test(
        hf_runner,
        aphrodite_runner,
        inputs_per_image,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=1,
        tensor_parallel_size=1,
    )


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.5, 0.75, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@torch.inference_mode()
def test_multi_images_models(hf_runner, aphrodite_runner, image_assets, model,
                             size_factors, dtype: str, max_tokens: int,
                             num_logprobs: int) -> None:
    images = [asset.pil_image for asset in image_assets]
    inputs_per_case = [
        ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
         [[rescale_image_size(image, factor) for image in images]
          for factor in size_factors])
    ]

    run_test(
        hf_runner,
        aphrodite_runner,
        inputs_per_case,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=2,
        tensor_parallel_size=1,
    )


@pytest.mark.parametrize("model", ["OpenGVLab/InternVL2-2B"])
@pytest.mark.parametrize("size_factors", [[0.5, 1.0]])
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@torch.inference_mode()
def test_different_num_patches(hf_runner, aphrodite_runner, image_assets, model,
                               size_factors, dtype: str, max_tokens: int,
                               num_logprobs: int) -> None:
    images = [asset.pil_image.resize((896, 896)) for asset in image_assets]

    inputs_batching = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    inputs_multi_images = [
        ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
         [[rescale_image_size(image, factor) for image in images]
          for factor in size_factors])
    ]

    for inputs in [inputs_batching, inputs_multi_images]:
        run_test(
            hf_runner,
            aphrodite_runner,
            inputs,
            model,
            dtype=dtype,
            max_tokens=max_tokens,
            num_logprobs=num_logprobs,
            mm_limit=2,
            tensor_parallel_size=1,
        )


@pytest.mark.parametrize(
    "models", [("OpenGVLab/InternVL2-2B", "OpenGVLab/InternVL2-2B-AWQ")])
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@torch.inference_mode()
def test_awq_models(aphrodite_runner, image_assets, models, size_factors,
                    dtype: str, max_tokens: int, num_logprobs: int) -> None:
    run_awq_test(
        aphrodite_runner,
        image_assets,
        models,
        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )