paligemma.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
  2. TypedDict, Union)
  3. import torch
  4. from loguru import logger
  5. from torch import nn
  6. from transformers import PaliGemmaConfig
  7. from aphrodite.attention import AttentionMetadata
  8. from aphrodite.common.config import CacheConfig, MultiModalConfig
  9. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  10. from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
  11. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  12. from aphrodite.modeling.layers.sampler import Sampler
  13. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  14. from aphrodite.modeling.models.gemma import GemmaModel
  15. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  16. from aphrodite.multimodal import MULTIMODAL_REGISTRY
  17. from aphrodite.multimodal.image import cached_get_tokenizer
  18. from aphrodite.quantization.base_config import QuantizationConfig
  19. from .interfaces import SupportsMultiModal
  20. from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
  21. dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
  22. from .utils import merge_multimodal_embeddings
  23. _KEYS_TO_MODIFY_MAPPING = {
  24. "language_model.model": "language_model",
  25. }
  26. class PaliGemmaImagePixelInputs(TypedDict):
  27. type: Literal["pixel_values"]
  28. data: torch.Tensor
  29. """Shape: (batch_size, num_channels, height, width)"""
  30. class PaliGemmaImageEmbeddingInputs(TypedDict):
  31. type: Literal["image_embeds"]
  32. data: torch.Tensor
  33. """Shape: `(batch_size, image_feature_size, hidden_size)`
  34. `hidden_size` must match the hidden size of language model backbone.
  35. """
  36. PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
  37. PaliGemmaImageEmbeddingInputs]
  38. def get_max_paligemma_image_tokens(ctx: InputContext):
  39. hf_config = ctx.get_hf_config(PaliGemmaConfig)
  40. vision_config = hf_config.vision_config
  41. return get_max_siglip_image_tokens(vision_config)
  42. def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
  43. mm_counts: Mapping[str, int]):
  44. hf_config = ctx.get_hf_config(PaliGemmaConfig)
  45. vision_config = hf_config.vision_config
  46. num_images = mm_counts["image"]
  47. seq_data = dummy_seq_data_for_siglip(
  48. vision_config,
  49. seq_len,
  50. num_images,
  51. image_token_id=hf_config.image_token_index,
  52. )
  53. mm_data = dummy_image_for_siglip(vision_config, num_images)
  54. return seq_data, mm_data
  55. def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
  56. """
  57. The correct prompt format needs to be:
  58. '<image>' * image_feature_size + '<bos>' + prompt + '\n'
  59. See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
  60. """ # noqa
  61. multi_modal_data = llm_inputs.get("multi_modal_data")
  62. if multi_modal_data is None or "image" not in multi_modal_data:
  63. return llm_inputs
  64. model_config = ctx.model_config
  65. hf_config = ctx.get_hf_config(PaliGemmaConfig)
  66. tokenizer = cached_get_tokenizer(model_config.tokenizer)
  67. image_feature_size = hf_config.text_config.num_image_tokens
  68. image_token_str = tokenizer.decode(hf_config.image_token_index)
  69. bos_token = tokenizer.decode(hf_config.bos_token_id)
  70. image_token_str_pad = image_token_str * image_feature_size
  71. image_token_ids_pad = [hf_config.image_token_index] * image_feature_size
  72. orig_prompt = llm_inputs.get("prompt")
  73. orig_prompt_ids = llm_inputs.get("prompt_token_ids")
  74. if orig_prompt is not None and image_token_str in orig_prompt:
  75. logger.warning(
  76. f"The image token '{image_token_str}' was detected in the prompt "
  77. "and will be removed. Please follow the proper prompt format"
  78. " documented on HuggingFace.")
  79. orig_prompt = orig_prompt.replace(image_token_str, "")
  80. orig_prompt_ids.remove(hf_config.image_token_index)
  81. new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
  82. new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline
  83. # NOTE: Create a defensive copy of the original inputs
  84. return LLMInputs(prompt_token_ids=new_token_ids,
  85. prompt=new_prompt,
  86. multi_modal_data=multi_modal_data)
  87. class PaliGemmaMultiModalProjector(nn.Module):
  88. def __init__(self, vision_hidden_size: int, projection_dim: int):
  89. super().__init__()
  90. self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
  91. def forward(self, image_features: torch.Tensor) -> torch.Tensor:
  92. hidden_states = self.linear(image_features)
  93. return hidden_states
  94. @MULTIMODAL_REGISTRY.register_image_input_mapper()
  95. @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
  96. @INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
  97. @INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
  98. class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
  99. def __init__(self,
  100. config: PaliGemmaConfig,
  101. multimodal_config: MultiModalConfig,
  102. cache_config: Optional[CacheConfig] = None,
  103. quant_config: Optional[QuantizationConfig] = None) -> None:
  104. super().__init__()
  105. self.config = config
  106. self.multimodal_config = multimodal_config
  107. # TODO: Port over SiglipVisionModel & TP
  108. self.vision_tower = SiglipVisionModel(config.vision_config)
  109. self.multi_modal_projector = PaliGemmaMultiModalProjector(
  110. vision_hidden_size=config.vision_config.hidden_size,
  111. projection_dim=config.vision_config.projection_dim)
  112. self.quant_config = quant_config
  113. self.language_model = GemmaModel(config.text_config, cache_config,
  114. quant_config)
  115. self.unpadded_vocab_size = config.text_config.vocab_size
  116. logit_scale = getattr(config, "logit_scale", 1.0)
  117. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  118. config.vocab_size, logit_scale)
  119. self.sampler = Sampler()
  120. def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
  121. h = w = self.config.vision_config.image_size
  122. expected_dims = (3, h, w)
  123. actual_dims = tuple(data.shape[1:])
  124. if actual_dims != expected_dims:
  125. expected_expr = ("batch_size", *map(str, expected_dims))
  126. raise ValueError(
  127. f"The expected shape of pixel values is {expected_expr}. "
  128. f"You supplied {tuple(data.shape)}.")
  129. return data
  130. def _parse_and_validate_image_input(
  131. self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
  132. pixel_values = kwargs.pop("pixel_values", None)
  133. image_embeds = kwargs.pop("image_embeds", None)
  134. if pixel_values is None and image_embeds is None:
  135. return None
  136. if pixel_values is not None:
  137. if not isinstance(pixel_values, torch.Tensor):
  138. raise ValueError("Incorrect type of pixel values. "
  139. f"Got type: {type(pixel_values)}")
  140. return PaliGemmaImagePixelInputs(
  141. type="pixel_values",
  142. data=self._validate_pixel_values(pixel_values),
  143. )
  144. if image_embeds is not None:
  145. if not isinstance(image_embeds, torch.Tensor):
  146. raise ValueError("Incorrect type of image embeddings. "
  147. f"Got type: {type(image_embeds)}")
  148. return PaliGemmaImageEmbeddingInputs(
  149. type="image_embeds",
  150. data=image_embeds,
  151. )
  152. raise AssertionError("This line should be unreachable.")
  153. def _image_pixels_to_features(
  154. self,
  155. vision_tower: SiglipVisionModel,
  156. pixel_values: torch.Tensor,
  157. ) -> torch.Tensor:
  158. target_dtype = vision_tower.get_input_embeddings().weight.dtype
  159. image_features = vision_tower(pixel_values.to(dtype=target_dtype))
  160. return image_features
  161. def _process_image_input(
  162. self,
  163. image_input: PaliGemmaImageInputs,
  164. ) -> torch.Tensor:
  165. if image_input["type"] == "image_embeds":
  166. return image_input["data"]
  167. assert self.vision_tower is not None
  168. pixel_values = image_input["data"]
  169. image_features = self._image_pixels_to_features(
  170. self.vision_tower,
  171. pixel_values,
  172. )
  173. return self.multi_modal_projector(image_features)
  174. def forward(self,
  175. input_ids: torch.Tensor,
  176. positions: torch.Tensor,
  177. kv_caches: List[torch.Tensor],
  178. attn_metadata: AttentionMetadata,
  179. intermediate_tensors: Optional[IntermediateTensors] = None,
  180. **kwargs: object) -> SamplerOutput:
  181. parsed_image_input = self._parse_and_validate_image_input(**kwargs)
  182. if parsed_image_input is not None:
  183. vision_embeddings = self._process_image_input(parsed_image_input)
  184. # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
  185. vision_embeddings = vision_embeddings * (self.config.hidden_size**
  186. -0.5)
  187. inputs_embeds = self.language_model.get_input_embeddings(input_ids)
  188. inputs_embeds = merge_multimodal_embeddings(
  189. input_ids, inputs_embeds, vision_embeddings,
  190. self.config.image_token_index)
  191. input_ids = None
  192. else:
  193. inputs_embeds = None
  194. hidden_states = self.language_model(input_ids,
  195. positions,
  196. kv_caches,
  197. attn_metadata,
  198. None,
  199. inputs_embeds=inputs_embeds)
  200. return hidden_states
  201. # Copied from vllm/modeling/models/gemma.py
  202. def compute_logits(
  203. self,
  204. hidden_states: torch.Tensor,
  205. sampling_metadata: SamplingMetadata,
  206. ) -> Optional[torch.Tensor]:
  207. logits = self.logits_processor(self.language_model.embed_tokens,
  208. hidden_states, sampling_metadata)
  209. return logits
  210. # Copied from vllm/modeling/models/gemma.py
  211. def sample(
  212. self,
  213. logits: torch.Tensor,
  214. sampling_metadata: SamplingMetadata,
  215. ) -> Optional[SamplerOutput]:
  216. next_tokens = self.sampler(logits, sampling_metadata)
  217. return next_tokens
  218. # Adapted from vllm/modeling/models/gemma.py
  219. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  220. stacked_params_mapping = [
  221. # (param_name, shard_name, shard_id)
  222. ("qkv_proj", "q_proj", "q"),
  223. ("qkv_proj", "k_proj", "k"),
  224. ("qkv_proj", "v_proj", "v"),
  225. ("gate_up_proj", "gate_proj", 0),
  226. ("gate_up_proj", "up_proj", 1),
  227. ]
  228. params_dict = dict(self.named_parameters())
  229. loaded_params = set()
  230. for name, loaded_weight in weights:
  231. for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
  232. if key_to_modify in name:
  233. name = name.replace(key_to_modify, new_key)
  234. use_default_weight_loading = False
  235. if "vision" in name:
  236. if self.vision_tower is not None:
  237. # We only do sharding for language model and
  238. # not vision model for now.
  239. use_default_weight_loading = True
  240. else:
  241. for (param_name, shard_name,
  242. shard_id) in stacked_params_mapping:
  243. if shard_name not in name:
  244. continue
  245. name = name.replace(shard_name, param_name)
  246. # Skip loading extra bias for GPTQ models.
  247. if name.endswith(".bias") and name not in params_dict:
  248. continue
  249. param = params_dict[name]
  250. weight_loader = param.weight_loader
  251. weight_loader(param, loaded_weight, shard_id)
  252. break
  253. else:
  254. # lm_head is not used in vllm as it is tied with
  255. # embed_token. To prevent errors, skip loading
  256. # lm_head.weight.
  257. if "lm_head.weight" in name:
  258. continue
  259. # Skip loading extra bias for GPTQ models.
  260. if name.endswith(".bias") and name not in params_dict:
  261. continue
  262. use_default_weight_loading = True
  263. if use_default_weight_loading:
  264. param = params_dict[name]
  265. weight_loader = getattr(param, "weight_loader",
  266. default_weight_loader)
  267. weight_loader(param, loaded_weight)
  268. loaded_params.add(name)
  269. unloaded_params = params_dict.keys() - loaded_params
  270. if unloaded_params:
  271. logger.warning(
  272. "Some weights are not initialized from checkpoints: "
  273. f"{unloaded_params}")