123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332 |
- # coding=utf-8
- # Adapted from
- # https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/modeling_orion.py
- # Copyright (c) OrionStar Inc.
- # LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
- """Inference-only Orion-14B model compatible with HuggingFace weights."""
- from typing import Any, Dict, Iterable, List, Optional, Tuple
- import torch
- from torch import nn
- from transformers import PretrainedConfig
- from aphrodite.attention import Attention, AttentionMetadata
- from aphrodite.common.config import CacheConfig
- from aphrodite.common.sequence import IntermediateTensors
- from aphrodite.distributed import get_tensor_model_parallel_world_size
- from aphrodite.modeling.layers.activation import SiluAndMul
- from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
- QKVParallelLinear,
- RowParallelLinear)
- from aphrodite.modeling.layers.logits_processor import LogitsProcessor
- from aphrodite.modeling.layers.rotary_embedding import get_rope
- from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
- from aphrodite.modeling.layers.vocab_parallel_embedding import (
- ParallelLMHead, VocabParallelEmbedding)
- from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
- from aphrodite.modeling.sampling_metadata import SamplingMetadata
- from aphrodite.quantization.base_config import QuantizationConfig
class OrionMLP(nn.Module):
    """Gated feed-forward sub-layer used inside each Orion decoder layer.

    The gate and up projections are fused into one column-parallel matmul
    whose output is fed to ``SiluAndMul`` (presumably SiLU(gate) * up --
    see its definition) and then to a row-parallel down projection.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        # One fused projection producing both the gate and the up halves.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        # Only the fused SiLU-and-multiply activation is wired up here.
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Apply gate/up projection, activation, then down projection."""
        fused, _unused_bias = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _unused_bias = self.down_proj(activated)
        return projected
class OrionAttention(nn.Module):
    """Self-attention with rotary position embeddings for Orion.

    Q, K and V come from one fused ``QKVParallelLinear`` projection and the
    result goes through a ``RowParallelLinear`` output projection, so both
    are sharded across tensor-parallel (TP) ranks. The number of KV heads may
    be smaller than the number of query heads: KV heads are partitioned
    across ranks when there are at least as many of them as ranks, and
    replicated otherwise.

    Args:
        hidden_size: Model hidden size; ``head_dim`` is
            ``hidden_size // num_heads``.
        num_heads: Total number of query heads (all TP ranks combined).
        num_kv_heads: Total number of key/value heads.
        rope_theta: Rotary embedding base passed to ``get_rope``.
        rope_scaling: Optional rotary scaling config passed to ``get_rope``.
        max_position_embeddings: Maximum position passed to ``get_rope``.
        cache_config: Forwarded to ``Attention``.
        quant_config: Forwarded to the linear layers and ``Attention``.

    Raises:
        ValueError: If the query or KV heads cannot be laid out evenly over
            the TP world size.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        # Explicit checks instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let an invalid layout through and make
        # the floor divisions below produce inconsistent per-rank head counts.
        if self.total_num_heads % tp_size != 0:
            raise ValueError(
                f"Total number of attention heads ({self.total_num_heads}) "
                f"must be divisible by tensor parallel size ({tp_size}).")
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            if self.total_num_kv_heads % tp_size != 0:
                raise ValueError(
                    f"Total number of KV heads ({self.total_num_kv_heads}) "
                    f"must be divisible by tensor parallel size ({tp_size}).")
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            if tp_size % self.total_num_kv_heads != 0:
                raise ValueError(
                    f"Tensor parallel size ({tp_size}) must be divisible by "
                    f"the total number of KV heads ({self.total_num_kv_heads}).")
        # When replicating, each rank still holds at least one KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the Q and K/V slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run rotary self-attention over the given tokens.

        Args:
            positions: Token positions consumed by the rotary embedding.
            hidden_states: Input activations; after the QKV projection the
                last dimension is split into Q/K/V (presumably shaped
                ``(num_tokens, hidden_size)`` -- confirm with the caller).
            kv_cache: This layer's KV cache tensor, passed to ``Attention``.
            attn_metadata: Attention metadata for the current batch.

        Returns:
            The attention output after the output projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
class OrionDecoderLayer(nn.Module):
    """One pre-norm transformer block: attention then MLP, each residual.

    Both sub-layers are preceded by an ``nn.LayerNorm`` whose epsilon is read
    from ``config.rms_norm_eps``. Rotary settings fall back to
    ``rope_theta=10000``, no scaling and ``max_position_embeddings=8192``
    when the config does not define them.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = OrionAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings", 8192),
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.mlp = OrionMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attention and MLP sub-layers with residual connections.

        The incoming ``residual`` argument is not used: residuals are added
        inside this method, and ``None`` is always returned in that slot.
        """
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = hidden_states + attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out
        return hidden_states, None
class OrionModel(nn.Module):
    """Orion decoder stack: token embedding, decoder layers, final LayerNorm.

    ``config.num_hidden_layers`` identical ``OrionDecoderLayer`` instances are
    applied in order, and the result is normalized by ``self.norm``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        decoder_layers = [
            OrionDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        ``kv_caches`` is indexed by layer position, so it must provide one
        entry per decoder layer.
        """
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for layer_idx, decoder_layer in enumerate(self.layers):
            hidden_states, residual = decoder_layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
                residual,
            )
        return self.norm(hidden_states)
class OrionForCausalLM(nn.Module):
    """Orion decoder stack plus a vocab-parallel LM head, logits and sampler.

    ``forward`` returns final hidden states; ``compute_logits`` turns them
    into logits through ``lm_head``, and ``sample`` picks next tokens.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = OrionModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Return the decoder's final hidden states.

        ``intermediate_tensors`` is accepted but ignored by this model.
        """
        return self.model(input_ids, positions, kv_caches, attn_metadata)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Select next tokens from ``logits`` using the configured sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        Separate checkpoint q/k/v and gate/up tensors are routed into the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters with a shard id;
        every other tensor is loaded by name. Rotary buffers and bias
        tensors the model has no parameter for are skipped.
        """
        # (fused parameter name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # Rotary buffers never loaded; the cos/sin caches may appear in
        # checkpoints of models trained with ColossalAI.
        skipped_buffers = (
            "rotary_emb.inv_freq",
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )
        params_dict = dict(self.named_parameters())

        def _is_unexpected_bias(param_name: str) -> bool:
            # GPTQ checkpoints may carry extra biases the model lacks.
            return (param_name.endswith(".bias")
                    and param_name not in params_dict)

        for name, loaded_weight in weights:
            if any(buffer_name in name for buffer_name in skipped_buffers):
                continue

            handled = False
            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                # ``name`` is rebound here on purpose: a skipped bias keeps
                # the rewritten name for any later checks.
                name = name.replace(shard_name, fused_name)
                if _is_unexpected_bias(name):
                    continue
                fused_param = params_dict[name]
                fused_param.weight_loader(fused_param, loaded_weight,
                                          shard_id)
                handled = True
                break

            if handled:
                continue
            if _is_unexpected_bias(name):
                continue
            param = params_dict[name]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, loaded_weight)
|