@@ -1,4 +1,3 @@
-import itertools
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
 
@@ -29,7 +28,7 @@ from .llava import LlavaMultiModalProjector
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (filter_weights, flatten_bn,
+from .utils import (flatten_bn, group_weights_with_prefix,
                     init_aphrodite_registered_model,
                     merge_multimodal_embeddings)
 
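Note: utils.py itself is not part of this diff. A minimal sketch of what
group_weights_with_prefix plausibly does, inferred from how load_weights
consumes its result below (weight names are bucketed by their first dotted
component, and that component is stripped from the stored name, which is why
the bare image_newline weight comes back with an empty name):

    from collections import defaultdict
    from typing import Dict, Iterable, List, Tuple

    import torch


    def group_weights_with_prefix(
        weights: Iterable[Tuple[str, torch.Tensor]],
    ) -> Dict[str, List[Tuple[str, torch.Tensor]]]:
        # "vision_tower.encoder.0.weight" lands in the "vision_tower"
        # bucket as ("encoder.0.weight", tensor); a dotless name such as
        # "image_newline" becomes ("", tensor) in its own bucket.
        groups: Dict[str, List[Tuple[str, torch.Tensor]]] = defaultdict(list)
        for name, weight in weights:
            prefix, _, rest = name.partition(".")
            groups[prefix].append((rest, weight))
        return groups

The real helper may well return lazy per-prefix iterators rather than lists;
the dict-of-lists form above is only for illustration.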
@@ -628,25 +627,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
-            weights, 4)
+        weights_group = group_weights_with_prefix(weights)
 
         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])
 
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load newline
-        newline_weights = filter_weights(newline_weights, "image_newline")
-        for name, loaded_weight in newline_weights:
+        for name, loaded_weight in weights_group["image_newline"]:
             assert name == ""
             param = self.image_newline
             weight_loader = getattr(param, "weight_loader",
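For context on the unchanged lines above: each parameter may carry a custom
weight_loader attribute, with default_weight_loader as the fallback. Its
usual shape in this codebase is roughly the following (a sketch, not taken
from this PR):

    import torch


    def default_weight_loader(param: torch.Tensor,
                              loaded_weight: torch.Tensor) -> None:
        # Fallback loader: the checkpoint tensor must match the parameter
        # shape exactly and is copied into it in place.
        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)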
@@ -654,6 +649,5 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])
 
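The diff does not state the rationale, but a likely one: itertools.tee must
buffer every item that any sibling iterator has not yet consumed, so draining
the four copies one component at a time kept references to the entire weight
stream alive, and each copy was then filtered by prefix anyway. Grouping once
up front replaces that with a single pass. A toy illustration of the tee
buffering behavior:

    import itertools

    # Hypothetical stand-in for a stream of (name, tensor) pairs.
    stream = iter([("vision_tower.w", 0), ("language_model.w", 1)])
    a, b = itertools.tee(stream, 2)

    # Draining `a` first forces tee to buffer both items internally until
    # `b` catches up; with four copies and a real checkpoint, that buffer
    # is the whole stream.
    assert list(a) == [("vision_tower.w", 0), ("language_model.w", 1)]
    assert list(b) == [("vision_tower.w", 0), ("language_model.w", 1)]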