orion.py

# coding=utf-8
# Adapted from
# https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/modeling_orion.py
# Copyright (c) OrionStar Inc.
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
"""Inference-only Orion-14B model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import progress_bar
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig


class OrionMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        # SiluAndMul splits the fused gate_up output in half and
        # returns silu(gate) * up.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class OrionAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across the tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
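        # Worked example (hypothetical head counts, not necessarily
        # Orion-14B's actual config): with 32 query heads and 8 KV heads,
        # tp_size=4 gives each rank 8 query heads and 2 KV heads, while
        # tp_size=16 gives each rank 2 query heads and a single replicated
        # KV head.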
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class OrionDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = OrionAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.mlp = OrionMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
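        # Note: Orion uses standard nn.LayerNorm rather than RMSNorm; the
        # epsilon is still read from config.rms_norm_eps.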
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + hidden_states
        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        # The residual connections are applied inline above, so None is
        # returned in the residual slot of the (hidden_states, residual)
        # interface expected by OrionModel.
        return hidden_states, None


class OrionModel(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            OrionDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class OrionForCausalLM(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = OrionModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
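        # Example: a checkpoint entry such as
        # "model.layers.0.self_attn.q_proj.weight" is loaded into the fused
        # "model.layers.0.self_attn.qkv_proj.weight" parameter as shard "q",
        # and "mlp.gate_proj" / "mlp.up_proj" load into "mlp.gate_up_proj"
        # as shards 0 and 1.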
        params_dict = dict(self.named_parameters())
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # The for/else branch runs only when no stacked mapping
                # matched, i.e. the weight maps 1:1 onto a single parameter.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
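

# Usage sketch (illustrative, not part of the model definition): this class
# is normally constructed by the Aphrodite engine through its model registry
# rather than instantiated directly. Assuming the engine exposes the
# vLLM-style top-level entry points (`LLM` and `SamplingParams`), serving an
# Orion checkpoint would look roughly like:
#
#     from aphrodite import LLM, SamplingParams
#
#     llm = LLM(model="OrionStarAI/Orion-14B-Base")
#     outputs = llm.generate(["Hello, my name is"],
#                            SamplingParams(max_tokens=32))
#     print(outputs[0].outputs[0].text)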