@@ -34,7 +34,7 @@ from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.attention import PagedAttentionWithRoPE
 from aphrodite.modeling.layers.sampler import Sampler
-from aphrodite.modeling.hf_downloader import hf_model_weights_iterator, load_tensor_parallel_weights
+from aphrodite.modeling.hf_downloader import load_tensor_parallel_weights, load_padded_tensor_parallel_vocab, hf_model_weights_iterator
 from aphrodite.modeling.megatron.parallel_state import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
 from aphrodite.modeling.megatron.tensor_parallel import VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear
 from aphrodite.common.sequence import SequenceOutputs
@@ -246,7 +246,8 @@ class LlamaForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     use_np_cache: bool = False,
+                     allow_patterns: str = "*.safetensors"):
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -262,19 +263,19 @@ class LlamaForCausalLM(nn.Module):
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, use_np_cache, allow_patterns):
             if "rotary_emb.inv_freq" in name:
                 continue
 
-            if "embed_tokens" in name or "lm_head" in name:
-                param = state_dict[name]
-                padded_vocab_size = (param.shape[0] *
-                                     tp_size)
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+            # if "embed_tokens" in name or "lm_head" in name:
+            #     param = state_dict[name]
+            #     padded_vocab_size = (param.shape[0] *
+            #                          tp_size)
+            #     num_extra_rows = padded_vocab_size - self.config.vocab_size
+            #     extra_rows = torch.empty(num_extra_rows,
+            #                              loaded_weight.shape[1])
+            #     extra_rows = extra_rows.to(loaded_weight)
+            #     loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
 
             is_attention_weight = False
             for weight_name, shard_size, offset in attention_weight_specs:
@@ -313,6 +314,18 @@ class LlamaForCausalLM(nn.Module):
                 continue
 
             param = state_dict[name]
+
+            if "embed_tokens" in name or "lm_head" in name:
+                param = state_dict[name]
+                padded_vocab_size = param.shape[0] * tp_size
+                if padded_vocab_size > self.config.vocab_size:
+                    load_padded_tensor_parallel_vocab(param, loaded_weight, name,
+                                                      self._column_parallel_weights,
+                                                      self._row_parallel_weights,
+                                                      tensor_model_parallel_rank)
+                    continue
+
+
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
                                          self._row_parallel_weights,
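
A brief usage sketch of the new `allow_patterns` parameter. The checkpoint id and `config` object are illustrative assumptions, not part of this patch; only the `load_weights` signature comes from the diff above.

```python
# Illustrative only: model/config construction and the repo id are assumed.
model = LlamaForCausalLM(config)

# Default behaviour after this patch: iterate over "*.safetensors" shards.
model.load_weights("meta-llama/Llama-2-7b-hf")

# Override the glob to fall back to PyTorch ".bin" shards, assuming
# hf_model_weights_iterator treats allow_patterns as a filename glob.
model.load_weights("meta-llama/Llama-2-7b-hf",
                   cache_dir=None,
                   use_np_cache=False,
                   allow_patterns="*.bin")
```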
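For context, a minimal, hypothetical sketch of the shard copy that `load_padded_tensor_parallel_vocab` is expected to perform for `embed_tokens`/`lm_head` when the vocab is padded: each rank copies its slice of the unpadded checkpoint rows and leaves any padding rows untouched. The real helper lives in `aphrodite.modeling.hf_downloader` and takes additional arguments; this is not its implementation.

```python
import torch

def padded_vocab_copy_sketch(param: torch.Tensor,
                             loaded_weight: torch.Tensor,
                             tp_rank: int) -> None:
    # `param` is this rank's shard of the padded embedding:
    # shape [padded_vocab_size // tp_size, hidden_size].
    # `loaded_weight` is assumed to already be a torch.Tensor with
    # the unpadded vocab rows.
    shard_size = param.shape[0]
    start = tp_rank * shard_size
    end = min(start + shard_size, loaded_weight.shape[0])
    num_rows = max(end - start, 0)
    # Copy only the rows present in the checkpoint; padding rows on the
    # last rank keep whatever the embedding was initialized with.
    param.data[:num_rows].copy_(loaded_weight[start:end])
```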