@@ -25,15 +25,13 @@ from typing import Iterable, List, Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn.parameter import Parameter
 from transformers import CohereConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
 from aphrodite.common.utils import progress_bar
-from aphrodite.distributed import (get_tensor_model_parallel_rank,
-                                   get_tensor_model_parallel_world_size)
+from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
@@ -43,7 +41,8 @@ from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.model_loader.weight_utils import (
+    default_weight_loader, row_parallel_weight_loader)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
@@ -67,25 +66,14 @@ class LayerNorm(nn.Module):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(param_shape))
         self.variance_epsilon = eps
-        set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
+        set_weight_attrs(self.weight,
+                         {"weight_loader": row_parallel_weight_loader})
 
     def forward(self, hidden_states, residuals=None):
         hidden_states = layer_norm_func(hidden_states, self.weight,
                                         self.variance_epsilon)
         return hidden_states, residuals
 
-    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
-        shard_dim = 0 if param.dim() != 1 else None
-        param_data = param.data
-        if shard_dim is not None:
-            shard_size = param_data.shape[shard_dim]
-            start_idx = tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
-                                                 shard_size)
-        assert param_data.shape == loaded_weight.shape
-        param_data.copy_(loaded_weight)
-
 
 # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
 class CohereMLP(nn.Module):
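
Note: the removed LayerNorm.weight_loader is superseded by the shared
row_parallel_weight_loader helper now imported from weight_utils. A minimal
sketch of what that helper is assumed to do, mirroring the deleted method
(the exact signature in aphrodite's weight_utils may differ):

    import torch
    from aphrodite.distributed import get_tensor_model_parallel_rank
    from aphrodite.modeling.model_loader.weight_utils import (
        default_weight_loader)

    def row_parallel_weight_loader(param: torch.Tensor,
                                   loaded_weight: torch.Tensor) -> None:
        # 1-D params (e.g. a plain norm weight) are copied whole on every
        # rank; higher-rank params are sharded along dim 0 by TP rank.
        tp_rank = get_tensor_model_parallel_rank()
        shard_dim = 0 if param.dim() != 1 else None
        if shard_dim is not None:
            shard_size = param.data.shape[shard_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
                                                 shard_size)
        return default_weight_loader(param, loaded_weight)
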
@@ -359,8 +347,11 @@ class CohereForCausalLM(nn.Module):
                                    attn_metadata)
         return hidden_states
 
-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
         is_not_lora = hasattr(self.model.embed_tokens, 'weight')
         if is_not_lora:
             logits = self.logits_processor(self.model.embed_tokens,
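
Note: compute_logits now returns Optional[torch.Tensor], so callers have to
tolerate a None result (e.g. on ranks where logits are not gathered). A
hypothetical caller-side sketch; the names are illustrative, not part of this
diff:

    logits = model.compute_logits(hidden_states, sampling_metadata)
    if logits is None:
        return None  # no logits on this rank, nothing to sample
    return model.sample(logits, sampling_metadata)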