Browse files

llama support for safetensors

AlpinDale, 1 year ago
parent
commit
8c2353e803
1 changed file with 25 additions and 12 deletions
  1. aphrodite/modeling/models/llama.py  +25 -12

+25 -12
aphrodite/modeling/models/llama.py

@@ -34,7 +34,7 @@ from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.attention import PagedAttentionWithRoPE
 from aphrodite.modeling.layers.sampler import Sampler
-from aphrodite.modeling.hf_downloader import hf_model_weights_iterator, load_tensor_parallel_weights
+from aphrodite.modeling.hf_downloader import load_tensor_parallel_weights, load_padded_tensor_parallel_vocab, hf_model_weights_iterator
 from aphrodite.modeling.megatron.parallel_state import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
 from aphrodite.modeling.megatron.tensor_parallel import VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear
 from aphrodite.common.sequence import SequenceOutputs
@@ -246,7 +246,8 @@ class LlamaForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     use_np_cache: bool = False):
+                     use_np_cache: bool = False,
+                     allow_patterns: str = "*.safetensors"):
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -262,19 +263,19 @@ class LlamaForCausalLM(nn.Module):
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, use_np_cache):
+                model_name_or_path, cache_dir, use_np_cache, allow_patterns):
             if "rotary_emb.inv_freq" in name:
                 continue
 
-            if "embed_tokens" in name or "lm_head" in name:
-                param = state_dict[name]
-                padded_vocab_size = (param.shape[0] *
-                                     tp_size)
-                num_extra_rows = padded_vocab_size - self.config.vocab_size
-                extra_rows = torch.empty(num_extra_rows,
-                                         loaded_weight.shape[1])
-                extra_rows = extra_rows.to(loaded_weight)
-                loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
+            # if "embed_tokens" in name or "lm_head" in name:
+            #     param = state_dict[name]
+            #     padded_vocab_size = (param.shape[0] *
+            #                          tp_size)
+            #     num_extra_rows = padded_vocab_size - self.config.vocab_size
+            #     extra_rows = torch.empty(num_extra_rows,
+            #                              loaded_weight.shape[1])
+            #     extra_rows = extra_rows.to(loaded_weight)
+            #     loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
 
             is_attention_weight = False
             for weight_name, shard_size, offset in attention_weight_specs:
@@ -313,6 +314,18 @@ class LlamaForCausalLM(nn.Module):
                 continue
 
             param = state_dict[name]
+
+            if "embed_tokens" in name or "lm_head" in name:
+                param = state_dict[name]
+                padded_vocab_size = param.shape[0] * tp_size
+                if padded_vocab_size > self.config.vocab_size:
+                    load_padded_tensor_parallel_vocab(param, loaded_weight, name,
+                                                      self._column_parallel_weights,
+                                                      self._row_parallel_weights,
+                                                      tensor_model_parallel_rank)
+                    continue
+                
+
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          self._column_parallel_weights,
                                          self._row_parallel_weights,
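
For context, here is a minimal sketch (not part of the commit) of how a weights iterator could honor the new allow_patterns glob, assuming it forwards the pattern to huggingface_hub.snapshot_download and reads the matching files with safetensors.safe_open. The actual hf_model_weights_iterator in aphrodite/modeling/hf_downloader.py is not shown in this diff and may be implemented differently; the function name below is hypothetical.

    import glob
    import os
    from typing import Iterator, Optional, Tuple

    import torch
    from huggingface_hub import snapshot_download
    from safetensors import safe_open

    def safetensors_weights_iterator(
            model_name_or_path: str,
            cache_dir: Optional[str] = None,
            allow_patterns: str = "*.safetensors",
    ) -> Iterator[Tuple[str, torch.Tensor]]:
        # Download (or reuse from cache) only the files matching the glob.
        folder = snapshot_download(model_name_or_path,
                                   cache_dir=cache_dir,
                                   allow_patterns=allow_patterns)
        for path in sorted(glob.glob(os.path.join(folder, allow_patterns))):
            # safe_open memory-maps the file and exposes tensors lazily.
            with safe_open(path, framework="pt", device="cpu") as f:
                for name in f.keys():
                    yield name, f.get_tensor(name)

With the default of "*.safetensors", load_weights now prefers safetensors checkpoints; passing a different glob such as "*.bin" would presumably fall back to regular PyTorch checkpoint files, assuming the downloader still supports them.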
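
The second change defers embed_tokens/lm_head loading to load_padded_tensor_parallel_vocab whenever the padded vocabulary (param.shape[0] * tp_size) exceeds config.vocab_size, instead of concatenating uninitialized extra rows onto the checkpoint tensor before sharding. A rough sketch of the row copy such a helper might perform, assuming it slices the unpadded vocabulary by rank (the real helper in hf_downloader.py also receives the parameter name and the column/row-parallel weight name lists, and may differ):

    import torch

    def padded_vocab_copy_sketch(param: torch.Tensor,
                                 loaded_weight: torch.Tensor,
                                 tensor_model_parallel_rank: int) -> None:
        # param is this rank's shard of the padded vocabulary
        # (padded_vocab_size // tp_size rows); loaded_weight covers only
        # the unpadded config.vocab_size rows from the checkpoint.
        shard_size = param.shape[0]
        start = tensor_model_parallel_rank * shard_size
        end = min(start + shard_size, loaded_weight.shape[0])
        if start < end:
            # Copy the real rows that land in this shard; padding rows keep
            # their initialized values.
            param.data[:end - start].copy_(loaded_weight[start:end])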