xverse.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376
  1. # coding=utf-8
  2. # Adapted from
  3. # https://huggingface.co/xverse/XVERSE-7B/blob/main/modeling_xverse.py
  4. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  5. #
  6. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  7. # and OPT implementations in this library. It has been modified from its
  8. # original forms to accommodate minor architectural differences compared
  9. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  10. #
  11. # Licensed under the Apache License, Version 2.0 (the "License");
  12. # you may not use this file except in compliance with the License.
  13. # You may obtain a copy of the License at
  14. #
  15. # http://www.apache.org/licenses/LICENSE-2.0
  16. #
  17. # Unless required by applicable law or agreed to in writing, software
  18. # distributed under the License is distributed on an "AS IS" BASIS,
  19. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  20. # See the License for the specific language governing permissions and
  21. # limitations under the License.
  22. """Inference-only Xverse model compatible with HuggingFace weights."""
  23. from typing import Any, Dict, Iterable, List, Optional, Tuple
  24. import torch
  25. from torch import nn
  26. from transformers import PretrainedConfig
  27. from aphrodite.attention import Attention, AttentionMetadata
  28. from aphrodite.common.config import CacheConfig, LoRAConfig
  29. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  30. from aphrodite.distributed import get_tensor_model_parallel_world_size
  31. from aphrodite.modeling.layers.activation import SiluAndMul
  32. from aphrodite.modeling.layers.layernorm import RMSNorm
  33. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  34. QKVParallelLinear,
  35. RowParallelLinear)
  36. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  37. from aphrodite.modeling.layers.rotary_embedding import get_rope
  38. from aphrodite.modeling.layers.sampler import Sampler
  39. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  40. ParallelLMHead, VocabParallelEmbedding)
  41. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  42. from aphrodite.modeling.models.interfaces import SupportsLoRA
  43. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  44. from aphrodite.quantization.base_config import QuantizationConfig
  45. class XverseMLP(nn.Module):
  46. def __init__(
  47. self,
  48. hidden_size: int,
  49. intermediate_size: int,
  50. hidden_act: str,
  51. quant_config: Optional[QuantizationConfig] = None,
  52. ) -> None:
  53. super().__init__()
  54. self.gate_up_proj = MergedColumnParallelLinear(
  55. hidden_size, [intermediate_size] * 2,
  56. bias=False,
  57. quant_config=quant_config)
  58. self.down_proj = RowParallelLinear(intermediate_size,
  59. hidden_size,
  60. bias=False,
  61. quant_config=quant_config)
  62. if hidden_act != "silu":
  63. raise ValueError(f"Unsupported activation: {hidden_act}. "
  64. "Only silu is supported for now.")
  65. self.act_fn = SiluAndMul()
  66. def forward(self, x):
  67. gate, _ = self.gate_up_proj(x)
  68. x = self.act_fn(gate)
  69. x, _ = self.down_proj(x)
  70. return x
  71. class XverseAttention(nn.Module):
  72. def __init__(
  73. self,
  74. hidden_size: int,
  75. num_heads: int,
  76. num_kv_heads: int,
  77. rope_theta: float = 10000,
  78. rope_scaling: Optional[Dict[str, Any]] = None,
  79. max_position_embeddings: int = 8192,
  80. quant_config: Optional[QuantizationConfig] = None,
  81. bias: bool = False,
  82. cache_config: Optional[CacheConfig] = None,
  83. ) -> None:
  84. super().__init__()
  85. self.hidden_size = hidden_size
  86. tp_size = get_tensor_model_parallel_world_size()
  87. self.total_num_heads = num_heads
  88. assert self.total_num_heads % tp_size == 0
  89. self.num_heads = self.total_num_heads // tp_size
  90. self.total_num_kv_heads = num_kv_heads
  91. # partition the KV heads across multiple tensor parallel GPUs.
  92. assert self.total_num_kv_heads % tp_size == 0
  93. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  94. self.head_dim = hidden_size // self.total_num_heads
  95. self.q_size = self.num_heads * self.head_dim
  96. self.kv_size = self.num_kv_heads * self.head_dim
  97. self.scaling = self.head_dim**-0.5
  98. self.rope_theta = rope_theta
  99. self.max_position_embeddings = max_position_embeddings
  100. self.qkv_proj = QKVParallelLinear(
  101. hidden_size,
  102. self.head_dim,
  103. self.total_num_heads,
  104. self.total_num_kv_heads,
  105. bias=bias,
  106. quant_config=quant_config,
  107. )
  108. self.o_proj = RowParallelLinear(
  109. self.total_num_heads * self.head_dim,
  110. hidden_size,
  111. bias=bias,
  112. quant_config=quant_config,
  113. )
  114. self.rotary_emb = get_rope(
  115. self.head_dim,
  116. rotary_dim=self.head_dim,
  117. max_position=max_position_embeddings,
  118. base=rope_theta,
  119. rope_scaling=rope_scaling,
  120. )
  121. self.attn = Attention(self.num_heads,
  122. self.head_dim,
  123. self.scaling,
  124. num_kv_heads=self.num_kv_heads,
  125. cache_config=cache_config,
  126. quant_config=quant_config)
  127. def forward(
  128. self,
  129. positions: torch.Tensor,
  130. hidden_states: torch.Tensor,
  131. kv_cache: torch.Tensor,
  132. attn_metadata: AttentionMetadata,
  133. ) -> torch.Tensor:
  134. qkv, _ = self.qkv_proj(hidden_states)
  135. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  136. q, k = self.rotary_emb(positions, q, k)
  137. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  138. output, _ = self.o_proj(attn_output)
  139. return output
  140. class XverseDecoderLayer(nn.Module):
  141. def __init__(
  142. self,
  143. config: PretrainedConfig,
  144. cache_config: Optional[CacheConfig] = None,
  145. quant_config: Optional[QuantizationConfig] = None,
  146. ) -> None:
  147. super().__init__()
  148. self.hidden_size = config.hidden_size
  149. rope_theta = getattr(config, "rope_theta", 10000)
  150. rope_scaling = getattr(config, "rope_scaling", None)
  151. max_position_embeddings = getattr(config, "max_position_embeddings",
  152. 8192)
  153. self.self_attn = XverseAttention(
  154. hidden_size=self.hidden_size,
  155. num_heads=config.num_attention_heads,
  156. num_kv_heads=getattr(config, "num_key_value_heads",
  157. config.num_attention_heads),
  158. rope_theta=rope_theta,
  159. rope_scaling=rope_scaling,
  160. max_position_embeddings=max_position_embeddings,
  161. quant_config=quant_config,
  162. bias=getattr(config, "bias", False),
  163. cache_config=cache_config,
  164. )
  165. self.mlp = XverseMLP(
  166. hidden_size=self.hidden_size,
  167. intermediate_size=config.intermediate_size,
  168. hidden_act=config.hidden_act,
  169. quant_config=quant_config,
  170. )
  171. self.input_layernorm = RMSNorm(config.hidden_size,
  172. eps=config.rms_norm_eps)
  173. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  174. eps=config.rms_norm_eps)
  175. def forward(
  176. self,
  177. positions: torch.Tensor,
  178. hidden_states: torch.Tensor,
  179. kv_cache: torch.Tensor,
  180. attn_metadata: AttentionMetadata,
  181. residual: Optional[torch.Tensor],
  182. ) -> Tuple[torch.Tensor, torch.Tensor]:
  183. # Self Attention
  184. if residual is None:
  185. residual = hidden_states
  186. hidden_states = self.input_layernorm(hidden_states)
  187. else:
  188. hidden_states, residual = self.input_layernorm(
  189. hidden_states, residual)
  190. hidden_states = self.self_attn(
  191. positions=positions,
  192. hidden_states=hidden_states,
  193. kv_cache=kv_cache,
  194. attn_metadata=attn_metadata,
  195. )
  196. # Fully Connected
  197. hidden_states, residual = self.post_attention_layernorm(
  198. hidden_states, residual)
  199. hidden_states = self.mlp(hidden_states)
  200. return hidden_states, residual
  201. class XverseModel(nn.Module):
  202. def __init__(
  203. self,
  204. config: PretrainedConfig,
  205. cache_config: Optional[CacheConfig] = None,
  206. quant_config: Optional[QuantizationConfig] = None,
  207. lora_config: Optional[LoRAConfig] = None,
  208. ) -> None:
  209. super().__init__()
  210. self.config = config
  211. self.padding_idx = config.pad_token_id
  212. lora_vocab = (lora_config.lora_extra_vocab_size *
  213. (lora_config.max_loras or 1)) if lora_config else 0
  214. self.vocab_size = config.vocab_size + lora_vocab
  215. self.org_vocab_size = config.vocab_size
  216. self.embed_tokens = VocabParallelEmbedding(
  217. self.vocab_size,
  218. config.hidden_size,
  219. org_num_embeddings=config.vocab_size,
  220. )
  221. self.layers = nn.ModuleList([
  222. XverseDecoderLayer(config, cache_config, quant_config)
  223. for _ in range(config.num_hidden_layers)
  224. ])
  225. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  226. def forward(
  227. self,
  228. input_ids: torch.Tensor,
  229. positions: torch.Tensor,
  230. kv_caches: List[torch.Tensor],
  231. attn_metadata: AttentionMetadata,
  232. ) -> torch.Tensor:
  233. hidden_states = self.embed_tokens(input_ids)
  234. residual = None
  235. for i in range(len(self.layers)):
  236. layer = self.layers[i]
  237. hidden_states, residual = layer(
  238. positions,
  239. hidden_states,
  240. kv_caches[i],
  241. attn_metadata,
  242. residual,
  243. )
  244. hidden_states, _ = self.norm(hidden_states, residual)
  245. return hidden_states
  246. class XverseForCausalLM(nn.Module, SupportsLoRA):
  247. packed_modules_mapping = {
  248. "qkv_proj": [
  249. "q_proj",
  250. "k_proj",
  251. "v_proj",
  252. ],
  253. "gate_up_proj": [
  254. "gate_proj",
  255. "up_proj",
  256. ],
  257. }
  258. # LoRA specific attributes
  259. supported_lora_modules = [
  260. "qkv_proj",
  261. "o_proj",
  262. "gate_up_proj",
  263. "down_proj",
  264. "embed_tokens",
  265. "lm_head",
  266. ]
  267. embedding_modules = {
  268. "embed_tokens": "input_embeddings",
  269. "lm_head": "output_embeddings",
  270. }
  271. embedding_padding_modules = ["lm_head"]
  272. def __init__(
  273. self,
  274. config: PretrainedConfig,
  275. cache_config: Optional[CacheConfig] = None,
  276. quant_config: Optional[QuantizationConfig] = None,
  277. lora_config: Optional[LoRAConfig] = None,
  278. ) -> None:
  279. super().__init__()
  280. self.config = config
  281. self.lora_config = lora_config
  282. self.quant_config = quant_config
  283. self.model = XverseModel(config, cache_config, quant_config)
  284. self.lm_head = ParallelLMHead(config.vocab_size,
  285. config.hidden_size,
  286. quant_config=quant_config)
  287. self.logits_processor = LogitsProcessor(config.vocab_size)
  288. self.sampler = Sampler()
  289. def forward(
  290. self,
  291. input_ids: torch.Tensor,
  292. positions: torch.Tensor,
  293. kv_caches: List[torch.Tensor],
  294. attn_metadata: AttentionMetadata,
  295. intermediate_tensors: Optional[IntermediateTensors] = None,
  296. ) -> torch.Tensor:
  297. hidden_states = self.model(input_ids, positions, kv_caches,
  298. attn_metadata)
  299. return hidden_states
  300. def compute_logits(
  301. self,
  302. hidden_states: torch.Tensor,
  303. sampling_metadata: SamplingMetadata,
  304. ) -> Optional[torch.Tensor]:
  305. logits = self.logits_processor(self.lm_head, hidden_states,
  306. sampling_metadata)
  307. return logits
  308. def sample(
  309. self,
  310. logits: torch.Tensor,
  311. sampling_metadata: SamplingMetadata,
  312. ) -> Optional[SamplerOutput]:
  313. next_tokens = self.sampler(logits, sampling_metadata)
  314. return next_tokens
  315. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  316. stacked_params_mapping = [
  317. ("qkv_proj", "q_proj", "q"),
  318. ("qkv_proj", "k_proj", "k"),
  319. ("qkv_proj", "v_proj", "v"),
  320. ("gate_up_proj", "gate_proj", 0),
  321. ("gate_up_proj", "up_proj", 1),
  322. ]
  323. params_dict = dict(self.named_parameters())
  324. for name, loaded_weight in weights:
  325. if ("rotary_emb.inv_freq" in name
  326. or "rotary_emb.cos_cached" in name
  327. or "rotary_emb.sin_cached" in name):
  328. continue
  329. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  330. if weight_name not in name:
  331. continue
  332. name = name.replace(weight_name, param_name)
  333. # Skip loading extra bias for GPTQ models.
  334. if name.endswith(".bias") and name not in params_dict:
  335. continue
  336. param = params_dict[name]
  337. weight_loader = param.weight_loader
  338. weight_loader(param, loaded_weight, shard_id)
  339. break
  340. else:
  341. # Skip loading extra bias for GPTQ models.
  342. if name.endswith(".bias") and name not in params_dict:
  343. continue
  344. param = params_dict[name]
  345. weight_loader = getattr(param, "weight_loader",
  346. default_weight_loader)
  347. weight_loader(param, loaded_weight)