|
@@ -31,6 +31,7 @@ from transformers import CohereConfig
|
|
|
from aphrodite.attention import Attention, AttentionMetadata
|
|
|
from aphrodite.common.config import CacheConfig, LoRAConfig
|
|
|
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
|
|
|
+from aphrodite.common.utils import progress_bar
|
|
|
from aphrodite.distributed import (get_tensor_model_parallel_rank,
|
|
|
get_tensor_model_parallel_world_size)
|
|
|
from aphrodite.modeling.layers.activation import SiluAndMul
|
|
@@ -40,8 +41,8 @@ from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
|
|
|
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
|
|
|
from aphrodite.modeling.layers.rotary_embedding import get_rope
|
|
|
from aphrodite.modeling.layers.sampler import Sampler
|
|
|
-from aphrodite.modeling.layers.vocab_parallel_embedding import \
|
|
|
- VocabParallelEmbedding
|
|
|
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
|
|
|
+ VocabParallelEmbedding)
|
|
|
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
|
|
|
from aphrodite.modeling.sampling_metadata import SamplingMetadata
|
|
|
from aphrodite.modeling.utils import set_weight_attrs
|
|
@@ -389,7 +390,9 @@ class CohereForCausalLM(nn.Module):
|
|
|
]
|
|
|
params_dict = dict(self.named_parameters())
|
|
|
loaded_params = set()
|
|
|
- for name, loaded_weight in weights:
|
|
|
+ weights_list = list(weights)
|
|
|
+ for name, loaded_weight in progress_bar(weights_list,
|
|
|
+ desc="Loading modules..."):
|
|
|
for param_name, shard_name, shard_id in stacked_params_mapping:
|
|
|
if shard_name not in name:
|
|
|
continue
|