image.py 861 B

123456789101112131415161718192021222324252627282930
  1. from dataclasses import dataclass
  2. from typing import Literal
  3. import torch
  4. from PIL import Image
  5. from aphrodite.assets.base import get_vllm_public_assets
  6. VLM_IMAGES_DIR = "vision_model_images"
  7. @dataclass(frozen=True)
  8. class ImageAsset:
  9. name: Literal["stop_sign", "cherry_blossom"]
  10. @property
  11. def pil_image(self) -> Image.Image:
  12. image_path = get_vllm_public_assets(filename=f"{self.name}.jpg",
  13. s3_prefix=VLM_IMAGES_DIR)
  14. return Image.open(image_path)
  15. @property
  16. def image_embeds(self) -> torch.Tensor:
  17. """
  18. Image embeddings, only used for testing purposes with llava 1.5.
  19. """
  20. image_path = get_vllm_public_assets(filename=f"{self.name}.pt",
  21. s3_prefix=VLM_IMAGES_DIR)
  22. return torch.load(image_path)