paligemma.py

from typing import Iterable, List, Literal, Optional, Tuple, TypedDict

import torch
from loguru import logger
from PIL import Image
from torch import nn
from transformers import PaliGemmaConfig, SiglipVisionConfig, SiglipVisionModel

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                       SequenceData)
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.linear import ColumnParallelLinear
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.gemma import GemmaModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.image import cached_get_tokenizer
from aphrodite.quantization.base_config import QuantizationConfig

from .interfaces import SupportsVision
from .utils import merge_vision_embeddings

_KEYS_TO_MODIFY_MAPPING = {
    "language_model.model": "language_model",
}

def get_max_paligemma_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    text_config = hf_config.text_config

    return text_config.num_image_tokens
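
# Note: num_image_tokens is fixed per checkpoint by the vision tower's patch
# grid; for example, a 224x224 SigLIP tower with patch size 14 yields
# (224 // 14) ** 2 == 256 image tokens. The concrete value comes from the HF
# text config above, so these figures are illustrative only.
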
def dummy_seq_data_for_paligemma(
    hf_config: PaliGemmaConfig,
    seq_len: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    if image_feature_size_override is None:
        image_feature_size = hf_config.text_config.num_image_tokens
    else:
        image_feature_size = image_feature_size_override

    token_ids = [image_token_id] * image_feature_size
    token_ids += [0] * (seq_len - image_feature_size)
    return SequenceData(token_ids)
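
# The dummy sequence above is used for the engine's profiling/warmup run: it
# front-loads `image_feature_size` image placeholder tokens and pads the rest
# of the sequence with token id 0 up to `seq_len`.
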
def dummy_image_for_paligemma(
    hf_config: SiglipVisionConfig,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image}

def dummy_data_for_paligemma(ctx: InputContext, seq_len: int):
    hf_config = ctx.get_hf_config(PaliGemmaConfig)
    vision_config = hf_config.vision_config

    seq_data = dummy_seq_data_for_paligemma(
        hf_config,
        seq_len,
        image_token_id=hf_config.image_token_index,
    )

    mm_data = dummy_image_for_paligemma(vision_config)
    return seq_data, mm_data

def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs):
    """
    The correct prompt format needs to be:
    '<image>' * image_feature_size + '<bos>' + prompt + '\n'

    See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55
    """  # noqa

    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(PaliGemmaConfig)

    tokenizer = cached_get_tokenizer(model_config.tokenizer)
    image_feature_size = hf_config.text_config.num_image_tokens
    image_token_str = tokenizer.decode(hf_config.image_token_index)
    bos_token = tokenizer.decode(hf_config.bos_token_id)
    image_token_str_pad = image_token_str * image_feature_size
    image_token_ids_pad = [hf_config.image_token_index] * image_feature_size

    orig_prompt = llm_inputs.get("prompt")
    orig_prompt_ids = llm_inputs.get("prompt_token_ids")

    if orig_prompt is not None and image_token_str in orig_prompt:
        logger.warning(
            f"The image token '{image_token_str}' was detected in the prompt "
            "and will be removed. Please follow the proper prompt format"
            " documented on HuggingFace.")
        orig_prompt = orig_prompt.replace(image_token_str, "")
        orig_prompt_ids.remove(hf_config.image_token_index)

    new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n"
    new_token_ids = image_token_ids_pad + orig_prompt_ids + [108]  # newline

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)
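
# Worked example of the transformation above (values are illustrative): with
# image_feature_size == 256, image_token_index == 257152 and a user prompt
# "caption en", the processor emits roughly
#   prompt           = "<image>" * 256 + "<bos>" + "caption en" + "\n"
#   prompt_token_ids = [257152] * 256
#                      + the original prompt_token_ids (which already begin
#                        with <bos>)
#                      + [108]
# where 108 is Gemma's newline token id, as noted in the code.
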
class PaliGemmaMultiModalProjector(nn.Module):

    def __init__(self, vision_hidden_size: int, projection_dim: int):
        super().__init__()

        self.linear = ColumnParallelLinear(vision_hidden_size,
                                           projection_dim,
                                           bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.linear(image_features)
        return hidden_states
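
# The projector above maps SigLIP patch embeddings (vision_hidden_size) into
# the language model's embedding space (projection_dim) with a single linear
# layer; e.g. roughly 1152 -> 2048 for the 3B checkpoints (sizes given for
# illustration only; check the checkpoint config).
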
class PaliGemmaImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: (batch_size, num_channels, height, width)"""


PaliGemmaImageInputs = PaliGemmaImagePixelInputs

@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):

    def __init__(self,
                 config: PaliGemmaConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Port over SiglipVisionModel & TP
        self.vision_tower = SiglipVisionModel(config.vision_config)
        self.multi_modal_projector = PaliGemmaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            projection_dim=config.vision_config.projection_dim)

        self.quant_config = quant_config
        self.language_model = GemmaModel(config.text_config, cache_config,
                                         quant_config)
        self.unpadded_vocab_size = config.text_config.vocab_size
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = Sampler()
    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data
    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[PaliGemmaImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)

        if pixel_values is None:
            return None

        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        return PaliGemmaImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )
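
    # `pixel_values` arrives via the registered image input mapper (the
    # default mapper runs the HF image processor), so a well-formed batch is
    # already a float tensor of shape (batch_size, 3, image_size, image_size);
    # the checks above guard against callers bypassing that path.
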
    def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:
        target_dtype = vision_tower.get_input_embeddings().weight.dtype
        image_outputs = vision_tower(pixel_values.to(dtype=target_dtype),
                                     output_hidden_states=True)

        selected_image_features = image_outputs.last_hidden_state

        return selected_image_features
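
    # SigLIP has no [CLS] token, so the full `last_hidden_state` (one
    # embedding per image patch) is used as the image features; there is no
    # LLaVA-style feature-select step here.
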
    def _process_image_pixels(
            self, inputs: PaliGemmaImagePixelInputs) -> torch.Tensor:
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(
            self, image_input: PaliGemmaImageInputs) -> torch.Tensor:
        assert self.vision_tower is not None
        image_features = self._process_image_pixels(image_input)

        return self.multi_modal_projector(image_features)
    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs: object) -> SamplerOutput:
        parsed_image_input = self._parse_and_validate_image_input(**kwargs)

        if parsed_image_input is not None:
            vision_embeddings = self._process_image_input(parsed_image_input)
            # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
            vision_embeddings = vision_embeddings * (self.config.hidden_size**
                                                     -0.5)

            inputs_embeds = self.language_model.get_input_embeddings(input_ids)

            inputs_embeds = merge_vision_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.config.image_token_index)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            None,
                                            inputs_embeds=inputs_embeds)

        return hidden_states
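
    # Note on the `hidden_size ** -0.5` factor above: GemmaModel scales its
    # input embeddings by sqrt(hidden_size), so pre-dividing the projected
    # image features by the same factor mirrors the reference HF
    # implementation, which divides image features by sqrt(hidden_size) before
    # the Gemma embedding normalizer is applied (see the modeling_paligemma.py
    # link in `forward`).
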
    # Copied from vllm/modeling/models/gemma.py
    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.language_model.embed_tokens,
                                       hidden_states, sampling_metadata)
        return logits

    # Copied from vllm/modeling/models/gemma.py
    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
    # Adapted from vllm/modeling/models/gemma.py
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params = set()
        for name, loaded_weight in weights:
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_tower is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, shard_name,
                     shard_id) in stacked_params_mapping:
                    if shard_name not in name:
                        continue
                    name = name.replace(shard_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    # lm_head is not used in vllm as it is tied with
                    # embed_token. To prevent errors, skip loading
                    # lm_head.weight.
                    if "lm_head.weight" in name:
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    use_default_weight_loading = True

            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

            loaded_params.add(name)

        unloaded_params = params_dict.keys() - loaded_params
        if unloaded_params:
            raise RuntimeError(
                "Some weights are not initialized from checkpoints: "
                f"{unloaded_params}")
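
    # Illustrative walk-through of the loading logic above (checkpoint key
    # names are typical HF PaliGemma keys, shown for clarity only):
    #   "language_model.model.layers.0.self_attn.q_proj.weight"
    #     -> _KEYS_TO_MODIFY_MAPPING rewrites "language_model.model" to
    #        "language_model", then stacked_params_mapping folds "q_proj" into
    #        the fused "qkv_proj" parameter with shard id "q".
    #   "vision_tower.vision_model.encoder.layers.0.mlp.fc1.weight"
    #     -> contains "vision", so it is loaded unsharded via the default
    #        weight loader.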