paligemma.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
  2. TypedDict, Union)
  3. import torch
  4. from loguru import logger
  5. from torch import nn
  6. from transformers import PaliGemmaConfig
  7. from aphrodite.attention import AttentionMetadata
  8. from aphrodite.common.config import CacheConfig, MultiModalConfig
  9. from aphrodite.common.sequence import IntermediateTensors
  10. from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
  11. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  12. from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
  13. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  14. from aphrodite.modeling.models.gemma import GemmaForCausalLM
  15. from aphrodite.modeling.models.gemma2 import Gemma2ForCausalLM
  16. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  17. from aphrodite.multimodal import MULTIMODAL_REGISTRY
  18. from aphrodite.multimodal.utils import cached_get_tokenizer
  19. from aphrodite.quantization.base_config import QuantizationConfig
  20. from .interfaces import SupportsMultiModal
  21. from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
  22. dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
  23. from .utils import group_weights_with_prefix, merge_multimodal_embeddings
  24. class PaliGemmaImagePixelInputs(TypedDict):
  25. type: Literal["pixel_values"]
  26. data: torch.Tensor
  27. """Shape: `(batch_size * num_images, num_channels, height, width)`"""
  28. class PaliGemmaImageEmbeddingInputs(TypedDict):
  29. type: Literal["image_embeds"]
  30. data: torch.Tensor
  31. """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
  32. `hidden_size` must match the hidden size of language model backbone.
  33. """
  34. PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
  35. PaliGemmaImageEmbeddingInputs]
  36. def get_max_paligemma_image_tokens(ctx: InputContext):
  37. hf_config = ctx.get_hf_config(PaliGemmaConfig)
  38. vision_config = hf_config.vision_config
  39. return get_max_siglip_image_tokens(vision_config)
  40. def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
  41. mm_counts: Mapping[str, int]):
  42. hf_config = ctx.get_hf_config(PaliGemmaConfig)
  43. vision_config = hf_config.vision_config
  44. num_images = mm_counts["image"]
  45. seq_data = dummy_seq_data_for_siglip(
  46. vision_config,
  47. seq_len,
  48. num_images,
  49. image_token_id=hf_config.image_token_index,
  50. )
  51. mm_data = dummy_image_for_siglip(vision_config, num_images)
  52. return seq_data, mm_data
  53. def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
  54. """
  55. The correct prompt format needs to be:
  56. '<image>' * image_feature_size + '<bos>' + prompt + '\n'
  57. See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
  58. """ # noqa
  59. multi_modal_data = llm_inputs.get("multi_modal_data")
  60. if multi_modal_data is None or "image" not in multi_modal_data:
  61. return llm_inputs
  62. model_config = ctx.model_config
  63. hf_config = ctx.get_hf_config(PaliGemmaConfig)
  64. tokenizer = cached_get_tokenizer(model_config.tokenizer)
  65. image_feature_size = hf_config.text_config.num_image_tokens
  66. image_token_str = tokenizer.decode(hf_config.image_token_index)
  67. bos_token = tokenizer.decode(hf_config.bos_token_id)
  68. image_token_str_pad = image_token_str * image_feature_size
  69. image_token_ids_pad = [hf_config.image_token_index] * image_feature_size
  70. orig_prompt = llm_inputs.get("prompt")
  71. orig_prompt_ids = llm_inputs.get("prompt_token_ids")
  72. if orig_prompt is not None and image_token_str in orig_prompt:
  73. logger.warning(
  74. f"The image token '{image_token_str}' was detected in the prompt "
  75. "and will be removed. Please follow the proper prompt format"
  76. " documented on HuggingFace.")
  77. orig_prompt = orig_prompt.replace(image_token_str, "")
  78. orig_prompt_ids.remove(hf_config.image_token_index)
  79. new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
  80. # The PaliGemma 2 tokenizer does not include a starting BOS token
  81. if orig_prompt_ids[0] != hf_config.bos_token_id:
  82. orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids
  83. new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline
  84. # NOTE: Create a defensive copy of the original inputs
  85. return LLMInputs(prompt_token_ids=new_token_ids,
  86. prompt=new_prompt,
  87. multi_modal_data=multi_modal_data)
  88. class PaliGemmaMultiModalProjector(nn.Module):
  89. def __init__(self, vision_hidden_size: int, projection_dim: int):
  90. super().__init__()
  91. self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
  92. def forward(self, image_features: torch.Tensor) -> torch.Tensor:
  93. hidden_states = self.linear(image_features)
  94. return hidden_states
  95. @MULTIMODAL_REGISTRY.register_image_input_mapper()
  96. @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
  97. @INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
  98. @INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
  99. class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
  100. def __init__(self,
  101. config: PaliGemmaConfig,
  102. multimodal_config: MultiModalConfig,
  103. cache_config: Optional[CacheConfig] = None,
  104. quant_config: Optional[QuantizationConfig] = None) -> None:
  105. super().__init__()
  106. self.config = config
  107. self.multimodal_config = multimodal_config
  108. self.vision_tower = SiglipVisionModel(config.vision_config)
  109. self.multi_modal_projector = PaliGemmaMultiModalProjector(
  110. vision_hidden_size=config.vision_config.hidden_size,
  111. projection_dim=config.vision_config.projection_dim)
  112. self.quant_config = quant_config
  113. if config.text_config.model_type == "gemma":
  114. self.language_model = GemmaForCausalLM(config.text_config,
  115. cache_config, quant_config)
  116. else:
  117. self.language_model = Gemma2ForCausalLM(config.text_config,
  118. cache_config,
  119. quant_config)
  120. self.unpadded_vocab_size = config.text_config.vocab_size
  121. logit_scale = getattr(config, "logit_scale", 1.0)
  122. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  123. config.text_config.vocab_size,
  124. logit_scale)
  125. self.sampler = Sampler()
  126. def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
  127. h = w = self.config.vision_config.image_size
  128. expected_dims = (3, h, w)
  129. actual_dims = tuple(data.shape[1:])
  130. if actual_dims != expected_dims:
  131. expected_expr = ("batch_size", *map(str, expected_dims))
  132. raise ValueError(
  133. f"The expected shape of pixel values is {expected_expr}. "
  134. f"You supplied {tuple(data.shape)}.")
  135. return data
  136. def _parse_and_validate_image_input(
  137. self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
  138. pixel_values = kwargs.pop("pixel_values", None)
  139. image_embeds = kwargs.pop("image_embeds", None)
  140. if pixel_values is None and image_embeds is None:
  141. return None
  142. if pixel_values is not None:
  143. if not isinstance(pixel_values, torch.Tensor):
  144. raise ValueError("Incorrect type of pixel values. "
  145. f"Got type: {type(pixel_values)}")
  146. # Remove the N dimension until multiple images are supported.
  147. pixel_values = pixel_values.squeeze(1)
  148. return PaliGemmaImagePixelInputs(
  149. type="pixel_values",
  150. data=self._validate_pixel_values(pixel_values),
  151. )
  152. if image_embeds is not None:
  153. if not isinstance(image_embeds, torch.Tensor):
  154. raise ValueError("Incorrect type of image embeddings. "
  155. f"Got type: {type(image_embeds)}")
  156. # Remove the N dimension until multiple images are supported.
  157. image_embeds = image_embeds.squeeze(1)
  158. return PaliGemmaImageEmbeddingInputs(
  159. type="image_embeds",
  160. data=image_embeds,
  161. )
  162. raise AssertionError("This line should be unreachable.")
  163. def _image_pixels_to_features(
  164. self,
  165. vision_tower: SiglipVisionModel,
  166. pixel_values: torch.Tensor,
  167. ) -> torch.Tensor:
  168. target_dtype = vision_tower.get_input_embeddings().weight.dtype
  169. image_features = vision_tower(pixel_values.to(dtype=target_dtype))
  170. return image_features
  171. def _process_image_input(
  172. self,
  173. image_input: PaliGemmaImageInputs,
  174. ) -> torch.Tensor:
  175. if image_input["type"] == "image_embeds":
  176. return image_input["data"]
  177. assert self.vision_tower is not None
  178. pixel_values = image_input["data"]
  179. image_features = self._image_pixels_to_features(
  180. self.vision_tower,
  181. pixel_values,
  182. )
  183. return self.multi_modal_projector(image_features)
  184. def forward(self,
  185. input_ids: torch.Tensor,
  186. positions: torch.Tensor,
  187. kv_caches: List[torch.Tensor],
  188. attn_metadata: AttentionMetadata,
  189. intermediate_tensors: Optional[IntermediateTensors] = None,
  190. **kwargs: object) -> SamplerOutput:
  191. parsed_image_input = self._parse_and_validate_image_input(**kwargs)
  192. if parsed_image_input is not None:
  193. vision_embeddings = self._process_image_input(parsed_image_input)
  194. # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
  195. vision_embeddings = vision_embeddings * (self.config.hidden_size**
  196. -0.5)
  197. inputs_embeds = self.language_model.model.get_input_embeddings(
  198. input_ids)
  199. inputs_embeds = merge_multimodal_embeddings(
  200. input_ids, inputs_embeds, vision_embeddings,
  201. self.config.image_token_index)
  202. input_ids = None
  203. else:
  204. inputs_embeds = None
  205. hidden_states = self.language_model.model(input_ids,
  206. positions,
  207. kv_caches,
  208. attn_metadata,
  209. None,
  210. inputs_embeds=inputs_embeds)
  211. return hidden_states
  212. def compute_logits(
  213. self,
  214. hidden_states: torch.Tensor,
  215. sampling_metadata: SamplingMetadata,
  216. ) -> Optional[torch.Tensor]:
  217. return self.language_model.compute_logits(hidden_states,
  218. sampling_metadata)
  219. def sample(
  220. self,
  221. logits: torch.Tensor,
  222. sampling_metadata: SamplingMetadata,
  223. ) -> Optional[SamplerOutput]:
  224. return self.language_model.sample(logits, sampling_metadata)
  225. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  226. # prepare weight iterators for components
  227. weights_group = group_weights_with_prefix(weights)
  228. # load vision tower
  229. self.vision_tower.load_weights(weights_group["vision_tower"])
  230. # load mlp projector
  231. mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
  232. for name, loaded_weight in weights_group["multi_modal_projector"]:
  233. param = mlp_params_dict[name]
  234. weight_loader = getattr(param, "weight_loader",
  235. default_weight_loader)
  236. weight_loader(param, loaded_weight)
  237. # load llm backbone
  238. self.language_model.load_weights(weights_group["language_model"])