# coding=utf-8
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py
# Copyright 2023 The vLLM team.
# Copyright 2023 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Fuyu model."""
import math
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict

import torch
import torch.nn as nn
import torch.utils.checkpoint
from PIL import Image
from transformers import FuyuConfig, FuyuImageProcessor

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                       SequenceData)
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.linear import ColumnParallelLinear
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.persimmon import PersimmonForCausalLM
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.base import MultiModalInputs
from aphrodite.multimodal.image import (cached_get_image_processor,
                                        cached_get_tokenizer)
from aphrodite.quantization.base_config import QuantizationConfig

from .interfaces import SupportsVision
from .utils import merge_vision_embeddings

# The following two token IDs are not exposed in the HF config.
_IMAGE_TOKEN_ID = 71011
_NEWLINE_TOKEN_ID = 71019

MAX_IMAGE_FEATURE_SIZE_HEIGHT = 1080
MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920


class FuyuImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """
    Shape:
    (batch_size, num_patches, patch_size_x * patch_size_y * num_channels)
    """


def _calculate_num_image_tokens(
    height: int,
    width: int,
) -> Tuple[int, int]:
    """
    Calculate the number of image tokens needed for a given image size.

    The expected Fuyu image prompt has the form:
        (image_token * ncols + newline_token) * nrows

    Args:
        height: height of the image in pixels
        width: width of the image in pixels
    Returns:
        ncols: number of image tokens in the x direction
        nrows: number of image tokens in the y direction
    """
    ncol = math.ceil(width / 30)
    nrow = math.ceil(height / 30)
    return ncol, nrow
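
# Worked example (illustrative, not part of the upstream file): for the
# maximum supported image size declared above,
#     _calculate_num_image_tokens(height=1080, width=1920) == (64, 36)
# since ceil(1920 / 30) = 64 columns and ceil(1080 / 30) = 36 rows.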


def get_max_fuyu_image_feature_size():
    return _calculate_num_image_tokens(
        height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )


def get_max_fuyu_image_tokens(ctx: InputContext):
    ncol, nrow = get_max_fuyu_image_feature_size()
    return (ncol + 1) * nrow
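
# Illustrative arithmetic (spelled out here as an assumption, not upstream
# code): with the 64 x 36 grid above, get_max_fuyu_image_tokens(ctx) returns
# (64 + 1) * 36 = 2340 tokens; the "+ 1" accounts for the newline token
# appended after each row of image tokens.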


def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int):
    ncol, nrow = get_max_fuyu_image_feature_size()
    image_feature_size = get_max_fuyu_image_tokens(ctx)

    token_ids = ([_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]) * nrow
    token_ids += [0] * (seq_len - image_feature_size)
    return SequenceData(token_ids)


def dummy_image_for_fuyu(
    image_width: int,
    image_height: int,
):
    image = Image.new("RGB", (image_width, image_height), color=0)
    return {"image": image}


def dummy_data_for_fuyu(ctx: InputContext, seq_len: int):
    seq_data = dummy_seq_data_for_fuyu(ctx, seq_len)
    mm_data = dummy_image_for_fuyu(MAX_IMAGE_FEATURE_SIZE_WIDTH,
                                   MAX_IMAGE_FEATURE_SIZE_HEIGHT)
    return seq_data, mm_data


def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
                           data: Image.Image):
    image_encoding = image_processor.preprocess(data, return_tensors="pt")
    batch_images = torch.stack([img[0] for img in image_encoding["images"]
                                ]).unsqueeze(1)
    image_unpadded_heights = torch.tensor(
        image_encoding["image_unpadded_heights"])
    image_unpadded_widths = torch.tensor(
        image_encoding["image_unpadded_widths"])

    batch_size = len(image_encoding["images"])
    image_present = torch.ones(batch_size, 1, 1)
    model_image_input = image_processor.preprocess_with_tokenizer_info(
        image_input=batch_images,
        image_present=image_present,
        image_unpadded_h=image_unpadded_heights,
        image_unpadded_w=image_unpadded_widths,
        image_placeholder_id=_IMAGE_TOKEN_ID,
        image_newline_id=_NEWLINE_TOKEN_ID,
        variable_sized=True,
    )
    return model_image_input
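
# Note on the dict returned above (inferred from how it is consumed below,
# not from upstream docs): the rest of this file uses two of its keys,
# "image_patches" (flattened patch pixel values per image) and
# "image_input_ids" (the image-placeholder / newline token grid).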


def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    image_data = multi_modal_data["image"]
    new_multi_modal_data = {}

    # process image data
    if isinstance(image_data, Image.Image):
        # Fuyu's image_processor also handles token padding
        image_processor: FuyuImageProcessor = cached_get_image_processor(
            model_config.model)

        model_image_input = _fuyu_image_preprocess(image_processor, image_data)
        image_patches = torch.stack([
            image_patch[0]
            for image_patch in model_image_input["image_patches"]
        ])
        new_multi_modal_data["image"] = image_patches

    elif isinstance(image_data, torch.Tensor):
        raise NotImplementedError("Embeddings input is not supported yet")
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    # process prompts
    prompt = llm_inputs.get("prompt")
    prompt_token_ids = llm_inputs["prompt_token_ids"]
    tokenizer = cached_get_tokenizer(model_config.model)
    # dim0 is batch_size, dim1 is subseq_size, which is always 1
    image_input_ids: List[List[
        torch.Tensor]] = model_image_input["image_input_ids"]
    image_input_ids = image_input_ids[0][0].tolist()
    bos_token = tokenizer.encode("<s>", add_special_tokens=False)[1:]
    boa_token = tokenizer.encode("\x04", add_special_tokens=False)[1:]

    new_prompt = prompt + "\x04"
    new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
        1:] + boa_token

    return LLMInputs(prompt=new_prompt,
                     prompt_token_ids=new_prompt_token_ids,
                     multi_modal_data=new_multi_modal_data)
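
# Resulting prompt layout (an illustrative sketch inferred from the code
# above, not taken from upstream documentation):
#
#     [_IMAGE_TOKEN_ID * ncol + _NEWLINE_TOKEN_ID] * nrow   # placeholder grid
#     <s>                                                   # re-encoded BOS
#     prompt_token_ids[1:]                                  # original prompt, BOS dropped
#     \x04                                                  # "beginning of answer" marker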


def input_mapper_for_fuyu(ctx: InputContext, data: object):
    model_config = ctx.model_config
    if isinstance(data, Image.Image):
        # Fuyu's image_processor also handles token padding
        image_processor: FuyuImageProcessor = cached_get_image_processor(
            model_config.model)

        model_image_input = _fuyu_image_preprocess(image_processor, data)
        data = torch.stack([
            image_patch[0]
            for image_patch in model_image_input["image_patches"]
        ])

    # image has been processed with prompt in input processor
    return MultiModalInputs({"image_patches": data})


@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
class FuyuForCausalLM(nn.Module, SupportsVision):

    def __init__(self,
                 config: FuyuConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.config = config
        self.multimodal_config = multimodal_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.image_token_id = _IMAGE_TOKEN_ID
        self.image_feature_size = config.patch_size**2 * config.num_channels

        self.vision_embed_tokens = ColumnParallelLinear(
            self.image_feature_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.language_model = PersimmonForCausalLM(config,
                                                   cache_config=cache_config,
                                                   quant_config=quant_config)
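
    # Note (inferred from the constructor above, not stated upstream):
    # vision_embed_tokens is a single linear projection that maps each
    # flattened patch of image_feature_size = patch_size**2 * num_channels
    # values into the language model's hidden_size, so that patch embeddings
    # can be merged directly with the token embeddings in forward().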

    def _parse_and_validate_image_input(self, **kwargs: object):
        image_patches = kwargs.pop("image_patches", None)

        if isinstance(image_patches, torch.Tensor):
            expected_feature_size = self.image_feature_size
            if image_patches.size(-1) != expected_feature_size:
                raise ValueError(
                    f"Expected image patches to have the last dimension of "
                    f"{expected_feature_size}, got {image_patches.size(-1)}")
            image_patches = image_patches.to(
                self.vision_embed_tokens.weight.dtype)
            return FuyuImagePixelInputs(type="pixel_values",
                                        data=image_patches)
        return None

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ):
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings, _ = self.vision_embed_tokens(
                image_input["data"])
            inputs_embeds = self.language_model.model.embed_tokens(input_ids)
            # Scatter the projected patch embeddings into the positions of
            # the image placeholder tokens.
            inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
                                                    vision_embeddings,
                                                    self.image_token_id)
        else:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.language_model.logits_processor(
            self.language_model.lm_head, hidden_states, sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.language_model.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # Copied from vllm/model_executor/models/bloom.py
                # NOTE: In Fuyu's fused QKV weight, the output dim is laid out
                # as (num_heads, 3, head_size), while the loader expects
                # (3, num_heads, head_size), so the weight must be permuted.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
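
# Minimal sketch of the fused-QKV layout conversion performed in
# load_weights() above (illustrative only; the head count and head size below
# are made-up numbers, not read from any real config or checkpoint):
#
#     w = torch.randn(4 * 3 * 8, 32)    # (num_heads=4, 3, head_size=8) fused on dim 0
#     w = (w.view(4, 3, 8, 32)          # split the fused output dim
#          .transpose(0, 1)             # reorder to (3, num_heads, head_size)
#          .reshape(4 * 3 * 8, 32))     # flatten back to the original shape
#
# After the transpose, the output dim is ordered as (3, num_heads, head_size),
# which is what the column-parallel weight loader expects.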