import itertools
import math
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import numpy as np
import torch
import torch.nn as nn
from transformers import (CLIPVisionConfig, LlavaNextVideoConfig,
                          SiglipVisionConfig)

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.common.utils import is_list_of
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.clip import CLIPVisionModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.utils import (cached_get_tokenizer,
                                        repeat_and_pad_placeholder_tokens)
from aphrodite.quantization.base_config import QuantizationConfig

from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip)
from .utils import (filter_weights, init_aphrodite_registered_model,
                    merge_multimodal_embeddings)

# For profile run
_MAX_FRAMES_PER_VIDEO = 32
_MAX_NUM_VIDEOS = 1


class LlavaNextVideoPixelInputs(TypedDict):
    type: Literal["pixel_values_videos"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size, num_frames, num_channels, height, width)`

    Note that `num_frames` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.

    Note that only one video input per batch is currently supported.
    """


def get_llava_next_video_frame_feature_size(
        hf_config: LlavaNextVideoConfig) -> int:
    # Support both CLIPVisionConfig and SiglipVisionConfig
    image_size = hf_config.vision_config.image_size
    patch_size = hf_config.vision_config.patch_size
    spatial_pool_stride = hf_config.spatial_pool_stride

    return int((image_size / patch_size / spatial_pool_stride) ** 2)
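

# For example, assuming the CLIP ViT-L/14-336 tower commonly paired with
# LLaVA-NeXT-Video (image_size=336, patch_size=14, spatial_pool_stride=2),
# each frame yields int((336 / 14 / 2) ** 2) = 144 tokens.

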
def _get_max_llm_tokens(ctx: InputContext) -> int:
    """
    Calculate the maximum number of tokens the language model can accept
    under its context-length constraints, accounting for RoPE scaling.
    """
    hf_text_config = ctx.model_config.hf_text_config
    model_config = ctx.model_config
    max_tokens = model_config.max_model_len
    rope_scaling = model_config.rope_scaling
    if rope_scaling:
        rope_scaling_factor = hf_text_config.rope_scaling["factor"]
    else:
        rope_scaling_factor = 1

    return int(max_tokens * rope_scaling_factor)


def get_max_llava_next_video_tokens(ctx: InputContext) -> int:
    # Currently set to 32 frames
    # TODO: max_tokens = _get_max_llm_tokens(ctx)
    hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
    tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
    return _MAX_FRAMES_PER_VIDEO * tokens_per_frame
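

# With the example CLIP configuration above, this amounts to
# 32 * 144 = 4608 placeholder tokens reserved per video.

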
def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int,
                                    mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
    vision_config = hf_config.vision_config

    # TODO: support multiple videos
    num_videos = mm_counts["video"]
    if num_videos != _MAX_NUM_VIDEOS:
        raise NotImplementedError(
            f"Only {_MAX_NUM_VIDEOS} video(s) are supported")

    # TODO: support configuring the number of frames
    frames_per_video = _MAX_FRAMES_PER_VIDEO
    # num_images = num_videos * frames_per_video

    # Fill the sequence with as much video data as possible
    tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config)
    video_feature_size = frames_per_video * tokens_per_frame

    if isinstance(vision_config, CLIPVisionConfig):
        seq_data = dummy_seq_data_for_clip(
            vision_config,
            seq_len,
            num_videos,
            image_token_id=hf_config.video_token_index,
            image_feature_size_override=video_feature_size,
        )

        pil_frame = dummy_image_for_clip(vision_config, num_images=1)
        np_frame = np.array(pil_frame["image"])
        mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
        mm_data = {"video": mm_data_per_video}
        return seq_data, mm_data
    elif isinstance(vision_config, SiglipVisionConfig):
        seq_data = dummy_seq_data_for_siglip(
            vision_config,
            seq_len,
            num_videos,
            image_token_id=hf_config.video_token_index,
            image_feature_size_override=video_feature_size,
        )

        pil_frame = dummy_image_for_siglip(vision_config, num_images=1)
        np_frame = np.array(pil_frame["image"])
        mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0)
        mm_data = {"video": mm_data_per_video}
        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def input_processor_for_llava_next_video(ctx: InputContext,
                                         llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "video" not in multi_modal_data:
        return llm_inputs

    video_data = multi_modal_data["video"]

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaNextVideoConfig)
    vision_config = hf_config.vision_config

    if isinstance(video_data, np.ndarray):
        # Supports both CLIP and Siglip
        num_frames = video_data.shape[0]
        frame_feature_size = get_llava_next_video_frame_feature_size(
            hf_config)
        video_feature_size = num_frames * frame_feature_size

        tokenizer = cached_get_tokenizer(model_config.tokenizer)
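        # Expand the video placeholder token in the prompt into
        # num_frames * frame_feature_size copies, so that the token
        # sequence lines up with the video embeddings merged in forward().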
        new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
            tokenizer,
            llm_inputs.get("prompt"),
            llm_inputs["prompt_token_ids"],
            placeholder_token_id=hf_config.video_token_index,
            repeat_count=video_feature_size,
        )

        return LLMInputs(
            prompt_token_ids=new_token_ids,
            prompt=new_prompt,
            multi_modal_data=multi_modal_data,
        )
    elif is_list_of(video_data, np.ndarray):
        raise NotImplementedError(
            "Processing multiple videos is not supported")

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def _init_vision_tower(hf_config: LlavaNextVideoConfig):
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the required feature layer
    vision_feature_layer = hf_config.vision_feature_layer
    if vision_feature_layer < 0:
        num_hidden_layers = (hf_config.vision_config.num_hidden_layers
                             + vision_feature_layer + 1)
    else:
        num_hidden_layers = vision_feature_layer + 1
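    # e.g., vision_feature_layer=-2 with 24 hidden layers keeps
    # 24 + (-2) + 1 = 23 layers, dropping only the final layer.
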
    if isinstance(vision_config, CLIPVisionConfig):
        return CLIPVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )
    elif isinstance(vision_config, SiglipVisionConfig):
        return SiglipVisionModel(
            vision_config,
            num_hidden_layers_override=num_hidden_layers,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


# adapted from transformers modeling_llava_next_video.py
class LlavaNextVideoPooler(nn.Module):

    def __init__(self, config):
        super().__init__()

        mode = config.spatial_pool_mode
        stride = config.spatial_pool_stride
        image_size = config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.image_size = image_size // patch_size**2

        if mode == "average":
            self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride)
        elif mode == "max":
            self.pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
        else:
            # TODO: Support Conv2d pooling layer, need to load weights
            raise ValueError(
                f"Unknown pooling mode: {mode}. Expected [`average`, `max`]")

    def forward(self, image_features):
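        # The self.image_size factors cancel in the two lines below, so
        # this reduces to the side length of the square token grid:
        # sqrt(num_patch_tokens).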
        ori_width = int(
            math.sqrt(image_features.shape[1] * self.image_size
                      // self.image_size))
        ori_height = int(ori_width * self.image_size // self.image_size)

        batch_size, _, dim = image_features.shape
        image_features_spatial = image_features.view(
            batch_size, ori_height, ori_width, dim).permute(0, 3, 1, 2)
        image_features_spatial = self.pool(image_features_spatial)

        return image_features_spatial.flatten(2).transpose(1, 2).contiguous()


class LlavaNextMultiModalProjector(nn.Module):

    def __init__(self, vision_hidden_size: int, text_hidden_size: int,
                 projector_hidden_act: str):
        super().__init__()

        self.linear_1 = nn.Linear(vision_hidden_size, text_hidden_size,
                                  bias=True)
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = nn.Linear(text_hidden_size, text_hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


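# The registry decorators below wire this model into the multimodal
# pipeline: the input processor expands video placeholder tokens at prompt
# time, the max-tokens hook bounds how many tokens a video may occupy, and
# the dummy-data hook supplies inputs for the memory profile run.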
@MULTIMODAL_REGISTRY.register_input_mapper("video")
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
    "video", get_max_llava_next_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video)
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: LlavaNextVideoConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # Initialize the vision tower only up to the required feature layer
        self.vision_tower = _init_vision_tower(config)
        self.multi_modal_projector = LlavaNextMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
        )
        self.language_model = init_aphrodite_registered_model(
            config.text_config, cache_config, quant_config)
        self.vision_resampler = LlavaNextVideoPooler(config)

    def _validate_video_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
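        # e.g., (3, 336, 336) when the vision tower operates at 336px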

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[2:])

            if actual_dims != expected_dims:
                expected_expr = ("num_frames", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values in each video frame "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_video_input(
            self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]:
        """
        A legal video input should have the following form:

            {
                "pixel_values_videos":
                    List[b, Tensor(num_frames, num_channels, height, width)]
            }
        """
        pixel_values = kwargs.pop("pixel_values_videos", None)

        if pixel_values is None:
            return None

        # list of tensors: videos with different shapes
        # single tensor: videos that all share the same shape
        if not (is_list_of(pixel_values, torch.Tensor)
                or isinstance(pixel_values, torch.Tensor)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return LlavaNextVideoPixelInputs(
            type="pixel_values_videos",
            data=pixel_values,
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _video_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)
        image_features = self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )
        image_features = self.vision_resampler(image_features)
        image_features = self.multi_modal_projector(image_features)
        return image_features

    def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs):
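        # Two cases: a single batched tensor when all videos share a frame
        # count, or a list of per-video tensors when frame counts differ.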
        assert self.vision_tower is not None
        video_pixels = inputs["data"]

        if isinstance(video_pixels, torch.Tensor):
            # TODO: support multiple videos per input
            b, num_videos, num_frames, c, h, w = video_pixels.shape
            assert num_videos == 1
            stacked_pixels = video_pixels.view(
                b * num_videos * num_frames, c, h, w)
            stacked_embeddings = self._video_pixels_to_features(
                self.vision_tower, stacked_pixels)
            return stacked_embeddings.view(
                b, num_frames, *stacked_embeddings.shape[1:])
        elif is_list_of(video_pixels, torch.Tensor):
            frames_per_videos = [v.shape[0] for v in video_pixels]
            stacked_pixels = torch.cat(video_pixels, dim=0)
            stacked_embeddings = self._video_pixels_to_features(
                self.vision_tower, stacked_pixels)
            return torch.split(stacked_embeddings, frames_per_videos, dim=0)
        else:
            raise ValueError(
                f"Unsupported type of video input {type(video_pixels)}")

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run the forward pass for LLaVA-NeXT-Video.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to
                a batch.
            pixel_values_videos: Pixels for each frame of each input video.
        """
        video_input = self._parse_and_validate_video_input(**kwargs)

        # Merge video embeddings into the input embeddings
        if video_input is not None:
            video_embeddings = self._process_video_pixels(video_input)
            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, video_embeddings,
                self.config.video_token_index)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            None,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # prepare one tee'd weight iterator per sub-module
        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)

        # load vision encoder
        vit_weights = filter_weights(vit_weights, "vision_tower")
        self.vision_tower.load_weights(vit_weights)

        # load mlp projector
        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in mlp_weights:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)
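
        # NOTE: the vision resampler (LlavaNextVideoPooler) has no
        # parameters in the `average`/`max` pooling modes, so there are no
        # weights to load for it here.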