123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- import base64
- import os
- from io import BytesIO
- from typing import Optional, Union
- from urllib.parse import urlparse
- import aiohttp
- import requests
- from PIL import Image
- from aphrodite.multimodal.base import MultiModalDataDict
- from aphrodite.version import __version__ as APHRODITE_VERSION
- APHRODITE_IMAGE_FETCH_TIMEOUT = int(
- os.getenv("APHRODITE_IMAGE_FETCH_TIMEOUT", 10))
- def _validate_remote_url(url: str, *, name: str):
- parsed_url = urlparse(url)
- if parsed_url.scheme not in ["http", "https"]:
- raise ValueError(f"Invalid '{name}': A valid '{name}' "
- "must have scheme 'http' or 'https'.")
- def _get_request_headers():
- return {"User-Agent": f"aphrodite/{APHRODITE_VERSION}"}
- def _load_image_from_bytes(b: bytes):
- image = Image.open(BytesIO(b))
- image.load()
- return image
- def _load_image_from_data_url(image_url: str):
- # Only split once and assume the second part is the base64 encoded image
- _, image_base64 = image_url.split(",", 1)
- return load_image_from_base64(image_base64)
- def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
- """
- Load a PIL image from a HTTP or base64 data URL.
- By default, the image is converted into RGB format.
- """
- if image_url.startswith('http'):
- _validate_remote_url(image_url, name="image_url")
- headers = _get_request_headers()
- with requests.get(url=image_url, headers=headers) as response:
- response.raise_for_status()
- image_raw = response.content
- image = _load_image_from_bytes(image_raw)
- elif image_url.startswith('data:image'):
- image = _load_image_from_data_url(image_url)
- else:
- raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
- "with either 'data:image' or 'http'.")
- return image.convert(image_mode)
- class ImageFetchAiohttp:
- aiohttp_client: Optional[aiohttp.ClientSession] = None
- @classmethod
- def get_aiohttp_client(cls) -> aiohttp.ClientSession:
- if cls.aiohttp_client is None:
- timeout = aiohttp.ClientTimeout(
- total=APHRODITE_IMAGE_FETCH_TIMEOUT)
- connector = aiohttp.TCPConnector()
- cls.aiohttp_client = aiohttp.ClientSession(timeout=timeout,
- connector=connector)
- return cls.aiohttp_client
- @classmethod
- async def fetch_image(
- cls,
- image_url: str,
- *,
- image_mode: str = "RGB",
- ) -> Image.Image:
- """
- Asynchronously load a PIL image from a HTTP or base64 data URL.
- By default, the image is converted into RGB format.
- """
- if image_url.startswith('http'):
- _validate_remote_url(image_url, name="image_url")
- client = cls.get_aiohttp_client()
- headers = _get_request_headers()
- async with client.get(url=image_url, headers=headers) as response:
- response.raise_for_status()
- image_raw = await response.read()
- image = _load_image_from_bytes(image_raw)
- # Only split once and assume the second part is the base64 encoded image
- elif image_url.startswith('data:image'):
- image = _load_image_from_data_url(image_url)
- else:
- raise ValueError(
- "Invalid 'image_url': A valid 'image_url' must start "
- "with either 'data:image' or 'http'.")
- return image.convert(image_mode)
- async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
- image = await ImageFetchAiohttp.fetch_image(image_url)
- return {"image": image}
- def encode_image_base64(
- image: Image.Image,
- *,
- image_mode: str = "RGB",
- format: str = "JPEG",
- ) -> str:
- """
- Encode a pillow image to base64 format.
- By default, the image is converted into RGB format before being encoded.
- """
- buffered = BytesIO()
- image = image.convert(image_mode)
- image.save(buffered, format)
- return base64.b64encode(buffered.getvalue()).decode('utf-8')
- def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
- """Load image from base64 format."""
- return _load_image_from_bytes(base64.b64decode(image))
- def rescale_image_size(image: Image.Image, size_factor: float) -> Image.Image:
- """Rescale the dimensions of an image by a constant factor."""
- new_width = int(image.width * size_factor)
- new_height = int(image.height * size_factor)
- return image.resize((new_width, new_height))
|