# Tests comparing HuggingFace and Aphrodite outputs for PaliGemma.
  1. import os
  2. from typing import List, Optional, Tuple, Type
  3. import pytest
  4. from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer
  5. from aphrodite.common.sequence import SampleLogprobs
  6. from aphrodite.common.utils import is_hip
  7. from aphrodite.multimodal.utils import rescale_image_size
  8. from ..conftest import IMAGE_ASSETS, AphroditeRunner, HfRunner, _ImageAssets
  9. from .utils import check_logprobs_close
  10. pytestmark = pytest.mark.vlm
  11. HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
  12. "stop_sign":
  13. "caption es",
  14. "cherry_blossom":
  15. "What is in the picture?",
  16. })
  17. models = ["google/paligemma-3b-mix-224"]
  18. # ROCm Triton FA can run into compilation issues with these models due to,
  19. # excessive use of shared memory. Use other backends in the meantime.
  20. # FIXME (mattwong, gshtrasb, hongxiayan)
  21. if is_hip():
  22. os.environ["APHRODITE_USE_TRITON_FLASH_ATTN"] = "0"
  23. def aphrodite_to_hf_output(aphrodite_output: Tuple[List[int], str,
  24. Optional[SampleLogprobs]],
  25. model: str):
  26. """Sanitize aphrodite output to be comparable with hf output."""
  27. output_ids, output_str, out_logprobs = aphrodite_output
  28. config = AutoConfig.from_pretrained(model)
  29. image_token_id = config.image_token_index
  30. tokenizer = AutoTokenizer.from_pretrained(model)
  31. eos_token_id = tokenizer.eos_token_id
  32. hf_output_ids = [
  33. token_id for idx, token_id in enumerate(output_ids)
  34. if token_id != image_token_id or output_ids[idx - 1] != image_token_id
  35. ]
  36. hf_output_str = output_str
  37. if hf_output_ids[-1] == eos_token_id:
  38. hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
  39. return hf_output_ids, hf_output_str, out_logprobs
  40. def run_test(
  41. hf_runner: Type[HfRunner],
  42. aphrodite_runner: Type[AphroditeRunner],
  43. image_assets: _ImageAssets,
  44. model: str,
  45. *,
  46. size_factors: List[float],
  47. dtype: str,
  48. max_tokens: int,
  49. num_logprobs: int,
  50. tensor_parallel_size: int,
  51. distributed_executor_backend: Optional[str] = None,
  52. ):
  53. """Inference result should be the same between hf and aphrodite.
  54. All the image fixtures for the test is under tests/images.
  55. For huggingface runner, we provide the PIL images as input.
  56. For aphrodite runner, we provide MultiModalDataDict objects
  57. and corresponding MultiModalConfig as input.
  58. Note, the text input is also adjusted to abide by aphrodite contract.
  59. The text output is sanitized to be able to compare with hf.
  60. """
  61. images = [asset.pil_image for asset in image_assets]
  62. inputs_per_image = [(
  63. [prompt for _ in size_factors],
  64. [rescale_image_size(image, factor) for factor in size_factors],
  65. ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
  66. # NOTE: take care of the order. run Aphrodite first, and then run HF.
  67. # Aphrodite needs a fresh new process without cuda initialization.
  68. # if we run HF first, the cuda initialization will be done and it
  69. # will hurt multiprocessing backend with fork method (the default method).
  70. # max_model_len should be greater than image_feature_size
  71. with aphrodite_runner(model,
  72. dtype=dtype,
  73. tensor_parallel_size=tensor_parallel_size,
  74. distributed_executor_backend=distributed_executor_backend,
  75. enforce_eager=True) as aphrodite_model:
  76. aphrodite_outputs_per_image = [
  77. aphrodite_model.generate_greedy_logprobs(prompts,
  78. max_tokens,
  79. num_logprobs=num_logprobs,
  80. images=images)
  81. for prompts, images in inputs_per_image
  82. ]
  83. with hf_runner(model, dtype=dtype,
  84. auto_cls=AutoModelForVision2Seq) as hf_model:
  85. hf_outputs_per_image = [
  86. hf_model.generate_greedy_logprobs_limit(prompts,
  87. max_tokens,
  88. num_logprobs=num_logprobs,
  89. images=images)
  90. for prompts, images in inputs_per_image
  91. ]
  92. for hf_outputs, aphrodite_outputs in zip(hf_outputs_per_image,
  93. aphrodite_outputs_per_image):
  94. check_logprobs_close(
  95. outputs_0_lst=hf_outputs,
  96. outputs_1_lst=[
  97. aphrodite_to_hf_output(aphrodite_output, model)
  98. for aphrodite_output in aphrodite_outputs
  99. ],
  100. name_0="hf",
  101. name_1="aphrodite",
  102. )
  103. @pytest.mark.parametrize("model", models)
  104. @pytest.mark.parametrize(
  105. "size_factors",
  106. [
  107. # No image
  108. [],
  109. # Single-scale
  110. [1.0],
  111. # Single-scale, batched
  112. [1.0, 1.0, 1.0],
  113. # Multi-scale
  114. [0.25, 0.5, 1.0],
  115. ],
  116. )
  117. @pytest.mark.parametrize("dtype", [
  118. pytest.param(
  119. "float",
  120. marks=pytest.mark.skipif(
  121. is_hip(),
  122. reason=
  123. "ROCm FA does not yet fully support 32-bit precision on PaliGemma")
  124. ), "half"
  125. ])
  126. @pytest.mark.parametrize("max_tokens", [128])
  127. @pytest.mark.parametrize("num_logprobs", [5])
  128. def test_models(hf_runner, aphrodite_runner, image_assets, model, size_factors,
  129. dtype: str, max_tokens: int, num_logprobs: int) -> None:
  130. run_test(
  131. hf_runner,
  132. aphrodite_runner,
  133. image_assets,
  134. model,
  135. size_factors=size_factors,
  136. dtype=dtype,
  137. max_tokens=max_tokens,
  138. num_logprobs=num_logprobs,
  139. tensor_parallel_size=1,
  140. )