123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- import base64
- from io import BytesIO
- from typing import Tuple, Union
- import librosa
- import numpy as np
- import soundfile
- from PIL import Image
- from aphrodite import envs
- from aphrodite.common.connections import global_http_connection
- from aphrodite.multimodal.base import MultiModalDataDict
# Timeouts passed as ``timeout=`` to the shared HTTP connection when
# downloading media. Both values are read from ``envs`` once, when this
# module is imported.
APHRODITE_AUDIO_FETCH_TIMEOUT = envs.APHRODITE_AUDIO_FETCH_TIMEOUT
APHRODITE_IMAGE_FETCH_TIMEOUT = envs.APHRODITE_IMAGE_FETCH_TIMEOUT
def _load_image_from_bytes(b: bytes):
    """Open raw image bytes with PIL and return the loaded image.

    ``load()`` is called right after opening so the image data is read
    before the image is handed back to the caller.
    """
    decoded = Image.open(BytesIO(b))
    decoded.load()
    return decoded
def _load_image_from_data_url(image_url: str):
    """Decode the base64 payload of a data URL into an image.

    Args:
        image_url: A data URL of the form ``<header>,<base64 payload>``.

    Returns:
        Whatever ``load_image_from_base64`` returns for the payload.

    Raises:
        ValueError: If the URL contains no ``,`` separating the header
            from the payload.
    """
    # Only split on the first comma; everything after it is assumed to be
    # the base64 encoded image.
    _, separator, image_base64 = image_url.partition(",")
    if not separator:
        # Fail with an actionable message instead of a tuple-unpacking error.
        raise ValueError("Invalid 'image_url': A 'data:image' URL must "
                         "contain a ',' before the base64 encoded image.")
    return load_image_from_base64(image_base64)
def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image:
    """
    Fetch an image referenced by an HTTP(S) URL or a base64 data URL.

    The decoded image is converted to ``image_mode`` (RGB unless overridden)
    before it is returned. A URL that starts with neither ``http`` nor
    ``data:image`` raises ``ValueError``.
    """
    if image_url.startswith('http'):
        raw_bytes = global_http_connection.get_bytes(
            image_url, timeout=APHRODITE_IMAGE_FETCH_TIMEOUT)
        return _load_image_from_bytes(raw_bytes).convert(image_mode)

    if image_url.startswith('data:image'):
        return _load_image_from_data_url(image_url).convert(image_mode)

    raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
                     "with either 'data:image' or 'http'.")
async def async_fetch_image(image_url: str,
                            *,
                            image_mode: str = "RGB") -> Image.Image:
    """
    Asynchronous variant of image fetching for HTTP(S) or base64 data URLs.

    Only the HTTP download is awaited; decoding happens inline. The decoded
    image is converted to ``image_mode`` (RGB unless overridden). A URL that
    starts with neither ``http`` nor ``data:image`` raises ``ValueError``.
    """
    if image_url.startswith('http'):
        raw_bytes = await global_http_connection.async_get_bytes(
            image_url, timeout=APHRODITE_IMAGE_FETCH_TIMEOUT)
        return _load_image_from_bytes(raw_bytes).convert(image_mode)

    if image_url.startswith('data:image'):
        return _load_image_from_data_url(image_url).convert(image_mode)

    raise ValueError("Invalid 'image_url': A valid 'image_url' must start "
                     "with either 'data:image' or 'http'.")
def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
    """
    Load audio from an HTTP(S) URL or a base64 ``data:audio`` URL.

    Args:
        audio_url: URL starting with ``http`` or ``data:audio``.

    Returns:
        The result of ``librosa.load(..., sr=None)`` on the fetched bytes,
        i.e. an ``(audio, sampling_rate)`` pair.

    Raises:
        ValueError: If the URL starts with neither supported prefix, or if
            a ``data:audio`` URL has no ``,`` separating the header from
            the base64 payload.
    """
    if audio_url.startswith("http"):
        audio_bytes = global_http_connection.get_bytes(
            audio_url, timeout=APHRODITE_AUDIO_FETCH_TIMEOUT)
    elif audio_url.startswith("data:audio"):
        # Split on the first comma only; the remainder is the payload.
        _, separator, audio_base64 = audio_url.partition(",")
        if not separator:
            # Fail with an actionable message instead of a tuple-unpacking
            # error when the payload separator is missing.
            raise ValueError("Invalid 'audio_url': A 'data:audio' URL must "
                             "contain a ',' before the base64 encoded audio.")
        audio_bytes = base64.b64decode(audio_base64)
    else:
        raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start "
                         "with either 'data:audio' or 'http'.")

    return librosa.load(BytesIO(audio_bytes), sr=None)
async def async_fetch_audio(
        audio_url: str) -> Tuple[np.ndarray, Union[int, float]]:
    """
    Asynchronously fetch audio from an HTTP(S) or base64 ``data:audio`` URL.

    Only the HTTP download is awaited; base64 decoding and the
    ``librosa.load`` call run inline.

    Args:
        audio_url: URL starting with ``http`` or ``data:audio``.

    Returns:
        The result of ``librosa.load(..., sr=None)`` on the fetched bytes,
        i.e. an ``(audio, sampling_rate)`` pair.

    Raises:
        ValueError: If the URL starts with neither supported prefix, or if
            a ``data:audio`` URL has no ``,`` separating the header from
            the base64 payload.
    """
    if audio_url.startswith("http"):
        audio_bytes = await global_http_connection.async_get_bytes(
            audio_url, timeout=APHRODITE_AUDIO_FETCH_TIMEOUT)
    elif audio_url.startswith("data:audio"):
        # Split on the first comma only; the remainder is the payload.
        _, separator, audio_base64 = audio_url.partition(",")
        if not separator:
            # Fail with an actionable message instead of a tuple-unpacking
            # error when the payload separator is missing.
            raise ValueError("Invalid 'audio_url': A 'data:audio' URL must "
                             "contain a ',' before the base64 encoded audio.")
        audio_bytes = base64.b64decode(audio_base64)
    else:
        raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start "
                         "with either 'data:audio' or 'http'.")

    return librosa.load(BytesIO(audio_bytes), sr=None)
async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict:
    """Fetch audio from ``audio_url`` and wrap it as multimodal data.

    Returns a dict with a single ``"audio"`` key holding the two-element
    tuple unpacked from ``async_fetch_audio``.
    """
    waveform, sampling_rate = await async_fetch_audio(audio_url)
    return {"audio": (waveform, sampling_rate)}
async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict:
    """Fetch the image at ``image_url`` and wrap it as multimodal data.

    Returns a dict with a single ``"image"`` key holding the image returned
    by ``async_fetch_image`` (called with its default image mode).
    """
    return {"image": await async_fetch_image(image_url)}
def encode_audio_base64(
    audio: np.ndarray,
    sampling_rate: int,
) -> str:
    """Write audio as WAV into memory and return it base64-encoded.

    The samples are written with ``soundfile`` at ``sampling_rate``; the
    resulting bytes are base64-encoded and decoded to a UTF-8 string.
    """
    wav_buffer = BytesIO()
    soundfile.write(wav_buffer, audio, sampling_rate, format="WAV")
    wav_bytes = wav_buffer.getvalue()
    encoded = base64.b64encode(wav_bytes)
    return encoded.decode('utf-8')
def encode_image_base64(
    image: Image.Image,
    *,
    image_mode: str = "RGB",
    format: str = "JPEG",
) -> str:
    """
    Serialize a pillow image and return the bytes as a base64 string.

    The image is first converted to ``image_mode`` (RGB by default), then
    saved in ``format`` (JPEG by default) to an in-memory buffer.
    """
    converted = image.convert(image_mode)
    out_buffer = BytesIO()
    converted.save(out_buffer, format)
    encoded = base64.b64encode(out_buffer.getvalue())
    return encoded.decode('utf-8')
def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
    """Decode a base64 string or bytes object and load it as an image."""
    raw_bytes = base64.b64decode(image)
    return _load_image_from_bytes(raw_bytes)
def rescale_image_size(image: Image.Image,
                       size_factor: float,
                       transpose: int = -1) -> Image.Image:
    """Scale both image dimensions by ``size_factor``, optionally transposing.

    The new width and height are the originals multiplied by
    ``size_factor`` and truncated with ``int()``. A non-negative
    ``transpose`` is converted to ``Image.Transpose`` and applied after
    resizing; the default of -1 skips the transpose step.
    """
    scaled_size = (int(image.width * size_factor),
                   int(image.height * size_factor))
    resized = image.resize(scaled_size)
    if transpose < 0:
        return resized
    return resized.transpose(Image.Transpose(transpose))
|