
feat: lora support for commandr models

AlpinDale, 7 months ago
commit da6765c084
2 files changed, 46 insertions(+), 6 deletions(-)
  1. aphrodite/modeling/models/commandr.py (+42, -6)
  2. kernels/punica/bgmv/bgmv_config.h (+4, -0)

aphrodite/modeling/models/commandr.py (+42, -6)

@@ -29,7 +29,7 @@ from torch.nn.parameter import Parameter
 from transformers import CohereConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
-from aphrodite.common.config import CacheConfig
+from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
@@ -264,10 +264,14 @@ class CohereModel(nn.Module):
         config: CohereConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.vocab_size = config.vocab_size
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                    config.hidden_size)
         self.layers = nn.ModuleList([
@@ -301,18 +305,44 @@ class CohereModel(nn.Module):
 
 class CohereForCausalLM(nn.Module):
 
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
+    ]
+    embedding_modules = {"embed_tokens": "input_embeddings"}
+    embedding_padding_modules = []
+
     def __init__(
         self,
         config: CohereConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.quant_config = quant_config
-        self.logits_processor = LogitsProcessor(config.vocab_size,
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size,
                                                 scale=config.logit_scale)
-        self.model = CohereModel(config, cache_config, quant_config)
+        self.model = CohereModel(config,
+                                 cache_config,
+                                 quant_config,
+                                 lora_config=lora_config)
         self.sampler = Sampler()
 
     @torch.no_grad()
@@ -329,8 +359,14 @@ class CohereForCausalLM(nn.Module):
 
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.model.embed_tokens.weight,
-                                       hidden_states, sampling_metadata)
+        is_not_lora = hasattr(self.model.embed_tokens, 'weight')
+        if is_not_lora:
+            embedding_weights = self.model.embed_tokens.weight
+        else:
+            embedding_weights = self.model.embed_tokens.base_layer.weight
+
+        logits = self.logits_processor(embedding_weights, hidden_states,
+                                       sampling_metadata)
         return logits
 
     def sample(
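
Note: a minimal usage sketch for the new LoRA path follows. It assumes Aphrodite mirrors vLLM's offline LoRA API; the LLM, SamplingParams, and LoRARequest entry points, the enable_lora / max_loras / max_lora_rank flags, and the model and adapter paths are illustrative assumptions and are not part of this commit.

# Illustrative only: assumes Aphrodite keeps vLLM-style offline LoRA entry
# points (LLM, SamplingParams, LoRARequest); paths below are placeholders.
from aphrodite import LLM, SamplingParams
from aphrodite.lora.request import LoRARequest

llm = LLM(
    model="CohereForAI/c4ai-command-r-v01",   # any CohereForCausalLM checkpoint
    enable_lora=True,                         # activates the LoRAConfig wired in above
    max_loras=1,
    max_lora_rank=16,
)
outputs = llm.generate(
    ["Write a short note about tensor parallelism."],
    SamplingParams(temperature=0.7, max_tokens=64),
    lora_request=LoRARequest("commandr-lora", 1, "/path/to/commandr-lora"),
)
print(outputs[0].outputs[0].text)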

kernels/punica/bgmv/bgmv_config.h (+4, -0)

@@ -63,6 +63,8 @@ void bgmv_kernel(out_T* __restrict__ Y, const in_T* __restrict__ X,
     f(in_T, out_T, W_T, narrow, 36864) \
     f(in_T, out_T, W_T, narrow, 43264) \
     f(in_T, out_T, W_T, narrow, 49152) \
+    f(in_T, out_T, W_T, narrow, 60544) \
+    f(in_T, out_T, W_T, narrow, 60672) \
     f(in_T, out_T, W_T, narrow, 64000) \
     f(in_T, out_T, W_T, narrow, 64256) \
     f(in_T, out_T, W_T, narrow, 64512) \
@@ -131,6 +133,8 @@ void bgmv_kernel(out_T* __restrict__ Y, const in_T* __restrict__ X,
     f(in_T, out_T, W_T, 36864, narrow) \
     f(in_T, out_T, W_T, 43264, narrow) \
     f(in_T, out_T, W_T, 49152, narrow) \
+    f(in_T, out_T, W_T, 60544, narrow) \
+    f(in_T, out_T, W_T, 60672, narrow) \
     f(in_T, out_T, W_T, 64000, narrow) \
     f(in_T, out_T, W_T, 64256, narrow) \
     f(in_T, out_T, W_T, 64512, narrow) \
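
Note: the two new entries (60544 and 60672) extend the list of feature dimensions for which the Punica BGMV kernels are pre-instantiated; a LoRA-targeted layer whose per-GPU input or output dimension is missing from this list has no matching kernel at dispatch time. Below is a rough, hypothetical pre-flight check (not Aphrodite code) that an adapter's lora_B output dimensions fall in the supported set. The SUPPORTED set is only the excerpt visible in this hunk, and the check ignores tensor-parallel sharding and q/k/v packing, so it is an illustration rather than a real validator.

# Hypothetical sanity check, not part of this commit. Assumes a PEFT-style
# safetensors adapter where lora_B tensors are stored as (out_features, rank).
from safetensors import safe_open

SUPPORTED = {36864, 43264, 49152, 60544, 60672, 64000, 64256, 64512}  # excerpt of bgmv_config.h

def check_adapter(path: str) -> None:
    with safe_open(path, framework="pt") as f:
        for name in f.keys():
            if "lora_B" not in name:
                continue
            out_dim = f.get_slice(name).get_shape()[0]
            if out_dim not in SUPPORTED:
                print(f"{name}: output dim {out_dim} is not in the pre-built BGMV shape list")

check_adapter("/path/to/commandr-lora/adapter_model.safetensors")  # placeholder path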