123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315 |
- from typing import List, Optional, Tuple, Type, overload
- import pytest
- from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
- BatchEncoding)
- from aphrodite.common.sequence import SampleLogprobs
- from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE
- from aphrodite.multimodal.utils import rescale_image_size
- from ..conftest import (IMAGE_ASSETS, AphroditeRunner, HfRunner,
- PromptImageInput, _ImageAssets)
- from .utils import check_logprobs_close
# Apply the ``vlm`` marker to every test collected from this module.
pytestmark = pytest.mark.vlm

# Passed to the Aphrodite runner as the per-prompt limit for "image" inputs.
_LIMIT_IMAGE_PER_PROMPT = 4

# Prompt text for each image asset, keyed by asset name.
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "USER: <image>\nWhat's the content of the image?\nASSISTANT:",
    "cherry_blossom":
    "USER: <image>\nWhat is the season?\nASSISTANT:",
})

# Model identifiers the tests in this module are parametrized over.
models = [
    "llava-hf/llava-1.5-7b-hf",
    # TODO: Get this model to produce meaningful output in Aphrodite
    # "TIGER-Lab/Mantis-8B-siglip-llama3",
]
def aphrodite_to_hf_output(aphrodite_output: Tuple[List[int], str,
                                                   Optional[SampleLogprobs]],
                           model: str):
    """Sanitize aphrodite output to be comparable with hf output.

    Args:
        aphrodite_output: ``(output_ids, output_str, out_logprobs)`` triple
            produced by the Aphrodite runner.
        model: Model name or path used to load the config (for the image
            token index) and the tokenizer (for the EOS token).

    Returns:
        A ``(hf_output_ids, hf_output_str, out_logprobs)`` triple where
        every run of consecutive image tokens is collapsed into a single
        token, the leading space is stripped from the text, and the decoded
        EOS token is appended to the text when the last kept token is EOS.
        The logprobs are returned unchanged.

    Raises:
        AssertionError: If ``output_str`` does not start with a space.
    """
    output_ids, output_str, out_logprobs = aphrodite_output

    config = AutoConfig.from_pretrained(model)
    image_token_id = config.image_token_index

    tokenizer = AutoTokenizer.from_pretrained(model)
    eos_token_id = tokenizer.eos_token_id

    # Keep a token unless it is an image token directly preceded by another
    # image token. The first token has no predecessor and is always kept;
    # without the explicit ``idx == 0`` check, ``output_ids[idx - 1]`` wraps
    # around and compares the first token against the *last* one.
    hf_output_ids = [
        token_id for idx, token_id in enumerate(output_ids)
        if (token_id != image_token_id or idx == 0
            or output_ids[idx - 1] != image_token_id)
    ]

    # ``startswith`` reports an empty string as an assertion failure rather
    # than an IndexError.
    assert output_str.startswith(" "), (
        f"expected output text to start with a space, got {output_str!r}")
    hf_output_str = output_str[1:]

    # Guard against an empty token list before inspecting the last token.
    if hf_output_ids and hf_output_ids[-1] == eos_token_id:
        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)

    return hf_output_ids, hf_output_str, out_logprobs
@overload
def run_test(
    hf_runner: Type[HfRunner],
    aphrodite_runner: Type[AphroditeRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    # Variant that rescales every image by each factor in ``size_factors``.
    ...


@overload
def run_test(
    hf_runner: Type[HfRunner],
    aphrodite_runner: Type[AphroditeRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    sizes: List[Tuple[int, int]],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    # Variant that resizes every image to each entry in ``sizes``.
    ...


def run_test(
    hf_runner: Type[HfRunner],
    aphrodite_runner: Type[AphroditeRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: Optional[List[float]] = None,
    sizes: Optional[List[Tuple[int, int]]] = None,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Build per-image prompt/image batches and delegate to ``_run_test``.

    One batch is produced for every (image, prompt) pair obtained by zipping
    the asset images with ``HF_IMAGE_PROMPTS``. Within a batch the prompt is
    repeated once per variant and the image is transformed once per variant:
    rescaled with ``rescale_image_size`` for each entry of ``size_factors``,
    or, when ``size_factors`` is ``None``, resized through the image's
    ``resize`` method for each entry of ``sizes``.

    Raises:
        ValueError: If both ``size_factors`` and ``sizes`` are ``None``.
    """
    pil_images = [asset.pil_image for asset in image_assets]

    # ``size_factors`` takes precedence when both keyword arguments are set.
    if size_factors is not None:
        variants = size_factors
        transform = rescale_image_size
    elif sizes is not None:
        variants = sizes

        def transform(image, size):
            return image.resize(size)
    else:
        raise ValueError("You must provide either `size_factors` or `sizes`")

    inputs_per_image = []
    for image, prompt in zip(pil_images, HF_IMAGE_PROMPTS):
        repeated_prompts = [prompt] * len(variants)
        transformed_images = [transform(image, variant) for variant in variants]
        inputs_per_image.append((repeated_prompts, transformed_images))

    _run_test(
        hf_runner,
        aphrodite_runner,
        inputs_per_image,
        model,
        dtype=dtype,
        max_tokens=max_tokens,
        num_logprobs=num_logprobs,
        tensor_parallel_size=tensor_parallel_size,
        distributed_executor_backend=distributed_executor_backend,
    )
def _run_test(
    hf_runner: Type[HfRunner],
    aphrodite_runner: Type[AphroditeRunner],
    inputs: List[Tuple[List[str], PromptImageInput]],
    model: str,
    *,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Check that Aphrodite and HF produce close greedy logprobs.

    Every ``(prompts, images)`` entry of ``inputs`` is run through both
    runners with the same ``max_tokens`` and ``num_logprobs``. The Aphrodite
    results are converted with ``aphrodite_to_hf_output`` and then compared
    entry by entry against the HF results using ``check_logprobs_close``.
    """
    # NOTE: For local use; this isn't tested in CI yet (see the TODO next to
    # the ``models`` list)
    if model.startswith("TIGER-Lab/Mantis"):
        from mantis.models.mllava import MLlavaProcessor

        torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
        mantis_processor = MLlavaProcessor.from_pretrained(
            model, torch_dtype=torch_dtype)
        assert isinstance(mantis_processor, MLlavaProcessor)
    else:
        torch_dtype = None
        mantis_processor = None

    # NOTE: take care of the order. run Aphrodite first, and then run HF.
    # Aphrodite needs a fresh new process without cuda initialization.
    # if we run HF first, the cuda initialization will be done and it
    # will hurt multiprocessing backend with fork method (the default method).

    # max_model_len should be greater than image_feature_size
    aphrodite_outputs_per_image = []
    with aphrodite_runner(
            model,
            dtype=dtype,
            max_model_len=4096,
            tensor_parallel_size=tensor_parallel_size,
            distributed_executor_backend=distributed_executor_backend,
            enforce_eager=True,
            limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT},
    ) as aphrodite_model:
        for prompts, images in inputs:
            aphrodite_outputs_per_image.append(
                aphrodite_model.generate_greedy_logprobs(
                    prompts,
                    max_tokens,
                    num_logprobs=num_logprobs,
                    images=images,
                ))

    # Pixel values are only cast when a Mantis processor was loaded above.
    cast_pixel_values = mantis_processor is not None

    def process(hf_inputs: BatchEncoding):
        if cast_pixel_values:
            pixel_values = hf_inputs["pixel_values"]
            hf_inputs["pixel_values"] = pixel_values.to(torch_dtype)
        return hf_inputs

    hf_outputs_per_image = []
    with hf_runner(
            model,
            dtype=dtype,
            postprocess_inputs=process,
            auto_cls=AutoModelForVision2Seq,
    ) as hf_model:
        for prompts, images in inputs:
            hf_outputs_per_image.append(
                hf_model.generate_greedy_logprobs_limit(
                    prompts,
                    max_tokens,
                    num_logprobs=num_logprobs,
                    images=images,
                ))

    for hf_outputs, aphrodite_outputs in zip(hf_outputs_per_image,
                                             aphrodite_outputs_per_image):
        # TODO: Check whether using original CLIPVisionModel can improve
        # consistency against HF
        sanitized_outputs = [
            aphrodite_to_hf_output(aphrodite_output, model)
            for aphrodite_output in aphrodite_outputs
        ]
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=sanitized_outputs,
            name_0="hf",
            name_1="aphrodite",
        )
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        [],  # No image
        [1.0],  # Single-scale
        [1.0, 1.0, 1.0],  # Single-scale, batched
        [0.25, 0.5, 1.0],  # Multi-scale
    ],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, aphrodite_runner, image_assets, model, size_factors,
                dtype: str, max_tokens: int, num_logprobs: int) -> None:
    """Compare HF and Aphrodite on the image assets at several scales.

    Runs ``run_test`` on a single device (``tensor_parallel_size=1``) for
    each parametrized list of size factors.
    """
    run_test(hf_runner,
             aphrodite_runner,
             image_assets,
             model,
             size_factors=size_factors,
             dtype=dtype,
             max_tokens=max_tokens,
             num_logprobs=num_logprobs,
             tensor_parallel_size=1)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models_multiple_image_inputs(hf_runner, aphrodite_runner, image_assets,
                                      model, dtype, max_tokens,
                                      num_logprobs) -> None:
    """Compare HF and Aphrodite on prompts that carry several images.

    A single batch mixes prompts with two, two, four and one image
    placeholder(s); the matching image entries use differently sized and
    shaped variants of the first two image assets.
    """
    stop_sign = image_assets[0].pil_image
    cherry_blossom = image_assets[1].pil_image

    two_image_prompt = "USER: <image><image>\nDescribe 2 images.\nASSISTANT:"
    four_image_prompt = ("USER: <image><image><image><image>\n"
                         "Describe 4 images.\nASSISTANT:")
    one_image_prompt = "USER: <image>\nWhat is the season?\nASSISTANT:"

    prompts = [
        two_image_prompt,
        two_image_prompt,
        four_image_prompt,
        one_image_prompt,
    ]

    # Each entry lines up with the prompt at the same index in ``prompts``.
    images = [
        [stop_sign, cherry_blossom],
        # Images with different sizes and aspect-ratios
        [rescale_image_size(stop_sign, 0.1), stop_sign],
        [
            stop_sign,
            rescale_image_size(stop_sign, 0.25),
            cherry_blossom.resize((183, 488)),
            cherry_blossom.resize((488, 183)),
        ],
        cherry_blossom,
    ]

    inputs = [(prompts, images)]

    _run_test(hf_runner,
              aphrodite_runner,
              inputs,
              model,
              dtype=dtype,
              max_tokens=max_tokens,
              num_logprobs=num_logprobs,
              tensor_parallel_size=1)
@pytest.mark.parametrize("model", models)
def test_context_length_too_short(aphrodite_runner, image_assets, model):
    """Expect a ``ValueError`` when the prompt cannot fit in the context.

    Both runner construction and generation happen inside the
    ``pytest.raises`` block, so the error may come from either step.
    """
    pil_images = [asset.pil_image for asset in image_assets]
    first_prompt = HF_IMAGE_PROMPTS[0]
    first_image = pil_images[0]

    with pytest.raises(ValueError, match="too long to fit into the model"):
        # LLaVA has a feature size of 576
        runner = aphrodite_runner(model, max_model_len=128, enforce_eager=True)

        with runner:
            runner.generate_greedy([first_prompt],
                                   max_tokens=1,
                                   images=[first_image])
|