fuyu.py

# coding=utf-8
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Fuyu model."""
import math
from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict

import torch
import torch.nn as nn
import torch.utils.checkpoint
from PIL import Image
from transformers import FuyuConfig, FuyuImageProcessor

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                       SequenceData)
from aphrodite.common.utils import progress_bar
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.linear import ColumnParallelLinear
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.persimmon import PersimmonForCausalLM
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.base import MultiModalInputs
from aphrodite.multimodal.image import (cached_get_image_processor,
                                        cached_get_tokenizer)
from aphrodite.quantization.base_config import QuantizationConfig

from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings
# Cannot find the following 2 numbers in the HF config.
_IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID = 71019
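# (They appear to be the ids of the "|SPEAKER|" and "|NEWLINE|" placeholder
# tokens in the fuyu-8b tokenizer, but that mapping is an assumption.)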

MAX_IMAGE_FEATURE_SIZE_HEIGHT = 1080
MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920


class FuyuImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """
    Shape:
    (batch_size, num_patches, patch_size_x * patch_size_y * num_channels)
    """


def _calculate_num_image_tokens(
    height: int,
    width: int,
) -> Tuple[int, int]:
    """
    Calculate the number of image tokens needed for a given image size.

    The expected Fuyu image prompt is in the format:
        (image_token * ncols + newline_token) * nrows

    args:
        height: int - height of the image
        width: int - width of the image
    returns:
        ncols: int - number of image tokens in the x direction
        nrows: int - number of image tokens in the y direction
    """
    ncol = math.ceil(width / 30)
    nrow = math.ceil(height / 30)
    return ncol, nrow
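
# Worked example for _calculate_num_image_tokens (not executed):
#   >>> _calculate_num_image_tokens(height=1080, width=1920)
#   (64, 36)
# i.e. a 1920x1080 image maps to a 64 x 36 grid of 30 x 30 patches, so its
# prompt occupies (64 + 1) * 36 = 2340 tokens, counting one newline token per
# row of patches.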


def get_max_fuyu_image_feature_size():
    return _calculate_num_image_tokens(
        height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )


def get_max_fuyu_image_tokens(ctx: InputContext):
    ncol, nrow = get_max_fuyu_image_feature_size()
    return (ncol + 1) * nrow


def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
    ncol, nrow = get_max_fuyu_image_feature_size()
    image_feature_size = get_max_fuyu_image_tokens(ctx)

    image_token_ids = ([_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]) * nrow
    token_ids = image_token_ids * num_images
    token_ids += [0] * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)


def dummy_image_for_fuyu(
    num_images: int,
    *,
    image_width: int,
    image_height: int,
):
    image = Image.new("RGB", (image_width, image_height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}


def dummy_data_for_fuyu(ctx: InputContext, seq_len: int,
                        mm_counts: Mapping[str, int]):
    num_images = mm_counts["image"]
    seq_data = dummy_seq_data_for_fuyu(ctx, seq_len, num_images)
    mm_data = dummy_image_for_fuyu(num_images,
                                   image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
                                   image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT)
    return seq_data, mm_data
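
# Note: the dummy data above is normally only consumed while the engine
# profiles memory usage; each dummy image is created at the maximum supported
# resolution so that profiling reflects the worst case (this describes how the
# registry's dummy-data hook is typically used, not something enforced here).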


def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
                           data: Image.Image):
    image_encoding = image_processor.preprocess(data, return_tensors="pt")
    batch_images = torch.stack([img[0] for img in image_encoding["images"]
                                ]).unsqueeze(1)
    image_unpadded_heights = torch.tensor(
        image_encoding["image_unpadded_heights"])
    image_unpadded_widths = torch.tensor(
        image_encoding["image_unpadded_widths"])

    batch_size = len(image_encoding["images"])
    image_present = torch.ones(batch_size, 1, 1)
    model_image_input = image_processor.preprocess_with_tokenizer_info(
        image_input=batch_images,
        image_present=image_present,
        image_unpadded_h=image_unpadded_heights,
        image_unpadded_w=image_unpadded_widths,
        image_placeholder_id=_IMAGE_TOKEN_ID,
        image_newline_id=_NEWLINE_TOKEN_ID,
        variable_sized=True,
    )
    return model_image_input
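
# The dict returned by preprocess_with_tokenizer_info is used below for two
# things: "image_patches" (the flattened pixel patches fed to the vision
# projection) and "image_input_ids" (the placeholder/newline token layout
# spliced into the prompt).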


def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    image_data = multi_modal_data["image"]
    new_multi_modal_data = {}

    # process image data
    if isinstance(image_data, Image.Image):
        # Fuyu's image_processor also handles token padding.
        image_processor: FuyuImageProcessor = cached_get_image_processor(
            model_config.model)

        model_image_input = _fuyu_image_preprocess(image_processor, image_data)
        image_patches = torch.stack([
            image_patch[0]
            for image_patch in model_image_input["image_patches"]
        ])
        new_multi_modal_data["image"] = image_patches

    elif isinstance(image_data, torch.Tensor):
        raise NotImplementedError("Embeddings input is not supported yet")
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    # process prompts
    prompt = llm_inputs.get("prompt")
    prompt_token_ids = llm_inputs["prompt_token_ids"]
    tokenizer = cached_get_tokenizer(model_config.model)
    # dim0 is batch_size, dim1 is subseq_size which will always be 1
    image_input_ids: List[List[
        torch.Tensor]] = model_image_input["image_input_ids"]
    image_input_ids = image_input_ids[0][0].tolist()
    bos_token = tokenizer.encode("<s>", add_special_tokens=False)[1:]
    boa_token = tokenizer.encode("\x04", add_special_tokens=False)[1:]

    new_prompt = prompt + "\x04"
    new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
        1:] + boa_token
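    # The rebuilt sequence is thus: [image placeholder/newline tokens] +
    # [<s>] + [original prompt tokens minus the leading BOS] +
    # [beginning-of-answer token "\x04"].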
    return LLMInputs(prompt=new_prompt,
                     prompt_token_ids=new_prompt_token_ids,
                     multi_modal_data=new_multi_modal_data)


def input_mapper_for_fuyu(ctx: InputContext, data: object):
    model_config = ctx.model_config
    if isinstance(data, Image.Image):
        # Fuyu's image_processor also handles token padding.
        image_processor: FuyuImageProcessor = cached_get_image_processor(
            model_config.model)

        model_image_input = _fuyu_image_preprocess(image_processor, data)
        data = torch.stack([
            image_patch[0]
            for image_patch in model_image_input["image_patches"]
        ])

    # image has been processed with prompt in input processor
    return MultiModalInputs({"image_patches": data})
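
# The "image_patches" key above becomes the keyword argument that
# FuyuForCausalLM.forward receives and validates in
# _parse_and_validate_image_input. The decorators below register the input
# mapper, the max-token estimate, the dummy data, and the input processor
# with the engine's multimodal/input registries.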


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
class FuyuForCausalLM(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: FuyuConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.config = config
        self.multimodal_config = multimodal_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.image_token_id = _IMAGE_TOKEN_ID
        self.image_feature_size = config.patch_size**2 * config.num_channels
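        # Fuyu has no separate vision encoder: each flattened patch
        # (patch_size**2 * num_channels pixel values, e.g. 30 * 30 * 3 = 2700
        # for fuyu-8b's 30x30 RGB patches) is projected directly into the
        # language model's hidden size by the linear layer below.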
        self.vision_embed_tokens = ColumnParallelLinear(
            self.image_feature_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.language_model = PersimmonForCausalLM(config,
                                                   cache_config=cache_config,
                                                   quant_config=quant_config)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
        image_patches = kwargs.pop("image_patches", None)

        if isinstance(image_patches, torch.Tensor):
            expected_feature_size = self.image_feature_size
            if image_patches.size(-1) != expected_feature_size:
                raise ValueError(
                    f"Expected image patches to have the last dimension of "
                    f"{expected_feature_size}, got {image_patches.size(-1)}")
            image_patches = image_patches.to(
                self.vision_embed_tokens.weight.dtype)
            return FuyuImagePixelInputs(type="pixel_values",
                                        data=image_patches)
        return None

    def _process_image_input(
            self, image_input: FuyuImagePixelInputs) -> torch.Tensor:

        assert self.vision_embed_tokens is not None
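        # ColumnParallelLinear returns (output, output_bias); only the
        # projected embeddings are needed here.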
        vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ):
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.embed_tokens(input_ids)
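            # Overwrite the text embeddings at image-token positions with the
            # projected patch embeddings.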
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.image_token_id)
        else:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.language_model.logits_processor(
            self.language_model.lm_head, hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.language_model.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # Copied from vllm/model_executor/models/bloom.py
                # NOTE: Fuyu's fused QKV weight lays out its output dim as
                # (num_heads * 3 * head_size), while the required layout is
                # (3 * num_heads * head_size), so the weight is converted here.
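                # Sketch of the conversion, assuming fuyu-8b's default config
                # (hidden_size=4096, num_attention_heads=64, head_size=64):
                # the fused output dim of 12288 is viewed as (64, 3, 64),
                # transposed to (3, 64, 64), and flattened back to 12288, so
                # all Q rows come first, then K, then V.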
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)