video.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. from dataclasses import dataclass
  2. from functools import lru_cache
  3. from typing import List, Optional
  4. import numpy as np
  5. import numpy.typing as npt
  6. from huggingface_hub import hf_hub_download
  7. from PIL import Image
  8. from aphrodite.multimodal.utils import (sample_frames_from_video,
  9. try_import_video_packages)
  10. from .base import get_cache_dir
  11. @lru_cache
  12. def download_video_asset(filename: str) -> str:
  13. """
  14. Download and open an image from huggingface
  15. repo: raushan-testing-hf/videos-test
  16. """
  17. video_directory = get_cache_dir() / "video-eample-data"
  18. video_directory.mkdir(parents=True, exist_ok=True)
  19. video_path = video_directory / filename
  20. video_path_str = str(video_path)
  21. if not video_path.exists():
  22. video_path_str = hf_hub_download(
  23. repo_id="raushan-testing-hf/videos-test",
  24. filename=filename,
  25. repo_type="dataset",
  26. cache_dir=video_directory,
  27. )
  28. return video_path_str
  29. def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
  30. cv2 = try_import_video_packages()
  31. cap = cv2.VideoCapture(path)
  32. if not cap.isOpened():
  33. raise ValueError(f"Could not open video file {path}")
  34. total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
  35. frames = []
  36. for i in range(total_frames):
  37. ret, frame = cap.read()
  38. if ret:
  39. frames.append(frame)
  40. cap.release()
  41. frames = np.stack(frames)
  42. frames = sample_frames_from_video(frames, num_frames)
  43. if len(frames) < num_frames:
  44. raise ValueError(
  45. f"Could not read enough frames from video file {path}"
  46. f" (expected {num_frames} frames, got {len(frames)})"
  47. )
  48. return frames
  49. def video_to_pil_images_list(
  50. path: str, num_frames: int = -1
  51. ) -> List[Image.Image]:
  52. cv2 = try_import_video_packages()
  53. frames = video_to_ndarrays(path, num_frames)
  54. return [
  55. Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
  56. for frame in frames
  57. ]
  58. @dataclass(frozen=True)
  59. class VideoAsset:
  60. name: str = "sample_demo_1.mp4"
  61. num_frames: int = -1
  62. local_path: Optional[str] = None
  63. @property
  64. def pil_images(self) -> List[Image.Image]:
  65. video_path = (self.local_path if self.local_path else
  66. download_video_asset(self.name))
  67. ret = video_to_pil_images_list(video_path, self.num_frames)
  68. return ret
  69. @property
  70. def np_ndarrays(self) -> List[npt.NDArray]:
  71. video_path = (self.local_path if self.local_path else
  72. download_video_asset(self.name))
  73. ret = video_to_ndarrays(video_path, self.num_frames)
  74. return ret