Browse files

fix: s-lora vocab embeddings

AlpinDale 1 year ago
parent
commit
8635901c76
2 changed files with 25 additions and 23 deletions
  1. aphrodite/modeling/layers/vocab_parallel_embedding.py (+17 -15)
  2. examples/slora_inference.py (+8 -8)

+ 17 - 15
aphrodite/modeling/layers/vocab_parallel_embedding.py

@@ -1,8 +1,11 @@
 from typing import Optional, Sequence
 
 import torch
+import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
+from aphrodite.modeling.megatron.communication_op import (
+    tensor_model_parallel_gather)
 from aphrodite.modeling.layers.linear import UnquantizedLinearMethod
 from aphrodite.modeling.megatron.parallel_state import (
     get_tensor_model_parallel_rank,
@@ -54,7 +57,7 @@ class VocabParallelEmbedding(torch.nn.Module):
                  num_embeddings: int,
                  embedding_dim: int,
                  params_dtype: Optional[torch.dtype] = None,
-                 linear_method=None,
+                 linear_method = None,
                  org_num_embeddings: Optional[int] = None,
                  padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
         super().__init__()
@@ -75,28 +78,26 @@ class VocabParallelEmbedding(torch.nn.Module):
                 self.tp_size))
         self.num_embeddings_per_partition = (self.vocab_end_index -
                                              self.vocab_start_index)
-        if linear_method is None or not linear_method.quant_config.quant_vocab(
-        ):
+        if linear_method is None or not linear_method.quant_config.quant_vocab():
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
         self.linear_weights = self.linear_method.create_weights(
-            self.embedding_dim, self.num_embeddings_per_partition,
-            self.embedding_dim, self.num_embeddings_padded, params_dtype)
+            self.embedding_dim, self.num_embeddings_per_partition, self.embedding_dim,
+            self.num_embeddings_padded, params_dtype)
         for name, weight in self.linear_weights.items():
             if isinstance(weight, torch.nn.parameter.Parameter):
                 self.register_parameter(name, weight)
                 set_weight_attrs(weight, {"weight_loader": self.weight_loader})
 
+
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         output_dim = getattr(param, "output_dim", None)
         if output_dim is not None:
-            assert loaded_weight.shape[output_dim] == self.num_embeddings
+            assert loaded_weight.shape[output_dim] == self.org_vocab_size
             loaded_weight = loaded_weight[self.vocab_start_index:self.
                                           vocab_end_index]
         if isinstance(param, torch.nn.parameter.UninitializedParameter):
-            param.materialize(
-                (self.num_embeddings_per_partition, loaded_weight.shape[1]),
-                dtype=loaded_weight.dtype)
+            param.materialize((self.num_embeddings_per_partition, loaded_weight.shape[1]), dtype=loaded_weight.dtype)
         if output_dim is not None:
             param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
         else:
@@ -113,8 +114,8 @@ class VocabParallelEmbedding(torch.nn.Module):
         else:
             masked_input = input_
             # Get the embeddings.
-        output_parallel = self.linear_method.apply_embedding(
-            self.linear_weights, masked_input)
+        output_parallel = self.linear_method.apply_embedding(self.linear_weights, masked_input)
+        # output_parallel = F.embedding(masked_input, self.weight)
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel[input_mask, :] = 0.0
@@ -144,11 +145,11 @@ class ParallelLMHead(VocabParallelEmbedding):
                  embedding_dim: int,
                  bias: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
-                 linear_method=None,
+                 linear_method = None,
                  org_num_embeddings: Optional[int] = None,
                  padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
-        super().__init__(num_embeddings, embedding_dim, params_dtype,
-                         linear_method, org_num_embeddings, padding_size)
+        super().__init__(num_embeddings, embedding_dim, params_dtype, linear_method,
+                         org_num_embeddings, padding_size)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition,
@@ -161,7 +162,8 @@ class ParallelLMHead(VocabParallelEmbedding):
             self.register_parameter("bias", None)
 
     def forward(self, input_):
-        logits = self.linear_method.apply_weights(self.linear_weights, input_)
+        logits = self.linear_method.apply_weights(
+            self.linear_weights, input_)
         if self.bias is not None:
             logits += self.bias
         return logits
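
The key change in `weight_loader` is the shape check: it now compares the checkpoint tensor against `self.org_vocab_size` (the base model's vocabulary) instead of `self.num_embeddings`, which with S-LoRA can be padded and extended beyond what the checkpoint actually contains. Below is a minimal, self-contained sketch of that slice-and-copy logic with hypothetical helper names; it is not the Aphrodite API, just an illustration of why the assert has to use the original vocab size.

```python
# Minimal sketch (hypothetical helpers, not the Aphrodite API). The
# checkpoint only holds org_vocab_size rows, while the allocated
# partition is sized for the padded vocabulary.
import torch


def vocab_range_for_rank(padded_vocab_size: int, rank: int,
                         world_size: int) -> tuple:
    # Evenly split the padded vocab across tensor-parallel ranks.
    per_partition = padded_vocab_size // world_size
    start = rank * per_partition
    return start, start + per_partition


def load_embedding_partition(loaded_weight: torch.Tensor,
                             org_vocab_size: int, vocab_start: int,
                             vocab_end: int,
                             partition_rows: int) -> torch.Tensor:
    # The fix: compare against the checkpoint's real vocab size, not the
    # (possibly LoRA-extended, padded) num_embeddings.
    assert loaded_weight.shape[0] == org_vocab_size
    shard = loaded_weight[vocab_start:vocab_end]  # slicing clamps past the end
    param = torch.zeros(partition_rows, loaded_weight.shape[1],
                        dtype=loaded_weight.dtype)
    # Copy the real rows; extra rows (LoRA tokens, padding) stay zero here.
    param[:shard.shape[0]].copy_(shard)
    return param


if __name__ == "__main__":
    org_vocab_size, lora_extra_rows, dim = 32000, 256, 8
    padded_vocab_size = org_vocab_size + lora_extra_rows
    checkpoint = torch.randn(org_vocab_size, dim)
    start, end = vocab_range_for_rank(padded_vocab_size, rank=0, world_size=1)
    part = load_embedding_partition(checkpoint, org_vocab_size, start, end,
                                    end - start)
    print(part.shape)  # torch.Size([32256, 8])
```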

+ 8 - 8
examples/slora_inference.py

@@ -23,7 +23,7 @@ def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]:
     return [
         ("A robot may not injure a human being",
          SamplingParams(temperature=0.0,
-                        logprobs=1,
+                        # logprobs=1,
                         prompt_logprobs=1,
                         max_tokens=128), None),
         ("To be or not to be,",
@@ -33,33 +33,33 @@ def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]:
                         max_tokens=128), None),
         ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
          SamplingParams(temperature=0.0,
-                        logprobs=1,
+                        # logprobs=1,
                         prompt_logprobs=1,
                         max_tokens=128,
                         stop_token_ids=[32003]),
-         LoRARequest("sql-lora", 1, lora_path)),
+         LoRARequest("l2-lora-test", 1, lora_path)),
         ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
          SamplingParams(n=3,
                         best_of=3,
                         temperature=0.8,
                         max_tokens=128,
                         stop_token_ids=[32003]),
-         LoRARequest("sql-lora", 1, lora_path)),
+         LoRARequest("l2-lora-test", 1, lora_path)),
         ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
          SamplingParams(temperature=0.0,
-                        logprobs=1,
+                        # logprobs=1,
                         prompt_logprobs=1,
                         max_tokens=128,
                         stop_token_ids=[32003]),
-         LoRARequest("sql-lora2", 2, lora_path)),
+         LoRARequest("l2-lora-test2", 2, lora_path)),
         ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
          SamplingParams(n=3,
                         best_of=3,
                         temperature=0.9,
                         max_tokens=128,
                         stop_token_ids=[32003]),
-         LoRARequest("sql-lora", 1, lora_path)),
-    ]
+         LoRARequest("l2-lora-test", 1, lora_path)),
+    ] # type: ignore
 
 
 def process_requests(engine: AphroditeEngine,
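
The adapter renames and the commented-out `logprobs=1` only change the test data; each tuple still pairs a prompt and its `SamplingParams` with an optional `LoRARequest`. The sketch below shows how such triples are typically drained into the engine; the `add_request`/`step`/`has_unfinished_requests` calls and import paths are assumed from the vLLM-style interface this example mirrors, not taken from this diff.

```python
# Hedged sketch: consuming (prompt, SamplingParams, LoRARequest) triples.
# Import paths and engine methods are assumptions; defer to
# examples/slora_inference.py for the authoritative version.
from typing import List, Optional, Tuple

from aphrodite import AphroditeEngine, SamplingParams  # assumed import path
from aphrodite.lora.request import LoRARequest         # assumed import path


def drain_requests(
    engine: AphroditeEngine,
    test_prompts: List[Tuple[str, SamplingParams, Optional[LoRARequest]]],
) -> None:
    request_id = 0
    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            prompt, sampling_params, lora_request = test_prompts.pop(0)
            # lora_request=None runs the base model; otherwise the named
            # adapter is applied for that request only.
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               lora_request=lora_request)
            request_id += 1
        for output in engine.step():
            if output.finished:
                print(output.outputs[0].text)
```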