phi3v.py

# coding=utf-8
# Copyright 2024 The PygmalionAI team.
# Copyright 2024 The vLLM team.
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from functools import lru_cache
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                    TypedDict, Union)

import numpy as np
import torch
import torch.nn as nn
from loguru import logger
from PIL import Image
from transformers import CLIPVisionConfig, PretrainedConfig

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, ModelConfig, MultiModalConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.clip import CLIPVisionModel
from aphrodite.modeling.models.llama import LlamaModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.multimodal.image import cached_get_tokenizer
from aphrodite.quantization.base_config import QuantizationConfig

from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                   input_processor_for_clip)
from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings

_KEYS_TO_MODIFY_MAPPING = {
    "model.vision_embed_tokens": "vision_embed_tokens",
}

# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 32044

# Result in the max possible feature size (h:w = 16:1)
MAX_IMAGE_FEATURE_SIZE_HEIGHT = 8000
MAX_IMAGE_FEATURE_SIZE_WIDTH = 50

CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
                                                     hidden_act="quick_gelu",
                                                     hidden_size=1024,
                                                     image_size=336,
                                                     intermediate_size=4096,
                                                     num_attention_heads=16,
                                                     num_channels=3,
                                                     num_hidden_layers=24,
                                                     patch_size=14,
                                                     projection_dim=768)
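
# NOTE: With image_size=336 and patch_size=14, each 336x336 crop yields a
# 24x24 grid of patches (576 patch tokens plus one CLS token), each with
# hidden size 1024 -- the geometry assumed by the HD transform below.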


class Phi3VImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different for each batch, in which case
    the data is passed as a list instead of a batched tensor.
    """

    image_sizes: torch.Tensor
    """
    Shape: `(batch_size, 2)`

    This should be in `(height, width)` format.
    """


class Phi3VImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """Shape: `(batch_size, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs]


class Phi3ImageEmbeddingBase(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.layer_idx: int
        self.type_feature: str
        self.img_processor: CLIPVisionModel

    def get_img_features(self,
                         img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        TYPE_FEATURE = self.type_feature

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the img_processor
        img_feature = self.img_processor(img_embeds)

        if TYPE_FEATURE == "patch":
            patch_feature = img_feature[:, 1:]
            return patch_feature

        if TYPE_FEATURE == "cls_patch":
            return img_feature

        raise NotImplementedError
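
    # For the CLIP config above, `img_processor` produces features of shape
    # (N, 577, 1024); selecting "patch" drops the leading CLS token and
    # yields (N, 576, 1024), i.e. one feature per 14x14 patch of each crop.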


# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
    """Phi3 Image embedding with HD transform."""

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__()

        # n_embed or hidden_size
        hidden_size = config.n_embd if hasattr(
            config, 'n_embd') else config.hidden_size

        clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
        self.layer_idx = config.img_processor.get('layer_idx', -2)

        # Initialize the CLIP only up to the required feature layer
        if self.layer_idx < 0:
            num_hidden_layers = clip_config.num_hidden_layers + \
                self.layer_idx + 1
        else:
            num_hidden_layers = self.layer_idx + 1

        self.img_processor = CLIPVisionModel(
            clip_config, num_hidden_layers_override=num_hidden_layers)
        image_dim_out = config.img_processor['image_dim_out']
        self.num_img_tokens = config.img_processor['num_img_tokens']

        self.image_dim_out = image_dim_out

        # global_gn and sub_gn for hd transform, serves as line separator
        self.use_hd_transform = config.embd_layer.get('use_hd_transform',
                                                      False)
        self.with_learnable_separator = config.embd_layer.get(
            'with_learnable_separator', False)
        self.hd_transform_order = config.embd_layer.get(
            'hd_transform_order', 'glb_sub')
        # with_hd_transform and with_learnable_separator should have same value
        assert self.use_hd_transform and self.with_learnable_separator

        # 1024 * 4, merge spatial to channel dimension
        self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4]))
        self.sub_GN = nn.Parameter(
            torch.empty([1, 1, 1, self.image_dim_out * 4]))

        dim_projection = hidden_size
        depth = 2
        layers = [nn.Linear(image_dim_out * 4, dim_projection)]
        for _ in range(1, depth):
            layers.extend(
                [nn.GELU(),
                 nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.type_feature = config.img_processor.get('type_feature', 'patch')
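
    # With image_dim_out=1024 the projection is a two-layer MLP:
    # Linear(4096 -> hidden_size) -> GELU -> Linear(hidden_size -> hidden_size).
    # The 4096-dim input comes from the 2x2 patch merge performed in
    # reshape_hd_patches_2x2merge below.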

    def forward(self, pixel_values: torch.FloatTensor,
                image_sizes: torch.Tensor) -> torch.FloatTensor:
        """
        process image and return vision embeddings.

        pixel_values: (num_images, num_crops, c, h, w)
        output: (num_images, num_img_tokens, hidden_size)
        """
        num_images, num_crops, c, h, w = pixel_values.shape
        pixel_values = pixel_values.flatten(0, 1)
        img_features = self.get_img_features(pixel_values)
        img_features = img_features.reshape(num_images, num_crops, -1,
                                            self.image_dim_out)
        image_features_proj = self.hd_feature_transform(
            img_features, image_sizes)
        return image_features_proj

    def hd_feature_transform(self, image_features, image_sizes):
        """
        image_features: (num_images, num_crops+1, 24*24, 1024)
        """
        assert (
            self.hd_transform_order == 'sub_glb'
        ), f'hd_transform_order `{self.hd_transform_order}` not implemented'
        if isinstance(self.img_projection, nn.Sequential):
            target_device = self.img_projection[0].bias.device
            target_dtype = self.img_projection[0].bias.dtype
        else:  # It's a single nn.Linear layer
            target_device = self.img_projection.bias.device
            target_dtype = self.img_projection.bias.dtype

        global_image_features = image_features[:,
                                               0]  # (num_images, 24*24, 1024)
        # global feature can be viewed as a special HD case with num_crops 1x1
        global_image_features_hd = self.reshape_hd_patches_2x2merge(
            global_image_features, 1, 1)
        global_image_features_hd_newline = self.add_image_newline(
            global_image_features_hd)

        batch_image_features_proj = []
        # need a for loop to process each image because of different image sizes
        # (patch arrangement is different for each image)
        for i, img_size in enumerate(image_sizes):
            h, w = img_size
            h_crop = h // 336
            w_crop = w // 336
            num_crops = h_crop * w_crop

            # NOTE: real num_crops is padded
            # (num_crops, 24*24, 1024)
            sub_image_features = image_features[i, 1:1 + num_crops]
            sub_image_features_hd = self.reshape_hd_patches_2x2merge(
                sub_image_features, h_crop, w_crop)
            sub_image_features_hd_newline = self.add_image_newline(
                sub_image_features_hd)

            # [sub features, separator, global features]
            image_embeddings = torch.cat([
                sub_image_features_hd_newline.squeeze(
                    0),  # (h_crop*12*(w_crop*12+1), 4096)
                self.glb_GN.squeeze(0),
                global_image_features_hd_newline[i],
            ])
            img_proj = self.img_projection(
                image_embeddings.to(target_device, target_dtype))
            batch_image_features_proj.append(img_proj)

        return batch_image_features_proj

    def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
        """
        image_features: (num_images*num_crops, 24*24, 1024)
        output: (num_images, h_crop*12, w_crop*12, 4096)
        where h_crop*w_crop == num_crops
        """
        N, L, C = image_features.shape
        assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
        num_images = N // (h_crop * w_crop)
        H = int(L**0.5)
        image_features_hd = (
            image_features.reshape(N, H, H, C)  # N, 24, 24, 1024
            .reshape(N, H // 2, 2, H // 2, 2, C)  # N, 12, 2, 12, 2, 1024
            .permute(0, 1, 3, 2, 4, 5)  # N, 12, 12, 2, 2, 1024
            .reshape(N, -1, 4 * C)  # N, 144, 4096
            .reshape(num_images, h_crop, w_crop, H // 2, H // 2,
                     -1)  # n_img, h_crop, w_crop, 12, 12, 4096
            .permute(0, 1, 3, 2, 4, 5)  # n_img, h_crop, 12, w_crop, 12, 4096
            .reshape(num_images, h_crop * H // 2, w_crop * H // 2,
                     4 * C)  # n_img, h_crop*12, w_crop*12, 4096
        )
        return image_features_hd
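
    # Worked example: for one image split into a 3x5 crop grid
    # (h_crop=3, w_crop=5), the input is (15, 576, 1024); each crop's 24x24
    # patch grid is merged 2x2 into a 12x12 grid of 4096-dim features, so the
    # output is (1, 36, 60, 4096).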

    def add_image_newline(self, image_features_hd):
        """
        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
        output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
        """
        num_images, h, w, hid_dim = image_features_hd.shape
        # add the newline token to the HD image feature patches
        newline_embeddings = self.sub_GN.expand(num_images, h, -1,
                                                -1)  # (n_img, h, 1, hid_dim)
        image_features_hd_newline = torch.cat(
            [image_features_hd, newline_embeddings],
            dim=2).reshape(num_images, -1, hid_dim)
        return image_features_hd_newline
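
    # Continuing the 3x5 example: one sub_GN separator is appended per feature
    # row, so (1, 36, 60, 4096) becomes (1, 36 * 61, 4096) = (1, 2196, 4096)
    # before the glb_GN separator and the global features are concatenated.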


# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
    target_height = int(np.ceil(height / padding_unit) * padding_unit)
    top_padding = int((target_height - height) / 2)
    bottom_padding = target_height - height - top_padding
    padded_width = width
    padded_height = height + top_padding + bottom_padding
    return padded_width, padded_height


# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
    transposed = False
    if width < height:
        width, height = height, width
        transposed = True

    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1

    new_width = int(scale * 336)
    new_height = int(new_width / ratio)

    padded_width, padded_height = _calc_padded_size(width=new_width,
                                                    height=new_height)

    if transposed:
        padded_width, padded_height = padded_height, padded_width

    return padded_width, padded_height
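
# Worked example: a 1000x500 image (ratio 2.0) allows scale=5 crops along the
# long side (5 * ceil(5 / 2.0) = 15 <= hd_num=16), so it is resized to
# 1680x840 and then padded to 1680x1008, i.e. a 5-wide by 3-tall grid of
# 336x336 crops.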


# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
def get_phi3v_image_feature_size(
    hf_config: PretrainedConfig,
    *,
    input_height: int,
    input_width: int,
) -> int:
    num_crops = getattr(hf_config, "num_crops", 16)
    new_width, new_height = _calc_hd_transform_size(width=input_width,
                                                    height=input_height,
                                                    hd_num=num_crops)

    return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \
        + (new_height // 336 + 1) * 12
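
# Continuing the 1000x500 example (transformed to 1680x1008), the formula
# gives (3 * 5 + 1) * 144 + 1 + (3 + 1) * 12 = 2353 image feature tokens:
# 144 merged features per crop (sub crops plus the global crop), one glb_GN
# separator, and one sub_GN newline per feature row. This matches the tensors
# built above: 2196 sub features + 1 separator + 156 global features.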


def get_max_phi3v_image_tokens(ctx: InputContext):
    return get_phi3v_image_feature_size(
        ctx.get_hf_config(PretrainedConfig),
        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )


def dummy_data_for_phi3v(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    num_images = mm_counts["image"]

    image_feature_size = get_max_phi3v_image_tokens(ctx)

    seq_data = dummy_seq_data_for_clip(
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        seq_len,
        num_images,
        image_token_id=_IMAGE_TOKEN_ID,
        image_feature_size_override=image_feature_size,
    )
    mm_data = dummy_image_for_clip(
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        num_images,
        image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
        image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
    )

    return seq_data, mm_data


# Reserve this function to also handle placeholders for additional images
# [ref: PR #5820]
@lru_cache
def _get_image_placeholder_token_ids(model_config: ModelConfig,
                                     idx: int) -> List[int]:
    assert idx > 0

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    # We need to get the token for "<", not "▁<"
    # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
    a_token_id, = tokenizer.encode("a", add_special_tokens=False)
    a_token_id_, *image_placeholder_token_ids = tokenizer.encode(
        f"a<|image_{idx}|>", add_special_tokens=False)
    assert a_token_id == a_token_id_

    return image_placeholder_token_ids
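
# NOTE: Encoding the placeholder behind a leading "a" and then stripping the
# "a" token keeps the SentencePiece tokenizer from emitting a word-boundary
# "▁<" token, so the returned ids are exactly those of "<|image_{idx}|>" as it
# appears mid-prompt.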


def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(PretrainedConfig)

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        w, h = image_data.size
        w, h = _calc_hd_transform_size(width=w, height=h)

        image_feature_size = get_phi3v_image_feature_size(hf_config,
                                                          input_width=w,
                                                          input_height=h)
    elif isinstance(image_data, torch.Tensor):
        image_feature_size = image_data.shape[0]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    prompt = llm_inputs.get("prompt")
    if prompt is None:
        new_prompt = None
    else:
        if prompt.count("<|image|>") > 0:
            logger.warning("Please follow the prompt format that is "
                           "documented on HuggingFace which does not involve "
                           "repeating <|image|> tokens.")
        elif len(re.findall(r"(<\|image_\d+\|>)+", prompt)) > 1:
            logger.warning("Multiple image input is not supported yet, "
                           "so any extra image tokens will be treated "
                           "as plain text.")

        new_prompt = prompt

    prompt_token_ids = llm_inputs["prompt_token_ids"]
    image_1_token_ids = _get_image_placeholder_token_ids(model_config, idx=1)

    new_token_ids: List[int] = []
    for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1):
        if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids:
            new_token_ids.append(_IMAGE_TOKEN_ID)

            # No need to further scan the list since we only replace once
            new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):])
            break
        else:
            new_token_ids.append(prompt_token_ids[i])

    # NOTE: Create a defensive copy of the original inputs
    llm_inputs = LLMInputs(prompt_token_ids=new_token_ids,
                           prompt=new_prompt,
                           multi_modal_data=multi_modal_data)

    return input_processor_for_clip(
        model_config,
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        llm_inputs,
        image_token_id=_IMAGE_TOKEN_ID,
        image_feature_size_override=image_feature_size,
    )
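
# NOTE: input_processor_for_clip then expands the single _IMAGE_TOKEN_ID
# placeholder to `image_feature_size` repeated image tokens, so the sequence
# length accounts for the vision embeddings merged in Phi3VForCausalLM.forward.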


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: PretrainedConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config
        self.image_token_id = _IMAGE_TOKEN_ID

        self.model = LlamaModel(config, cache_config, quant_config)

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_embed_tokens = Phi3HDImageEmbedding(config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        if list(data.shape[1:]) != [2]:
            raise ValueError(
                f"The expected shape of image sizes is batch dimension plus "
                f"{[2]}. You supplied {tuple(data.shape)}.")

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:

        h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values in each batch element "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Phi3VImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        # NOTE: An early `if pixel_values is None: return None` check would
        # make the image_embeds branch below unreachable, so only bail out
        # when neither input is provided.
        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, torch.Tensor):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return Phi3VImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
                image_sizes=self._validate_image_sizes(image_sizes))

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return Phi3VImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: Phi3VImageInputs,
    ) -> torch.Tensor:

        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_embed_tokens is not None
        image_embeds = self.vision_embed_tokens(image_input["data"],
                                                image_input["image_sizes"])

        return image_embeds

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs: object):
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.model.get_input_embeddings(input_ids)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.image_token_id)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.model(input_ids,
                                   positions,
                                   kv_caches,
                                   attn_metadata,
                                   intermediate_tensors,
                                   inputs_embeds=inputs_embeds)

        return hidden_states
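
    # NOTE: merge_multimodal_embeddings overwrites the text embeddings at the
    # positions holding _IMAGE_TOKEN_ID with the vision embeddings, which is
    # why the input processor must expand the placeholder to exactly
    # image_feature_size tokens.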

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            # post_layernorm is not needed in CLIPVisionModel
            if "vision_model.post_layernorm" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # We only do sharding for language model
                # and not vision model for now.
                if "vision_embed_tokens" in name and self.vision_embed_tokens:
                    continue
                if weight_name not in name:
                    continue
                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name in params_dict:
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
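
    # Example: a checkpoint weight named
    # "model.layers.0.self_attn.q_proj.weight" is loaded into shard "q" of the
    # fused parameter "model.layers.0.self_attn.qkv_proj.weight" via
    # stacked_params_mapping, while "model.vision_embed_tokens.*" weights are
    # first renamed to "vision_embed_tokens.*" by _KEYS_TO_MODIFY_MAPPING.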