# coding=utf-8
# Copyright 2024 The PygmalionAI team.
# Copyright 2024 The vLLM team.
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from functools import lru_cache
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import numpy as np
import torch
import torch.nn as nn
from loguru import logger
from PIL import Image
from transformers import CLIPVisionConfig, PretrainedConfig

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, ModelConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import progress_bar
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.clip import CLIPVisionModel
from aphrodite.modeling.models.llama import LlamaModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.image import cached_get_tokenizer
from aphrodite.quantization.base_config import QuantizationConfig

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                   input_processor_for_clip)
from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings

_KEYS_TO_MODIFY_MAPPING = {
    "model.vision_embed_tokens": "vision_embed_tokens",
}

# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 32044

# Result in the max possible feature size (h:w = 16:1)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = 8000
MAX_IMAGE_FEATURE_SIZE_WIDTH = 50

CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
                                                     hidden_act="quick_gelu",
                                                     hidden_size=1024,
                                                     image_size=336,
                                                     intermediate_size=4096,
                                                     num_attention_heads=16,
                                                     num_channels=3,
                                                     num_hidden_layers=24,
                                                     patch_size=14,
                                                     projection_dim=768)


class Phi3VImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.
    """

    image_sizes: torch.Tensor
    """
    Shape: `(batch_size, 2)`

    This should be in `(height, width)` format.
    """


class Phi3VImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """Shape: `(batch_size, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs]


class Phi3ImageEmbeddingBase(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.layer_idx: int
        self.type_feature: str
        self.img_processor: CLIPVisionModel

    def get_img_features(self,
                         img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        TYPE_FEATURE = self.type_feature

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the img_processor
        img_feature = self.img_processor(img_embeds)

        if TYPE_FEATURE == "patch":
            # CLIP prepends a class (CLS) embedding to the patch tokens, so
            # "patch" mode drops the first token and keeps only the per-patch
            # features.
            patch_feature = img_feature[:, 1:]
            return patch_feature

        if TYPE_FEATURE == "cls_patch":
            return img_feature

        raise NotImplementedError


# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
    """Phi3 Image embedding with HD transform."""

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__()

        # n_embed or hidden_size
        hidden_size = config.n_embd if hasattr(
            config, 'n_embd') else config.hidden_size

        clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
        self.layer_idx = config.img_processor.get('layer_idx', -2)
        # Initialize the CLIP only up to the required feature layer
        if self.layer_idx < 0:
            num_hidden_layers = clip_config.num_hidden_layers + \
                self.layer_idx + 1
        else:
            num_hidden_layers = self.layer_idx + 1
        self.img_processor = CLIPVisionModel(
            clip_config, num_hidden_layers_override=num_hidden_layers)

        image_dim_out = config.img_processor['image_dim_out']
        self.num_img_tokens = config.img_processor['num_img_tokens']

        self.image_dim_out = image_dim_out

        # global_gn and sub_gn for hd transform, serves as line separator
        self.use_hd_transform = config.embd_layer.get('use_hd_transform',
                                                      False)
        self.with_learnable_separator = config.embd_layer.get(
            'with_learnable_separator', False)
        self.hd_transform_order = config.embd_layer.get(
            'hd_transform_order', 'glb_sub')
        # with_hd_transform and with_learnable_separator should have same value
        assert self.use_hd_transform and self.with_learnable_separator

        # 1024 * 4, merge spatial to channel dimension
        self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4]))
        self.sub_GN = nn.Parameter(
            torch.empty([1, 1, 1, self.image_dim_out * 4]))

        dim_projection = hidden_size
        depth = 2
        layers = [nn.Linear(image_dim_out * 4, dim_projection)]
        for _ in range(1, depth):
            layers.extend(
                [nn.GELU(),
                 nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.type_feature = config.img_processor.get('type_feature', 'patch')

    def forward(self, pixel_values: torch.FloatTensor,
                image_sizes: torch.Tensor) -> torch.FloatTensor:
        """
        process image and return vision embeddings.

        pixel_values: (num_images, num_crops, c, h, w)
        output: (num_images, num_img_tokens, hidden_size)
        """
        num_images, num_crops, c, h, w = pixel_values.shape
        pixel_values = pixel_values.flatten(0, 1)
        img_features = self.get_img_features(pixel_values)
        img_features = img_features.reshape(num_images, num_crops, -1,
                                            self.image_dim_out)
        image_features_proj = self.hd_feature_transform(
            img_features, image_sizes)
        return image_features_proj

    def hd_feature_transform(self, image_features, image_sizes):
        """
        image_features: (num_images, num_crops+1, 24*24, 1024)
        """
        assert (
            self.hd_transform_order == 'sub_glb'
        ), f'hd_transform_order `{self.hd_transform_order}` not implemented'
        if isinstance(self.img_projection, nn.Sequential):
            target_device = self.img_projection[0].bias.device
            target_dtype = self.img_projection[0].bias.dtype
        else:  # It's a single nn.Linear layer
            target_device = self.img_projection.bias.device
            target_dtype = self.img_projection.bias.dtype

        global_image_features = image_features[:,
                                               0]  # (num_images, 24*24, 1024)
        # global feature can be viewed as a special HD case with num_crops 1x1
        global_image_features_hd = self.reshape_hd_patches_2x2merge(
            global_image_features, 1, 1)
        global_image_features_hd_newline = self.add_image_newline(
            global_image_features_hd)

        batch_image_features_proj = []
        # need a for loop to process each image because of different image sizes
        # (patch arrangement is different for each image)
        for i, img_size in enumerate(image_sizes):
            h, w = img_size
            h_crop = h // 336
            w_crop = w // 336
            num_crops = h_crop * w_crop

            # NOTE: real num_crops is padded
            # (num_crops, 24*24, 1024)
            sub_image_features = image_features[i, 1:1 + num_crops]
            sub_image_features_hd = self.reshape_hd_patches_2x2merge(
                sub_image_features, h_crop, w_crop)
            sub_image_features_hd_newline = self.add_image_newline(
                sub_image_features_hd)

            # [sub features, separator, global features]
            image_embeddings = torch.cat([
                sub_image_features_hd_newline.squeeze(
                    0),  # (h_crop*12*(w_crop*12+1), 4096)
                self.glb_GN.squeeze(0),
                global_image_features_hd_newline[i],
            ])
            img_proj = self.img_projection(
                image_embeddings.to(target_device, target_dtype))
            batch_image_features_proj.append(img_proj)

        return batch_image_features_proj

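    # Each 336x336 crop yields 24x24 = 576 patch features from CLIP
    # ViT-L/14-336 (336 / 14 = 24 patches per side). The method below merges
    # every 2x2 neighborhood of patches along the channel dimension, giving a
    # 12x12 grid of 4 * 1024 = 4096-dim features per crop.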
    def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
        """
        image_features: (num_images*num_crops, 24*24, 1024)
        output: (num_images, h_crop*12, w_crop*12, 4096)
        where h_crop*w_crop == num_crops
        """
        N, L, C = image_features.shape
        assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
        num_images = N // (h_crop * w_crop)
        H = int(L**0.5)
        image_features_hd = (
            image_features.reshape(N, H, H, C)  # N, 24, 24, 1024
            .reshape(N, H // 2, 2, H // 2, 2, C)  # N, 12, 2, 12, 2, 1024
            .permute(0, 1, 3, 2, 4, 5)  # N, 12, 12, 2, 2, 1024
            .reshape(N, -1, 4 * C)  # N, 144, 4096
            .reshape(num_images, h_crop, w_crop, H // 2, H // 2,
                     -1)  # n_img, h_crop, w_crop, 12, 12, 4096
            .permute(0, 1, 3, 2, 4, 5)  # n_img, h_crop, 12, w_crop, 12, 4096
            .reshape(num_images, h_crop * H // 2, w_crop * H // 2,
                     4 * C)  # n_img, h_crop*12, w_crop*12, 4096
        )
        return image_features_hd

    def add_image_newline(self, image_features_hd):
        """
        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
        output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
        """
        num_images, h, w, hid_dim = image_features_hd.shape
        # add the newline token to the HD image feature patches
        newline_embeddings = self.sub_GN.expand(num_images, h, -1,
                                                -1)  # (n_img, h, 1, hid_dim)
        image_features_hd_newline = torch.cat(
            [image_features_hd, newline_embeddings],
            dim=2).reshape(num_images, -1, hid_dim)
        return image_features_hd_newline


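# The two helpers below mirror the HD-transform sizing of the HF image
# processor: the image is scaled so it is covered by at most `hd_num` 336x336
# crops, then padded up to a multiple of 336. For example, a 1000x1000 image
# is mapped to 1344x1344, i.e. a 4x4 grid of crops.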
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
    target_height = int(np.ceil(height / padding_unit) * padding_unit)
    top_padding = int((target_height - height) / 2)
    bottom_padding = target_height - height - top_padding
    padded_width = width
    padded_height = height + top_padding + bottom_padding
    return padded_width, padded_height


# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
    transposed = False
    if width < height:
        width, height = height, width
        transposed = True

    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1

    new_width = int(scale * 336)
    new_height = int(new_width / ratio)

    padded_width, padded_height = _calc_padded_size(width=new_width,
                                                    height=new_height)

    if transposed:
        padded_width, padded_height = padded_height, padded_width

    return padded_width, padded_height


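# Token count for an HD-transformed image, with h_crops = new_height // 336
# and w_crops = new_width // 336:
#   (h_crops * w_crops + 1) * 144  2x2-merged patch features for every crop
#                                  plus the global view,
#   + (h_crops + 1) * 12           one sub_GN "newline" per feature row of
#                                  the sub-image and global grids,
#   + 1                            the glb_GN separator.
# e.g. 1344x1344 (4x4 crops) -> (16 + 1) * 144 + 1 + (4 + 1) * 12 = 2509.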
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
def get_phi3v_image_feature_size(
    hf_config: PretrainedConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    num_crops = getattr(hf_config, "num_crops", 16)
    new_width, new_height = _calc_hd_transform_size(width=input_width,
                                                    height=input_height,
                                                    hd_num=num_crops)

    return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \
        + (new_height // 336 + 1) * 12


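# Evaluated at MAX_IMAGE_FEATURE_SIZE_HEIGHT x MAX_IMAGE_FEATURE_SIZE_WIDTH
# (8000x50), the HD transform produces a 16x1 crop grid, which maximizes the
# per-row newline term: (16 + 1) * 144 + 1 + (16 + 1) * 12 = 2653 tokens.
# This worst case is used for profiling and dummy-data generation.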
def get_max_phi3v_image_tokens(ctx: InputContext):

    return get_phi3v_image_feature_size(
        ctx.get_hf_config(PretrainedConfig),
        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )


def dummy_data_for_phi3v(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    num_images = mm_counts["image"]

    image_feature_size = get_max_phi3v_image_tokens(ctx)

    seq_data = dummy_seq_data_for_clip(
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        seq_len,
        num_images,
        image_token_id=_IMAGE_TOKEN_ID,
        image_feature_size_override=image_feature_size,
    )
    mm_data = dummy_image_for_clip(
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        num_images,
        image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
        image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
    )

    return seq_data, mm_data


# Reserve this function to also handle placeholders for additional images
# [ref: PR #5820]
@lru_cache
def _get_image_placeholder_token_ids(model_config: ModelConfig,
                                     idx: int) -> List[int]:
    assert idx > 0

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    # We need to get the token for "<", not "▁<"
    # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
    a_token_id, = tokenizer.encode("a", add_special_tokens=False)
    a_token_id_, *image_placeholder_token_ids = tokenizer.encode(
        f"a<|image_{idx}|>", add_special_tokens=False)
    assert a_token_id == a_token_id_

    return image_placeholder_token_ids


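# Replaces the first <|image_1|> placeholder in the tokenized prompt with a
# single _IMAGE_TOKEN_ID; input_processor_for_clip then expands that token to
# `image_feature_size` copies so the prompt length matches the number of
# vision embeddings produced at runtime.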
def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(PretrainedConfig)

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        w, h = image_data.size
        w, h = _calc_hd_transform_size(width=w, height=h)

        image_feature_size = get_phi3v_image_feature_size(hf_config,
                                                          input_width=w,
                                                          input_height=h)
    elif isinstance(image_data, torch.Tensor):
        image_feature_size = image_data.shape[0]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    prompt = llm_inputs.get("prompt")
    if prompt is None:
        new_prompt = None
    else:
        if prompt.count("<|image|>") > 0:
            logger.warning("Please follow the prompt format that is "
                           "documented on HuggingFace which does not involve "
                           "repeating <|image|> tokens.")
        elif len(re.findall(r"(<\|image_\d+\|>)+", prompt)) > 1:
            logger.warning("Multiple image input is not supported yet, "
                           "so any extra image tokens will be treated "
                           "as plain text.")

        new_prompt = prompt

    prompt_token_ids = llm_inputs["prompt_token_ids"]
    image_1_token_ids = _get_image_placeholder_token_ids(model_config, idx=1)

    new_token_ids: List[int] = []
    for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1):
        if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids:
            new_token_ids.append(_IMAGE_TOKEN_ID)

            # No need to further scan the list since we only replace once
            new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):])
            break
        else:
            new_token_ids.append(prompt_token_ids[i])

    # NOTE: Create a defensive copy of the original inputs
    llm_inputs = LLMInputs(prompt_token_ids=new_token_ids,
                           prompt=new_prompt,
                           multi_modal_data=multi_modal_data)

    return input_processor_for_clip(
        model_config,
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        llm_inputs,
        image_token_id=_IMAGE_TOKEN_ID,
        image_feature_size_override=image_feature_size,
    )


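# The decorators below register this model's multimodal hooks: the default
# image input mapper and the max-image-token estimate with the multimodal
# registry, and the dummy-data and input-processor callbacks with the input
# registry.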
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: PretrainedConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config
        self.image_token_id = _IMAGE_TOKEN_ID

        self.model = LlamaModel(config, cache_config, quant_config)

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_embed_tokens = Phi3HDImageEmbedding(config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        if list(data.shape[1:]) != [2]:
            raise ValueError(
                f"The expected shape of image sizes is batch dimension plus "
                f"{[2]}. You supplied {tuple(data.shape)}.")

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:

        h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values in each batch element "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Phi3VImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, torch.Tensor):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return Phi3VImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
                image_sizes=self._validate_image_sizes(image_sizes))

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return Phi3VImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: Phi3VImageInputs,
    ) -> torch.Tensor:

        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_embed_tokens is not None
        image_embeds = self.vision_embed_tokens(image_input["data"],
                                                image_input["image_sizes"])

        return image_embeds

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs: object):
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.model.get_input_embeddings(input_ids)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.image_token_id)
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.model(input_ids,
                                   positions,
                                   kv_caches,
                                   attn_metadata,
                                   intermediate_tensors,
                                   inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

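    # The stacked-params mapping below handles checkpoints that store q/k/v
    # (and gate/up) projections as separate tensors: each such shard is
    # written into its slice of the fused qkv_proj / gate_up_proj parameter
    # via that parameter's weight_loader; all other weights fall through to
    # the default loader in the for-else branch.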
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            if "rotary_emb.inv_freq" in name:
                continue
            # post_layernorm is not needed in CLIPVisionModel
            if "vision_model.post_layernorm" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # We only do sharding for language model
                # and not vision model for now.
                if "vision_embed_tokens" in name and self.vision_embed_tokens:
                    continue
                if weight_name not in name:
                    continue
                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name in params_dict:
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)