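"""Tests for mm_processor_kwargs handling.

Each place that consumes mm_processor_kwargs (the input processor, the
dummy data factory, the max multimodal token calculation, and the input
mapper) is patched with a mock, and we verify that valid keyword overrides
are forwarded while invalid ones are filtered out.
"""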
from array import array
from typing import Mapping
from unittest.mock import patch

import pytest
import torch

from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
                                       SequenceData)
from aphrodite.inputs import InputContext, LLMInputs
from aphrodite.inputs.registry import InputRegistry
from aphrodite.multimodal import MultiModalRegistry

from ..models.utils import build_model_context
# Used for fast tests where the model doesn't matter
DUMMY_MODEL_ID = "facebook/opt-125m"
# Used for tests that need a multimodal model
MULTIMODAL_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"

# For mm_processor_kwargs - we test overrides by defining mocks for each place
# it is used, and ensuring that passing an override value through the
# processor kwargs produces the intended result for things like sequence
# length etc.
DEFAULT_NUM_CROPS = 4
NUM_CROPS_OVERRIDE = 16


# Mocks for all of the places that we use the mm_processor_kwargs
# to override values in different callables
@pytest.fixture
def use_processor_mock():
    """Patches the internal model input processor with an override callable."""

    def custom_processor(
        ctx: InputContext, llm_inputs: LLMInputs, *, num_crops=DEFAULT_NUM_CROPS
    ):
        # For testing purposes, we don't worry about the llm inputs / return
        # type validation, and just return the value of the kwarg that we
        # clobber.
        return num_crops
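    # Patch the registry lookup so every model resolves to our mock processor.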
    with patch(
        "aphrodite.inputs.registry.InputRegistry._get_model_input_processor",
        return_value=custom_processor,
    ):
        yield


@pytest.fixture
def use_dummy_data_mock():
    """Patches the internal dummy data factory with an override callable."""

    def custom_dummy_data_factory(
        self,
        ctx: InputContext,
        seq_len: int,
        mm_counts: Mapping[str, int],
        *,
        num_crops=DEFAULT_NUM_CROPS,
    ):
        seq_data = SequenceData(
            array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [0] * num_crops)
        )
        return seq_data, None
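    # Swap in our factory as the registry's default dummy data factory.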
    with patch(
        "aphrodite.inputs.registry.InputRegistry._default_dummy_data_factory",
        custom_dummy_data_factory,
    ):
        yield


# Lazy import to avoid CUDA reinitialization error
def mm_model_cls():
    from aphrodite.modeling.models.phi3v import Phi3VForCausalLM

    return Phi3VForCausalLM


# Lambdas whose signatures match the max token calc & input mapper callables,
# plus the extra keyword-only mm_processor_kwargs
get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops
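# Mock mapper whose fake pixel values encode num_crops in their second dim,
# matching the [batch, num_crops + 1, 3, 336, 336] shape asserted below.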
custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: {
    "pixel_values": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
}


### Test for default processor logic & mm_processor_kwargs wrapping
def test_default_processor_is_a_noop():
    """Ensure that by default, there is no processor override."""
    dummy_registry = InputRegistry()
    ctx = build_model_context(DUMMY_MODEL_ID)
    processor = dummy_registry.create_input_processor(ctx.model_config)
    proc_inputs = LLMInputs(prompt_token_ids=[], prompt="")
    proc_outputs = processor(inputs=proc_inputs)
    assert proc_inputs is proc_outputs


@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_processor_default_kwargs(use_processor_mock, num_crops):
    """Ensure input processors can use processor kwargs."""
    dummy_registry = InputRegistry()
    # If we have a value for num_crops, pass the override value and make
    # sure we get that value as a return value from our mock processor;
    # otherwise fall back to the default value.
    mm_processor_kwargs = (
        None if num_crops is None else {"num_crops": num_crops}
    )
    expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops
    ctx = build_model_context(
        DUMMY_MODEL_ID, mm_processor_kwargs=mm_processor_kwargs
    )
    processor = dummy_registry.create_input_processor(ctx.model_config)

    num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
    assert num_crops_val == expected_num_crops


@pytest.mark.parametrize(
    "mm_processor_kwargs",
    [
        # Not part of the signature
        {"does_not_exist": 100},
        # Part of the signature, not keyword only
        {"ctx": "something bad"},
    ],
)
def test_processor_with_sad_kwarg_overrides(
    use_processor_mock, mm_processor_kwargs
):
    """Ensure that input processors filter out invalid mm_processor_kwargs."""
    dummy_registry = InputRegistry()
    ctx = build_model_context(
        DUMMY_MODEL_ID, mm_processor_kwargs=mm_processor_kwargs
    )
    processor = dummy_registry.create_input_processor(ctx.model_config)

    num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
    assert num_crops_val == DEFAULT_NUM_CROPS


### Test overrides for the dummy data
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
    """Ensure dummy data factories can use processor kwargs."""
    mm_processor_kwargs = (
        None if num_crops is None else {"num_crops": num_crops}
    )
    expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
    dummy_registry = InputRegistry()
    ctx = build_model_context(
        DUMMY_MODEL_ID, mm_processor_kwargs=mm_processor_kwargs
    )
    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    # NOTE: seq_len is thrown away here since this will leverage the
    # default dummy data factory that we have patched in, whose seq
    # len is solely dependent on the value of the mm_processor_kwargs.
    seq_data, _ = dummy_registry.dummy_data_for_profiling(
        ctx.model_config, seq_len=-1, mm_registry=mm_registry
    )
    assert len(seq_data.prompt_token_ids) == expected_seq_count


@pytest.mark.parametrize(
    "mm_processor_kwargs",
    [
        # Not part of the signature
        {"does_not_exist": 100},
        # Part of the signature, not keyword only
        {"ctx": "something bad"},
    ],
)
def test_dummy_data_with_sad_kwarg_overrides(
    use_dummy_data_mock, mm_processor_kwargs
):
    """Ensure the dummy data factory filters out invalid mm_processor_kwargs."""
    dummy_registry = InputRegistry()
    ctx = build_model_context(
        DUMMY_MODEL_ID, mm_processor_kwargs=mm_processor_kwargs
    )
    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    # NOTE: seq_len is thrown away here since this will leverage the
    # default dummy data factory that we have patched in, whose seq
    # len is solely dependent on the value of the mm_processor_kwargs.
    seq_data, _ = dummy_registry.dummy_data_for_profiling(
        ctx.model_config, seq_len=-1, mm_registry=mm_registry
    )
    assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS


### Test overrides for the max token count per multimodal instance
@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_max_tokens_kwarg_overrides(num_crops):
    """Ensure max token calcs can use processor kwargs."""
    mm_processor_kwargs = (
        None if num_crops is None else {"num_crops": num_crops}
    )
    expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
    ctx = build_model_context(
        MULTIMODAL_MODEL_ID,
        trust_remote_code=True,
        mm_processor_kwargs=mm_processor_kwargs,
        limit_mm_per_prompt={"image": 1},
    )
    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
    # Patch the image registry for phi3v with our lambda that is compatible
    # with overrides, then ensure that calling the method correctly echoes
    # our num_crops value back from the mm_processor_kwargs.
    with patch.object(
        mm_registry._get_plugin("image"),
        "_max_mm_tokens",
        {mm_model_cls(): get_num_crops},
    ):
        max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
            ctx.model_config
        )
    assert expected_seq_count == max_multimodal_tokens


@pytest.mark.parametrize(
    "mm_processor_kwargs",
    [
        # Not part of the signature
        {"does_not_exist": 100},
        # Part of the signature, not keyword only
        {"ctx": "something bad"},
    ],
)
def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
    """Ensure that max token calcs filter out invalid mm_processor_kwargs."""
    ctx = build_model_context(
        MULTIMODAL_MODEL_ID,
        trust_remote_code=True,
        mm_processor_kwargs=mm_processor_kwargs,
        limit_mm_per_prompt={"image": 1},
    )
    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
    # Similar to the above, but since these kwargs get filtered,
    # we always get our default value back.
    with patch.object(
        mm_registry._get_plugin("image"),
        "_max_mm_tokens",
        {mm_model_cls(): get_num_crops},
    ):
        max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
            ctx.model_config
        )
    assert max_multimodal_tokens == DEFAULT_NUM_CROPS


### Test overrides for the mapper
@pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE])
def test_default_mapper_with_processor_kwargs(image_assets, num_crops):
    """Ensure that the mapper processor kwargs can fall back to HF models."""
    # NOTE - we don't validate bad inputs for the default mapper, because it's
    # through the automodel interface in transformers, so we can't easily
    # inspect what kwargs are or are not allowed.
    ctx = build_model_context(
        MULTIMODAL_MODEL_ID,
        trust_remote_code=True,
        mm_processor_kwargs={"num_crops": num_crops},
        limit_mm_per_prompt={"image": 1},
    )
    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)

    image = image_assets[0].pil_image
    mm_inputs = {"image": image}
    mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
    # Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336]
    assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1


@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
    """Ensure custom mappers can use processor kwargs."""
    mm_processor_kwargs = (
        None if num_crops is None else {"num_crops": num_crops}
    )
    expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
    ctx = build_model_context(
        MULTIMODAL_MODEL_ID,
        trust_remote_code=True,
        mm_processor_kwargs=mm_processor_kwargs,
        limit_mm_per_prompt={"image": 1},
    )
    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
    # Patch the image registry for phi3v with our lambda that is compatible
    # with overrides, then ensure that calling the method correctly echoes
    # our num_crops value back from the mm_processor_kwargs.
    image = image_assets[0].pil_image
    mm_inputs = {"image": image}

    with patch.object(
        mm_registry._get_plugin("image"),
        "_default_input_mapper",
        {mm_model_cls(): custom_mapper},
    ):
        mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)

    assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1


@pytest.mark.parametrize(
    "mm_processor_kwargs",
    [
        # Not part of the signature
        {"does_not_exist": 100},
        # Part of the signature, not keyword only
        {"ctx": "something bad"},
    ],
)
def test_custom_mapper_with_sad_kwarg_overrides(
    image_assets, mm_processor_kwargs
):
    """Ensure that custom mappers filter out invalid mm_processor_kwargs."""
    ctx = build_model_context(
        MULTIMODAL_MODEL_ID,
        trust_remote_code=True,
        mm_processor_kwargs=mm_processor_kwargs,
        limit_mm_per_prompt={"image": 1},
    )
    mm_registry = MultiModalRegistry()
    mm_registry.init_mm_limits_per_prompt(ctx.model_config)
    # Patch the image registry for phi3v with our lambda that is compatible
    # with overrides; since the bad kwargs are filtered out, the mapper
    # should always fall back to the default num_crops.
    image = image_assets[0].pil_image
    mm_inputs = {"image": image}

    with patch.object(
        mm_registry._get_plugin("image"),
        "_default_input_mapper",
        {mm_model_cls(): custom_mapper},
    ):
        mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)

    assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1