Selaa lähdekoodia

feat: add progress bar for loading individual weight modules (#640)

* wip

* chore:add progress_bar to all models (#649)

* chore:add progress_bar to all models

* Remove whitespace

---------

Co-authored-by: AlpinDale <52078762+AlpinDale@users.noreply.github.com>

* ruff

---------

Co-authored-by: ewof <marroq@marroq.xyz>
AlpinDale 6 kuukautta sitten
vanhempi
commit
3f712cd287
51 muutettua tiedostoa jossa 226 lisäystä ja 61 poistoa
  1. 22 0
      aphrodite/common/utils.py
  2. 6 6
      aphrodite/modeling/model_loader/weight_utils.py
  3. 4 1
      aphrodite/modeling/models/arctic.py
  4. 4 1
      aphrodite/modeling/models/baichuan.py
  5. 4 1
      aphrodite/modeling/models/blip2.py
  6. 4 1
      aphrodite/modeling/models/bloom.py
  7. 4 2
      aphrodite/modeling/models/chameleon.py
  8. 4 1
      aphrodite/modeling/models/chatglm.py
  9. 6 3
      aphrodite/modeling/models/commandr.py
  10. 4 1
      aphrodite/modeling/models/dbrx.py
  11. 4 1
      aphrodite/modeling/models/decilm.py
  12. 4 1
      aphrodite/modeling/models/deepseek.py
  13. 4 1
      aphrodite/modeling/models/deepseek_v2.py
  14. 4 1
      aphrodite/modeling/models/falcon.py
  15. 4 1
      aphrodite/modeling/models/fuyu.py
  16. 4 1
      aphrodite/modeling/models/gemma.py
  17. 4 1
      aphrodite/modeling/models/gemma2.py
  18. 4 1
      aphrodite/modeling/models/gpt2.py
  19. 4 1
      aphrodite/modeling/models/gpt_bigcode.py
  20. 4 1
      aphrodite/modeling/models/gpt_j.py
  21. 4 1
      aphrodite/modeling/models/gpt_neox.py
  22. 4 1
      aphrodite/modeling/models/intern_vit.py
  23. 4 1
      aphrodite/modeling/models/internlm2.py
  24. 4 1
      aphrodite/modeling/models/jais.py
  25. 4 1
      aphrodite/modeling/models/jamba.py
  26. 4 2
      aphrodite/modeling/models/llama.py
  27. 4 1
      aphrodite/modeling/models/llama_embedding.py
  28. 4 1
      aphrodite/modeling/models/medusa.py
  29. 4 1
      aphrodite/modeling/models/minicpm.py
  30. 4 1
      aphrodite/modeling/models/minicpmv.py
  31. 4 1
      aphrodite/modeling/models/mixtral.py
  32. 4 1
      aphrodite/modeling/models/mixtral_quant.py
  33. 4 1
      aphrodite/modeling/models/mlp_speculator.py
  34. 4 1
      aphrodite/modeling/models/mpt.py
  35. 4 1
      aphrodite/modeling/models/nemotron.py
  36. 4 1
      aphrodite/modeling/models/olmo.py
  37. 4 1
      aphrodite/modeling/models/opt.py
  38. 4 1
      aphrodite/modeling/models/orion.py
  39. 4 1
      aphrodite/modeling/models/paligemma.py
  40. 4 1
      aphrodite/modeling/models/persimmon.py
  41. 4 1
      aphrodite/modeling/models/phi.py
  42. 4 1
      aphrodite/modeling/models/phi3_small.py
  43. 4 1
      aphrodite/modeling/models/phi3v.py
  44. 4 1
      aphrodite/modeling/models/qwen.py
  45. 4 1
      aphrodite/modeling/models/qwen2.py
  46. 4 2
      aphrodite/modeling/models/qwen2_moe.py
  47. 4 1
      aphrodite/modeling/models/siglip.py
  48. 4 1
      aphrodite/modeling/models/stablelm.py
  49. 4 1
      aphrodite/modeling/models/starcoder2.py
  50. 4 2
      aphrodite/modeling/models/utils.py
  51. 4 1
      aphrodite/modeling/models/xverse.py

+ 22 - 0
aphrodite/common/utils.py

@@ -27,10 +27,13 @@ import psutil
 import torch
 import torch.types
 from loguru import logger
+from rich.progress import (BarColumn, MofNCompleteColumn, Progress,
+                           SpinnerColumn, TextColumn, TimeElapsedColumn)
 from typing_extensions import ParamSpec, TypeIs, assert_never
 
 from aphrodite import _custom_ops as ops
 from aphrodite.common.logger import enable_trace_function_call
+from aphrodite.distributed import get_tensor_model_parallel_rank
 
 # Exception strings for non-implemented encoder/decoder scenarios
 
@@ -1129,3 +1132,22 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
     """Utility function to run async task in a lock"""
     async with lock:
         return await task(*args, **kwargs)
+
+
+def progress_bar(iterable, desc="Processing"):
+    show_progress = get_tensor_model_parallel_rank() == 0
+    if show_progress:
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            BarColumn(),
+            MofNCompleteColumn(),
+            TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
+            TimeElapsedColumn(),
+        ) as progress:
+            task = progress.add_task(f"[cyan]{desc}", total=len(iterable))
+            for item in iterable:
+                yield item
+                progress.update(task, advance=1)
+    else:
+        yield from iterable

+ 6 - 6
aphrodite/modeling/model_loader/weight_utils.py

@@ -330,8 +330,8 @@ def np_cache_weights_iterator(
 
     Will dump the model weights to numpy files if they are not already dumped.
     """
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
+    enable_tqdm = False #not torch.distributed.is_initialized(
+    #) or torch.distributed.get_rank() == 0
     # Convert the model weights from torch tensors to numpy arrays for
     # faster loading.
     np_folder = os.path.join(hf_folder, "np")
@@ -370,8 +370,8 @@ def safetensors_weights_iterator(
     hf_weights_files: List[str]
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
+    enable_tqdm = False #not torch.distributed.is_initialized(
+    #) or torch.distributed.get_rank() == 0
     for st_file in tqdm(
             hf_weights_files,
             desc="Loading safetensors checkpoint shards",
@@ -387,8 +387,8 @@ def pt_weights_iterator(
     hf_weights_files: List[str]
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model bin/pt files."""
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
+    enable_tqdm = False #not torch.distributed.is_initialized(
+    #) or torch.distributed.get_rank() == 0
     for bin_file in tqdm(
             hf_weights_files,
             desc="Loading pt checkpoint shards",

+ 4 - 1
aphrodite/modeling/models/arctic.py

@@ -8,6 +8,7 @@ from torch import nn
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -487,7 +488,9 @@ class ArcticForCausalLM(nn.Module):
             "It will take ~10 minutes loading from the 16-bit weights. "
             "Alternatively, use the prequantized 8-bit weights of arctic "
             "and set load-format to `sharded_state` will accelerate loading.")
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue

+ 4 - 1
aphrodite/modeling/models/baichuan.py

@@ -28,6 +28,7 @@ from transformers import PretrainedConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -364,7 +365,9 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
             if name == "lm_head.weight":

+ 4 - 1
aphrodite/modeling/models/blip2.py

@@ -9,6 +9,7 @@ from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
 from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                        SequenceData)
+from aphrodite.common.utils import progress_bar
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -656,7 +657,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
         ]
         params_dict = dict(self.named_parameters())
 
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "lm_head.weight" in name:
                 continue
             if "rotary_emb.inv_freq" in name:

+ 4 - 1
aphrodite/modeling/models/bloom.py

@@ -26,6 +26,7 @@ from transformers import BloomConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import get_act_fn
@@ -307,7 +308,9 @@ class BloomForCausalLM(nn.Module):
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if name == "lm_head.weight":
                 continue
             if not name.startswith("transformer."):

+ 4 - 2
aphrodite/modeling/models/chameleon.py

@@ -12,7 +12,7 @@ from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
 from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                        SequenceData)
-from aphrodite.common.utils import print_warning_once
+from aphrodite.common.utils import print_warning_once, progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -986,7 +986,9 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsVision):
             (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
 

+ 4 - 1
aphrodite/modeling/models/chatglm.py

@@ -11,6 +11,7 @@ from torch.nn import LayerNorm
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -385,7 +386,9 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_pos_emb.inv_freq" in name:
                 continue
             if "word_embeddings" in name:

+ 6 - 3
aphrodite/modeling/models/commandr.py

@@ -31,6 +31,7 @@ from transformers import CohereConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -40,8 +41,8 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
-from aphrodite.modeling.layers.vocab_parallel_embedding import \
-    VocabParallelEmbedding
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.utils import set_weight_attrs
@@ -389,7 +390,9 @@ class CohereForCausalLM(nn.Module):
         ]
         params_dict = dict(self.named_parameters())
         loaded_params = set()
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             for param_name, shard_name, shard_id in stacked_params_mapping:
                 if shard_name not in name:
                     continue

+ 4 - 1
aphrodite/modeling/models/dbrx.py

@@ -7,6 +7,7 @@ import torch.nn as nn
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -407,7 +408,9 @@ class DbrxForCausalLM(nn.Module):
             f"experts.mlp.{weight_name}",
         ) for weight_name in ["w1", "v1", "w2"]]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             for param_name, weight_name in expert_params_mapping:
                 if weight_name not in name:
                     continue

+ 4 - 1
aphrodite/modeling/models/decilm.py

@@ -29,6 +29,7 @@ import torch
 from transformers import LlamaConfig
 
 from aphrodite.common.config import CacheConfig, LoRAConfig
+from aphrodite.common.utils import progress_bar
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.llama import LlamaForCausalLM
 from aphrodite.quantization.base_config import QuantizationConfig
@@ -76,7 +77,9 @@ class DeciLMForCausalLM(LlamaForCausalLM):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
 

+ 4 - 1
aphrodite/modeling/models/deepseek.py

@@ -30,6 +30,7 @@ from transformers import PretrainedConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -419,7 +420,9 @@ class DeepseekForCausalLM(nn.Module):
         ]
 
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:

+ 4 - 1
aphrodite/modeling/models/deepseek_v2.py

@@ -31,6 +31,7 @@ from transformers import PretrainedConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -485,7 +486,9 @@ class DeepseekV2ForCausalLM(nn.Module):
             num_experts=self.config.n_routed_experts)
 
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:

+ 4 - 1
aphrodite/modeling/models/falcon.py

@@ -29,6 +29,7 @@ from transformers import FalconConfig as HF_FalconConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -418,7 +419,9 @@ class FalconForCausalLM(nn.Module):
             total_num_kv_heads = total_num_heads
         num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if name == "lm_head.weight":
                 # Falcon uses tied embeddings.
                 continue

+ 4 - 1
aphrodite/modeling/models/fuyu.py

@@ -28,6 +28,7 @@ from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
 from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                        SequenceData)
+from aphrodite.common.utils import progress_bar
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.linear import ColumnParallelLinear
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
@@ -300,7 +301,9 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name

+ 4 - 1
aphrodite/modeling/models/gemma.py

@@ -25,6 +25,7 @@ from transformers import GemmaConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import GeluAndMul
 from aphrodite.modeling.layers.layernorm import GemmaRMSNorm
@@ -374,7 +375,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             for (param_name, shard_name, shard_id) in stacked_params_mapping:
                 if shard_name not in name:
                     continue

+ 4 - 1
aphrodite/modeling/models/gemma2.py

@@ -25,6 +25,7 @@ from transformers import Gemma2Config
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import GeluAndMul
 from aphrodite.modeling.layers.layernorm import GemmaRMSNorm
@@ -366,7 +367,9 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             for (param_name, shard_name, shard_id) in stacked_params_mapping:
                 if shard_name not in name:
                     continue

+ 4 - 1
aphrodite/modeling/models/gpt2.py

@@ -26,6 +26,7 @@ from transformers import GPT2Config
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -249,7 +250,9 @@ class GPT2LMHeadModel(nn.Module):
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.

+ 4 - 1
aphrodite/modeling/models/gpt_bigcode.py

@@ -27,6 +27,7 @@ from transformers import GPTBigCodeConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -268,7 +269,9 @@ class GPTBigCodeForCausalLM(nn.Module):
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "lm_head.weight" in name:
                 continue
             if ".attn.bias" in name:

+ 4 - 1
aphrodite/modeling/models/gpt_j.py

@@ -25,6 +25,7 @@ from transformers import GPTJConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -269,7 +270,9 @@ class GPTJForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "attn.bias" in name or "attn.masked_bias" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:

+ 4 - 1
aphrodite/modeling/models/gpt_neox.py

@@ -25,6 +25,7 @@ from transformers import GPTNeoXConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -273,7 +274,9 @@ class GPTNeoXForCausalLM(nn.Module):
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if ("attention.bias" in name or "attention.masked_bias" in name
                     or "rotary_emb.inv_freq" in name):
                 continue

+ 4 - 1
aphrodite/modeling/models/intern_vit.py

@@ -11,6 +11,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import PretrainedConfig
 
+from aphrodite.common.utils import progress_bar
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -272,7 +273,9 @@ class InternVisionModel(nn.Module):
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)

+ 4 - 1
aphrodite/modeling/models/internlm2.py

@@ -8,6 +8,7 @@ from transformers import PretrainedConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -299,7 +300,9 @@ class InternLM2ForCausalLM(nn.Module):
             ("gate_up_proj", "w3", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:

+ 4 - 1
aphrodite/modeling/models/jais.py

@@ -28,6 +28,7 @@ from torch import nn
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -310,7 +311,9 @@ class JAISLMHeadModel(nn.Module):
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.

+ 4 - 1
aphrodite/modeling/models/jamba.py

@@ -12,6 +12,7 @@ from aphrodite.attention.backends.abstract import AttentionMetadata
 from aphrodite.attention.layer import Attention
 from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 # yapf: disable
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
@@ -716,7 +717,9 @@ class JambaForCausalLM(nn.Module, HasInnerState):
             num_experts=self.config.num_experts)
 
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
 

+ 4 - 2
aphrodite/modeling/models/llama.py

@@ -30,7 +30,7 @@ from transformers import LlamaConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.common.utils import is_hip
+from aphrodite.common.utils import is_hip, progress_bar
 from aphrodite.distributed import (get_current_tp_rank_partition_size,
                                    get_pp_group,
                                    get_tensor_model_parallel_rank,
@@ -463,7 +463,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
             (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name

+ 4 - 1
aphrodite/modeling/models/llama_embedding.py

@@ -5,6 +5,7 @@ from torch import nn
 
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.sequence import PoolerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.modeling.layers.pooler import Pooler, PoolingType
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.llama import LlamaModel
@@ -58,7 +59,9 @@ class LlamaEmbeddingModel(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.model.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name

+ 4 - 1
aphrodite/modeling/models/medusa.py

@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 
 from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
@@ -131,7 +132,9 @@ class Medusa(nn.Module):
 
         weights_map = {}
 
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             name = name.replace("medusa_heads.", "")
 
             if name == "token_map":

+ 4 - 1
aphrodite/modeling/models/minicpm.py

@@ -32,6 +32,7 @@ from transformers import PretrainedConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -503,7 +504,9 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
             for weight_name in ["w1", "w2", "w3"]
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name

+ 4 - 1
aphrodite/modeling/models/minicpmv.py

@@ -40,6 +40,7 @@ from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
 from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                        SequenceData)
+from aphrodite.common.utils import progress_bar
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.linear import ReplicatedLinear
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -654,7 +655,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsVision):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                 if key_to_modify in name:
                     name = name.replace(key_to_modify, new_key)

+ 4 - 1
aphrodite/modeling/models/mixtral.py

@@ -30,6 +30,7 @@ from transformers import MixtralConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_pp_group,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.fused_moe import FusedMoE
@@ -419,7 +420,9 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
             num_experts=self.config.num_local_experts)
 
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
 

+ 4 - 1
aphrodite/modeling/models/mixtral_quant.py

@@ -32,6 +32,7 @@ from transformers import MixtralConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -384,7 +385,9 @@ class MixtralForCausalLM(nn.Module):
         ]
 
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:

+ 4 - 1
aphrodite/modeling/models/mlp_speculator.py

@@ -5,6 +5,7 @@ import torch
 import torch.nn as nn
 
 from aphrodite.common.sequence import SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.modeling import SamplingMetadata
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.sampler import Sampler
@@ -181,7 +182,9 @@ class MLPSpeculator(nn.Module):
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             param = params_dict.get(name.replace("speculator.", ""))
             if param is not None:
                 weight_loader = getattr(param, "weight_loader",

+ 4 - 1
aphrodite/modeling/models/mpt.py

@@ -9,6 +9,7 @@ import torch.nn as nn
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import get_act_fn
@@ -294,7 +295,9 @@ class MPTForCausalLM(nn.Module):
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
                 continue

+ 4 - 1
aphrodite/modeling/models/nemotron.py

@@ -31,6 +31,7 @@ from transformers import NemotronConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_pp_group,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import get_act_fn
@@ -490,7 +491,9 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA):
             (".qkv_proj", ".v_proj", "v"),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name

+ 4 - 1
aphrodite/modeling/models/olmo.py

@@ -30,6 +30,7 @@ from transformers import OlmoConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -334,7 +335,9 @@ class OlmoForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name

+ 4 - 1
aphrodite/modeling/models/opt.py

@@ -26,6 +26,7 @@ from transformers import OPTConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -344,7 +345,9 @@ class OPTForCausalLM(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "lm_head.weight" in name:
                 continue
             if name.startswith("decoder."):

+ 4 - 1
aphrodite/modeling/models/orion.py

@@ -13,6 +13,7 @@ from transformers import PretrainedConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -300,7 +301,9 @@ class OrionForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name

+ 4 - 1
aphrodite/modeling/models/paligemma.py

@@ -8,6 +8,7 @@ from transformers import PaliGemmaConfig
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, MultiModalConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.sampler import Sampler
@@ -284,7 +285,9 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
         ]
         params_dict = dict(self.named_parameters())
         loaded_params = set()
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                 if key_to_modify in name:
                     name = name.replace(key_to_modify, new_key)

+ 4 - 1
aphrodite/modeling/models/persimmon.py

@@ -31,6 +31,7 @@ from transformers.activations import ReLUSquaredActivation
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
@@ -301,7 +302,9 @@ class PersimmonForCausalLM(nn.Module):
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name

+ 4 - 1
aphrodite/modeling/models/phi.py

@@ -44,6 +44,7 @@ from transformers import PhiConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -305,7 +306,9 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
         ]
         params_dict = dict(self.named_parameters())
 
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
 

+ 4 - 1
aphrodite/modeling/models/phi3_small.py

@@ -8,6 +8,7 @@ from transformers.configuration_utils import PretrainedConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -436,7 +437,9 @@ class Phi3SmallForCausalLM(nn.Module):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
 
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
             if name.endswith(".bias") and name not in params_dict:

+ 4 - 1
aphrodite/modeling/models/phi3v.py

@@ -28,6 +28,7 @@ from transformers import CLIPVisionConfig, PretrainedConfig
 from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import CacheConfig, ModelConfig, MultiModalConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.sampler import Sampler
@@ -603,7 +604,9 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
             (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
             # post_layernorm is not needed in CLIPVisionModel

+ 4 - 1
aphrodite/modeling/models/qwen.py

@@ -13,6 +13,7 @@ from transformers import PretrainedConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -272,7 +273,9 @@ class QWenLMHeadModel(nn.Module):
             ("gate_up_proj", "w1", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:

+ 4 - 1
aphrodite/modeling/models/qwen2.py

@@ -31,6 +31,7 @@ from transformers import Qwen2Config
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import (get_current_tp_rank_partition_size,
                                    get_pp_group,
                                    get_tensor_model_parallel_rank,
@@ -394,7 +395,9 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
             if self.config.tie_word_embeddings and "lm_head.weight" in name:

+ 4 - 2
aphrodite/modeling/models/qwen2_moe.py

@@ -32,7 +32,7 @@ from transformers import PretrainedConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.common.utils import print_warning_once
+from aphrodite.common.utils import print_warning_once, progress_bar
 from aphrodite.distributed import (get_pp_group,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
@@ -446,7 +446,9 @@ class Qwen2MoeForCausalLM(nn.Module):
             num_experts=self.config.num_experts)
 
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:

+ 4 - 1
aphrodite/modeling/models/siglip.py

@@ -14,6 +14,7 @@ from xformers.ops import memory_efficient_attention
 
 from aphrodite.common.config import ModelConfig
 from aphrodite.common.sequence import SequenceData
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.inputs import LLMInputs
 from aphrodite.modeling.layers.activation import get_act_fn
@@ -641,7 +642,9 @@ class SiglipVisionModel(nn.Module):
         params_dict = dict(self.named_parameters())
         layer_count = len(self.vision_model.encoder.layers)
 
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             # omit layers when num_hidden_layers_override is set
             if "vision_model.encoder.layers." in name:
                 layer_idx = int(name.split(".")[3])

+ 4 - 1
aphrodite/modeling/models/stablelm.py

@@ -28,6 +28,7 @@ from transformers import PretrainedConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
@@ -281,7 +282,9 @@ class StablelmForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
             if ("rotary_emb.cos_cached" in name

+ 4 - 1
aphrodite/modeling/models/starcoder2.py

@@ -27,6 +27,7 @@ from transformers import Starcoder2Config
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
@@ -290,7 +291,9 @@ class Starcoder2ForCausalLM(nn.Module):
         ]
 
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if "rotary_emb.inv_freq" in name:
                 continue
 

+ 4 - 2
aphrodite/modeling/models/utils.py

@@ -7,7 +7,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.common.config import (CacheConfig, LoRAConfig, MultiModalConfig,
                                      SchedulerConfig)
-from aphrodite.common.utils import is_pin_memory_available
+from aphrodite.common.utils import is_pin_memory_available, progress_bar
 from aphrodite.modeling.model_loader.loader import build_model
 from aphrodite.modeling.models import ModelRegistry
 from aphrodite.multimodal import BatchedTensors
@@ -21,7 +21,9 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
     See also:
         :ref:`init_aphrodite_registered_model`
     """
-    for name, loaded_weight in weights:
+    weights_list = list(weights)
+    for name, loaded_weight in progress_bar(weights_list,
+                                desc="Loading modules..."):
         name = name.split(".")
         if prefix == name.pop(0):
             name = ".".join(name)

+ 4 - 1
aphrodite/modeling/models/xverse.py

@@ -29,6 +29,7 @@ from transformers import PretrainedConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
@@ -347,7 +348,9 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in weights:
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(weights_list,
+                                                desc="Loading modules..."):
             if ("rotary_emb.inv_freq" in name
                     or "rotary_emb.cos_cached" in name
                     or "rotary_emb.sin_cached" in name):