test_utils.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import base64
  2. import mimetypes
  3. from tempfile import NamedTemporaryFile
  4. from typing import Dict, Tuple
  5. import numpy as np
  6. import pytest
  7. from PIL import Image
  8. from aphrodite.multimodal.utils import async_fetch_image, fetch_image
  9. # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
  10. TEST_IMAGE_URLS = [
  11. "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
  12. "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png",
  13. "https://upload.wikimedia.org/wikipedia/commons/thumb/9/91/Venn_diagram_rgb.svg/1280px-Venn_diagram_rgb.svg.png",
  14. "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
  15. ]
  16. @pytest.fixture(scope="module")
  17. def url_images() -> Dict[str, Image.Image]:
  18. return {image_url: fetch_image(image_url) for image_url in TEST_IMAGE_URLS}
  19. def get_supported_suffixes() -> Tuple[str, ...]:
  20. # We should at least test the file types mentioned in GPT-4 with Vision
  21. OPENAI_SUPPORTED_SUFFIXES = ('.png', '.jpeg', '.jpg', '.webp', '.gif')
  22. # Additional file types that are supported by us
  23. EXTRA_SUPPORTED_SUFFIXES = ('.bmp', '.tiff')
  24. return OPENAI_SUPPORTED_SUFFIXES + EXTRA_SUPPORTED_SUFFIXES
  25. def _image_equals(a: Image.Image, b: Image.Image) -> bool:
  26. return (np.asarray(a) == np.asarray(b.convert(a.mode))).all()
  27. @pytest.mark.asyncio
  28. @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
  29. async def test_fetch_image_http(image_url: str):
  30. image_sync = fetch_image(image_url)
  31. image_async = await async_fetch_image(image_url)
  32. assert _image_equals(image_sync, image_async)
  33. @pytest.mark.asyncio
  34. @pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
  35. @pytest.mark.parametrize("suffix", get_supported_suffixes())
  36. async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
  37. image_url: str, suffix: str):
  38. url_image = url_images[image_url]
  39. try:
  40. mime_type = Image.MIME[Image.registered_extensions()[suffix]]
  41. except KeyError:
  42. try:
  43. mime_type = mimetypes.types_map[suffix]
  44. except KeyError:
  45. pytest.skip('No MIME type')
  46. with NamedTemporaryFile(suffix=suffix) as f:
  47. try:
  48. url_image.save(f.name)
  49. except Exception as e:
  50. if e.args[0] == 'cannot write mode RGBA as JPEG':
  51. pytest.skip('Conversion not supported')
  52. raise
  53. base64_image = base64.b64encode(f.read()).decode("utf-8")
  54. data_url = f"data:{mime_type};base64,{base64_image}"
  55. data_image_sync = fetch_image(data_url)
  56. if _image_equals(url_image, Image.open(f)):
  57. assert _image_equals(url_image, data_image_sync)
  58. else:
  59. pass # Lossy format; only check that image can be opened
  60. data_image_async = await async_fetch_image(data_url)
  61. assert _image_equals(data_image_sync, data_image_async)