utils.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. import base64
  2. import os
  3. from io import BytesIO
  4. from typing import Union
  5. from PIL import Image
  6. from aphrodite.common.connections import global_http_connection
  7. from aphrodite.multimodal.base import MultiModalDataDict
  8. APHRODITE_IMAGE_FETCH_TIMEOUT = int(
  9. os.getenv("APHRODITE_IMAGE_FETCH_TIMEOUT", 10))
  10. def _load_image_from_bytes(b: bytes):
  11. image = Image.open(BytesIO(b))
  12. image.load()
  13. return image
  14. def _load_image_from_data_url(image_url: str):
  15. # Only split once and assume the second part is the base64 encoded image
  16. _, image_base64 = image_url.split(",", 1)
  17. return load_image_from_base64(image_base64)
  18. def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
  19. """
  20. Load a PIL image from a HTTP or base64 data URL.
  21. By default, the image is converted into RGB format.
  22. """
  23. if image_url.startswith('http'):
  24. image_raw = global_http_connection.get_bytes(
  25. image_url, timeout=APHRODITE_IMAGE_FETCH_TIMEOUT)
  26. image = _load_image_from_bytes(image_raw)
  27. elif image_url.startswith('data:image'):
  28. image = _load_image_from_data_url(image_url)
  29. else:
  30. raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
  31. "with either 'data:image' or 'http'.")
  32. return image.convert(image_mode)
  33. async def async_fetch_image(image_url: str,
  34. *,
  35. image_mode: str = "RGB") -> Image.Image:
  36. """
  37. Asynchronously load a PIL image from a HTTP or base64 data URL.
  38. By default, the image is converted into RGB format.
  39. """
  40. if image_url.startswith('http'):
  41. image_raw = await global_http_connection.async_get_bytes(
  42. image_url, timeout=APHRODITE_IMAGE_FETCH_TIMEOUT)
  43. image = _load_image_from_bytes(image_raw)
  44. elif image_url.startswith('data:image'):
  45. image = _load_image_from_data_url(image_url)
  46. else:
  47. raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
  48. "with either 'data:image' or 'http'.")
  49. return image.convert(image_mode)
  50. async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
  51. image = await async_fetch_image(image_url)
  52. return {"image": image}
  53. def encode_image_base64(
  54. image: Image.Image,
  55. *,
  56. image_mode: str = "RGB",
  57. format: str = "JPEG",
  58. ) -> str:
  59. """
  60. Encode a pillow image to base64 format.
  61. By default, the image is converted into RGB format before being encoded.
  62. """
  63. buffered = BytesIO()
  64. image = image.convert(image_mode)
  65. image.save(buffered, format)
  66. return base64.b64encode(buffered.getvalue()).decode('utf-8')
  67. def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
  68. """Load image from base64 format."""
  69. return _load_image_from_bytes(base64.b64decode(image))
  70. def rescale_image_size(image: Image.Image, size_factor: float) -> Image.Image:
  71. """Rescale the dimensions of an image by a constant factor."""
  72. new_width = int(image.width * size_factor)
  73. new_height = int(image.height * size_factor)
  74. return image.resize((new_width, new_height))