# test_internvl.py

import types
from typing import List, Optional, Tuple, Type

import pytest
import torch
from huggingface_hub import snapshot_download
from PIL.Image import Image
from transformers import AutoConfig

from aphrodite.common.utils import is_cpu
from aphrodite.modeling.models.internvl import (IMG_CONTEXT, IMG_END,
                                                IMG_START,
                                                image_to_pixel_values)
from aphrodite.multimodal.utils import rescale_image_size

from ..conftest import IMAGE_ASSETS, AphroditeRunner, HfRunner, _ImageAssets
from .utils import check_logprobs_close

pytestmark = pytest.mark.vlm

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "<|im_start|>User\n<image>\nWhat's the content in the center of the image?<|im_end|>\n<|im_start|>Assistant\n",  # noqa: E501
    "cherry_blossom":
    "<|im_start|>User\n<image>\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n",  # noqa: E501
})

# we use snapshot_download to prevent conflicts between
# dynamic_module and trust_remote_code for hf_runner
DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
models = [
    snapshot_download("OpenGVLab/InternVL2-1B",
                      allow_patterns=DOWNLOAD_PATTERN),
    snapshot_download("OpenGVLab/InternVL2-2B",
                      allow_patterns=DOWNLOAD_PATTERN),
    # Broken due to outdated implementation of Phi-3
    # See: https://huggingface.co/OpenGVLab/InternVL2-4B/discussions/3
    # snapshot_download("OpenGVLab/InternVL2-4B"),
]


class InternVLProcessor:
    """A simple processor for the InternVL2 HF model, which lacks one."""

    def __init__(self, hf_runner: HfRunner):
        self.num_image_token = hf_runner.model.num_image_token
        self.tokenizer = hf_runner.tokenizer
        self.dtype = hf_runner.model.dtype

        self.config = AutoConfig.from_pretrained(hf_runner.model_name)
        self.vision_config = self.config.vision_config
        self.use_thumbnail = self.config.use_thumbnail
        self.min_num = self.config.min_dynamic_patch
        self.max_num = self.config.max_dynamic_patch
        self.image_size = self.vision_config.image_size

    def __call__(self, text: str, images: Image, **kwargs):
        # Convert the PIL image into dynamic-resolution patches and expand the
        # <image> placeholder into the corresponding image token sequence.
        pixel_values = image_to_pixel_values(images, self.image_size,
                                             self.min_num, self.max_num,
                                             self.use_thumbnail).to(self.dtype)
        num_patches_list = [pixel_values.shape[0]]
        for num_patches in num_patches_list:
            context_tokens = IMG_CONTEXT * self.num_image_token * num_patches
            image_tokens = IMG_START + context_tokens + IMG_END
            text = text.replace('<image>', image_tokens, 1)
        prompt = self.tokenizer(text, return_tensors="pt")
        prompt.update({"pixel_values": pixel_values})
        return prompt
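

# NOTE: illustrative sketch only (the concrete numbers are assumptions, not
# taken from the test itself). For a single patch and num_image_token == 256,
# the prompt expansion performed by InternVLProcessor.__call__ above amounts
# to:
#
#     text.replace("<image>", IMG_START + IMG_CONTEXT * 256 + IMG_END, 1)
#
# i.e. each <image> placeholder becomes one IMG_START ... IMG_END block
# holding num_image_token * num_patches copies of the IMG_CONTEXT token,
# whose embeddings are later overwritten with ViT features in generate()
# below.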


# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py
def generate(
    self,
    pixel_values: torch.FloatTensor,
    input_ids: torch.FloatTensor,
    attention_mask: Optional[torch.LongTensor] = None,
    **generate_kwargs,
) -> torch.LongTensor:
    """Generate method for InternVL2 model without fixed use_cache."""
    assert self.img_context_token_id is not None
    vit_embeds = self.extract_feature(pixel_values)
    input_embeds = self.language_model.get_input_embeddings()(input_ids)
    B, N, C = input_embeds.shape
    input_embeds = input_embeds.reshape(B * N, C)

    input_ids = input_ids.reshape(B * N)
    selected = (input_ids == self.img_context_token_id)
    assert selected.sum() != 0
    # Splice the vision embeddings into the positions of the image context
    # placeholder tokens.
    input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)

    input_embeds = input_embeds.reshape(B, N, C)

    outputs = self.language_model.generate(
        inputs_embeds=input_embeds,
        attention_mask=attention_mask,
        **generate_kwargs,
    )

    return outputs


def run_test(
    hf_runner: Type[HfRunner],
    aphrodite_runner: Type[AphroditeRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Inference results should be the same between hf and aphrodite.

    All the image fixtures for the test are under tests/images.
    For the huggingface runner, we provide the PIL images as input.
    For the aphrodite runner, we provide MultiModalDataDict objects
    and the corresponding MultiModalConfig as input.
    Note, the text input is also adjusted to abide by the aphrodite contract.
    The text output is sanitized to be able to compare with hf.
    """
    images = [asset.pil_image for asset in image_assets]

    inputs_per_image = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    # NOTE: take care of the order. run Aphrodite first, and then run HF.
    # Aphrodite needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).

    # max_model_len should be greater than image_feature_size
    with aphrodite_runner(model,
                          max_model_len=4096,
                          dtype=dtype,
                          tensor_parallel_size=tensor_parallel_size,
                          distributed_executor_backend=distributed_executor_backend,
                          enforce_eager=True) as aphrodite_model:
        aphrodite_outputs_per_image = [
            aphrodite_model.generate_greedy_logprobs(prompts,
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     images=images)
            for prompts, images in inputs_per_image
        ]

    with hf_runner(model, dtype=dtype) as hf_model:
        # Patch the HF model so it can be driven like other multimodal models:
        # register the image context token id, attach the ad-hoc processor,
        # and bind the use_cache-free generate() defined above.
        img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids(
            "<IMG_CONTEXT>")
        hf_model.model.img_context_token_id = img_context_token_id
        hf_model.processor = InternVLProcessor(hf_model)
        hf_model.model.get_output_embeddings = lambda: \
            hf_model.model.language_model.get_output_embeddings()
        hf_model.model.generate = types.MethodType(generate, hf_model.model)
        eos_token_id = hf_model.tokenizer.eos_token_id
        hf_outputs_per_image = [
            hf_model.generate_greedy_logprobs_limit(prompts,
                                                    max_tokens,
                                                    num_logprobs=num_logprobs,
                                                    images=hf_images,
                                                    eos_token_id=eos_token_id)
            for prompts, hf_images in inputs_per_image
        ]

    for hf_outputs, aphrodite_outputs in zip(hf_outputs_per_image,
                                             aphrodite_outputs_per_image):
        # TODO: Check whether using original CLIPVisionModel can improve
        # consistency against HF
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=aphrodite_outputs,
            name_0="hf",
            name_1="aphrodite",
        )


def run_awq_test(
    aphrodite_runner: Type[AphroditeRunner],
    image_assets: _ImageAssets,
    models: Tuple[str, str],
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    source_model, quant_model = models

    images = [asset.pil_image for asset in image_assets]

    inputs_per_image = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    # NOTE: unlike run_test, both runs below go through aphrodite_runner, so
    # ordering relative to HF does not matter here.
    # max_model_len should be greater than image_feature_size
    with aphrodite_runner(source_model,
                          max_model_len=4096,
                          dtype=dtype,
                          tensor_parallel_size=tensor_parallel_size,
                          distributed_executor_backend=distributed_executor_backend,
                          enforce_eager=True) as aphrodite_model:
        source_outputs_per_image = [
            aphrodite_model.generate_greedy_logprobs(prompts,
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     images=images)
            for prompts, images in inputs_per_image
        ]

    with aphrodite_runner(quant_model,
                          quantization="awq",
                          max_model_len=4096,
                          dtype=dtype,
                          tensor_parallel_size=tensor_parallel_size,
                          distributed_executor_backend=distributed_executor_backend,
                          enforce_eager=True) as aphrodite_model:
        quant_outputs_per_image = [
            aphrodite_model.generate_greedy_logprobs(prompts,
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     images=images)
            for prompts, images in inputs_per_image
        ]

    for source_outputs, quant_outputs in zip(source_outputs_per_image,
                                             quant_outputs_per_image):
        # Compare the unquantized model against its AWQ-quantized counterpart.
        check_logprobs_close(
            outputs_0_lst=source_outputs,
            outputs_1_lst=quant_outputs,
            name_0="source",
            name_1="awq",
        )


target_dtype = "half"
if is_cpu():
    target_dtype = "bfloat16"


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@torch.inference_mode()
def test_models(hf_runner, aphrodite_runner, image_assets, model, size_factors,
                dtype: str, max_tokens: int, num_logprobs: int) -> None:
    run_test(
        hf_runner,
        aphrodite_runner,
        image_assets,
        model,
        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )


@pytest.mark.parametrize(
    "models", [("OpenGVLab/InternVL2-2B", "OpenGVLab/InternVL2-2B-AWQ")])
@pytest.mark.parametrize(
    "size_factors",
    [
        # No image
        [],
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@torch.inference_mode()
def test_awq_models(aphrodite_runner, image_assets, models, size_factors,
                    dtype: str, max_tokens: int, num_logprobs: int) -> None:
    run_awq_test(
        aphrodite_runner,
        image_assets,
        models,
        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=1,
    )
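

# How to run (assumption: this file lives under the repository's tests/ tree
# and the "vlm" marker is registered in the pytest config; adjust the path to
# your checkout):
#
#     pytest tests/models/test_internvl.py -m vlm -k test_models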