orion.py

# coding=utf-8
# Adapted from
# https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/modeling_orion.py
# Copyright (c) OrionStar Inc.
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
"""Inference-only Orion-14B model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.sequence import SamplerOutput
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig


class OrionMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()
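        # NOTE: SiluAndMul consumes the concatenated [gate, up] projection and
        # returns silu(gate) * up, halving the last dimension. A minimal
        # sketch of the equivalent computation on a tensor `gate_up` of shape
        # [..., 2 * d] (names here are illustrative only):
        #
        #   gate, up = gate_up.chunk(2, dim=-1)
        #   out = torch.nn.functional.silu(gate) * up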

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class OrionAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to the TP size, so
            # we partition the KV heads across the tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
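        # For example (hypothetical values): with total_num_kv_heads=8 and
        # tp_size=4 each rank keeps 2 KV heads; with total_num_kv_heads=2 and
        # tp_size=4 each rank keeps a replica of 1 KV head.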
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class OrionDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = OrionAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
        )
        self.mlp = OrionMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
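        # NOTE: Orion uses standard nn.LayerNorm; the config reuses the
        # `rms_norm_eps` field for the LayerNorm epsilon, so the attribute
        # name below does not imply RMSNorm.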
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
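        # The incoming `residual` is unused: residual additions happen inside
        # this layer, and None is returned in its place solely to keep the
        # (hidden_states, residual) interface expected by OrionModel.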
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states, None


class OrionModel(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            OrionDecoderLayer(config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class OrionForCausalLM(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = OrionModel(config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
            ("qkv_proj", "v_proj", "v"),
        ]
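        # For example, a checkpoint tensor such as
        # "model.layers.0.self_attn.q_proj.weight" (name shown for
        # illustration) is loaded into the "q" shard of the fused
        # "model.layers.0.self_attn.qkv_proj.weight" parameter.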
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)