from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

import torch
import torch.nn as nn
from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
                          apply_chunking_to_forward)

from aphrodite.attention import AttentionMetadata
from aphrodite.common.config import CacheConfig, MultiModalConfig
from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                       SequenceData)
from aphrodite.common.utils import progress_bar
from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from aphrodite.modeling.layers.activation import get_act_fn
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.opt import OPTModel
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.multimodal import MULTIMODAL_REGISTRY
from aphrodite.quantization import QuantizationConfig

from .blip import (BlipVisionModel, dummy_image_for_blip,
                   get_max_blip_image_tokens)
from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings

_KEYS_TO_MODIFY_MAPPING = {
    "language_model.lm_head": "lm_head",
    "language_model.model": "language_model",
}

# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN = "<image>"
BLIP2_IMAGE_TOKEN_ID = 50265


class Blip2ImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size, num_channels, height, width)`"""


class Blip2ImageEmbeddingInputs(TypedDict):
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """


Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs]


class Blip2QFormerMultiHeadAttention(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        is_cross_attention: bool = False,
    ) -> None:
        super().__init__()

        self.config = config

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of "
                f"the number of attention heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = (config.hidden_size //
                                    config.num_attention_heads)
        self.all_head_size = (self.num_attention_heads *
                              self.attention_head_size)
        self.scaling = self.attention_head_size**-0.5

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            kv_hidden_size = config.encoder_hidden_size
        else:
            kv_hidden_size = config.hidden_size
        self.key = nn.Linear(kv_hidden_size, self.all_head_size)
        self.value = nn.Linear(kv_hidden_size, self.all_head_size)

        self.position_embedding_type = getattr(config,
                                               "position_embedding_type",
                                               "absolute")
        if self.position_embedding_type != "absolute":
            raise NotImplementedError("Unsupported position_embedding_type: "
                                      f"{self.position_embedding_type}")

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
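
    # Reshape (batch, seq_len, all_head_size) into
    # (batch, num_heads, seq_len, head_size) so attention is computed per head.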
    def transpose_for_scores(self, x):
        x = x.view(*x.size()[:-1], self.num_attention_heads,
                   self.attention_head_size)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
    ):
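        # In cross-attention, the queries come from the Q-Former hidden states
        # while the keys/values are projected from the image encoder outputs;
        # in self-attention, all three come from `hidden_states`.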
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(
                self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(
                self.value(encoder_hidden_states))
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer)

        attention_scores = torch.matmul(query_layer,
                                        key_layer.transpose(-1, -2))
        attention_probs = torch.softmax(attention_scores * self.scaling,
                                        dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        context_layer = context_layer.view(*context_layer.size()[:-2],
                                           self.all_head_size)

        return context_layer


class Blip2QFormerSelfOutput(nn.Module):

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class Blip2QFormerAttention(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        is_cross_attention: bool = False,
    ) -> None:
        super().__init__()

        self.attention = Blip2QFormerMultiHeadAttention(
            config,
            quant_config=quant_config,
            cache_config=cache_config,
            is_cross_attention=is_cross_attention,
        )

        self.output = Blip2QFormerSelfOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
    ) -> Tuple[torch.Tensor]:
        self_output = self.attention(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
        )
        attention_output = self.output(self_output, hidden_states)

        return attention_output


class Blip2QFormerIntermediate(nn.Module):

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()

        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        self.intermediate_act_fn = get_act_fn(config.hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class Blip2QFormerOutput(nn.Module):

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()

        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class Blip2QFormerLayer(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        layer_idx: int,
    ) -> None:
        super().__init__()

        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = Blip2QFormerAttention(config,
                                               quant_config=quant_config,
                                               cache_config=cache_config)

        self.layer_idx = layer_idx
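
        # Cross-attention to the image features is only present in every
        # `cross_attention_frequency`-th layer; the remaining layers are
        # self-attention only.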
        if layer_idx % config.cross_attention_frequency == 0:
            self.crossattention = Blip2QFormerAttention(
                config,
                quant_config=quant_config,
                cache_config=cache_config,
                is_cross_attention=True)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False

        self.intermediate_query = Blip2QFormerIntermediate(config)
        self.output_query = Blip2QFormerOutput(config)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        query_length: int,
    ):
        attention_output = self.attention(hidden_states)

        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                query_attention_output = self.crossattention(
                    query_attention_output,
                    encoder_hidden_states=encoder_hidden_states,
                )

            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )

            if attention_output.shape[1] > query_length:
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text],
                                         dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )

        return layer_output
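
    # NOTE: `feed_forward_chunk` references `self.intermediate`/`self.output`,
    # which are not created in `__init__`. In this file the Q-Former only ever
    # receives query embeddings (see `Blip2QFormerModel.forward`), so the
    # text-token branch that would call it is never taken.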
    def feed_forward_chunk(self,
                           attention_output: torch.Tensor) -> torch.Tensor:
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(
            self, attention_output: torch.Tensor) -> torch.Tensor:
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output,
                                         attention_output)
        return layer_output


class Blip2QFormerEncoder(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
    ) -> None:
        super().__init__()

        self.config = config

        self.layer = nn.ModuleList([
            Blip2QFormerLayer(config,
                              quant_config=quant_config,
                              cache_config=cache_config,
                              layer_idx=layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        query_length: int,
    ) -> torch.Tensor:
        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]

            hidden_states = layer_module(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                query_length=query_length,
            )

        return hidden_states


# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1025
class Blip2QFormerModel(nn.Module):

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
    ) -> None:
        super().__init__()

        self.config = config

        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.encoder = Blip2QFormerEncoder(config,
                                           quant_config=quant_config,
                                           cache_config=cache_config)

    def forward(
        self,
        query_embeds: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
    ) -> torch.Tensor:
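        # `query_embeds` are the learned query tokens (expanded per image);
        # `encoder_hidden_states` are the vision encoder features that the
        # cross-attention layers attend to.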
        query_length = query_embeds.shape[1]

        embedding_output = self.layernorm(query_embeds)
        embedding_output = self.dropout(embedding_output)

        sequence_output = self.encoder(
            embedding_output,
            encoder_hidden_states=encoder_hidden_states,
            query_length=query_length,
        )

        return sequence_output
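

# Each input image is represented by `num_query_tokens` placeholder tokens in
# the prompt (32 for the standard BLIP-2 checkpoints), matching the number of
# query embeddings produced by the Q-Former.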
def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
    return hf_config.num_query_tokens


def get_max_blip2_image_tokens(ctx: InputContext):
    hf_config = ctx.get_hf_config(Blip2Config)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, Blip2VisionConfig):
        return get_max_blip_image_tokens(vision_config)

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def dummy_data_for_blip2(ctx: InputContext, seq_len: int):
    hf_config = ctx.get_hf_config(Blip2Config)
    vision_config = hf_config.vision_config

    image_feature_size = get_blip2_image_feature_size(hf_config)
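    # The dummy sequence is all image placeholder tokens followed by padding
    # up to `seq_len`; it is used when profiling the model's memory usage.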
    token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
    token_ids += [0] * (seq_len - image_feature_size)
    seq_data = SequenceData(token_ids)

    if isinstance(vision_config, Blip2VisionConfig):
        mm_data = dummy_image_for_blip(vision_config)

        return seq_data, mm_data

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    hf_config = ctx.get_hf_config(Blip2Config)
    image_feature_size = get_blip2_image_feature_size(hf_config)

    # The original model places image tokens at the front
    # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
    new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
    new_token_ids += llm_inputs["prompt_token_ids"]

    new_prompt = llm_inputs.get("prompt")
    if new_prompt is not None:
        new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt

    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):

    def __init__(self,
                 config: Blip2Config,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config
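
        # The multimodal pipeline here is: BLIP vision encoder -> Q-Former
        # (with learned query tokens) -> linear projection into the hidden
        # size of the OPT language model.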

        # TODO: Optionally initialize this to support embeddings.
        self.vision_model = BlipVisionModel(config.vision_config)

        self.query_tokens = nn.Parameter(
            torch.zeros(1, config.num_query_tokens,
                        config.qformer_config.hidden_size))

        self.qformer = Blip2QFormerModel(config.qformer_config,
                                         cache_config=cache_config,
                                         quant_config=quant_config)

        self.language_projection = nn.Linear(
            config.qformer_config.hidden_size,
            config.text_config.hidden_size,
            bias=True,
        )

        self.quant_config = quant_config

        self.language_model = OPTModel(config.text_config, cache_config,
                                       quant_config)

        self.unpadded_vocab_size = config.text_config.vocab_size
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size)
        self.sampler = Sampler()
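
    # OPT ties its output projection to the input embedding matrix, so the
    # token embedding weights double as the LM head here.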
    def get_lm_head(self):
        return self.language_model.decoder.embed_tokens

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Blip2ImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            return Blip2ImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return Blip2ImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(self, vision_model: BlipVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_model(pixel_values)

        return image_features

    def _process_image_pixels(self,
                              inputs: Blip2ImagePixelInputs) -> torch.Tensor:
        assert self.vision_model is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_model, pixel_values)

    def _process_image_input(self,
                             image_input: Blip2ImageInputs) -> torch.Tensor:
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None
        image_features = self._process_image_pixels(image_input)
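
        # Broadcast the learned query tokens across the batch, run the
        # Q-Former against the image features, and project the result into
        # the language model's embedding space.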
        query_tokens = self.query_tokens.expand(image_features.shape[0], -1,
                                                -1)
        query_output = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_features,
        )

        return self.language_projection(query_output)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for BLIP-2.

        One key thing to understand is that `input_ids` already accounts for
        the positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"Question: What's the content of the image? Answer:"`.

        Tokenizer outputs:
        `[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]`.

        To reserve space in the KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        dummy tokens (denoted as `50265`), resulting in:
        `[50265, ..., 50265, 2, 45641, 35, ..., 31652, 35]`.

        We insert 32 tokens since that is the number of query embeddings
        outputted by the Q-Former and inputted to the language model.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image.

        See also:
            :class:`Blip2ImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
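
            # Replace the placeholder image-token embeddings in
            # `inputs_embeds` with the projected Q-Former outputs.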
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                BLIP2_IMAGE_TOKEN_ID)

            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.language_model(input_ids,
                                            positions,
                                            kv_caches,
                                            attn_metadata,
                                            inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.get_lm_head(), hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # Only doing this for the language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        weights_list = list(weights)
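
        # Checkpoints store q_proj/k_proj/v_proj (and gate/up projections)
        # separately, while the runtime fuses them into single parameters;
        # each shard is loaded through the parameter's `weight_loader`.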
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            if "lm_head.weight" in name:
                continue
            if "rotary_emb.inv_freq" in name:
                continue

            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)

            use_default_weight_loading = False
            if "vision" in name:
                if self.vision_model is not None:
                    # We only do sharding for language model and
                    # not vision model for now.
                    use_default_weight_loading = True
            else:
                for (param_name, weight_name,
                     shard_id) in stacked_params_mapping:
                    if weight_name not in name:
                        continue

                    param = params_dict[name.replace(weight_name, param_name)]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    use_default_weight_loading = True

            if use_default_weight_loading:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)