import itertools
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
from loguru import logger
from torch import nn
from transformers import PaliGemmaConfig

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.gemma import GemmaForCausalLM
from aphrodite.modeling.models.gemma2 import Gemma2ForCausalLM
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.utils import cached_get_tokenizer
from aphrodite.quantization.base_config import QuantizationConfig

from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import filter_weights, merge_multimodal_embeddings


class PaliGemmaImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, num_channels, height, width)`"""


class PaliGemmaImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """


PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
                             PaliGemmaImageEmbeddingInputs]
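
# Note: these shapes are per flattened image. For the 224px PaliGemma
# checkpoints, SigLIP uses a 14x14 patch size, so one image typically maps to
# pixel values of shape (1, 3, 224, 224) or precomputed embeddings of shape
# (1, 256, hidden_size); these numbers are illustrative only, the
# authoritative sizes come from `config.vision_config`.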


def get_max_paligemma_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    vision_config = hf_config.vision_config

    return get_max_siglip_image_tokens(vision_config)


def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
                             mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    seq_data = dummy_seq_data_for_siglip(
        vision_config,
        seq_len,
        num_images,
        image_token_id=hf_config.image_token_index,
    )

    mm_data = dummy_image_for_siglip(vision_config, num_images)
    return seq_data, mm_data
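
# The dummy data above is only consumed when the engine profiles memory usage
# before serving: the sequence is filled with `image_token_index` placeholders
# and paired with a placeholder image at the SigLIP input resolution.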


def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
    """
    The correct prompt format needs to be:
    '<image>' * image_feature_size + '<bos>' + prompt + '\n'

    See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
    """  # noqa

    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(PaliGemmaConfig)

    tokenizer = cached_get_tokenizer(model_config.tokenizer)
    image_feature_size = hf_config.text_config.num_image_tokens
    image_token_str = tokenizer.decode(hf_config.image_token_index)
    bos_token = tokenizer.decode(hf_config.bos_token_id)
    image_token_str_pad = image_token_str * image_feature_size
    image_token_ids_pad = [hf_config.image_token_index] * image_feature_size

    orig_prompt = llm_inputs.get("prompt")
    orig_prompt_ids = llm_inputs.get("prompt_token_ids")

    if orig_prompt is not None and image_token_str in orig_prompt:
        logger.warning(
            f"The image token '{image_token_str}' was detected in the prompt "
            "and will be removed. Please follow the proper prompt format"
            " documented on HuggingFace.")
        orig_prompt = orig_prompt.replace(image_token_str, "")
        # Drop every stray image token id, mirroring the string replacement.
        while hf_config.image_token_index in orig_prompt_ids:
            orig_prompt_ids.remove(hf_config.image_token_index)

    new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"

    # The PaliGemma 2 tokenizer does not include a starting BOS token
    if orig_prompt_ids[0] != hf_config.bos_token_id:
        orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids

    new_token_ids = image_token_ids_pad + orig_prompt_ids + [108]  # newline

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)
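
# Worked example of the transformation above (hypothetical ids, with
# image_feature_size shortened to 4 for readability):
#
#   prompt:        "caption en"
#   new_prompt:    "<image><image><image><image><bos>caption en\n"
#   new_token_ids: [image_id] * 4 + [bos_id] + ids("caption en") + [108]
#
# where 108 is the Gemma tokenizer id used for the trailing newline.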


class PaliGemmaMultiModalProjector(nn.Module):

    def __init__(self, vision_hidden_size: int, projection_dim: int):
        super().__init__()

        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.linear(image_features)
        return hidden_states
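
# The projector is a single linear map from SigLIP's hidden size to
# `projection_dim`, which PaliGemma's config ties to the language model's
# embedding width so that projected image features can be merged with the
# text embeddings downstream.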


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: PaliGemmaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        self.vision_tower = SiglipVisionModel(config.vision_config)
        self.multi_modal_projector = PaliGemmaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            projection_dim=config.vision_config.projection_dim)

        self.quant_config = quant_config

        if config.text_config.model_type == "gemma":
            self.language_model = GemmaForCausalLM(config.text_config,
                                                   cache_config, quant_config)
        else:
            self.language_model = Gemma2ForCausalLM(config.text_config,
                                                    cache_config,
                                                    quant_config)

        self.unpadded_vocab_size = config.text_config.vocab_size
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.text_config.vocab_size,
                                                logit_scale)
        self.sampler = Sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Remove the N dimension until multiple images are supported.
            pixel_values = pixel_values.squeeze(1)

            return PaliGemmaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            # Remove the N dimension until multiple images are supported.
            image_embeds = image_embeds.squeeze(1)

            return PaliGemmaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        vision_tower: SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        target_dtype = vision_tower.get_input_embeddings().weight.dtype
        image_features = vision_tower(pixel_values.to(dtype=target_dtype))

        return image_features

    def _process_image_input(
        self,
        image_input: PaliGemmaImageInputs,
    ) -> torch.Tensor:
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        pixel_values = image_input["data"]
        image_features = self._image_pixels_to_features(
            self.vision_tower,
            pixel_values,
        )

        return self.multi_modal_projector(image_features)
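
    # Image path summary: raw pixels are cast to the vision tower's dtype,
    # encoded by SigLIP into per-patch features, then projected into the
    # language model's embedding space; precomputed `image_embeds` bypass both
    # steps. Scaling and merging into the text embeddings happens in `forward`.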

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs: object) -> SamplerOutput:
        parsed_image_input = self._parse_and_validate_image_input(**kwargs)

        if parsed_image_input is not None:
            vision_embeddings = self._process_image_input(parsed_image_input)
            # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294  # noqa
            vision_embeddings = vision_embeddings * (self.config.hidden_size**
                                                     -0.5)

            inputs_embeds = self.language_model.model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  None,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # prepare weight iterators for components
        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)

        # load vision tower
        vit_weights = filter_weights(vit_weights, "vision_tower")
        self.vision_tower.load_weights(vit_weights)

        # load mlp projector
        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
        mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
        for name, loaded_weight in mlp_weights:
            param = mlp_params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # load llm backbone
        llm_weights = filter_weights(llm_weights, "language_model")
        self.language_model.load_weights(llm_weights)
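

# Minimal usage sketch (assumption: Aphrodite exposes a vLLM-style `LLM`
# entrypoint that accepts a PIL image under the "image" key; exact argument
# names may differ between releases):
#
#   from PIL import Image
#   from aphrodite import LLM
#
#   llm = LLM(model="google/paligemma-3b-mix-224")
#   outputs = llm.generate({
#       "prompt": "caption en",
#       "multi_modal_data": {"image": Image.open("example.jpg")},
#   })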