test_utils.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. import base64
  2. import mimetypes
  3. from tempfile import NamedTemporaryFile
  4. from typing import Dict, Tuple
  5. import numpy as np
  6. import pytest
  7. from PIL import Image
  8. from transformers import AutoConfig, AutoTokenizer
  9. from aphrodite.multimodal.utils import (async_fetch_image, fetch_image,
  10. repeat_and_pad_placeholder_tokens)
  11. # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
  12. TEST_IMAGE_URLS = [
  13. "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
  14. "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
  15. "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
  16. "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
  17. ]
  18. @pytest.fixture(scope="module")
  19. def url_images() -> Dict[str, Image.Image]:
  20. return {image_url: fetch_image(image_url) for image_url in TEST_IMAGE_URLS}
  21. def get_supported_suffixes() -> Tuple[str, ...]:
  22. # We should at least test the file types mentioned in GPT-4 with Vision
  23. OPENAI_SUPPORTED_SUFFIXES = ('.png', '.jpeg', '.jpg', '.webp', '.gif')
  24. # Additional file types that are supported by us
  25. EXTRA_SUPPORTED_SUFFIXES = ('.bmp', '.tiff')
  26. return OPENAI_SUPPORTED_SUFFIXES + EXTRA_SUPPORTED_SUFFIXES
  27. def _image_equals(a: Image.Image, b: Image.Image) -> bool:
  28. return (np.asarray(a) == np.asarray(b.convert(a.mode))).all()
  29. @pytest.mark.asyncio
  30. @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
  31. async def test_fetch_image_http(image_url: str):
  32. image_sync = fetch_image(image_url)
  33. image_async = await async_fetch_image(image_url)
  34. assert _image_equals(image_sync, image_async)
  35. @pytest.mark.asyncio
  36. @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
  37. @pytest.mark.parametrize("suffix", get_supported_suffixes())
  38. async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
  39. image_url: str, suffix: str):
  40. url_image = url_images[image_url]
  41. try:
  42. mime_type = Image.MIME[Image.registered_extensions()[suffix]]
  43. except KeyError:
  44. try:
  45. mime_type = mimetypes.types_map[suffix]
  46. except KeyError:
  47. pytest.skip('No MIME type')
  48. with NamedTemporaryFile(suffix=suffix) as f:
  49. try:
  50. url_image.save(f.name)
  51. except Exception as e:
  52. if e.args[0] == 'cannot write mode RGBA as JPEG':
  53. pytest.skip('Conversion not supported')
  54. raise
  55. base64_image = base64.b64encode(f.read()).decode("utf-8")
  56. data_url = f"data:{mime_type};base64,{base64_image}"
  57. data_image_sync = fetch_image(data_url)
  58. if _image_equals(url_image, Image.open(f)):
  59. assert _image_equals(url_image, data_image_sync)
  60. else:
  61. pass # Lossy format; only check that image can be opened
  62. data_image_async = await async_fetch_image(data_url)
  63. assert _image_equals(data_image_sync, data_image_async)
  64. @pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
  65. def test_repeat_and_pad_placeholder_tokens(model):
  66. config = AutoConfig.from_pretrained(model)
  67. image_token_id = config.image_token_index
  68. tokenizer = AutoTokenizer.from_pretrained(model)
  69. test_cases = [
  70. ("<image>", 2, "<image><image>", [32000, 32000]),
  71. ("<image><image>", 2, "<image><image><image>", [32000, 32000, 32000]),
  72. ("<image><image>", [3, 2], "<image><image><image><image><image>",
  73. [32000, 32000, 32000, 32000, 32000]),
  74. ("Image:<image>Image:<image>!", [3, 2],
  75. "Image:<image><image><image>Image:<image><image>!",
  76. [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]),
  77. ("<image>", [3, 2], "<image><image><image>", [32000, 32000, 32000]),
  78. ]
  79. for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases:
  80. new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
  81. tokenizer=tokenizer,
  82. prompt=prompt,
  83. prompt_token_ids=tokenizer.encode(prompt,
  84. add_special_tokens=False),
  85. placeholder_token_id=image_token_id,
  86. repeat_count=repeat_count,
  87. )
  88. assert new_prompt == expected_prompt
  89. assert new_token_ids == expected_token_ids