import itertools
import math
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import numpy as np
import torch
import torch.nn as nn
from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
                          SiglipVisionConfig)

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.common.utils import is_list_of
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.clip import CLIPVisionModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.utils import (cached_get_tokenizer,
                                        repeat_and_pad_placeholder_tokens)
from aphrodite.quantization.base_config import QuantizationConfig

from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip)
from .utils import (filter_weights, init_aphrodite_registered_model,
                    merge_multimodal_embeddings)

# For profile run
_MAX_FRAMES_PER_VIDEO = 32
_MAX_NUM_VIDEOS = 1


class LlavaNextVideoPixelInputs(TypedDict):
    type: Literal["pixel_values_videos"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size, num_frames, num_channels, height, width)`

    Note that `num_frames` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.

    Note that only one video input per batch is currently supported.
    """


def get_llava_next_video_frame_feature_size(
        hf_config: LlavaNextVideoConfig) -> int:
    # Support both CLIPVisionConfig and SiglipVisionConfig
    image_size = hf_config.vision_config.image_size
    patch_size = hf_config.vision_config.patch_size
    spatial_pool_stride = hf_config.spatial_pool_stride

    return int((image_size / patch_size / spatial_pool_stride)**2)


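# Worked example (a sketch, assuming the common CLIP ViT-L/14-336 tower with
# spatial_pool_stride=2 used by the reference LLaVA-NeXT-Video checkpoints):
# (336 / 14 / 2) ** 2 = 12 ** 2 = 144 feature tokens per video frame.

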
def _get_max_llm_tokens(ctx: InputContext) -> int:
    """
    Calculate the maximum number of tokens the language model can accept,
    accounting for RoPE scaling. This bounds how many video frames fit
    within the context length.
    """
    hf_text_config = ctx.model_config.hf_text_config
    model_config = ctx.model_config
    max_tokens = model_config.max_model_len
    rope_scaling = model_config.rope_scaling
    if rope_scaling:
        rope_scaling_factor = hf_text_config.rope_scaling["factor"]
    else:
        rope_scaling_factor = 1
    max_tokens *= rope_scaling_factor

    return max_tokens


def get_max_llava_next_video_tokens(ctx: InputContext) -> int:
    # Currently set to 32 frames
    # TODO: max_tokens = _get_max_llm_tokens(ctx)
    hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
    tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
    return _MAX_FRAMES_PER_VIDEO * tokens_per_frame


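# Profiling budget sketch: with the 144-tokens-per-frame example above, the
# cap is _MAX_FRAMES_PER_VIDEO * 144 = 32 * 144 = 4608 video tokens.

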
def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int,
                                    mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
    vision_config = hf_config.vision_config

    # TODO: support multiple videos
    num_videos = mm_counts["video"]
    if num_videos != _MAX_NUM_VIDEOS:
        raise NotImplementedError(
            f"Only {_MAX_NUM_VIDEOS} videos are supported")

    # TODO: support configuring the number of frames
    frames_per_video = _MAX_FRAMES_PER_VIDEO
    # num_images = num_videos * frames_per_video

    # fill the sequence with as much video data as possible
    tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
    video_feature_size = frames_per_video * tokens_per_frame

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            num_videos,
            image_token_id=hf_config.video_token_index,
            image_feature_size_override=video_feature_size,
        )

        pil_frame = dummy_image_for_clip(vision_config, num_images=1)
        np_frame = np.array(pil_frame["image"])
        mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
        mm_data = {"video": mm_data_per_video}
        return seq_data, mm_data
    elif isinstance(vision_config, SiglipVisionConfig):
        seq_data = dummy_seq_data_for_siglip(
            vision_config,
            seq_len,
            num_videos,
            image_token_id=hf_config.video_token_index,
            image_feature_size_override=video_feature_size,
        )

        pil_frame = dummy_image_for_siglip(vision_config, num_images=1)
        np_frame = np.array(pil_frame["image"])
        mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
        mm_data = {"video": mm_data_per_video}
        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def input_processor_for_llava_next_video(ctx: InputContext,
                                         llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "video" not in multi_modal_data:
        return llm_inputs
    video_data = multi_modal_data["video"]

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
    vision_config = hf_config.vision_config

    if isinstance(video_data, np.ndarray):
        # Supports both CLIP and Siglip
        num_frames = video_data.shape[0]
        frame_feature_size = \
            get_llava_next_video_frame_feature_size(hf_config)
        video_feature_size = num_frames * frame_feature_size

        tokenizer = cached_get_tokenizer(model_config.tokenizer)

        new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
            tokenizer,
            llm_inputs.get("prompt"),
            llm_inputs["prompt_token_ids"],
            placeholder_token_id=hf_config.video_token_index,
            repeat_count=video_feature_size,
        )

        return LLMInputs(prompt_token_ids=new_token_ids,
                         prompt=new_prompt,
                         multi_modal_data=multi_modal_data)

    elif is_list_of(video_data, np.ndarray):
        raise NotImplementedError(
            "Processing multiple videos is not supported")

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


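# Placeholder-expansion sketch (illustrative numbers, not tied to a specific
# checkpoint): for an 8-frame video at 144 tokens per frame, the single video
# placeholder token (hf_config.video_token_index) in the prompt is repeated
# 8 * 144 = 1152 times so that it lines up one-to-one with the projected
# video features merged later in forward().

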
def _init_vision_tower(hf_config: LlavaNextVideoConfig):
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer
    vision_feature_layer = hf_config.vision_feature_layer
    if vision_feature_layer < 0:
        num_hidden_layers = hf_config.vision_config.num_hidden_layers \
            + vision_feature_layer + 1
    else:
        num_hidden_layers = vision_feature_layer + 1

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


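# Example (hypothetical but typical values): with vision_feature_layer = -2
# and a 24-layer CLIP tower, num_hidden_layers = 24 + (-2) + 1 = 23, so the
# final transformer layer is never instantiated or loaded.

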
# adapted from transformers modeling_llava_next_video.py
class LlavaNextVideoPooler(nn.Module):

    def __init__(self, config):
        super().__init__()

        mode = config.spatial_pool_mode
        stride = config.spatial_pool_stride
        image_size = config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.image_size = image_size // patch_size**2

        if mode == "average":
            self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride)
        elif mode == "max":
            self.pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
        else:
            # TODO: Support Conv2d pooling layer, need to load weights
            raise ValueError(
                f"Unknown pooling mode: {mode}. Expected [`average`, `max`]")

    def forward(self, image_features):
        ori_width = int(
            math.sqrt(image_features.shape[1] * self.image_size //
                      self.image_size))
        ori_height = int(ori_width * self.image_size // self.image_size)

        batch_size, _, dim = image_features.shape
        image_features_spatial = image_features \
            .view(batch_size, ori_height, ori_height, dim) \
            .permute(0, 3, 1, 2)
        image_features_spatial = self.pool(image_features_spatial)

        return image_features_spatial.flatten(2).transpose(1, 2).contiguous()


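# Shape sketch for the pooler above (assuming a 336px/patch-14 CLIP tower,
# i.e. 24 x 24 = 576 patch features per frame, and spatial_pool_stride = 2):
# (N, 576, C) -> view (N, 24, 24, C) -> permute (N, C, 24, 24)
# -> pool (N, C, 12, 12) -> flatten/transpose -> (N, 144, C).

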
class LlavaNextMultiModalProjector(nn.Module):

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        super().__init__()

        self.linear_1 = nn.Linear(vision_hidden_size,
                                  text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden_size,
                                  text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


@MULTIMODAL_REGISTRY.register_input_mapper("video")
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "video", get_max_llava_next_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video)
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: LlavaNextVideoConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # Initialize the vision tower only up to the required feature layer
        self.vision_tower = _init_vision_tower(config)
        self.multi_modal_projector = LlavaNextMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)
        self.language_model = init_aphrodite_registered_model(
            config.text_config, cache_config, quant_config)
        self.vision_resampler = LlavaNextVideoPooler(config)

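    # Data flow for a request that carries video input (descriptive note):
    # pixel_values_videos -> vision_tower -> vision_resampler (spatial
    # pooling) -> multi_modal_projector -> merged into the language model's
    # input embeddings at the video placeholder positions.
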
    def _validate_video_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[2:])

            if actual_dims != expected_dims:
                expected_expr = ("num_frames", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values in each video frame "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_video_input(
            self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]:
        """
        A legal video input should have the following format:
        {
            "pixel_values_videos":
                List[b, Tensor(nb_frames, nb_channels, height, width)]
        }
        """
        pixel_values = kwargs.pop("pixel_values_videos", None)

        if pixel_values is None:
            return None

        if not (is_list_of(pixel_values,
                           torch.Tensor)  # different shape videos
                or isinstance(pixel_values,
                              torch.Tensor)):  # same shape videos
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return LlavaNextVideoPixelInputs(
            type="pixel_values_videos",
            data=pixel_values,
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _video_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)
        image_features = self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )
        image_features = self.vision_resampler(image_features)
        image_features = self.multi_modal_projector(image_features)
        return image_features

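    # Per-frame feature pipeline (shapes assuming the 336px CLIP-L tower):
    # pixels (N, 3, 336, 336) -> vision tower (N, 577, 1024)
    # -> "default" feature select drops the CLS token -> (N, 576, 1024)
    # -> vision_resampler (N, 144, 1024) -> projector (N, 144, text_hidden).
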
    def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs):
        assert self.vision_tower is not None

        video_pixels = inputs["data"]

        if isinstance(video_pixels, torch.Tensor):
            # TODO: support multiple videos per input
            b, num_videos, num_frames, c, h, w = video_pixels.shape
            assert num_videos == 1
            stacked_pixels = video_pixels.view(b * num_videos * num_frames, c,
                                               h, w)
            stacked_embeddings = self._video_pixels_to_features(
                self.vision_tower, stacked_pixels)
            return stacked_embeddings.view(b, num_frames,
                                           *stacked_embeddings.shape[1:])

        elif is_list_of(video_pixels, torch.Tensor):
            frames_per_videos = [v.shape[0] for v in video_pixels]
            stacked_pixels = torch.cat(video_pixels, dim=0)
            stacked_embeddings = self._video_pixels_to_features(
                self.vision_tower, stacked_pixels)
            return torch.split(stacked_embeddings, frames_per_videos, dim=0)

        else:
            raise ValueError(
                f"Unsupported type of video input {type(video_pixels)}")

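    # Reshape sketch for the batched-tensor path (illustrative numbers):
    # video_pixels (b=2, num_videos=1, num_frames=8, 3, 336, 336)
    # -> stacked_pixels (16, 3, 336, 336) -> embeddings (16, 144, hidden)
    # -> returned as (2, 8, 144, hidden).
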
    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs: object) -> SamplerOutput:
        """Run forward pass for LLaVA-NeXT-Video.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values_videos: Pixel values of each frame in each input
                video.
        """
        video_input = self._parse_and_validate_video_input(**kwargs)

        # merge video embeddings into input embeddings
        if video_input is not None:
            video_embeddings = self._process_video_pixels(video_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, video_embeddings,
                self.config.video_token_index)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # prepare weight iterators
        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
            weights, 4)

        # load vision encoder
        vit_weights = filter_weights(vit_weights, "vision_tower")
        self.vision_tower.load_weights(vit_weights)

        # load mlp projector
        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in mlp_weights:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)

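    # Expected checkpoint key prefixes (a sketch based on the HF
    # LLaVA-NeXT-Video layout; the exact names come from the checkpoint):
    #   vision_tower.*           -> self.vision_tower
    #   multi_modal_projector.*  -> self.multi_modal_projector
    #   language_model.*         -> self.language_model
    # filter_weights() keeps only the names under the given prefix and strips
    # it before each sub-module's own loader consumes the weights.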