paligemma.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
  2. TypedDict, Union)
  3. import torch
  4. from loguru import logger
  5. from torch import nn
  6. from transformers import PaliGemmaConfig
  7. from aphrodite.attention import AttentionMetadata
  8. from aphrodite.common.config import CacheConfig, MultiModalConfig
  9. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  10. from aphrodite.common.utils import progress_bar
  11. from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
  12. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  13. from aphrodite.modeling.layers.sampler import Sampler
  14. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  15. from aphrodite.modeling.models.gemma import GemmaModel
  16. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  17. from aphrodite.multimodal import MULTIMODAL_REGISTRY
  18. from aphrodite.multimodal.image import cached_get_tokenizer
  19. from aphrodite.quantization.base_config import QuantizationConfig
  20. from .interfaces import SupportsMultiModal
  21. from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
  22. dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
  23. from .utils import merge_multimodal_embeddings
  24. _KEYS_TO_MODIFY_MAPPING = {
  25. "language_model.model": "language_model",
  26. }
  27. class PaliGemmaImagePixelInputs(TypedDict):
  28. type: Literal["pixel_values"]
  29. data: torch.Tensor
  30. """Shape: (batch_size, num_channels, height, width)"""
  31. class PaliGemmaImageEmbeddingInputs(TypedDict):
  32. type: Literal["image_embeds"]
  33. data: torch.Tensor
  34. """Shape: `(batch_size, image_feature_size, hidden_size)`
  35. `hidden_size` must match the hidden size of language model backbone.
  36. """
  37. PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs,
  38. PaliGemmaImageEmbeddingInputs]
  39. def get_max_paligemma_image_tokens(ctx: InputContext):
  40. hf_config = ctx.get_hf_config(PaliGemmaConfig)
  41. vision_config = hf_config.vision_config
  42. return get_max_siglip_image_tokens(vision_config)
  43. def dummy_data_for_paligemma(ctx: InputContext, seq_len: int,
  44. mm_counts: Mapping[str, int]):
  45. hf_config = ctx.get_hf_config(PaliGemmaConfig)
  46. vision_config = hf_config.vision_config
  47. num_images = mm_counts["image"]
  48. seq_data = dummy_seq_data_for_siglip(
  49. vision_config,
  50. seq_len,
  51. num_images,
  52. image_token_id=hf_config.image_token_index,
  53. )
  54. mm_data = dummy_image_for_siglip(vision_config, num_images)
  55. return seq_data, mm_data
  56. def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
  57. """
  58. The correct prompt format needs to be:
  59. '<image>' * image_feature_size + '<bos>' + prompt + '\n'
  60. See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
  61. """ # noqa
  62. multi_modal_data = llm_inputs.get("multi_modal_data")
  63. if multi_modal_data is None or "image" not in multi_modal_data:
  64. return llm_inputs
  65. model_config = ctx.model_config
  66. hf_config = ctx.get_hf_config(PaliGemmaConfig)
  67. tokenizer = cached_get_tokenizer(model_config.tokenizer)
  68. image_feature_size = hf_config.text_config.num_image_tokens
  69. image_token_str = tokenizer.decode(hf_config.image_token_index)
  70. bos_token = tokenizer.decode(hf_config.bos_token_id)
  71. image_token_str_pad = image_token_str * image_feature_size
  72. image_token_ids_pad = [hf_config.image_token_index] * image_feature_size
  73. orig_prompt = llm_inputs.get("prompt")
  74. orig_prompt_ids = llm_inputs.get("prompt_token_ids")
  75. if orig_prompt is not None and image_token_str in orig_prompt:
  76. logger.warning(
  77. f"The image token '{image_token_str}' was detected in the prompt "
  78. "and will be removed. Please follow the proper prompt format"
  79. " documented on HuggingFace.")
  80. orig_prompt = orig_prompt.replace(image_token_str, "")
  81. orig_prompt_ids.remove(hf_config.image_token_index)
  82. new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
  83. new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline
  84. # NOTE: Create a defensive copy of the original inputs
  85. return LLMInputs(prompt_token_ids=new_token_ids,
  86. prompt=new_prompt,
  87. multi_modal_data=multi_modal_data)
  88. class PaliGemmaMultiModalProjector(nn.Module):
  89. def __init__(self, vision_hidden_size: int, projection_dim: int):
  90. super().__init__()
  91. self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True)
  92. def forward(self, image_features: torch.Tensor) -> torch.Tensor:
  93. hidden_states = self.linear(image_features)
  94. return hidden_states
  95. @MULTIMODAL_REGISTRY.register_image_input_mapper()
  96. @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
  97. @INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
  98. @INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
  99. class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
  100. def __init__(self,
  101. config: PaliGemmaConfig,
  102. multimodal_config: MultiModalConfig,
  103. cache_config: Optional[CacheConfig] = None,
  104. quant_config: Optional[QuantizationConfig] = None) -> None:
  105. super().__init__()
  106. self.config = config
  107. self.multimodal_config = multimodal_config
  108. # TODO: Port over SiglipVisionModel & TP
  109. self.vision_tower = SiglipVisionModel(config.vision_config)
  110. self.multi_modal_projector = PaliGemmaMultiModalProjector(
  111. vision_hidden_size=config.vision_config.hidden_size,
  112. projection_dim=config.vision_config.projection_dim)
  113. self.quant_config = quant_config
  114. self.language_model = GemmaModel(config.text_config, cache_config,
  115. quant_config)
  116. self.unpadded_vocab_size = config.text_config.vocab_size
  117. logit_scale = getattr(config, "logit_scale", 1.0)
  118. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  119. config.vocab_size, logit_scale)
  120. self.sampler = Sampler()
  121. def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
  122. h = w = self.config.vision_config.image_size
  123. expected_dims = (3, h, w)
  124. actual_dims = tuple(data.shape[1:])
  125. if actual_dims != expected_dims:
  126. expected_expr = ("batch_size", *map(str, expected_dims))
  127. raise ValueError(
  128. f"The expected shape of pixel values is {expected_expr}. "
  129. f"You supplied {tuple(data.shape)}.")
  130. return data
  131. def _parse_and_validate_image_input(
  132. self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
  133. pixel_values = kwargs.pop("pixel_values", None)
  134. image_embeds = kwargs.pop("image_embeds", None)
  135. if pixel_values is None and image_embeds is None:
  136. return None
  137. if pixel_values is not None:
  138. if not isinstance(pixel_values, torch.Tensor):
  139. raise ValueError("Incorrect type of pixel values. "
  140. f"Got type: {type(pixel_values)}")
  141. return PaliGemmaImagePixelInputs(
  142. type="pixel_values",
  143. data=self._validate_pixel_values(pixel_values),
  144. )
  145. if image_embeds is not None:
  146. if not isinstance(image_embeds, torch.Tensor):
  147. raise ValueError("Incorrect type of image embeddings. "
  148. f"Got type: {type(image_embeds)}")
  149. return PaliGemmaImageEmbeddingInputs(
  150. type="image_embeds",
  151. data=image_embeds,
  152. )
  153. raise AssertionError("This line should be unreachable.")
  154. def _image_pixels_to_features(
  155. self,
  156. vision_tower: SiglipVisionModel,
  157. pixel_values: torch.Tensor,
  158. ) -> torch.Tensor:
  159. target_dtype = vision_tower.get_input_embeddings().weight.dtype
  160. image_features = vision_tower(pixel_values.to(dtype=target_dtype))
  161. return image_features
  162. def _process_image_input(
  163. self,
  164. image_input: PaliGemmaImageInputs,
  165. ) -> torch.Tensor:
  166. if image_input["type"] == "image_embeds":
  167. return image_input["data"]
  168. assert self.vision_tower is not None
  169. pixel_values = image_input["data"]
  170. image_features = self._image_pixels_to_features(
  171. self.vision_tower,
  172. pixel_values,
  173. )
  174. return self.multi_modal_projector(image_features)
  175. def forward(self,
  176. input_ids: torch.Tensor,
  177. positions: torch.Tensor,
  178. kv_caches: List[torch.Tensor],
  179. attn_metadata: AttentionMetadata,
  180. intermediate_tensors: Optional[IntermediateTensors] = None,
  181. **kwargs: object) -> SamplerOutput:
  182. parsed_image_input = self._parse_and_validate_image_input(**kwargs)
  183. if parsed_image_input is not None:
  184. vision_embeddings = self._process_image_input(parsed_image_input)
  185. # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
  186. vision_embeddings = vision_embeddings * (self.config.hidden_size**
  187. -0.5)
  188. inputs_embeds = self.language_model.get_input_embeddings(input_ids)
  189. inputs_embeds = merge_multimodal_embeddings(
  190. input_ids, inputs_embeds, vision_embeddings,
  191. self.config.image_token_index)
  192. input_ids = None
  193. else:
  194. inputs_embeds = None
  195. hidden_states = self.language_model(input_ids,
  196. positions,
  197. kv_caches,
  198. attn_metadata,
  199. None,
  200. inputs_embeds=inputs_embeds)
  201. return hidden_states
  202. # Copied from vllm/modeling/models/gemma.py
  203. def compute_logits(
  204. self,
  205. hidden_states: torch.Tensor,
  206. sampling_metadata: SamplingMetadata,
  207. ) -> Optional[torch.Tensor]:
  208. logits = self.logits_processor(self.language_model.embed_tokens,
  209. hidden_states, sampling_metadata)
  210. return logits
  211. # Copied from vllm/modeling/models/gemma.py
  212. def sample(
  213. self,
  214. logits: torch.Tensor,
  215. sampling_metadata: SamplingMetadata,
  216. ) -> Optional[SamplerOutput]:
  217. next_tokens = self.sampler(logits, sampling_metadata)
  218. return next_tokens
  219. # Adapted from vllm/modeling/models/gemma.py
  220. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  221. stacked_params_mapping = [
  222. # (param_name, shard_name, shard_id)
  223. ("qkv_proj", "q_proj", "q"),
  224. ("qkv_proj", "k_proj", "k"),
  225. ("qkv_proj", "v_proj", "v"),
  226. ("gate_up_proj", "gate_proj", 0),
  227. ("gate_up_proj", "up_proj", 1),
  228. ]
  229. params_dict = dict(self.named_parameters())
  230. loaded_params = set()
  231. weights_list = list(weights)
  232. for name, loaded_weight in progress_bar(weights_list,
  233. desc="Loading modules..."):
  234. for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
  235. if key_to_modify in name:
  236. name = name.replace(key_to_modify, new_key)
  237. use_default_weight_loading = False
  238. if "vision" in name:
  239. if self.vision_tower is not None:
  240. # We only do sharding for language model and
  241. # not vision model for now.
  242. use_default_weight_loading = True
  243. else:
  244. for (param_name, shard_name,
  245. shard_id) in stacked_params_mapping:
  246. if shard_name not in name:
  247. continue
  248. name = name.replace(shard_name, param_name)
  249. # Skip loading extra bias for GPTQ models.
  250. if name.endswith(".bias") and name not in params_dict:
  251. continue
  252. param = params_dict[name]
  253. weight_loader = param.weight_loader
  254. weight_loader(param, loaded_weight, shard_id)
  255. break
  256. else:
  257. # lm_head is not used in vllm as it is tied with
  258. # embed_token. To prevent errors, skip loading
  259. # lm_head.weight.
  260. if "lm_head.weight" in name:
  261. continue
  262. # Skip loading extra bias for GPTQ models.
  263. if name.endswith(".bias") and name not in params_dict:
  264. continue
  265. use_default_weight_loading = True
  266. if use_default_weight_loading:
  267. param = params_dict[name]
  268. weight_loader = getattr(param, "weight_loader",
  269. default_weight_loader)
  270. weight_loader(param, loaded_weight)
  271. loaded_params.add(name)
  272. unloaded_params = params_dict.keys() - loaded_params
  273. if unloaded_params:
  274. logger.warning(
  275. "Some weights are not initialized from checkpoints: "
  276. f"{unloaded_params}")