from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import torch
from loguru import logger
from torch import nn
from transformers import PaliGemmaConfig

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.gemma import GemmaModel
from aphrodite.modeling.models.gemma2 import Gemma2Model
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.utils import cached_get_tokenizer
from aphrodite.quantization.base_config import QuantizationConfig

from .interfaces import SupportsMultiModal
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import merge_multimodal_embeddings

_KEYS_TO_MODIFY_MAPPING = {
    "language_model.model": "language_model",
}
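# load_weights() below rewrites checkpoint keys with this mapping, so HF
# weights stored under "language_model.model.*" land on the
# "language_model.*" submodule used here.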


class PaliGemmaImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size, num_channels, height, width)`"""


class PaliGemmaImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """


PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
                             PaliGemmaImageEmbeddingInputs]


def get_max_paligemma_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    vision_config = hf_config.vision_config

    return get_max_siglip_image_tokens(vision_config)
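# The per-image token budget is delegated to SigLIP and derived from its
# patch grid, e.g. (224 / 14) ** 2 = 256 tokens for the 224px PaliGemma
# checkpoints (illustrative numbers, not read from this file).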


def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
                             mm_counts: Mapping[str, int]):
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    seq_data = dummy_seq_data_for_siglip(
        vision_config,
        seq_len,
        num_images,
        image_token_id=hf_config.image_token_index,
    )

    mm_data = dummy_image_for_siglip(vision_config, num_images)
    return seq_data, mm_data
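# The dummy sequence/image pair above is what INPUT_REGISTRY hands to the
# engine when it profiles memory usage; presumably it only needs to be
# shape-correct, not semantically meaningful.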


def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
    """
    The correct prompt format needs to be:
    '<image>' * image_feature_size + '<bos>' + prompt + '\n'

    See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
    """  # noqa

    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(PaliGemmaConfig)

    tokenizer = cached_get_tokenizer(model_config.tokenizer)
    image_feature_size = hf_config.text_config.num_image_tokens
    image_token_str = tokenizer.decode(hf_config.image_token_index)
    bos_token = tokenizer.decode(hf_config.bos_token_id)
    image_token_str_pad = image_token_str * image_feature_size
    image_token_ids_pad = [hf_config.image_token_index] * image_feature_size

    orig_prompt = llm_inputs.get("prompt")
    orig_prompt_ids = llm_inputs.get("prompt_token_ids")

    if orig_prompt is not None and image_token_str in orig_prompt:
        logger.warning(
            f"The image token '{image_token_str}' was detected in the prompt "
            "and will be removed. Please follow the proper prompt format"
            " documented on HuggingFace.")
        orig_prompt = orig_prompt.replace(image_token_str, "")
        orig_prompt_ids.remove(hf_config.image_token_index)

    new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"

    # The PaliGemma 2 tokenizer does not include a starting BOS token
    if orig_prompt_ids[0] != hf_config.bos_token_id:
        orig_prompt_ids = [hf_config.bos_token_id] + orig_prompt_ids

    new_token_ids = image_token_ids_pad + orig_prompt_ids + [108]  # newline

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)
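# Illustrative result (hypothetical prompt "caption en", assuming
# image_feature_size == 4 for brevity):
#   prompt    -> "<image><image><image><image><bos>caption en\n"
#   token ids -> [image_token_index] * 4 + [bos_token_id, *prompt_ids, 108]
# where 108 is the newline token id used above.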


class PaliGemmaMultiModalProjector(nn.Module):

    def __init__(self, vision_hidden_size: int, projection_dim: int):
        super().__init__()

        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.linear(image_features)
        return hidden_states
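# The projector is a single linear layer mapping SigLIP's hidden size to
# config.vision_config.projection_dim, the width the language model expects
# for merged image embeddings. The decorators on the class below register
# the image input mapper, max image token count, dummy data, and input
# processor with Aphrodite's MULTIMODAL_REGISTRY / INPUT_REGISTRY.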


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: PaliGemmaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Port over SiglipVisionModel & TP
        self.vision_tower = SiglipVisionModel(config.vision_config)
        self.multi_modal_projector = PaliGemmaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            projection_dim=config.vision_config.projection_dim)

        self.quant_config = quant_config

        if config.text_config.model_type == "gemma":
            self.language_model = GemmaModel(config.text_config, cache_config,
                                             quant_config)
        else:
            self.language_model = Gemma2Model(config.text_config,
                                              cache_config, quant_config)

        self.unpadded_vocab_size = config.text_config.vocab_size
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.text_config.vocab_size,
                                                logit_scale)
        self.sampler = Sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data
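    # For example, with image_size == 224 the accepted shape is
    # (batch_size, 3, 224, 224); anything else raises the ValueError above.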

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            return PaliGemmaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return PaliGemmaImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")
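    # "pixel_values" / "image_embeds" are expected to arrive via the
    # multimodal keyword arguments produced by the registered image input
    # mapper; only one of the two is normally present per request.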

    def _image_pixels_to_features(
        self,
        vision_tower: SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        target_dtype = vision_tower.get_input_embeddings().weight.dtype
        image_features = vision_tower(pixel_values.to(dtype=target_dtype))

        return image_features

    def _process_image_input(
        self,
        image_input: PaliGemmaImageInputs,
    ) -> torch.Tensor:
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_tower is not None
        pixel_values = image_input["data"]
        image_features = self._image_pixels_to_features(
            self.vision_tower,
            pixel_values,
        )

        return self.multi_modal_projector(image_features)
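    # Pipeline: pixel_values -> SigLIP vision tower -> linear projector ->
    # per-patch embeddings sized for the language model. Precomputed
    # image_embeds skip the vision tower entirely.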

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs: object) -> SamplerOutput:
        parsed_image_input = self._parse_and_validate_image_input(**kwargs)

        if parsed_image_input is not None:
            vision_embeddings = self._process_image_input(parsed_image_input)
            # Scale the vision embeddings by 1/sqrt(hidden_size); this
            # appears to offset the sqrt(hidden_size) embedding normalizer
            # applied inside the Gemma backbone and matches the HF reference:
            # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
            vision_embeddings = vision_embeddings * (self.config.hidden_size**
                                                     -0.5)

            inputs_embeds = self.language_model.get_input_embeddings(
                input_ids)

            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            None,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    # Copied from vllm/modeling/models/gemma.py
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.language_model.embed_tokens,
                                       hidden_states, sampling_metadata)
        return logits

    # Copied from vllm/modeling/models/gemma.py
    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    # Adapted from vllm/modeling/models/gemma.py
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
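        # The checkpoint stores q/k/v (and gate/up) projections separately,
        # while the language model presumably uses fused qkv_proj and
        # gate_up_proj layers; each shard is loaded into its slice of the
        # fused parameter via the parameter's weight_loader below.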
        params_dict = dict(self.named_parameters())
        loaded_params = set()
        for name, loaded_weight in weights:
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for the language model,
                    # not the vision model, for now.
                    use_default_weight_loading = True
            else:
                for (param_name, shard_name,
                     shard_id) in stacked_params_mapping:
                    if shard_name not in name:
                        continue
                    name = name.replace(shard_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    # lm_head is not used in vllm as it is tied with
                    # embed_tokens. To prevent errors, skip loading
                    # lm_head.weight.
                    if "lm_head.weight" in name:
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    use_default_weight_loading = True

            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

            loaded_params.add(name)

        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            logger.warning(
                "Some weights are not initialized from checkpoints: "
                f"{unloaded_params}")