# coding=utf-8
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Fuyu model."""
import math
from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict

import torch
import torch.nn as nn
import torch.utils.checkpoint
from PIL import Image
from transformers import FuyuConfig, FuyuImageProcessor

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                       SequenceData)
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.linear import ColumnParallelLinear
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.persimmon import PersimmonForCausalLM
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.base import MultiModalInputs
from aphrodite.multimodal.image import (cached_get_image_processor,
                                        cached_get_tokenizer)
from aphrodite.quantization.base_config import QuantizationConfig

from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings

# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID = 71019

MAX_IMAGE_FEATURE_SIZE_HEIGHT = 1080
MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920


class FuyuImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """
    Shape:
    (batch_size, num_patches, patch_size_x * patch_size_y * num_channels)
    """


def _calculate_num_image_tokens(
    height: int,
    width: int,
) -> Tuple[int, int]:
    """
    Calculate the number of image tokens needed for a given image size.

    The expected Fuyu image prompt is in the format:
        (image_token * ncols + newline_token) * nrows

    args:
        height: int - height of the image in pixels
        width: int - width of the image in pixels
    returns:
        ncol: int - number of image tokens in the x direction
        nrow: int - number of image tokens in the y direction
    """
    ncol = math.ceil(width / 30)
    nrow = math.ceil(height / 30)
    return ncol, nrow
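
# Worked example (illustrative arithmetic only, using the max feature size
# constants above): a 1920x1080 image gives ncol = ceil(1920 / 30) = 64 and
# nrow = ceil(1080 / 30) = 36, so the image prompt occupies
# (64 + 1) * 36 = 2340 tokens once the per-row newline token is counted,
# which is exactly what get_max_fuyu_image_tokens() returns.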


def get_max_fuyu_image_feature_size():
    return _calculate_num_image_tokens(
        height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )


def get_max_fuyu_image_tokens(ctx: InputContext):
    ncol, nrow = get_max_fuyu_image_feature_size()
    return (ncol + 1) * nrow


def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
    ncol, nrow = get_max_fuyu_image_feature_size()
    image_feature_size = get_max_fuyu_image_tokens(ctx)

    image_token_ids = ([_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]) * nrow
    token_ids = image_token_ids * num_images
    token_ids += [0] * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)


def dummy_image_for_fuyu(
    num_images: int,
    *,
    image_width: int,
    image_height: int,
):
    image = Image.new("RGB", (image_width, image_height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}


def dummy_data_for_fuyu(ctx: InputContext, seq_len: int,
                        mm_counts: Mapping[str, int]):
    num_images = mm_counts["image"]
    seq_data = dummy_seq_data_for_fuyu(ctx, seq_len, num_images)
    mm_data = dummy_image_for_fuyu(num_images,
                                   image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
                                   image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT)
    return seq_data, mm_data


def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
                           data: Image.Image):
    image_encoding = image_processor.preprocess(data, return_tensors="pt")
    batch_images = torch.stack([img[0] for img in image_encoding["images"]
                                ]).unsqueeze(1)
    image_unpadded_heights = torch.tensor(
        image_encoding["image_unpadded_heights"])
    image_unpadded_widths = torch.tensor(
        image_encoding["image_unpadded_widths"])

    batch_size = len(image_encoding["images"])
    image_present = torch.ones(batch_size, 1, 1)
    model_image_input = image_processor.preprocess_with_tokenizer_info(
        image_input=batch_images,
        image_present=image_present,
        image_unpadded_h=image_unpadded_heights,
        image_unpadded_w=image_unpadded_widths,
        image_placeholder_id=_IMAGE_TOKEN_ID,
        image_newline_id=_NEWLINE_TOKEN_ID,
        variable_sized=True,
    )
    return model_image_input


def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    image_data = multi_modal_data["image"]
    new_multi_modal_data = {}

    # process image data
    if isinstance(image_data, Image.Image):
        # Fuyu's image_processor can also finish token padding
        image_processor: FuyuImageProcessor = cached_get_image_processor(
            model_config.model)

        model_image_input = _fuyu_image_preprocess(image_processor, image_data)
        image_patches = torch.stack([
            image_patch[0]
            for image_patch in model_image_input["image_patches"]
        ])
        new_multi_modal_data["image"] = image_patches
    elif isinstance(image_data, torch.Tensor):
        raise NotImplementedError("Embeddings input is not supported yet")
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    # process prompts
    prompt = llm_inputs.get("prompt")
    prompt_token_ids = llm_inputs["prompt_token_ids"]
    tokenizer = cached_get_tokenizer(model_config.model)
    # dim0 is batch_size, dim1 is subseq_size which will always be 1
    image_input_ids: List[List[
        torch.Tensor]] = model_image_input["image_input_ids"]
    image_input_ids = image_input_ids[0][0].tolist()
    bos_token = tokenizer.encode("<s>", add_special_tokens=False)[1:]
    boa_token = tokenizer.encode("\x04", add_special_tokens=False)[1:]

    new_prompt = prompt + "\x04"
    new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
        1:] + boa_token
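    # The resulting layout (a sketch of what the lines above assemble):
    #   [image placeholder tokens with a newline token per row]
    #   + [<s> BOS] + [original prompt tokens minus their leading BOS]
    #   + [\x04 beginning-of-answer token]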
    return LLMInputs(prompt=new_prompt,
                     prompt_token_ids=new_prompt_token_ids,
                     multi_modal_data=new_multi_modal_data)


def input_mapper_for_fuyu(ctx: InputContext, data: object):
    model_config = ctx.model_config
    if isinstance(data, Image.Image):
        # Fuyu's image_processor can also finish token padding
        image_processor: FuyuImageProcessor = cached_get_image_processor(
            model_config.model)

        model_image_input = _fuyu_image_preprocess(image_processor, data)
        data = torch.stack([
            image_patch[0]
            for image_patch in model_image_input["image_patches"]
        ])

    # image has been processed with prompt in input processor
    return MultiModalInputs({"image_patches": data})


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
class FuyuForCausalLM(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: FuyuConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.config = config
        self.multimodal_config = multimodal_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.image_token_id = _IMAGE_TOKEN_ID
        self.image_feature_size = config.patch_size**2 * config.num_channels

        self.vision_embed_tokens = ColumnParallelLinear(
            self.image_feature_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.language_model = PersimmonForCausalLM(config,
                                                   cache_config=cache_config,
                                                   quant_config=quant_config)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
        image_patches = kwargs.pop("image_patches", None)

        if isinstance(image_patches, torch.Tensor):
            expected_feature_size = self.image_feature_size
            if image_patches.size(-1) != expected_feature_size:
                raise ValueError(
                    f"Expected image patches to have the last dimension of "
                    f"{expected_feature_size}, got {image_patches.size(-1)}")
            image_patches = image_patches.to(
                self.vision_embed_tokens.weight.dtype)
            return FuyuImagePixelInputs(type="pixel_values",
                                        data=image_patches)
        return None

    def _process_image_input(
            self, image_input: FuyuImagePixelInputs) -> torch.Tensor:
        assert self.vision_embed_tokens is not None
        vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ):
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.embed_tokens(input_ids)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.image_token_id)
        else:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.language_model.logits_processor(
            self.language_model.lm_head, hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.language_model.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # copy from vllm/model_executor/models/bloom.py
                # NOTE: Fuyu's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
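                # Illustrative shapes (hypothetical numbers, not read from
                # the config): with num_heads = 64 and head_size = 64, the
                # fused output rows arrive grouped as (64, 3, 64); the
                # view/transpose below regroups them as (3, 64, 64) and then
                # flattens back to the original 2D shape.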
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)