import types
from typing import List, Optional, Tuple, Type, Union

import pytest
import torch
from PIL.Image import Image
from transformers import AutoConfig

from aphrodite.common.utils import is_cpu
from aphrodite.multimodal.utils import rescale_image_size

from ....conftest import (IMAGE_ASSETS, AphroditeRunner, HfRunner,
                          PromptImageInput, _ImageAssets)
from ...utils import check_logprobs_close

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "<|im_start|>User\n<image>\nWhat's the content in the center of the image?<|im_end|>\n<|im_start|>Assistant\n",  # noqa: E501
    "cherry_blossom":
    "<|im_start|>User\n<image>\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n",  # noqa: E501
})
HF_MULTIIMAGE_IMAGE_PROMPT = "<|im_start|>User\nImage-1: <image>\nImage-2: <image>\nDescribe the two images in detail.<|im_end|>\n<|im_start|>Assistant\n"  # noqa: E501

models = [
    "OpenGVLab/InternVL2-1B",
    "OpenGVLab/InternVL2-2B",
    # Broken due to outdated implementation of Phi-3
    # See: https://huggingface.co/OpenGVLab/InternVL2-4B/discussions/3
    # "OpenGVLab/InternVL2-4B",
]


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py
def generate(
    self,
    pixel_values: torch.FloatTensor,
    input_ids: torch.LongTensor,
    attention_mask: Optional[torch.LongTensor] = None,
    **generate_kwargs,
) -> torch.LongTensor:
    """Generate method for InternVL2 model without fixed use_cache."""
    assert self.img_context_token_id is not None
    vit_embeds = self.extract_feature(pixel_values)
    input_embeds = self.language_model.get_input_embeddings()(input_ids)
    B, N, C = input_embeds.shape
    input_embeds = input_embeds.reshape(B * N, C)

    input_ids = input_ids.reshape(B * N)
    selected = (input_ids == self.img_context_token_id)
    assert selected.sum() != 0
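    # Splice the ViT patch embeddings into the positions of the
    # <IMG_CONTEXT> placeholder tokens so that generation runs purely on
    # inputs_embeds.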
    input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
    input_embeds = input_embeds.reshape(B, N, C)

    outputs = self.language_model.generate(
        inputs_embeds=input_embeds,
        attention_mask=attention_mask,
        **generate_kwargs,
    )

    return outputs


def run_test(
    hf_runner: Type[HfRunner],
    aphrodite_runner: Type[AphroditeRunner],
    inputs: List[Tuple[List[str], PromptImageInput]],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    mm_limit: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference result should be the same between hf and aphrodite.

    All the image fixtures for the test are from IMAGE_ASSETS.
    For huggingface runner, we provide the PIL images as input.
    For aphrodite runner, we provide MultiModalDataDict objects
    and corresponding MultiModalConfig as input.
    Note, the text input is also adjusted to abide by aphrodite contract.
    The text output is sanitized to be able to compare with hf.
    """
    # NOTE: take care of the order. run Aphrodite first, and then run HF.
    # Aphrodite needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).

    class InternVLProcessor:
        """A simple processor for InternVL2, which lacks an HF processor."""

        def __init__(self, hf_runner: HfRunner):
            self.num_image_token = hf_runner.model.num_image_token
            self.tokenizer = hf_runner.tokenizer
            self.dtype = hf_runner.model.dtype

            self.config = AutoConfig.from_pretrained(hf_runner.model_name)
            self.vision_config = self.config.vision_config
            self.use_thumbnail = self.config.use_thumbnail
            self.min_num = self.config.min_dynamic_patch
            self.max_num = self.config.max_dynamic_patch
            self.image_size = self.vision_config.image_size

        def __call__(self, text: str, images: Union[Image, List[Image]],
                     **kwargs):
            from aphrodite.modeling.models.internvl import (
                IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)
            images = [images] if isinstance(images, Image) else images
            pixel_values = [
                image_to_pixel_values(image, self.image_size, self.min_num,
                                      self.max_num,
                                      self.use_thumbnail).to(self.dtype)
                for image in images
            ]
            num_patches_list = [
                pixel_value.shape[0] for pixel_value in pixel_values
            ]
            pixel_values = torch.cat(pixel_values, dim=0)
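            # Expand each <image> placeholder in the prompt into the
            # IMG_START + IMG_CONTEXT * num_image_token * num_patches
            # + IMG_END token sequence that the HF model expects.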
            for num_patches in num_patches_list:
                context_tokens = IMG_CONTEXT * self.num_image_token \
                    * num_patches
                image_tokens = IMG_START + context_tokens + IMG_END
                text = text.replace('<image>', image_tokens, 1)
            prompt = self.tokenizer(text, return_tensors="pt")
            prompt.update({"pixel_values": pixel_values})
            return prompt
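
    # limit_mm_per_prompt must cover the largest number of images used in
    # a single prompt (1 for the single-image tests, 2 for multi-image).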
    # max_model_len should be greater than image_feature_size
    with aphrodite_runner(model,
                          max_model_len=4096,
                          dtype=dtype,
                          limit_mm_per_prompt={"image": mm_limit},
                          tensor_parallel_size=tensor_parallel_size,
                          distributed_executor_backend=distributed_executor_backend,
                          enforce_eager=True) as aphrodite_model:
        aphrodite_outputs_per_image = [
            aphrodite_model.generate_greedy_logprobs(prompts,
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     images=images)
            for prompts, images in inputs
        ]
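
    # The HF InternVL2 checkpoints ship without a processor and with a
    # generate() that fixes use_cache, so attach our own processor and
    # patch in the generate() defined above.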
    with hf_runner(model, dtype=dtype) as hf_model:
        img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids(
            "<IMG_CONTEXT>")
        hf_model.model.img_context_token_id = img_context_token_id
        hf_model.processor = InternVLProcessor(hf_model)
        hf_model.model.get_output_embeddings = lambda: \
            hf_model.model.language_model.get_output_embeddings()
        hf_model.model.generate = types.MethodType(generate, hf_model.model)
        eos_token_id = hf_model.tokenizer.eos_token_id
        hf_outputs_per_image = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=hf_images,
                                                    eos_token_id=eos_token_id)
            for prompts, hf_images in inputs
        ]

    for hf_outputs, aphrodite_outputs in zip(hf_outputs_per_image,
                                             aphrodite_outputs_per_image):
        # TODO: Check whether using original CLIPVisionModel can improve
        # consistency against HF
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=aphrodite_outputs,
            name_0="hf",
            name_1="aphrodite",
        )


def run_awq_test(
    aphrodite_runner: Type[AphroditeRunner],
    image_assets: _ImageAssets,
    models: Tuple[str, str],
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    source_model, quant_model = models

    images = [asset.pil_image for asset in image_assets]
    inputs_per_image = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    # NOTE: take care of the order. run Aphrodite first, and then run HF.
    # Aphrodite needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).

    # max_model_len should be greater than image_feature_size
    with aphrodite_runner(source_model,
                          max_model_len=4096,
                          dtype=dtype,
                          tensor_parallel_size=tensor_parallel_size,
                          distributed_executor_backend=distributed_executor_backend,
                          enforce_eager=True) as aphrodite_model:
        source_outputs_per_image = [
            aphrodite_model.generate_greedy_logprobs(prompts,
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     images=images)
            for prompts, images in inputs_per_image
        ]
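
    # Run the same prompts through the AWQ-quantized checkpoint so its
    # logprobs can be compared against the unquantized source model below.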
    with aphrodite_runner(quant_model,
                          quantization="awq",
                          max_model_len=4096,
                          dtype=dtype,
                          tensor_parallel_size=tensor_parallel_size,
                          distributed_executor_backend=distributed_executor_backend,
                          enforce_eager=True) as aphrodite_model:
        quant_outputs_per_image = [
            aphrodite_model.generate_greedy_logprobs(prompts,
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     images=images)
            for prompts, images in inputs_per_image
        ]

    for source_outputs, quant_outputs in zip(source_outputs_per_image,
                                             quant_outputs_per_image):
        # TODO: Check whether using original CLIPVisionModel can improve
        # consistency against HF
        check_logprobs_close(
            outputs_0_lst=source_outputs,
            outputs_1_lst=quant_outputs,
            name_0="source",
            name_1="awq",
        )
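

# Prefer half precision, but fall back to bfloat16 when running on the
# CPU backend.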
target_dtype = "half"
if is_cpu():
    target_dtype = "bfloat16"


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@torch.inference_mode()
def test_models(hf_runner, aphrodite_runner, image_assets, model, size_factors,
                dtype: str, max_tokens: int, num_logprobs: int) -> None:
    images = [asset.pil_image for asset in image_assets]
    inputs_per_image = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    run_test(
        hf_runner,
        aphrodite_runner,
        inputs_per_image,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=1,
        tensor_parallel_size=1,
    )


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.5, 0.75, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@torch.inference_mode()
def test_multi_images_models(hf_runner, aphrodite_runner, image_assets, model,
                             size_factors, dtype: str, max_tokens: int,
                             num_logprobs: int) -> None:
    images = [asset.pil_image for asset in image_assets]
    inputs_per_case = [
        ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
         [[rescale_image_size(image, factor) for image in images]
          for factor in size_factors])
    ]
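
    # Each case pairs the two-image prompt with both assets rescaled by
    # the same factor, so every prompt consumes exactly two images.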
    run_test(
        hf_runner,
        aphrodite_runner,
        inputs_per_case,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        mm_limit=2,
        tensor_parallel_size=1,
    )


@pytest.mark.parametrize("model", ["OpenGVLab/InternVL2-2B"])
@pytest.mark.parametrize("size_factors", [[0.5, 1.0]])
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@torch.inference_mode()
def test_different_num_patches(hf_runner, aphrodite_runner, image_assets, model,
                               size_factors, dtype: str, max_tokens: int,
                               num_logprobs: int) -> None:
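    # Resize both assets to a common base resolution so the 0.5x and 1.0x
    # size factors below yield different dynamic patch counts.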
    images = [asset.pil_image.resize((896, 896)) for asset in image_assets]

    inputs_batching = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    inputs_multi_images = [
        ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
         [[rescale_image_size(image, factor) for image in images]
          for factor in size_factors])
    ]
    for inputs in [inputs_batching, inputs_multi_images]:
        run_test(
            hf_runner,
            aphrodite_runner,
            inputs,
            model,
            dtype=dtype,
            max_tokens=max_tokens,
            num_logprobs=num_logprobs,
            mm_limit=2,
            tensor_parallel_size=1,
        )


@pytest.mark.parametrize(
    "models", [("OpenGVLab/InternVL2-2B", "OpenGVLab/InternVL2-2B-AWQ")])
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@torch.inference_mode()
def test_awq_models(aphrodite_runner, image_assets, models, size_factors,
                    dtype: str, max_tokens: int, num_logprobs: int) -> None:
    run_awq_test(
        aphrodite_runner,
        image_assets,
        models,
        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )