orion.py

# coding=utf-8
# Adapted from
# https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/modeling_orion.py
# Copyright (c) OrionStar Inc.
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
"""Inference-only Orion-14B model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig


class OrionMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
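
# Note (added commentary, not in the upstream file): OrionMLP is the usual
# gated feed-forward block. gate_up_proj produces the gate and up projections
# fused along the last dimension, SiluAndMul computes silu(gate) * up, and
# down_proj maps back to hidden_size. As an illustration, with hidden_size H
# and intermediate_size I, gate_up has width 2 * I per tensor-parallel rank
# before SiluAndMul splits it in half.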


class OrionAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across the tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across the tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
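        # Added commentary (not in the upstream file): for illustration, with
        # 8 total KV heads and tp_size=4, each rank keeps 8 // 4 = 2 KV heads;
        # with 2 total KV heads and tp_size=4, the floor division gives 0, so
        # max(1, ...) keeps one replicated KV head per rank, and the asserts
        # above guarantee the split or replication divides evenly.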
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class OrionDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = OrionAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.mlp = OrionMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + hidden_states
        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states, None
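
# Note (added commentary, not in the upstream file): OrionDecoderLayer is a
# pre-norm transformer block using plain nn.LayerNorm (the eps value is taken
# from config.rms_norm_eps, but no RMSNorm is involved): LayerNorm -> self
# attention -> residual add, then LayerNorm -> MLP -> residual add. It adds
# the residual itself, so the second return value is always None.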


class OrionModel(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            OrionDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states
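
# Note (added commentary, not in the upstream file): kv_caches is expected to
# hold one cache tensor per decoder layer, indexed in the same order as
# self.layers, which is why the loop pairs kv_caches[i] with layer i. The
# residual value threaded through the loop is always None here (see
# OrionDecoderLayer.forward); it is presumably kept only so the call signature
# matches the other decoder layers in the codebase.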


class OrionForCausalLM(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = OrionModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
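        # Added commentary (not in the upstream file): this mapping rewrites
        # per-projection checkpoint names onto the fused parameters defined
        # above. For example, a checkpoint tensor named
        # "model.layers.0.self_attn.q_proj.weight" is loaded into
        # "model.layers.0.self_attn.qkv_proj.weight" with shard_id "q", and
        # "model.layers.0.mlp.gate_proj.weight" into
        # "model.layers.0.mlp.gate_up_proj.weight" at shard index 0.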
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
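
# A hedged usage sketch (added commentary, not part of the original file).
# In Aphrodite this class is normally instantiated by the engine's model
# loader rather than by hand; very roughly, the flow looks like the snippet
# below, where `iter_hf_weights` stands in for whatever helper yields
# (name, tensor) pairs from the HuggingFace checkpoint and is an assumption,
# not a real API:
#
#     from transformers import AutoConfig
#     config = AutoConfig.from_pretrained("OrionStarAI/Orion-14B-Base",
#                                         trust_remote_code=True)
#     model = OrionForCausalLM(config)
#     model.load_weights(iter_hf_weights("OrionStarAI/Orion-14B-Base"))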