# llava_next_video.py

import math
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import numpy as np
import torch
import torch.nn as nn
from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
                          SiglipVisionConfig)

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.common.utils import is_list_of
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.clip import CLIPVisionModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.utils import (cached_get_tokenizer,
                                        repeat_and_pad_placeholder_tokens)
from aphrodite.quantization.base_config import QuantizationConfig

from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip)
from .utils import (group_weights_with_prefix,
                    init_aphrodite_registered_model,
                    merge_multimodal_embeddings)

# For profile run
_MAX_FRAMES_PER_VIDEO = 32
_MAX_NUM_VIDEOS = 1

class LlavaNextVideoPixelInputs(TypedDict):
    type: Literal["pixel_values_videos"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size, num_frames, num_channels, height, width)`

    Note that `num_frames` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.

    Note that only one video input per batch is currently supported.
    """

def get_llava_next_video_frame_feature_size(
        hf_config: LlavaNextVideoConfig) -> int:
    # Support both CLIPVisionConfig and SiglipVisionConfig
    image_size = hf_config.vision_config.image_size
    patch_size = hf_config.vision_config.patch_size
    spatial_pool_stride = hf_config.spatial_pool_stride

    return int((image_size / patch_size / spatial_pool_stride)**2)
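
# NOTE: Worked example (an assumption, based on the CLIP ViT-L/14-336 tower
# commonly used by LLaVA-NeXT-Video: image_size=336, patch_size=14,
# spatial_pool_stride=2): (336 / 14 / 2) ** 2 = 12 ** 2 = 144 tokens per frame.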

def _get_max_llm_tokens(ctx: InputContext) -> int:
    """
    Calculate the maximum number of tokens the language model can accept,
    i.e. the context-length budget available for video frames, accounting
    for RoPE scaling.
    """
    hf_text_config = ctx.model_config.hf_text_config
    model_config = ctx.model_config
    max_tokens = model_config.max_model_len
    rope_scaling = model_config.rope_scaling

    if rope_scaling:
        rope_scaling_factor = hf_text_config.rope_scaling["factor"]
    else:
        rope_scaling_factor = 1
    max_tokens *= rope_scaling_factor

    return max_tokens

def get_max_llava_next_video_tokens(ctx: InputContext) -> int:
    # Currently set to 32 frames
    # TODO: max_tokens = _get_max_llm_tokens(ctx)
    hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
    tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
    return _MAX_FRAMES_PER_VIDEO * tokens_per_frame

def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int,
                                    mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
    vision_config = hf_config.vision_config

    # TODO: support multiple videos
    num_videos = mm_counts["video"]
    if num_videos != _MAX_NUM_VIDEOS:
        raise NotImplementedError(
            f"Only {_MAX_NUM_VIDEOS} videos are supported")

    # TODO: support configuring the number of frames
    frames_per_video = _MAX_FRAMES_PER_VIDEO
    # num_images = num_videos * frames_per_video

    # Fill the sequence with as much video data as possible
    tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
    video_feature_size = frames_per_video * tokens_per_frame

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            num_videos,
            image_token_id=hf_config.video_token_index,
            image_feature_size_override=video_feature_size,
        )

        pil_frame = dummy_image_for_clip(vision_config, num_images=1)
        np_frame = np.array(pil_frame["image"])
        mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
        mm_data = {"video": mm_data_per_video}
        return seq_data, mm_data
    elif isinstance(vision_config, SiglipVisionConfig):
        seq_data = dummy_seq_data_for_siglip(
            vision_config,
            seq_len,
            num_videos,
            image_token_id=hf_config.video_token_index,
            image_feature_size_override=video_feature_size,
        )

        pil_frame = dummy_image_for_siglip(vision_config, num_images=1)
        np_frame = np.array(pil_frame["image"])
        mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
        mm_data = {"video": mm_data_per_video}
        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)

def input_processor_for_llava_next_video(ctx: InputContext,
                                         llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "video" not in multi_modal_data:
        return llm_inputs
    video_data = multi_modal_data["video"]

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
    vision_config = hf_config.vision_config

    if isinstance(video_data, np.ndarray):
        # Supports both CLIP and Siglip
        num_frames = video_data.shape[0]
        frame_feature_size = get_llava_next_video_frame_feature_size(
            hf_config)
        video_feature_size = num_frames * frame_feature_size

        tokenizer = cached_get_tokenizer(model_config.tokenizer)

        new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
            tokenizer,
            llm_inputs.get("prompt"),
            llm_inputs["prompt_token_ids"],
            placeholder_token_id=hf_config.video_token_index,
            repeat_count=video_feature_size,
        )

        return LLMInputs(prompt_token_ids=new_token_ids,
                         prompt=new_prompt,
                         multi_modal_data=multi_modal_data)

    elif is_list_of(video_data, np.ndarray):
        raise NotImplementedError(
            "Processing multiple videos is not supported")

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)
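
# NOTE: Illustrative sketch of the expansion above (the prompt and frame count
# are hypothetical): a prompt containing one <video> placeholder token, paired
# with a 16-frame video at 144 features per frame, ends up with that
# placeholder repeated 16 * 144 = 2304 times in `prompt_token_ids` before the
# language model sees it.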

def _init_vision_tower(hf_config: LlavaNextVideoConfig):
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer
    vision_feature_layer = hf_config.vision_feature_layer
    if vision_feature_layer < 0:
        num_hidden_layers = (hf_config.vision_config.num_hidden_layers +
                             vision_feature_layer + 1)
    else:
        num_hidden_layers = vision_feature_layer + 1

    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)

# adapted from transformers' modeling_llava_next_video.py
class LlavaNextVideoPooler(nn.Module):

    def __init__(self, config):
        super().__init__()

        mode = config.spatial_pool_mode
        stride = config.spatial_pool_stride
        image_size = config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.image_size = image_size // patch_size**2

        if mode == "average":
            self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride)
        elif mode == "max":
            self.pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
        else:
            # TODO: Support Conv2d pooling layer, need to load weights
            raise ValueError(
                f"Unknown pooling mode: {mode}. Expected [`average`, `max`]")

    def forward(self, image_features):
        ori_width = int(
            math.sqrt(image_features.shape[1] * self.image_size //
                      self.image_size))
        ori_height = int(ori_width * self.image_size // self.image_size)

        batch_size, _, dim = image_features.shape
        # Reshape the flattened patch sequence into a 2D grid and move the
        # feature dimension first so the pooling layer can reduce the grid
        # by `stride` along each spatial axis
        image_features_spatial = image_features.view(batch_size, ori_height,
                                                     ori_height,
                                                     dim).permute(0, 3, 1, 2)
        image_features_spatial = self.pool(image_features_spatial)

        # Flatten back to a sequence: (batch_size, num_pooled_patches, dim)
        return image_features_spatial.flatten(2).transpose(1, 2).contiguous()

class LlavaNextMultiModalProjector(nn.Module):

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        super().__init__()

        self.linear_1 = nn.Linear(vision_hidden_size,
                                  text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden_size,
                                  text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
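
# NOTE: Under the common LLaVA-NeXT-Video-7B configuration (an assumption, not
# read from this file), this two-layer MLP maps CLIP ViT-L/14 features
# (hidden_size=1024) into the 7B language backbone's hidden size (4096).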

@MULTIMODAL_REGISTRY.register_input_mapper("video")
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "video", get_max_llava_next_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video)
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: LlavaNextVideoConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # Initialize the vision tower only up to the required feature layer
        self.vision_tower = _init_vision_tower(config)
        self.multi_modal_projector = LlavaNextMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)
        self.language_model = init_aphrodite_registered_model(
            config.text_config, cache_config, quant_config)
        self.vision_resampler = LlavaNextVideoPooler(config)

    def _validate_video_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:

        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[2:])

            if actual_dims != expected_dims:
                expected_expr = ("num_frames", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values in each video frame "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_video_input(
            self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]:
        """
        A legal video input should have the following dimensions:
        {
            "pixel_values_videos":
                List[b, Tensor(nb_frames, nb_channels, height, width)]
        }
        """
        pixel_values = kwargs.pop("pixel_values_videos", None)

        if pixel_values is None:
            return None

        if not (is_list_of(pixel_values, torch.Tensor)  # different shapes
                or isinstance(pixel_values, torch.Tensor)):  # same shape
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return LlavaNextVideoPixelInputs(
            type="pixel_values_videos",
            data=pixel_values,
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _video_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)
        image_features = self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )
        image_features = self.vision_resampler(image_features)
        image_features = self.multi_modal_projector(image_features)
        return image_features

    def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs):
        assert self.vision_tower is not None

        video_pixels = inputs["data"]

        if isinstance(video_pixels, torch.Tensor):
            # TODO: support multiple videos per input
            b, num_videos, num_frames, c, h, w = video_pixels.shape
            assert num_videos == 1

            # Fold the batch, video, and frame dimensions together so the
            # vision tower sees a flat batch of frames
            stacked_pixels = video_pixels.view(b * num_videos * num_frames, c,
                                               h, w)
            stacked_embeddings = self._video_pixels_to_features(
                self.vision_tower, stacked_pixels)
            return stacked_embeddings.view(b, num_frames,
                                           *stacked_embeddings.shape[1:])

        elif is_list_of(video_pixels, torch.Tensor):
            # Videos with different frame counts arrive as a list of tensors
            frames_per_videos = [v.shape[0] for v in video_pixels]
            stacked_pixels = torch.cat(video_pixels, dim=0)
            stacked_embeddings = self._video_pixels_to_features(
                self.vision_tower, stacked_pixels)
            return torch.split(stacked_embeddings, frames_per_videos, dim=0)

        else:
            raise ValueError(
                f"Unsupported type of video input {type(video_pixels)}")

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run the forward pass for LLaVA-NeXT-Video.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values_videos: Pixels of each frame for each input video.
        """
        video_input = self._parse_and_validate_video_input(**kwargs)

        # merge video embeddings into input embeddings
        if video_input is not None:
            video_embeddings = self._process_video_pixels(video_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, video_embeddings,
                self.config.video_token_index)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            None,
            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # prepare weight iterators for components
        weights_group = group_weights_with_prefix(weights)

        # load vision encoder
        self.vision_tower.load_weights(weights_group["vision_tower"])

        # load mlp projector
        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in weights_group["multi_modal_projector"]:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        self.language_model.load_weights(weights_group["language_model"])
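
# NOTE: `group_weights_with_prefix` is assumed here to bucket checkpoint
# entries by their top-level name prefix, so that "vision_tower.*",
# "multi_modal_projector.*" and "language_model.*" weights are routed to the
# matching submodules above.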