@@ -30,7 +30,7 @@ from transformers import LlamaConfig
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
-from aphrodite.common.utils import is_hip, tensor_progress_bar
+from aphrodite.common.utils import is_hip
 from aphrodite.distributed import (get_current_tp_rank_partition_size,
                                    get_pp_group,
                                    get_tensor_model_parallel_rank,
@@ -477,7 +477,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
                         device=device),
         })
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], total_bytes:int):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -487,8 +487,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
             (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in tensor_progress_bar(weights, total_bytes,
-                                                       "Loading modules..."):
+        for name, loaded_weight in weights:
            name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight)
            if "rotary_emb.inv_freq" in name:
                continue