fuyu.py

# coding=utf-8
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Fuyu model."""
import math
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict

import torch
import torch.nn as nn
import torch.utils.checkpoint
from PIL import Image
from transformers import FuyuConfig, FuyuImageProcessor

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                       SequenceData)
from aphrodite.common.utils import progress_bar
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.linear import ColumnParallelLinear
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.persimmon import PersimmonForCausalLM
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.base import MultiModalInputs
from aphrodite.multimodal.image import (cached_get_image_processor,
                                        cached_get_tokenizer)
from aphrodite.quantization.base_config import QuantizationConfig

from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings

# The following 2 token IDs cannot be found in the HF config.
_IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID = 71019

MAX_IMAGE_FEATURE_SIZE_HEIGHT = 1080
MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920


class FuyuImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """
    Shape:
    (batch_size, num_patches, patch_size_x * patch_size_y * num_channels)
    """


def _calculate_num_image_tokens(
    height: int,
    width: int,
) -> Tuple[int, int]:
    """
    Calculate the number of image tokens needed for a given image size.

    The expected Fuyu image prompt is in the format:
        (image_token * ncols + newline_token) * nrows

    args:
        height: int - height of the image in pixels
        width: int - width of the image in pixels
    returns:
        ncols: int - number of image tokens in the x direction
        nrows: int - number of image tokens in the y direction
    """
    ncol = math.ceil(width / 30)
    nrow = math.ceil(height / 30)
    return ncol, nrow
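

# For example, at the maximum supported size of 1920x1080, Fuyu's 30x30-pixel
# patch grid gives ncol = ceil(1920 / 30) = 64 and nrow = ceil(1080 / 30) = 36,
# so get_max_fuyu_image_tokens() below returns (64 + 1) * 36 = 2340 tokens
# (the +1 per row accounts for the newline token).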
def get_max_fuyu_image_feature_size():
    return _calculate_num_image_tokens(
        height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )


def get_max_fuyu_image_tokens(ctx: InputContext):
    ncol, nrow = get_max_fuyu_image_feature_size()
    return (ncol + 1) * nrow


def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int):
    ncol, nrow = get_max_fuyu_image_feature_size()
    image_feature_size = get_max_fuyu_image_tokens(ctx)

    token_ids = ([_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]) * nrow
    token_ids += [0] * (seq_len - image_feature_size)
    return SequenceData(token_ids)


def dummy_image_for_fuyu(
    image_width: int,
    image_height: int,
):
    image = Image.new("RGB", (image_width, image_height), color=0)
    return {"image": image}


def dummy_data_for_fuyu(ctx: InputContext, seq_len: int):
    seq_data = dummy_seq_data_for_fuyu(ctx, seq_len)
    mm_data = dummy_image_for_fuyu(MAX_IMAGE_FEATURE_SIZE_WIDTH,
                                   MAX_IMAGE_FEATURE_SIZE_HEIGHT)
    return seq_data, mm_data


def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
                           data: Image.Image):
    image_encoding = image_processor.preprocess(data, return_tensors="pt")
    batch_images = torch.stack([img[0] for img in image_encoding["images"]
                                ]).unsqueeze(1)
    image_unpadded_heights = torch.tensor(
        image_encoding["image_unpadded_heights"])
    image_unpadded_widths = torch.tensor(
        image_encoding["image_unpadded_widths"])

    batch_size = len(image_encoding["images"])
    image_present = torch.ones(batch_size, 1, 1)
    model_image_input = image_processor.preprocess_with_tokenizer_info(
        image_input=batch_images,
        image_present=image_present,
        image_unpadded_h=image_unpadded_heights,
        image_unpadded_w=image_unpadded_widths,
        image_placeholder_id=_IMAGE_TOKEN_ID,
        image_newline_id=_NEWLINE_TOKEN_ID,
        variable_sized=True,
    )
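    # The returned dict includes "image_patches" (flattened per-patch pixel
    # values) and "image_input_ids" (placeholder/newline token ids); both are
    # consumed by the input processor and mapper below.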
    return model_image_input


def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    image_data = multi_modal_data["image"]
    new_multi_modal_data = {}

    # process image data
    if isinstance(image_data, Image.Image):
        # Fuyu's image_processor also handles token padding
        image_processor: FuyuImageProcessor = cached_get_image_processor(
            model_config.model)

        model_image_input = _fuyu_image_preprocess(image_processor, image_data)
        image_patches = torch.stack([
            image_patch[0]
            for image_patch in model_image_input["image_patches"]
        ])
        new_multi_modal_data["image"] = image_patches

    elif isinstance(image_data, torch.Tensor):
        raise NotImplementedError("Embeddings input is not supported yet")
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    # process prompts
    prompt = llm_inputs.get("prompt")
    prompt_token_ids = llm_inputs["prompt_token_ids"]
    tokenizer = cached_get_tokenizer(model_config.model)
    # dim0 is batch_size, dim1 is subseq_size, which is always 1
    image_input_ids: List[List[
        torch.Tensor]] = model_image_input["image_input_ids"]
    image_input_ids = image_input_ids[0][0].tolist()
    bos_token = tokenizer.encode("<s>", add_special_tokens=False)[1:]
    boa_token = tokenizer.encode("\x04", add_special_tokens=False)[1:]
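
    # Final layout of the expanded prompt:
    #   [image tokens + per-row newlines] [<s>] [original prompt] [\x04]
    # where "\x04" is Fuyu's beginning-of-answer marker.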
    new_prompt = prompt + "\x04"
    new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
        1:] + boa_token

    return LLMInputs(prompt=new_prompt,
                     prompt_token_ids=new_prompt_token_ids,
                     multi_modal_data=new_multi_modal_data)


def input_mapper_for_fuyu(ctx: InputContext, data: object):
    model_config = ctx.model_config
    if isinstance(data, Image.Image):
        # Fuyu's image_processor also handles token padding
        image_processor: FuyuImageProcessor = cached_get_image_processor(
            model_config.model)

        model_image_input = _fuyu_image_preprocess(image_processor, data)
        data = torch.stack([
            image_patch[0]
            for image_patch in model_image_input["image_patches"]
        ])

    # The image has already been processed together with the prompt
    # in the input processor.
    return MultiModalInputs({"image_patches": data})
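

# The decorators below register the Fuyu-specific hooks with the multimodal
# and input registries, so the engine applies them automatically whenever this
# model is loaded: dummy data for memory profiling, the input processor for
# prompt expansion, and the mapper for image tensors.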
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
class FuyuForCausalLM(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: FuyuConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.config = config
        self.multimodal_config = multimodal_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.image_token_id = _IMAGE_TOKEN_ID
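        # Each flattened patch holds patch_size * patch_size * num_channels
        # values (30 * 30 * 3 = 2700 for the released Fuyu-8B config).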
        self.image_feature_size = config.patch_size**2 * config.num_channels

        self.vision_embed_tokens = ColumnParallelLinear(
            self.image_feature_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.language_model = PersimmonForCausalLM(config,
                                                   cache_config=cache_config,
                                                   quant_config=quant_config)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
        image_patches = kwargs.pop("image_patches", None)

        if isinstance(image_patches, torch.Tensor):
            expected_feature_size = self.image_feature_size
            if image_patches.size(-1) != expected_feature_size:
                raise ValueError(
                    f"Expected image patches to have the last dimension of "
                    f"{expected_feature_size}, got {image_patches.size(-1)}")
            image_patches = image_patches.to(
                self.vision_embed_tokens.weight.dtype)
            return FuyuImagePixelInputs(type="pixel_values",
                                        data=image_patches)
        return None

    def _process_image_input(
            self, image_input: FuyuImagePixelInputs) -> torch.Tensor:
        assert self.vision_embed_tokens is not None
        vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ):
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.embed_tokens(input_ids)
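            # Scatter the projected patch embeddings into the positions that
            # hold the image placeholder token, so the language model receives
            # a single mixed embedding sequence.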
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.image_token_id)
        else:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.language_model.logits_processor(
            self.language_model.lm_head, hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.language_model.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            param = params_dict[name]
            if "query_key_value" in name:
                # copied from vllm/model_executor/models/bloom.py
                # NOTE: Fuyu's fused QKV weight is laid out along output_dim as
                # (num_heads * 3 * head_size), while the required layout is
                # (3 * num_heads * head_size). Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)