12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720 |
- from array import array
- from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
- TypedDict, Union)
- import torch
- import torch.nn as nn
- from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig,
- apply_chunking_to_forward)
- from aphrodite.attention import AttentionMetadata
- from aphrodite.common.config import CacheConfig, MultiModalConfig
- from aphrodite.common.sequence import IntermediateTensors, SequenceData
- from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
- from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
- from aphrodite.modeling.layers.activation import get_act_fn
- from aphrodite.modeling.layers.logits_processor import LogitsProcessor
- from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
- from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
- from aphrodite.modeling.models.opt import OPTModel
- from aphrodite.modeling.sampling_metadata import SamplingMetadata
- from aphrodite.multimodal import MULTIMODAL_REGISTRY
- from aphrodite.quantization import QuantizationConfig
- from .blip import (BlipVisionModel, dummy_image_for_blip,
- get_max_blip_image_tokens)
- from .interfaces import SupportsMultiModal
- from .utils import merge_multimodal_embeddings
- _KEYS_TO_MODIFY_MAPPING = {
- "language_model.lm_head": "lm_head",
- "language_model.model": "language_model",
- }
- # We use this internally as placeholders since there is no image token
- # defined on the HuggingFace repo
- BLIP2_IMAGE_TOKEN = "<image>"
- BLIP2_IMAGE_TOKEN_ID = 50265
class Blip2ImagePixelInputs(TypedDict):
    """Image input given as raw pixel values."""

    type: Literal["pixel_values"]
    # Shape: `(batch_size * num_images, num_channels, height, width)`
    data: torch.Tensor
class Blip2ImageEmbeddingInputs(TypedDict):
    """Image input given as precomputed embeddings.

    The embeddings are returned as-is by the model, so their hidden size
    must already match the language-model backbone.
    """

    type: Literal["image_embeds"]
    # Shape: `(batch_size * num_images, image_feature_size, hidden_size)`;
    # `hidden_size` must match the hidden size of language model backbone.
    data: torch.Tensor
# Either form of image input; discriminate on the ``"type"`` key.
Blip2ImageInputs = Union[Blip2ImagePixelInputs,
                         Blip2ImageEmbeddingInputs]
class Blip2QFormerMultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention used inside the Q-Former.

    Queries are always projected from ``hidden_states``. Keys and values are
    projected from ``encoder_hidden_states`` when it is passed to
    ``forward`` (cross-attention) and from ``hidden_states`` otherwise. When
    built with ``is_cross_attention=True`` the key/value projections expect
    inputs of width ``config.encoder_hidden_size``. No attention mask is
    applied.

    ``quant_config`` and ``cache_config`` are accepted for interface parity
    with the other Q-Former modules but are not used here.
    """

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        is_cross_attention: bool = False,
    ) -> None:
        super().__init__()
        self.config = config

        num_heads = config.num_attention_heads
        if config.hidden_size % num_heads != 0:
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of "
                f"the number of attention heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = num_heads
        self.attention_head_size = config.hidden_size // num_heads
        self.all_head_size = num_heads * self.attention_head_size
        # 1/sqrt(head_dim), applied to the raw scores before the softmax.
        self.scaling = self.attention_head_size**-0.5

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        kv_hidden_size = (config.encoder_hidden_size
                          if is_cross_attention else config.hidden_size)
        self.key = nn.Linear(kv_hidden_size, self.all_head_size)
        self.value = nn.Linear(kv_hidden_size, self.all_head_size)

        self.position_embedding_type = getattr(config,
                                               "position_embedding_type",
                                               "absolute")
        if self.position_embedding_type != "absolute":
            raise NotImplementedError("Unsupported position_embedding_type: "
                                      f"{self.position_embedding_type}")

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dim into heads and move heads before the sequence.

        ``(B, S, num_heads * head_size) -> (B, num_heads, S, head_size)``.
        """
        split = x.view(*x.shape[:-1], self.num_attention_heads,
                       self.attention_head_size)
        return split.transpose(1, 2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """Attend from ``hidden_states`` to itself or to the encoder states.

        Returns a tensor of shape ``(B, S, all_head_size)`` where ``S`` is
        the sequence length of ``hidden_states``.
        """
        kv_source = (hidden_states if encoder_hidden_states is None else
                     encoder_hidden_states)
        key_layer = self.transpose_for_scores(self.key(kv_source))
        value_layer = self.transpose_for_scores(self.value(kv_source))
        query_layer = self.transpose_for_scores(self.query(hidden_states))

        scores = query_layer @ key_layer.transpose(-1, -2)
        probs = (scores * self.scaling).softmax(dim=-1)
        # Dropping attention probabilities removes whole tokens from what
        # each position attends to; this follows the original Transformer.
        probs = self.dropout(probs)

        context = (probs @ value_layer).transpose(1, 2).contiguous()
        return context.view(*context.shape[:-2], self.all_head_size)
class Blip2QFormerSelfOutput(nn.Module):
    """Post-attention block: dense -> dropout -> residual add -> LayerNorm."""

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        # Attribute name ``LayerNorm`` must match checkpoint parameter names.
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor,
    ) -> torch.Tensor:
        """Project ``hidden_states`` and add the residual ``input_tensor``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
class Blip2QFormerAttention(nn.Module):
    """Attention sub-layer of the Q-Former.

    Runs :class:`Blip2QFormerMultiHeadAttention` (self- or cross-attention,
    depending on ``is_cross_attention``) and then
    :class:`Blip2QFormerSelfOutput`, which projects the result and adds
    ``hidden_states`` back as a residual before a LayerNorm.
    """

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        is_cross_attention: bool = False,
    ) -> None:
        super().__init__()

        self.attention = Blip2QFormerMultiHeadAttention(
            config,
            quant_config=quant_config,
            cache_config=cache_config,
            is_cross_attention=is_cross_attention,
        )

        self.output = Blip2QFormerSelfOutput(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """Apply attention plus the residual/LayerNorm output block.

        Args:
            hidden_states: Query-side input; also used as the residual.
            encoder_hidden_states: If given, keys/values are taken from it
                (cross-attention); otherwise from ``hidden_states``.

        Returns:
            A single tensor with the same shape as ``hidden_states``
            (previously mis-annotated as a tuple).
        """
        self_output = self.attention(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
        )
        attention_output = self.output(self_output, hidden_states)

        return attention_output
class Blip2QFormerIntermediate(nn.Module):
    """Feed-forward expansion: hidden_size -> intermediate_size, then act."""

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # Activation is looked up by name from the config (``hidden_act``).
        self.intermediate_act_fn = get_act_fn(config.hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Expand ``hidden_states`` and apply the configured activation."""
        return self.intermediate_act_fn(self.dense(hidden_states))
class Blip2QFormerOutput(nn.Module):
    """Feed-forward contraction with residual: dense -> dropout -> add -> LN.

    Maps ``intermediate_size`` back to ``hidden_size`` and adds the
    pre-feed-forward activations as a residual before the LayerNorm.
    """

    def __init__(self, config: Blip2QFormerConfig) -> None:
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        # Attribute name ``LayerNorm`` must match checkpoint parameter names.
        self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self,
        hidden_states: torch.Tensor,
        input_tensor: torch.Tensor,
    ) -> torch.Tensor:
        """Contract ``hidden_states`` and add the residual ``input_tensor``."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
class Blip2QFormerLayer(nn.Module):
    """A single Q-Former transformer layer.

    Self-attention runs over the whole input. The first ``query_length``
    positions (the learned query tokens) then cross-attend to
    ``encoder_hidden_states`` on layers where
    ``layer_idx % config.cross_attention_frequency == 0``, and pass through
    the query-specific feed-forward block (``intermediate_query`` /
    ``output_query``).

    Only the query-specific feed-forward weights are defined by this layer.
    Positions beyond ``query_length`` (text tokens) therefore cannot be
    processed; reaching that path raises ``NotImplementedError``.
    """

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
        layer_idx: int,
    ) -> None:
        super().__init__()

        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        # Dimension along which apply_chunking_to_forward splits its input.
        self.seq_len_dim = 1
        self.attention = Blip2QFormerAttention(config,
                                               quant_config=quant_config,
                                               cache_config=cache_config)

        self.layer_idx = layer_idx

        if layer_idx % config.cross_attention_frequency == 0:
            self.crossattention = Blip2QFormerAttention(
                config,
                quant_config=quant_config,
                cache_config=cache_config,
                is_cross_attention=True)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False

        self.intermediate_query = Blip2QFormerIntermediate(config)
        self.output_query = Blip2QFormerOutput(config)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        query_length: int,
    ) -> torch.Tensor:
        """Run self-attention, optional cross-attention and feed-forward.

        Args:
            hidden_states: Layer input; the first ``query_length`` positions
                along dim 1 are treated as query tokens.
            encoder_hidden_states: Keys/values for cross-attention.
            query_length: Number of leading query positions.

        Raises:
            NotImplementedError: if the input contains positions beyond
                ``query_length`` (or ``query_length`` is 0), since no
                text feed-forward weights exist.
        """
        attention_output = self.attention(hidden_states)

        if query_length > 0:
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                query_attention_output = self.crossattention(
                    query_attention_output,
                    encoder_hidden_states=encoder_hidden_states,
                )

            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )

            if attention_output.shape[1] > query_length:
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text],
                                         dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )

        return layer_output

    def feed_forward_chunk(self,
                           attention_output: torch.Tensor) -> torch.Tensor:
        """Feed-forward for non-query (text) positions: not supported.

        The original implementation called ``self.intermediate`` and
        ``self.output``, which this layer never defines, so it could only
        fail with an opaque AttributeError. The single parameter is kept
        because apply_chunking_to_forward inspects the callable's signature.
        """
        raise NotImplementedError(
            "Blip2QFormerLayer only supports query tokens: no text "
            "feed-forward weights (intermediate/output) are defined.")

    def feed_forward_chunk_query(
            self, attention_output: torch.Tensor) -> torch.Tensor:
        """Query-token feed-forward: expand, contract, residual + LayerNorm."""
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output
class Blip2QFormerEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` :class:`Blip2QFormerLayer`s."""

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
    ) -> None:
        super().__init__()

        self.config = config

        self.layer = nn.ModuleList([
            Blip2QFormerLayer(config,
                              quant_config=quant_config,
                              cache_config=cache_config,
                              layer_idx=layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        query_length: int,
    ) -> torch.Tensor:
        """Apply every layer in order and return the final hidden states.

        Args:
            hidden_states: Input to the first layer.
            encoder_hidden_states: Passed unchanged to every layer for
                cross-attention.
            query_length: Number of leading query positions, forwarded to
                each layer.
        """
        # Iterate the ModuleList directly; it was built with exactly
        # ``config.num_hidden_layers`` entries.
        for layer_module in self.layer:
            hidden_states = layer_module(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                query_length=query_length,
            )

        return hidden_states
# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1025
class Blip2QFormerModel(nn.Module):
    """Q-Former: LayerNorm + dropout on the query embeddings, then encoder.

    Only query embeddings are fed in (no text tokens), so the whole input
    sequence is treated as query positions.
    """

    def __init__(
        self,
        config: Blip2QFormerConfig,
        *,
        quant_config: Optional[QuantizationConfig],
        cache_config: Optional[CacheConfig],
    ) -> None:
        super().__init__()
        self.config = config

        self.layernorm = nn.LayerNorm(config.hidden_size,
                                      eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.encoder = Blip2QFormerEncoder(config,
                                           quant_config=quant_config,
                                           cache_config=cache_config)

    def forward(
        self,
        query_embeds: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
    ) -> torch.Tensor:
        """Encode ``query_embeds`` while cross-attending to image features.

        Dim 1 of ``query_embeds`` is used as the query length.
        """
        normalized = self.dropout(self.layernorm(query_embeds))
        return self.encoder(
            normalized,
            encoder_hidden_states=encoder_hidden_states,
            query_length=query_embeds.shape[1],
        )
def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
    """Return how many LM placeholder tokens one image expands to.

    This is the number of Q-Former query tokens, i.e.
    ``hf_config.num_query_tokens``.
    """
    num_query_tokens = hf_config.num_query_tokens
    return num_query_tokens
def get_max_blip2_image_tokens(ctx: InputContext):
    """Return the maximum number of LM tokens occupied by a single image.

    BLIP-2 feeds the language model the Q-Former outputs, so each image is
    represented by ``num_query_tokens`` placeholders. This matches what
    ``input_processor_for_blip2`` and ``dummy_data_for_blip2`` insert. The
    previous implementation returned the vision encoder's patch count, which
    disagreed with the sequence the model actually builds.

    Raises:
        NotImplementedError: if the vision config is not a
            ``Blip2VisionConfig``.
    """
    hf_config = ctx.get_hf_config(Blip2Config)
    vision_config = hf_config.vision_config

    if not isinstance(vision_config, Blip2VisionConfig):
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    return get_blip2_image_feature_size(hf_config)
def dummy_seq_data_for_blip2(
    hf_config: Blip2Config,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    """Build dummy sequence data: image placeholders followed by zero padding.

    The sequence starts with ``image_feature_size * num_images`` copies of
    ``image_token_id``. It is then padded with ``0`` up to ``seq_len``.
    If ``seq_len`` is smaller than the placeholder count, no padding is
    added and the sequence is longer than ``seq_len``.

    Args:
        hf_config: Source of the per-image feature size when no override
            is given.
        seq_len: Target total length.
        num_images: Number of images to reserve placeholders for.
        image_token_id: Placeholder token ID.
        image_feature_size_override: Per-image placeholder count to use
            instead of the config-derived value.
    """
    image_feature_size = (get_blip2_image_feature_size(hf_config)
                          if image_feature_size_override is None else
                          image_feature_size_override)

    token_ids = array(APHRODITE_TOKEN_ID_ARRAY_TYPE,
                      [image_token_id]) * image_feature_size * num_images
    num_padding = seq_len - image_feature_size * num_images
    token_ids.extend(array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [0]) * num_padding)
    return SequenceData(token_ids)
def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    """Return ``(seq_data, mm_data)`` used for memory profiling.

    The sequence holds ``mm_counts["image"]`` images' worth of
    ``BLIP2_IMAGE_TOKEN_ID`` placeholders. The multimodal data comes from
    the BLIP dummy-image helper.

    Raises:
        NotImplementedError: if the vision config is not a
            ``Blip2VisionConfig``.
    """
    hf_config = ctx.get_hf_config(Blip2Config)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    seq_data = dummy_seq_data_for_blip2(
        hf_config,
        seq_len,
        num_images,
        image_token_id=BLIP2_IMAGE_TOKEN_ID,
    )

    if not isinstance(vision_config, Blip2VisionConfig):
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    return seq_data, dummy_image_for_blip(vision_config, num_images)
def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
    """Prepend image placeholders to prompts that carry image data.

    Inputs without ``"image"`` multimodal data are returned unchanged.
    Otherwise ``num_query_tokens`` copies of ``BLIP2_IMAGE_TOKEN_ID`` are
    placed before the prompt token IDs. The prompt text, if present, gets
    the same number of ``BLIP2_IMAGE_TOKEN`` strings as a prefix.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    image_feature_size = get_blip2_image_feature_size(
        ctx.get_hf_config(Blip2Config))

    # The original model places image tokens at the front
    # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514
    placeholder_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
    new_token_ids = placeholder_ids + list(llm_inputs["prompt_token_ids"])

    prompt = llm_inputs.get("prompt")
    new_prompt = (None if prompt is None else
                  BLIP2_IMAGE_TOKEN * image_feature_size + prompt)

    return LLMInputs(prompt_token_ids=new_token_ids,
                     prompt=new_prompt,
                     multi_modal_data=multi_modal_data)
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
    """BLIP-2: BLIP vision tower -> Q-Former -> linear projection -> OPT.

    Pixel inputs are encoded by ``vision_model``. ``qformer`` summarizes
    them using the learned ``query_tokens``, and ``language_projection``
    maps the result to the OPT hidden size. The projected embeddings replace
    the ``BLIP2_IMAGE_TOKEN_ID`` placeholders in the OPT input embeddings.
    """

    def __init__(self,
                 config: Blip2Config,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config

        qformer_hidden_size = config.qformer_config.hidden_size

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_model = BlipVisionModel(config.vision_config)

        # Learned queries: (1, num_query_tokens, qformer hidden size).
        self.query_tokens = nn.Parameter(
            torch.zeros(1, config.num_query_tokens, qformer_hidden_size))

        self.qformer = Blip2QFormerModel(config.qformer_config,
                                         cache_config=cache_config,
                                         quant_config=quant_config)

        self.language_projection = nn.Linear(
            qformer_hidden_size,
            config.text_config.hidden_size,
            bias=True,
        )

        self.quant_config = quant_config

        self.language_model = OPTModel(config.text_config, cache_config,
                                       quant_config)

        self.unpadded_vocab_size = config.text_config.vocab_size
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size)
        self.sampler = Sampler()

    def get_lm_head(self):
        """Return OPT's input embedding table, reused for the logits."""
        return self.language_model.decoder.embed_tokens

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that ``data`` is ``(batch, 3, image_size, image_size)``."""
        image_size = self.config.vision_config.image_size
        expected_dims = (3, image_size, image_size)

        if tuple(data.shape[1:]) == expected_dims:
            return data

        expected_expr = ("batch_size", *map(str, expected_dims))
        raise ValueError(
            f"The expected shape of pixel values is {expected_expr}. "
            f"You supplied {tuple(data.shape)}.")

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Blip2ImageInputs]:
        """Extract the image input from ``kwargs``, if any.

        ``pixel_values`` takes precedence over ``image_embeds``. Returns
        ``None`` when neither is supplied.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")
            # Remove the N dimension until multiple images are supported.
            pixel_values = pixel_values.squeeze(1)
            return Blip2ImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")
            # Remove the N dimension until multiple images are supported.
            image_embeds = image_embeds.squeeze(1)
            return Blip2ImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        return None

    def _image_pixels_to_features(self, vision_model: BlipVisionModel,
                                  pixel_values: torch.Tensor) -> torch.Tensor:
        """Run the vision tower on ``pixel_values``."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        return vision_model(pixel_values)

    def _process_image_pixels(self,
                              inputs: Blip2ImagePixelInputs) -> torch.Tensor:
        """Encode pixel inputs with this model's vision tower."""
        assert self.vision_model is not None
        return self._image_pixels_to_features(self.vision_model,
                                              inputs["data"])

    def _process_image_input(self,
                             image_input: Blip2ImageInputs) -> torch.Tensor:
        """Turn an image input into LM-space embeddings.

        Precomputed embeddings are returned unchanged. Pixels go through
        vision tower -> Q-Former (one query set per image) -> projection.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None
        image_features = self._process_image_pixels(image_input)

        batch_size = image_features.shape[0]
        query_output = self.qformer(
            query_embeds=self.query_tokens.expand(batch_size, -1, -1),
            encoder_hidden_states=image_features,
        )
        return self.language_projection(query_output)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        **kwargs: object,
    ) -> SamplerOutput:
        """Run forward pass for BLIP-2.

        ``input_ids`` already contain placeholders for the image embeddings
        that will be spliced in. For the text prompt
        `"Question: What's the content of the image? Answer:"` the tokenizer
        yields `[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]`.
        To reserve KV-cache space, the input processor prepends one dummy
        token (`50265`) per Q-Former query output, typically 32 of them:
        `[50265, ..., 50265, 2, 45641, 35, ..., 31652, 35]`.
        This keeps `positions` and `attn_metadata` consistent with
        `input_ids`.

        When image data is present, the placeholder positions are replaced
        with the image embeddings. The language model then receives
        ``inputs_embeds`` instead of ``input_ids``.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image.

        See also:
            :class:`Blip2ImageInputs`
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        inputs_embeds = None
        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            text_embeddings = self.language_model.get_input_embeddings(
                input_ids)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, text_embeddings, vision_embeddings,
                BLIP2_IMAGE_TOKEN_ID)
            # The LM must consume the merged embeddings, not the raw IDs.
            input_ids = None

        return self.language_model(input_ids,
                                   positions,
                                   kv_caches,
                                   attn_metadata,
                                   inputs_embeds=inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the embedding table."""
        return self.logits_processor(self.get_lm_head(), hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        - Names containing ``lm_head.weight`` or ``rotary_emb.inv_freq`` are
          skipped.
        - Remaining names are rewritten through ``_KEYS_TO_MODIFY_MAPPING``.
        - Vision weights use the parameter's own (or the default) loader.
        - Other weights whose name contains a stacked shard name (e.g.
          ``q_proj``) are routed into the fused parameter with a shard ID.
        - Everything else is loaded directly.
        """
        # only doing this for language model part for now.
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        skipped_markers = ("lm_head.weight", "rotary_emb.inv_freq")

        def _load_unsharded(target_name, tensor):
            # Use the parameter's custom loader when it defines one.
            param = params_dict[target_name]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, tensor)

        for ckpt_name, loaded_weight in weights:
            if any(marker in ckpt_name for marker in skipped_markers):
                continue

            name = ckpt_name
            for old_prefix, new_prefix in _KEYS_TO_MODIFY_MAPPING.items():
                if old_prefix in name:
                    name = name.replace(old_prefix, new_prefix)

            if "vision" in name:
                # BlipVisionModel does not need sharding
                if self.vision_model is not None:
                    _load_unsharded(name, loaded_weight)
                continue

            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name in name:
                    param = params_dict[name.replace(shard_name, param_name)]
                    param.weight_loader(param, loaded_weight, shard_id)
                    break
            else:
                _load_unsharded(name, loaded_weight)
|