
VLM: refactor composite weight loading logic (#1097)

AlpinDale 1 month ago
commit 91d03c04d2
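
In broad strokes, each model touched below used to tee the incoming weight iterator once per component and filter every copy by prefix; the refactor groups the stream by its top-level prefix once and then indexes the result by component name. A minimal, self-contained sketch of the two shapes, using made-up weight names and simplified stand-ins (filter_by_prefix for the repo's filter_weights, group_by_prefix for the new group_weights_with_prefix; both are hypothetical helpers written only for this illustration):

import itertools
from typing import Dict, Iterable, List, Tuple

import torch

# Made-up (name, tensor) pairs, shaped like a composite VLM checkpoint.
WEIGHTS = [
    ("vision_tower.patch_embed.weight", torch.zeros(1)),
    ("multi_modal_projector.linear_1.weight", torch.zeros(1)),
    ("language_model.embed_tokens.weight", torch.zeros(1)),
]


def filter_by_prefix(weights: Iterable[Tuple[str, torch.Tensor]],
                     prefix: str) -> Iterable[Tuple[str, torch.Tensor]]:
    # Simplified stand-in for filter_weights: keep one component's weights
    # and strip the component prefix from each name.
    for name, weight in weights:
        head, _, rest = name.partition(".")
        if head == prefix:
            yield rest, weight


def group_by_prefix(
    weights: Iterable[Tuple[str, torch.Tensor]],
) -> Dict[str, List[Tuple[str, torch.Tensor]]]:
    # Simplified (eager) stand-in for group_weights_with_prefix.
    groups: Dict[str, List[Tuple[str, torch.Tensor]]] = {}
    for name, weight in weights:
        head, _, rest = name.partition(".")
        groups.setdefault(head, []).append((rest, weight))
    return groups


# Before: tee the stream once per component, then filter each copy by prefix
# (mlp_w and llm_w would be filtered the same way for their components).
vit_w, mlp_w, llm_w = itertools.tee(iter(WEIGHTS), 3)
old_names = [name for name, _ in filter_by_prefix(vit_w, "vision_tower")]

# After: group by top-level prefix once, then index by component name.
new_names = [name for name, _ in group_by_prefix(WEIGHTS)["vision_tower"]]

assert old_names == new_names == ["patch_embed.weight"]

The actual helper keeps each group lazy via itertools.tee and filter_weights rather than materializing lists; see the utils.py hunk at the end of this diff.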

+ 5 - 9
aphrodite/modeling/models/internvl.py

@@ -4,7 +4,6 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-import itertools
 import re
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
@@ -33,7 +32,7 @@ from aphrodite.quantization import QuantizationConfig
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    get_clip_num_patches)
 from .interfaces import SupportsMultiModal
-from .utils import (filter_weights, flatten_bn,
+from .utils import (flatten_bn, group_weights_with_prefix,
                     init_aphrodite_registered_model,
                     merge_multimodal_embeddings)
 
@@ -515,21 +514,18 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+        weights_group = group_weights_with_prefix(weights)
 
         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_model")
-        self.vision_model.load_weights(vit_weights)
+        self.vision_model.load_weights(weights_group["vision_model"])
 
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "mlp1")
         mlp_params_dict = dict(self.mlp1.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["mlp1"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])

+ 5 - 9
aphrodite/modeling/models/llava.py

@@ -1,4 +1,3 @@
-import itertools
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
 
@@ -26,7 +25,7 @@ from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
-from .utils import (filter_weights, flatten_bn,
+from .utils import (flatten_bn, group_weights_with_prefix,
                     init_aphrodite_registered_model,
                     merge_multimodal_embeddings)
 
@@ -392,21 +391,18 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+        weights_group = group_weights_with_prefix(weights)
 
         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])
 
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])

+ 6 - 12
aphrodite/modeling/models/llava_next.py

@@ -1,4 +1,3 @@
-import itertools
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
 
@@ -29,7 +28,7 @@ from .llava import LlavaMultiModalProjector
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (filter_weights, flatten_bn,
+from .utils import (flatten_bn, group_weights_with_prefix,
                     init_aphrodite_registered_model,
                     merge_multimodal_embeddings)
 
@@ -628,25 +627,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
-            weights, 4)
+        weights_group = group_weights_with_prefix(weights)
 
         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])
 
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load newline
-        newline_weights = filter_weights(newline_weights, "image_newline")
-        for name, loaded_weight in newline_weights:
+        for name, loaded_weight in weights_group["image_newline"]:
             assert name == ""
             param = self.image_newline
             weight_loader = getattr(param, "weight_loader",
@@ -654,6 +649,5 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])
     

+ 11 - 15
aphrodite/modeling/models/llava_next_video.py

@@ -1,4 +1,3 @@
-import itertools
 import math
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
@@ -28,7 +27,7 @@ from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
 from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip)
-from .utils import (filter_weights, init_aphrodite_registered_model,
+from .utils import (group_weights_with_prefix, init_aphrodite_registered_model,
                     merge_multimodal_embeddings)
 
 # For profile run
@@ -427,22 +426,19 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators
-        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
-            weights, 4
-        )
+        # prepare weight iterators for components
+        weights_group = group_weights_with_prefix(weights)
+
         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])
+
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
-            weight_loader = getattr(
-                param, "weight_loader", default_weight_loader
-            )
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
             weight_loader(param, loaded_weight)
+
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])

+ 5 - 9
aphrodite/modeling/models/paligemma.py

@@ -1,4 +1,3 @@
-import itertools
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
 
@@ -24,7 +23,7 @@ from aphrodite.quantization.base_config import QuantizationConfig
 from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
-from .utils import filter_weights, merge_multimodal_embeddings
+from .utils import group_weights_with_prefix, merge_multimodal_embeddings
 
 
 class PaliGemmaImagePixelInputs(TypedDict):
@@ -288,21 +287,18 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+        weights_group = group_weights_with_prefix(weights)
 
         # load vision tower
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])
 
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])

+ 5 - 7
aphrodite/modeling/models/ultravox.py

@@ -1,7 +1,6 @@
 # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
 """PyTorch Ultravox model."""
 
-import itertools
 import math
 from array import array
 from functools import lru_cache
@@ -29,7 +28,8 @@ from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.interfaces import SupportsMultiModal
-from aphrodite.modeling.models.utils import (filter_weights, flatten_bn,
+from aphrodite.modeling.models.utils import (flatten_bn,
+                                             group_weights_with_prefix,
                                              init_aphrodite_registered_model,
                                              merge_multimodal_embeddings)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
@@ -449,11 +449,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        projector_weights, llm_weights = itertools.tee(weights, 2)
+        weights_group = group_weights_with_prefix(weights)
 
         # load projector weights
-        projector_weights = filter_weights(projector_weights,
-                                           "multi_modal_projector")
+        projector_weights = weights_group["multi_modal_projector"]
         projector_params_dict = dict(
             self.multi_modal_projector.named_parameters())
         for name, loaded_weight in projector_weights:
@@ -463,5 +462,4 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])

+ 33 - 1
aphrodite/modeling/models/utils.py

@@ -1,3 +1,5 @@
+import itertools
+from collections import UserDict
 from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
                     Union, overload)
 
@@ -16,7 +18,22 @@ from aphrodite.multimodal.base import NestedTensors
 from aphrodite.quantization import QuantizationConfig
 
 
-def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
+class WeightsGroup(UserDict):
+    """
+    Wraps the grouped weights dictionary to give a more informative error
+    message when attempting to access a weight component that does not exist.
+    """
+    def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]:
+        try:
+            return super().__getitem__(key)
+        except KeyError as exc:
+            msg = (f"There are no weights with the prefix: {key}. "
+                   f"Available prefixes: {set(self.keys())}")
+            raise KeyError(msg) from exc
+
+
+def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
+                   prefix: str) -> Iterable[Tuple[str, torch.Tensor]]:
     """
     Helper function to load weights for inner aphrodite models.
 
@@ -32,6 +49,21 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
             yield name, loaded_weight
 
 
+def group_weights_with_prefix(
+    weights: Iterable[Tuple[str, torch.Tensor]]
+) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]:
+    """
+    Helper function to group weights by their top-level prefix.
+    """
+    init_weights, repeated_weights = itertools.tee(weights, 2)
+    weights_prefix = {name.split(".")[0] for name, _ in init_weights}
+    repeated_weights = itertools.tee(repeated_weights, len(weights_prefix))
+    return WeightsGroup({
+        prefix: filter_weights(component, prefix)
+        for component, prefix in zip(repeated_weights, weights_prefix)
+    })
+
+
 def init_aphrodite_registered_model(
     hf_config: PretrainedConfig,
     cache_config: Optional[CacheConfig],
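
For reference, a small usage sketch of the new helper in isolation. The import path matches the one ultravox.py uses above; the weight names are made up. As the projector-loading loops above rely on, each group yields names with the component prefix stripped, and each group is a one-shot iterator, so it should be consumed only once:

import torch

from aphrodite.modeling.models.utils import group_weights_with_prefix

# Made-up (name, tensor) pairs, shaped like a composite VLM checkpoint.
weights = [
    ("vision_tower.patch_embed.weight", torch.zeros(1)),
    ("multi_modal_projector.linear_1.weight", torch.zeros(1)),
    ("language_model.embed_tokens.weight", torch.zeros(1)),
]

weights_group = group_weights_with_prefix(weights)

# Each group yields (name, tensor) pairs with the component prefix stripped.
vit_names = [name for name, _ in weights_group["vision_tower"]]
assert vit_names == ["patch_embed.weight"]

# Unknown prefixes raise the more descriptive KeyError from WeightsGroup.
try:
    weights_group["image_newline"]
except KeyError as exc:
    print(exc)  # mentions the available prefixes

Indexing a prefix that never appeared in the checkpoint fails through WeightsGroup rather than as a bare dictionary lookup, which is what makes the missing-component error message more useful.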