video.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. from functools import lru_cache
  2. from typing import List, Union
  3. import numpy as np
  4. from loguru import logger
  5. from aphrodite.common.config import ModelConfig
  6. from aphrodite.common.utils import is_list_of
  7. from aphrodite.inputs.registry import InputContext
  8. from aphrodite.transformers_utils.image_processor import get_video_processor
  9. from aphrodite.transformers_utils.tokenizer import get_tokenizer
  10. from .base import MultiModalData, MultiModalInputs
  11. from .image import ImagePlugin
  12. cached_get_video_processor = lru_cache(get_video_processor)
  13. cached_get_tokenizer = lru_cache(get_tokenizer)
  14. VideoInput = Union[
  15. "np.ndarray", # single video input
  16. List["np.ndarray"],
  17. # TODO: support more types
  18. # List[Image.Image], List[List[Image.Image]],
  19. # "torch.Tensor",
  20. # List["torch.Tensor"],
  21. # List[List["np.ndarrray"]],
  22. # List[List["torch.Tensor"]],
  23. ]
  24. class VideoPlugin(ImagePlugin):
  25. """Plugin for video data."""
  26. def get_data_key(self) -> str:
  27. return "video"
  28. def _get_hf_video_processor(self, model_config: ModelConfig):
  29. return cached_get_video_processor(
  30. model_config.model, trust_remote_code=model_config.trust_remote_code
  31. )
  32. def _default_input_mapper(
  33. self,
  34. ctx: InputContext,
  35. data: MultiModalData[object],
  36. ) -> MultiModalInputs:
  37. model_config = ctx.model_config
  38. # single video input as np.ndarray
  39. if isinstance(data, np.ndarray):
  40. video_processor = self._get_hf_video_processor(model_config)
  41. if video_processor is None:
  42. raise RuntimeError(
  43. "No HuggingFace processor is available "
  44. "to process the image object"
  45. )
  46. try:
  47. batch_data = video_processor(data, return_tensors="pt").data
  48. except Exception:
  49. logger.error(f"Failed to process image ({data})")
  50. raise
  51. return MultiModalInputs(batch_data)
  52. elif is_list_of(data, np.ndarray):
  53. raise NotImplementedError(
  54. "Multi video for a prompt is not supported yet"
  55. )
  56. raise TypeError(f"Invalid video type: {type(data)}")
  57. def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
  58. return 4096