paligemma.py

from typing import Iterable, List, Literal, Optional, Tuple, TypedDict

import torch
from loguru import logger
from torch import nn
from transformers import PaliGemmaConfig

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.gemma import GemmaModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.image import cached_get_tokenizer
from aphrodite.quantization.base_config import QuantizationConfig

from .interfaces import SupportsVision
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import merge_vision_embeddings

_KEYS_TO_MODIFY_MAPPING = {
    "language_model.model": "language_model",
}


def get_max_paligemma_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    vision_config = hf_config.vision_config

    return get_max_siglip_image_tokens(vision_config)
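
# For a SigLIP tower, the maximum number of image tokens is the patch count,
# (image_size // patch_size) ** 2. Assuming the 224px / patch-14 PaliGemma
# checkpoint, that works out to (224 // 14) ** 2 == 256 placeholder tokens;
# the authoritative value always comes from get_max_siglip_image_tokens above.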


def dummy_data_for_paligemma(ctx: InputContext, seq_len: int):
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    vision_config = hf_config.vision_config

    seq_data = dummy_seq_data_for_siglip(
        vision_config,
        seq_len,
        image_token_id=hf_config.image_token_index,
    )

    mm_data = dummy_image_for_siglip(vision_config)
    return seq_data, mm_data
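
# The dummy sequence/image pair returned above is what the engine feeds through
# the model during start-up profiling; assuming the standard profiling flow, it
# represents the worst case of a seq_len-long sequence whose leading positions
# are image placeholder tokens, paired with a full-size dummy image.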


def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
    """
    The correct prompt format needs to be:
    '<image>' * image_feature_size + '<bos>' + prompt + '\n'

    See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
    """  # noqa

    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(PaliGemmaConfig)

    tokenizer = cached_get_tokenizer(model_config.tokenizer)
    image_feature_size = hf_config.text_config.num_image_tokens
    image_token_str = tokenizer.decode(hf_config.image_token_index)
    bos_token = tokenizer.decode(hf_config.bos_token_id)
    image_token_str_pad = image_token_str * image_feature_size
    image_token_ids_pad = [hf_config.image_token_index] * image_feature_size

    orig_prompt = llm_inputs.get("prompt")
    orig_prompt_ids = llm_inputs.get("prompt_token_ids")

    if orig_prompt is not None and image_token_str in orig_prompt:
        logger.warning(
            f"The image token '{image_token_str}' was detected in the prompt "
            "and will be removed. Please follow the proper prompt format"
            " documented on HuggingFace.")
        orig_prompt = orig_prompt.replace(image_token_str, "")
        orig_prompt_ids.remove(hf_config.image_token_index)

    new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
    new_token_ids = image_token_ids_pad + orig_prompt_ids + [108]  # newline

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)
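
# Example of the transformation above (hypothetical values): assuming
# image_feature_size == 256 and a user prompt of "caption en", the processor
# emits
#
#     new_prompt    == "<image>" * 256 + "<bos>" + "caption en" + "\n"
#     new_token_ids == [image_token_index] * 256 + orig_prompt_ids + [108]
#
# which matches the format documented in the HF processing_paligemma code
# referenced in the docstring.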


class PaliGemmaMultiModalProjector(nn.Module):

    def __init__(self, vision_hidden_size: int, projection_dim: int):
        super().__init__()

        self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states = self.linear(image_features)
        return hidden_states


class PaliGemmaImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: (batch_size, num_channels, height, width)"""


PaliGemmaImageInputs = PaliGemmaImagePixelInputs


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):

    def __init__(self,
                 config: PaliGemmaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Port over SiglipVisionModel & TP
        self.vision_tower = SiglipVisionModel(config.vision_config)
        self.multi_modal_projector = PaliGemmaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            projection_dim=config.vision_config.projection_dim)

        self.quant_config = quant_config
        self.language_model = GemmaModel(config.text_config, cache_config,
                                         quant_config)
        self.unpadded_vocab_size = config.text_config.vocab_size
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return PaliGemmaImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )
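
    # Example of the validation above (hypothetical 224px checkpoint): with
    # vision_config.image_size == 224, pixel_values must be a
    # (batch_size, 3, 224, 224) tensor; a (1, 3, 336, 336) input would raise
    # ValueError("The expected shape of pixel values is "
    #            "('batch_size', '3', '224', '224'). "
    #            "You supplied (1, 3, 336, 336).")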

    def _image_pixels_to_features(
        self,
        vision_tower: SiglipVisionModel,
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:

        target_dtype = vision_tower.get_input_embeddings().weight.dtype
        image_features = vision_tower(pixel_values.to(dtype=target_dtype))

        return image_features

    def _process_image_pixels(
        self,
        inputs: PaliGemmaImagePixelInputs,
    ) -> torch.Tensor:
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(
            self.vision_tower,
            pixel_values,
        )

    def _process_image_input(
        self,
        image_input: PaliGemmaImageInputs,
    ) -> torch.Tensor:
        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        return self.multi_modal_projector(image_features)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs: object) -> SamplerOutput:
        parsed_image_input = self._parse_and_validate_image_input(**kwargs)

        if parsed_image_input is not None:
            vision_embeddings = self._process_image_input(parsed_image_input)
            # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
            vision_embeddings = vision_embeddings * (self.config.hidden_size**
                                                     -0.5)

            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            None,
                                            inputs_embeds=inputs_embeds)

        return hidden_states
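
    # The (hidden_size ** -0.5) factor in forward follows the HF PaliGemma
    # implementation: GemmaModel later multiplies the merged embeddings by
    # hidden_size ** 0.5, so pre-dividing the projected image features keeps
    # them at their original scale while the text embeddings receive Gemma's
    # usual normalizer. Assuming hidden_size == 2048, the factor is
    # 2048 ** -0.5, roughly 0.0221.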

    # Copied from vllm/modeling/models/gemma.py
    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.language_model.embed_tokens,
                                       hidden_states, sampling_metadata)
        return logits

    # Copied from vllm/modeling/models/gemma.py
    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    # Adapted from vllm/modeling/models/gemma.py
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()
        for name, loaded_weight in weights:
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, shard_name,
                     shard_id) in stacked_params_mapping:
                    if shard_name not in name:
                        continue
                    name = name.replace(shard_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    # lm_head is not used in vllm as it is tied with
                    # embed_token. To prevent errors, skip loading
                    # lm_head.weight.
                    if "lm_head.weight" in name:
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    use_default_weight_loading = True

            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

            loaded_params.add(name)

        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            logger.warning(
                "Some weights are not initialized from checkpoints: "
                f"{unloaded_params}")