
pass tensor progress bar as weight iterator

50h100a 4 months ago
parent
commit
1a5e612eaf
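
The helper tensor_progress_bar is imported from aphrodite.common.utils but its body is not part of this diff. From the call sites below, it takes the (name, tensor) weights iterator, a total byte count, and a description string, and yields the pairs through unchanged. A minimal sketch of what such a wrapper plausibly looks like; the tqdm settings and byte accounting here are assumptions, not the project's actual implementation:

from typing import Iterable, Iterator, Tuple

import torch
from tqdm import tqdm


def tensor_progress_bar(
    weights: Iterable[Tuple[str, torch.Tensor]],
    total_bytes: int,
    desc: str,
) -> Iterator[Tuple[str, torch.Tensor]]:
    # Yield (name, tensor) pairs unchanged while advancing a byte-based
    # progress bar, so callers can treat the wrapper as a drop-in
    # replacement for the raw weights iterator.
    with tqdm(total=total_bytes, desc=desc, unit="B", unit_scale=True) as pbar:
        for name, tensor in weights:
            yield name, tensor
            pbar.update(tensor.element_size() * tensor.nelement())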

+ 5 - 4
aphrodite/modeling/model_loader/loader.py

@@ -24,7 +24,7 @@ from aphrodite.common.config import (APHRODITE_USE_MODELSCOPE, CacheConfig,
                                      DeviceConfig, LoadConfig, LoadFormat,
                                      LoRAConfig, ModelConfig, MultiModalConfig,
                                      ParallelConfig, SchedulerConfig)
-from aphrodite.common.utils import is_pin_memory_available
+from aphrodite.common.utils import is_pin_memory_available, tensor_progress_bar
 from aphrodite.modeling.model_loader.tensorizer import (
     TensorizerConfig, is_aphrodite_tensorized, load_with_tensorizer,
     serialize_aphrodite_model, tensorizer_weights_iterator)
@@ -345,14 +345,15 @@ class DefaultModelLoader(BaseModelLoader):
                 model = _initialize_model(model_config, self.load_config,
                                           lora_config, cache_config,
                                           scheduler_config)
-            model.load_weights(
-                *self._get_weights_iterator(model_config.model,
+
+            weights, wgt_bytes = self._get_weights_iterator(model_config.model,
                                            model_config.revision,
                                            fall_back_to_pt=getattr(
                                                model,
                                                "fall_back_to_pt_during_load",
                                                True))
-            )
+            model.load_weights(tensor_progress_bar(weights, wgt_bytes,
+                                                   "Loading modules..."))
 
             for _, module in model.named_modules():
                 quant_method = getattr(module, "quant_method", None)

+ 3 - 4
aphrodite/modeling/models/llama.py

@@ -30,7 +30,7 @@ from transformers import LlamaConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.common.utils import is_hip, tensor_progress_bar
+from aphrodite.common.utils import is_hip
 from aphrodite.distributed import (get_current_tp_rank_partition_size,
                                    get_pp_group,
                                    get_tensor_model_parallel_rank,
@@ -477,7 +477,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
                         device=device),
         })
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], total_bytes:int):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -487,8 +487,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
             (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in tensor_progress_bar(weights, total_bytes,
-                                                       "Loading modules..."):
+        for name, loaded_weight in weights:
             name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight)
             if "rotary_emb.inv_freq" in name:
                 continue

+ 1 - 4
aphrodite/modeling/models/qwen2.py

@@ -31,7 +31,6 @@ from transformers import Qwen2Config
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_current_tp_rank_partition_size,
                                    get_pp_group,
                                    get_tensor_model_parallel_rank,
@@ -398,9 +397,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        weights_list = list(weights)
-        for name, loaded_weight in progress_bar(weights_list,
-                                                desc="Loading modules..."):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             if self.config.tie_word_embeddings and "lm_head.weight" in name: