|
@@ -29,7 +29,7 @@ from torch.nn.parameter import Parameter
|
|
|
from transformers import CohereConfig
|
|
|
|
|
|
from aphrodite.attention import Attention, AttentionMetadata
|
|
|
-from aphrodite.common.config import CacheConfig
|
|
|
+from aphrodite.common.config import CacheConfig, LoRAConfig
|
|
|
from aphrodite.common.sequence import SamplerOutput
|
|
|
from aphrodite.distributed import (get_tensor_model_parallel_rank,
|
|
|
get_tensor_model_parallel_world_size)
|
|
@@ -264,10 +264,14 @@ class CohereModel(nn.Module):
|
|
|
config: CohereConfig,
|
|
|
cache_config: Optional[CacheConfig] = None,
|
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
|
+ lora_config: Optional[LoRAConfig] = None,
|
|
|
):
|
|
|
super().__init__()
|
|
|
self.config = config
|
|
|
- self.vocab_size = config.vocab_size
|
|
|
+ lora_vocab = (lora_config.lora_extra_vocab_size *
|
|
|
+ (lora_config.max_loras or 1)) if lora_config else 0
|
|
|
+ self.vocab_size = config.vocab_size + lora_vocab
|
|
|
+ self.org_vocab_size = config.vocab_size
|
|
|
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
|
|
config.hidden_size)
|
|
|
self.layers = nn.ModuleList([
|
|
@@ -301,18 +305,44 @@ class CohereModel(nn.Module):
|
|
|
|
|
|
class CohereForCausalLM(nn.Module):
|
|
|
|
|
|
+ packed_modules_mapping = {
|
|
|
+ "qkv_proj": [
|
|
|
+ "q_proj",
|
|
|
+ "k_proj",
|
|
|
+ "v_proj",
|
|
|
+ ],
|
|
|
+ "gate_up_proj": [
|
|
|
+ "gate_proj",
|
|
|
+ "up_proj",
|
|
|
+ ],
|
|
|
+ }
|
|
|
+ # LoRA specific attributes
|
|
|
+ supported_lora_modules = [
|
|
|
+ "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens"
|
|
|
+ ]
|
|
|
+ embedding_modules = {"embed_tokens": "input_embeddings"}
|
|
|
+ embedding_padding_modules = []
|
|
|
+
|
|
|
def __init__(
|
|
|
self,
|
|
|
config: CohereConfig,
|
|
|
cache_config: Optional[CacheConfig] = None,
|
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
|
+ lora_config: Optional[LoRAConfig] = None,
|
|
|
) -> None:
|
|
|
super().__init__()
|
|
|
self.config = config
|
|
|
+ self.unpadded_vocab_size = config.vocab_size
|
|
|
+ if lora_config:
|
|
|
+ self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
|
|
self.quant_config = quant_config
|
|
|
- self.logits_processor = LogitsProcessor(config.vocab_size,
|
|
|
+ self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
|
|
+ config.vocab_size,
|
|
|
scale=config.logit_scale)
|
|
|
- self.model = CohereModel(config, cache_config, quant_config)
|
|
|
+ self.model = CohereModel(config,
|
|
|
+ cache_config,
|
|
|
+ quant_config,
|
|
|
+ lora_config=lora_config)
|
|
|
self.sampler = Sampler()
|
|
|
|
|
|
@torch.no_grad()
|
|
@@ -329,8 +359,14 @@ class CohereForCausalLM(nn.Module):
|
|
|
|
|
|
def compute_logits(self, hidden_states: torch.Tensor,
|
|
|
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
|
|
- logits = self.logits_processor(self.model.embed_tokens.weight,
|
|
|
- hidden_states, sampling_metadata)
|
|
|
+ is_not_lora = hasattr(self.model.embed_tokens, 'weight')
|
|
|
+ if is_not_lora:
|
|
|
+ embedding_weights = self.model.embed_tokens.weight
|
|
|
+ else:
|
|
|
+ embedding_weights = self.model.embed_tokens.base_layer.weight
|
|
|
+
|
|
|
+ logits = self.logits_processor(embedding_weights, hidden_states,
|
|
|
+ sampling_metadata)
|
|
|
return logits
|
|
|
|
|
|
def sample(
|