test_processor_kwargs.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. from array import array
  2. from typing import Mapping
  3. from unittest.mock import patch
  4. import pytest
  5. import torch
  6. from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
  7. SequenceData)
  8. from aphrodite.inputs import InputContext, LLMInputs
  9. from aphrodite.inputs.registry import InputRegistry
  10. from aphrodite.multimodal import MultiModalRegistry
  11. from ..models.utils import build_model_context
# Used for fast tests where the model doesn't matter
DUMMY_MODEL_ID = "facebook/opt-125m"
# Used for tests that need a multimodal model
MULTIMODAL_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
# For mm_processor_kwargs - we test overrides by defining mocks for each place
# it is used, and ensuring that passing an override value through the
# processor kwargs produces the intended result (e.g. the sequence length).
# Default that the mock callables in this file fall back to when no
# ``num_crops`` override is supplied.
DEFAULT_NUM_CROPS = 4
# Override value that the parametrized tests pass via mm_processor_kwargs.
NUM_CROPS_OVERRIDE = 16
# Mocks for all of the places that we use the mm_processor_kwargs
# to override values in different callables
  23. @pytest.fixture
  24. def use_processor_mock():
  25. """Patches the internal model input processor with an override callable."""
  26. def custom_processor(
  27. ctx: InputContext, llm_inputs: LLMInputs, *, num_crops=DEFAULT_NUM_CROPS
  28. ):
  29. # For testing purposes, we don't worry about the llm inputs / return
  30. # type validation, and just return the value of the kwarg that we
  31. # clobber.
  32. return num_crops
  33. with patch(
  34. "aphrodite.inputs.registry.InputRegistry._get_model_input_processor",
  35. return_value=custom_processor,
  36. ):
  37. yield
  38. @pytest.fixture
  39. def use_dummy_data_mock():
  40. """Patches the internal model input processor with an override callable."""
  41. def custom_dummy_data_factory(
  42. self,
  43. ctx: InputContext,
  44. seq_len: int,
  45. mm_counts: Mapping[str, int],
  46. *,
  47. num_crops=DEFAULT_NUM_CROPS,
  48. ):
  49. seq_data = SequenceData(
  50. array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [0] * num_crops)
  51. )
  52. return seq_data, None
  53. with patch(
  54. "aphrodite.inputs.registry.InputRegistry._default_dummy_data_factory",
  55. custom_dummy_data_factory,
  56. ):
  57. yield
  58. # Lazy import to avoid CUDA reinitialization error
  59. def mm_model_cls():
  60. from aphrodite.modeling.models.phi3v import Phi3VForCausalLM
  61. return Phi3VForCausalLM
  62. # lambda whose signature matches max token calcs extra & mapper + extra kwargs
  63. get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops
  64. custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: {
  65. "num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336))
  66. }
  67. ### Test for default processor logic & mm_processor_kwargs wrapping
  68. def test_default_processor_is_a_noop():
  69. """Ensure that by default, there is no processor override."""
  70. dummy_registry = InputRegistry()
  71. ctx = build_model_context(DUMMY_MODEL_ID)
  72. processor = dummy_registry.create_input_processor(ctx.model_config)
  73. proc_inputs = LLMInputs(prompt_token_ids=[], prompt="")
  74. proc_outputs = processor(inputs=proc_inputs)
  75. assert proc_inputs is proc_outputs
  76. @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
  77. def test_processor_default_kwargs(use_processor_mock, num_crops):
  78. """Ensure input processors can use processor kwargs."""
  79. dummy_registry = InputRegistry()
  80. # If we have a value for num_crops, pass the override value and make
  81. # sure we get that value as a return-value from out mock processor,
  82. # otherwise fall back to the default value
  83. mm_processor_kwargs = (
  84. None if num_crops is None else {"num_crops": num_crops}
  85. )
  86. expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops
  87. ctx = build_model_context(
  88. DUMMY_MODEL_ID, mm_processor_kwargs=mm_processor_kwargs
  89. )
  90. processor = dummy_registry.create_input_processor(ctx.model_config)
  91. num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
  92. assert num_crops_val == expected_num_crops
  93. @pytest.mark.parametrize(
  94. "mm_processor_kwargs",
  95. [
  96. # Not part of the signature
  97. {"does_not_exist": 100},
  98. # Part of the signature, not keyword only
  99. {"ctx": "something bad"},
  100. ],
  101. )
  102. def test_processor_with_sad_kwarg_overrides(
  103. use_processor_mock, mm_processor_kwargs
  104. ):
  105. """Ensure that input processors filter out invalid mm_processor_kwargs"""
  106. dummy_registry = InputRegistry()
  107. ctx = build_model_context(
  108. DUMMY_MODEL_ID, mm_processor_kwargs=mm_processor_kwargs
  109. )
  110. processor = dummy_registry.create_input_processor(ctx.model_config)
  111. num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt=""))
  112. assert num_crops_val == DEFAULT_NUM_CROPS
  113. ### Test overrides for the dummy data
  114. @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
  115. def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops):
  116. """Ensure dummy data factories can use processor kwargs."""
  117. mm_processor_kwargs = (
  118. None if num_crops is None else {"num_crops": num_crops}
  119. )
  120. expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
  121. dummy_registry = InputRegistry()
  122. ctx = build_model_context(
  123. DUMMY_MODEL_ID, mm_processor_kwargs=mm_processor_kwargs
  124. )
  125. mm_registry = MultiModalRegistry()
  126. mm_registry.init_mm_limits_per_prompt(ctx.model_config)
  127. # NOTE: seq_len is thrown away here since this will leverage the
  128. # default dummy data factory that we have patched in, whose seq
  129. # len is solely dependent on the value of the mm_processor_kwargs.
  130. seq_data, _ = dummy_registry.dummy_data_for_profiling(
  131. ctx.model_config, seq_len=-1, mm_registry=mm_registry
  132. )
  133. assert len(seq_data.prompt_token_ids) == expected_seq_count
  134. @pytest.mark.parametrize(
  135. "mm_processor_kwargs",
  136. [
  137. # Not part of the signature
  138. {"does_not_exist": 100},
  139. # Part of the signature, not keyword only
  140. {"ctx": "something bad"},
  141. ],
  142. )
  143. def test_dummy_data_with_sad_kwarg_overrides(
  144. use_dummy_data_mock, mm_processor_kwargs
  145. ):
  146. """Ensure the dummy data factory filters out invalid mm_processor_kwargs"""
  147. dummy_registry = InputRegistry()
  148. ctx = build_model_context(
  149. DUMMY_MODEL_ID, mm_processor_kwargs=mm_processor_kwargs
  150. )
  151. mm_registry = MultiModalRegistry()
  152. mm_registry.init_mm_limits_per_prompt(ctx.model_config)
  153. # NOTE: seq_len is thrown away here since this will leverage the
  154. # default dummy data factory that we have patched in, whose seq
  155. # len is solely dependent on the value of the mm_processor_kwargs.
  156. seq_data, _ = dummy_registry.dummy_data_for_profiling(
  157. ctx.model_config, seq_len=-1, mm_registry=mm_registry
  158. )
  159. assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS
  160. ### Test overrides for the max token count per multimodal instance
  161. @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
  162. def test_max_tokens_kwarg_overrides(num_crops):
  163. """Ensure max token calcs can use processor kwargs."""
  164. mm_processor_kwargs = (
  165. None if num_crops is None else {"num_crops": num_crops}
  166. )
  167. expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
  168. ctx = build_model_context(
  169. MULTIMODAL_MODEL_ID,
  170. trust_remote_code=True,
  171. mm_processor_kwargs=mm_processor_kwargs,
  172. limit_mm_per_prompt={"image": 1},
  173. )
  174. mm_registry = MultiModalRegistry()
  175. mm_registry.init_mm_limits_per_prompt(ctx.model_config)
  176. # Patch the image registry for phi3v with our lambda that is compatible
  177. # with overrides, then ensure that calling the method correctly echos
  178. # our num_crops value back from the mm_processor_kwargs.
  179. with patch.object(
  180. mm_registry._get_plugin("image"),
  181. "_max_mm_tokens",
  182. {mm_model_cls(): get_num_crops},
  183. ):
  184. max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
  185. ctx.model_config
  186. )
  187. assert expected_seq_count == max_multimodal_tokens
  188. @pytest.mark.parametrize(
  189. "mm_processor_kwargs",
  190. [
  191. # Not part of the signature
  192. {"does_not_exist": 100},
  193. # Part of the signature, not keyword only
  194. {"ctx": "something bad"},
  195. ],
  196. )
  197. def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
  198. """Ensure that max token calcs filters out invalid mm_processor_kwargs"""
  199. ctx = build_model_context(
  200. MULTIMODAL_MODEL_ID,
  201. trust_remote_code=True,
  202. mm_processor_kwargs=mm_processor_kwargs,
  203. limit_mm_per_prompt={"image": 1},
  204. )
  205. mm_registry = MultiModalRegistry()
  206. mm_registry.init_mm_limits_per_prompt(ctx.model_config)
  207. # Similar before, but since these kwargs get filtered,
  208. # we always get our default value back.
  209. with patch.object(
  210. mm_registry._get_plugin("image"),
  211. "_max_mm_tokens",
  212. {mm_model_cls(): get_num_crops},
  213. ):
  214. max_multimodal_tokens = mm_registry.get_max_multimodal_tokens(
  215. ctx.model_config
  216. )
  217. assert max_multimodal_tokens == DEFAULT_NUM_CROPS
  218. ### Test overrides for the mapper
  219. @pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE])
  220. def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
  221. """Ensure that the mapper processor kwargs can fall back to HF models."""
  222. # NOTE - we don't validate bad inputs for the default mapper, because it's
  223. # through the automodel interface in transformers, so we can't easily
  224. # inspect what kwargs are or are not allowed.
  225. ctx = build_model_context(
  226. MULTIMODAL_MODEL_ID,
  227. trust_remote_code=True,
  228. mm_processor_kwargs={"num_crops": num_crops},
  229. limit_mm_per_prompt={"image": 1},
  230. )
  231. mm_registry = MultiModalRegistry()
  232. mm_registry.init_mm_limits_per_prompt(ctx.model_config)
  233. image = image_assets[0].pil_image
  234. mm_inputs = {"image": image}
  235. mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
  236. # Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336]
  237. assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1
  238. @pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE])
  239. def test_custom_mapper_kwarg_overrides(image_assets, num_crops):
  240. """Ensure custom mappers can use processor kwargs."""
  241. mm_processor_kwargs = (
  242. None if num_crops is None else {"num_crops": num_crops}
  243. )
  244. expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
  245. ctx = build_model_context(
  246. MULTIMODAL_MODEL_ID,
  247. trust_remote_code=True,
  248. mm_processor_kwargs=mm_processor_kwargs,
  249. limit_mm_per_prompt={"image": 1},
  250. )
  251. mm_registry = MultiModalRegistry()
  252. mm_registry.init_mm_limits_per_prompt(ctx.model_config)
  253. # Patch the image registry for phi3v with our lambda that is compatible
  254. # with overrides, then ensure that calling the method correctly echos
  255. # our num_crops value back from the mm_processor_kwargs.
  256. image = image_assets[0].pil_image
  257. mm_inputs = {"image": image}
  258. with patch.object(
  259. mm_registry._get_plugin("image"),
  260. "_default_input_mapper",
  261. {mm_model_cls(): custom_mapper},
  262. ):
  263. mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
  264. assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1
  265. @pytest.mark.parametrize(
  266. "mm_processor_kwargs",
  267. [
  268. # Not part of the signature
  269. {"does_not_exist": 100},
  270. # Part of the signature, not keyword only
  271. {"ctx": "something bad"},
  272. ],
  273. )
  274. def test_custom_mapper_with_sad_kwarg_overrides(
  275. image_assets, mm_processor_kwargs
  276. ):
  277. """Ensure that custom mappers filters out invalid mm_processor_kwargs"""
  278. ctx = build_model_context(
  279. MULTIMODAL_MODEL_ID,
  280. trust_remote_code=True,
  281. mm_processor_kwargs=mm_processor_kwargs,
  282. limit_mm_per_prompt={"image": 1},
  283. )
  284. mm_registry = MultiModalRegistry()
  285. mm_registry.init_mm_limits_per_prompt(ctx.model_config)
  286. # Patch the image registry for phi3v with our lambda that is compatible
  287. # with overrides, then ensure that calling the method correctly echos
  288. # our num_crops value back from the mm_processor_kwargs.
  289. image = image_assets[0].pil_image
  290. mm_inputs = {"image": image}
  291. with patch.object(
  292. mm_registry._get_plugin("image"),
  293. "_default_input_mapper",
  294. {mm_model_cls(): custom_mapper},
  295. ):
  296. mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs)
  297. assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1