# coding=utf-8
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Fuyu model."""
import math
from array import array
from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict

import torch
import torch.nn as nn
import torch.utils.checkpoint
from PIL import Image
from transformers import FuyuConfig, FuyuImageProcessor

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors, SequenceData
from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.linear import ColumnParallelLinear
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.persimmon import PersimmonForCausalLM
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.base import MultiModalInputs
from aphrodite.multimodal.image import cached_get_image_processor
from aphrodite.multimodal.utils import cached_get_tokenizer
from aphrodite.quantization.base_config import QuantizationConfig

from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings

# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID = 71019

MAX_IMAGE_FEATURE_SIZE_HEIGHT = 1080
MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920


class FuyuImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """
    Shape:
    (batch_size, num_patches, patch_size_x * patch_size_y * num_channels)
    """

def _calculate_num_image_tokens(
    height: int,
    width: int,
) -> Tuple[int, int]:
    """
    Calculate the number of image tokens needed for a given image size.
    The expected Fuyu image prompt is in the format:
        (image_token * ncols + newline_token) * nrows
    Args:
        height: int - height of the image in pixels
        width: int - width of the image in pixels
    Returns:
        ncols: int - number of image tokens in the x direction
        nrows: int - number of image tokens in the y direction
    """
    ncol = math.ceil(width / 30)
    nrow = math.ceil(height / 30)
    return ncol, nrow
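
# Worked example (for orientation only, not used by the code): the hard-coded
# 30 above corresponds to Fuyu's 30x30-pixel patches. For the 1920x1080
# maximum profiling size this gives ncol = ceil(1920 / 30) = 64 and
# nrow = ceil(1080 / 30) = 36, so one image contributes
# (64 + 1) * 36 = 2340 tokens, where the "+1" is the newline token that
# terminates each row of patches (see get_max_fuyu_image_tokens below).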

def get_max_fuyu_image_feature_size():
    return _calculate_num_image_tokens(
        height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )


def get_max_fuyu_image_tokens(ctx: InputContext):
    ncol, nrow = get_max_fuyu_image_feature_size()
    return (ncol + 1) * nrow


def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
    ncol, nrow = get_max_fuyu_image_feature_size()
    image_feature_size = get_max_fuyu_image_tokens(ctx)

    image_token_ids = (
        array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [_IMAGE_TOKEN_ID]) * ncol +
        array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [_NEWLINE_TOKEN_ID])) * nrow
    token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                      image_token_ids) * num_images
    token_ids += array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                       [0]) * (seq_len - image_feature_size * num_images)
    return SequenceData(token_ids)


def dummy_image_for_fuyu(
    num_images: int,
    *,
    image_width: int,
    image_height: int,
):
    image = Image.new("RGB", (image_width, image_height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}


def dummy_data_for_fuyu(ctx: InputContext, seq_len: int,
                        mm_counts: Mapping[str, int]):
    num_images = mm_counts["image"]
    seq_data = dummy_seq_data_for_fuyu(ctx, seq_len, num_images)
    mm_data = dummy_image_for_fuyu(num_images,
                                   image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
                                   image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT)
    return seq_data, mm_data


def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
                           data: Image.Image):
    image_encoding = image_processor.preprocess(data, return_tensors="pt")
    batch_images = torch.stack([img[0] for img in image_encoding["images"]
                                ]).unsqueeze(1)
    image_unpadded_heights = torch.tensor(
        image_encoding["image_unpadded_heights"])
    image_unpadded_widths = torch.tensor(
        image_encoding["image_unpadded_widths"])

    batch_size = len(image_encoding["images"])
    image_present = torch.ones(batch_size, 1, 1)
    model_image_input = image_processor.preprocess_with_tokenizer_info(
        image_input=batch_images,
        image_present=image_present,
        image_unpadded_h=image_unpadded_heights,
        image_unpadded_w=image_unpadded_widths,
        image_placeholder_id=_IMAGE_TOKEN_ID,
        image_newline_id=_NEWLINE_TOKEN_ID,
        variable_sized=True,
    )
    return model_image_input
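
# NOTE: `preprocess_with_tokenizer_info` returns an encoding from which two
# entries are consumed downstream: "image_patches" (the flattened per-patch
# pixel values that become the model's visual input) and "image_input_ids"
# (the image-placeholder and newline token ids that are spliced into the text
# prompt by `input_processor_for_fuyu` below).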

def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    image_data = multi_modal_data["image"]
    new_multi_modal_data = {}
    # process image data
    if isinstance(image_data, Image.Image):
        # Fuyu's image_processor can also finish token padding
        image_processor: FuyuImageProcessor = cached_get_image_processor(
            model_config.model)

        model_image_input = _fuyu_image_preprocess(image_processor, image_data)
        image_patches = torch.stack([
            image_patch[0]
            for image_patch in model_image_input["image_patches"]
        ])
        new_multi_modal_data["image"] = image_patches

    elif isinstance(image_data, torch.Tensor):
        raise NotImplementedError("Embeddings input is not supported yet")
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    # process prompts
    prompt = llm_inputs.get("prompt")
    prompt_token_ids = llm_inputs["prompt_token_ids"]
    tokenizer = cached_get_tokenizer(model_config.model)
    # dim0 is batch_size, dim1 is subseq_size which will always be 1
    image_input_ids: List[List[
        torch.Tensor]] = model_image_input["image_input_ids"]
    image_input_ids = image_input_ids[0][0].tolist()
    bos_token = tokenizer.encode("<s>", add_special_tokens=False)[1:]
    boa_token = tokenizer.encode("\x04", add_special_tokens=False)[1:]

    new_prompt = prompt + "\x04"
    new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
        1:] + boa_token

    return LLMInputs(prompt=new_prompt,
                     prompt_token_ids=new_prompt_token_ids,
                     multi_modal_data=new_multi_modal_data)
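
# The processed prompt therefore has the layout:
#   [image placeholder ids: ncol image tokens + 1 newline token per row]
#   + [<s>] + [original prompt ids minus their leading special token]
#   + ["\x04" beginning-of-answer token]
# which matches the (image_token * ncols + newline_token) * nrows format
# described in `_calculate_num_image_tokens`.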

def input_mapper_for_fuyu(ctx: InputContext, data: object):
    model_config = ctx.model_config
    if isinstance(data, Image.Image):
        # Fuyu's image_processor can also finish token padding
        image_processor: FuyuImageProcessor = cached_get_image_processor(
            model_config.model)

        model_image_input = _fuyu_image_preprocess(image_processor, data)
        data = torch.stack([
            image_patch[0]
            for image_patch in model_image_input["image_patches"]
        ])

    # image has been processed with prompt in input processor
    return MultiModalInputs({"image_patches": data})
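
# The mapper's "image_patches" key is what later arrives in
# `FuyuForCausalLM.forward(**kwargs)`; `_parse_and_validate_image_input`
# checks that its last dimension equals `image_feature_size` before handing
# it to the vision projection.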

@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
class FuyuForCausalLM(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: FuyuConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.config = config
        self.multimodal_config = multimodal_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.text_config.vocab_size
        self.image_token_id = _IMAGE_TOKEN_ID
        self.image_feature_size = config.patch_size**2 * config.num_channels

        self.vision_embed_tokens = ColumnParallelLinear(
            self.image_feature_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.language_model = PersimmonForCausalLM(config,
                                                   cache_config=cache_config,
                                                   quant_config=quant_config)
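
    # NOTE: each flattened patch has patch_size**2 * num_channels values
    # (30 * 30 * 3 = 2700 with the default FuyuConfig used by fuyu-8b), and
    # `vision_embed_tokens` is a single linear projection from that size to
    # the text hidden size; patches are embedded directly, with no separate
    # vision encoder.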

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
        image_patches = kwargs.pop("image_patches", None)

        if isinstance(image_patches, torch.Tensor):
            # Remove the N dimension until multiple images are supported.
            image_patches = image_patches.squeeze(1)

            expected_feature_size = self.image_feature_size
            if image_patches.size(-1) != expected_feature_size:
                raise ValueError(
                    f"Expected image patches to have the last dimension of "
                    f"{expected_feature_size}, got {image_patches.size(-1)}")
            image_patches = image_patches.to(
                self.vision_embed_tokens.weight.dtype)
            return FuyuImagePixelInputs(type="pixel_values",
                                        data=image_patches)
        return None

    def _process_image_input(
            self, image_input: FuyuImagePixelInputs) -> torch.Tensor:
        assert self.vision_embed_tokens is not None
        vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ):
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.model.embed_tokens(input_ids)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.image_token_id)
        else:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states
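
    # NOTE on forward(): `merge_multimodal_embeddings` overwrites the rows of
    # `inputs_embeds` whose input id equals `self.image_token_id` with the
    # projected patch embeddings. On decode steps no image kwargs are passed,
    # so `inputs_embeds` stays None and PersimmonForCausalLM falls back to its
    # own token embedding lookup.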

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.language_model.logits_processor(
            self.language_model.lm_head, hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.language_model.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # copy from vllm/model_executor/models/bloom.py
                # NOTE: Fuyu's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
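
# Usage sketch (illustrative only; assumes Aphrodite exposes the vLLM-style
# offline `LLM` entry point and the single-image multimodal prompt format
# implied by the registry hooks above):
#
#     from PIL import Image
#     from aphrodite import LLM, SamplingParams
#
#     llm = LLM(model="adept/fuyu-8b")
#     image = Image.open("chart.png")
#     outputs = llm.generate(
#         {"prompt": "Describe this image.\n",
#          "multi_modal_data": {"image": image}},
#         SamplingParams(max_tokens=64),
#     )
#     print(outputs[0].outputs[0].text)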