xverse.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. # coding=utf-8
  2. # Adapted from
  3. # https://huggingface.co/xverse/XVERSE-7B/blob/main/modeling_xverse.py
  4. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  5. #
  6. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  7. # and OPT implementations in this library. It has been modified from its
  8. # original forms to accommodate minor architectural differences compared
  9. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  10. #
  11. # Licensed under the Apache License, Version 2.0 (the "License");
  12. # you may not use this file except in compliance with the License.
  13. # You may obtain a copy of the License at
  14. #
  15. # http://www.apache.org/licenses/LICENSE-2.0
  16. #
  17. # Unless required by applicable law or agreed to in writing, software
  18. # distributed under the License is distributed on an "AS IS" BASIS,
  19. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  20. # See the License for the specific language governing permissions and
  21. # limitations under the License.
  22. """Inference-only Xverse model compatible with HuggingFace weights."""
  23. from typing import Any, Dict, Iterable, List, Optional, Tuple
  24. import torch
  25. from torch import nn
  26. from transformers import PretrainedConfig
  27. from aphrodite.attention import Attention, AttentionMetadata
  28. from aphrodite.common.config import CacheConfig, LoRAConfig
  29. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  30. from aphrodite.common.utils import progress_bar
  31. from aphrodite.distributed import get_tensor_model_parallel_world_size
  32. from aphrodite.modeling.layers.activation import SiluAndMul
  33. from aphrodite.modeling.layers.layernorm import RMSNorm
  34. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  35. QKVParallelLinear,
  36. RowParallelLinear)
  37. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  38. from aphrodite.modeling.layers.rotary_embedding import get_rope
  39. from aphrodite.modeling.layers.sampler import Sampler
  40. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  41. ParallelLMHead, VocabParallelEmbedding)
  42. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  43. from aphrodite.modeling.models.interfaces import SupportsLoRA
  44. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  45. from aphrodite.quantization.base_config import QuantizationConfig
  46. class XverseMLP(nn.Module):
  47. def __init__(
  48. self,
  49. hidden_size: int,
  50. intermediate_size: int,
  51. hidden_act: str,
  52. quant_config: Optional[QuantizationConfig] = None,
  53. ) -> None:
  54. super().__init__()
  55. self.gate_up_proj = MergedColumnParallelLinear(
  56. hidden_size, [intermediate_size] * 2,
  57. bias=False,
  58. quant_config=quant_config)
  59. self.down_proj = RowParallelLinear(intermediate_size,
  60. hidden_size,
  61. bias=False,
  62. quant_config=quant_config)
  63. if hidden_act != "silu":
  64. raise ValueError(f"Unsupported activation: {hidden_act}. "
  65. "Only silu is supported for now.")
  66. self.act_fn = SiluAndMul()
  67. def forward(self, x):
  68. gate, _ = self.gate_up_proj(x)
  69. x = self.act_fn(gate)
  70. x, _ = self.down_proj(x)
  71. return x
  72. class XverseAttention(nn.Module):
  73. def __init__(
  74. self,
  75. hidden_size: int,
  76. num_heads: int,
  77. num_kv_heads: int,
  78. rope_theta: float = 10000,
  79. rope_scaling: Optional[Dict[str, Any]] = None,
  80. max_position_embeddings: int = 8192,
  81. quant_config: Optional[QuantizationConfig] = None,
  82. bias: bool = False,
  83. cache_config: Optional[CacheConfig] = None,
  84. ) -> None:
  85. super().__init__()
  86. self.hidden_size = hidden_size
  87. tp_size = get_tensor_model_parallel_world_size()
  88. self.total_num_heads = num_heads
  89. assert self.total_num_heads % tp_size == 0
  90. self.num_heads = self.total_num_heads // tp_size
  91. self.total_num_kv_heads = num_kv_heads
  92. # partition the KV heads across multiple tensor parallel GPUs.
  93. assert self.total_num_kv_heads % tp_size == 0
  94. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  95. self.head_dim = hidden_size // self.total_num_heads
  96. self.q_size = self.num_heads * self.head_dim
  97. self.kv_size = self.num_kv_heads * self.head_dim
  98. self.scaling = self.head_dim**-0.5
  99. self.rope_theta = rope_theta
  100. self.max_position_embeddings = max_position_embeddings
  101. self.qkv_proj = QKVParallelLinear(
  102. hidden_size,
  103. self.head_dim,
  104. self.total_num_heads,
  105. self.total_num_kv_heads,
  106. bias=bias,
  107. quant_config=quant_config,
  108. )
  109. self.o_proj = RowParallelLinear(
  110. self.total_num_heads * self.head_dim,
  111. hidden_size,
  112. bias=bias,
  113. quant_config=quant_config,
  114. )
  115. self.rotary_emb = get_rope(
  116. self.head_dim,
  117. rotary_dim=self.head_dim,
  118. max_position=max_position_embeddings,
  119. base=rope_theta,
  120. rope_scaling=rope_scaling,
  121. )
  122. self.attn = Attention(self.num_heads,
  123. self.head_dim,
  124. self.scaling,
  125. num_kv_heads=self.num_kv_heads,
  126. cache_config=cache_config,
  127. quant_config=quant_config)
  128. def forward(
  129. self,
  130. positions: torch.Tensor,
  131. hidden_states: torch.Tensor,
  132. kv_cache: torch.Tensor,
  133. attn_metadata: AttentionMetadata,
  134. ) -> torch.Tensor:
  135. qkv, _ = self.qkv_proj(hidden_states)
  136. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  137. q, k = self.rotary_emb(positions, q, k)
  138. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  139. output, _ = self.o_proj(attn_output)
  140. return output
  141. class XverseDecoderLayer(nn.Module):
  142. def __init__(
  143. self,
  144. config: PretrainedConfig,
  145. cache_config: Optional[CacheConfig] = None,
  146. quant_config: Optional[QuantizationConfig] = None,
  147. ) -> None:
  148. super().__init__()
  149. self.hidden_size = config.hidden_size
  150. rope_theta = getattr(config, "rope_theta", 10000)
  151. rope_scaling = getattr(config, "rope_scaling", None)
  152. max_position_embeddings = getattr(config, "max_position_embeddings",
  153. 8192)
  154. self.self_attn = XverseAttention(
  155. hidden_size=self.hidden_size,
  156. num_heads=config.num_attention_heads,
  157. num_kv_heads=getattr(config, "num_key_value_heads",
  158. config.num_attention_heads),
  159. rope_theta=rope_theta,
  160. rope_scaling=rope_scaling,
  161. max_position_embeddings=max_position_embeddings,
  162. quant_config=quant_config,
  163. bias=getattr(config, "bias", False),
  164. cache_config=cache_config,
  165. )
  166. self.mlp = XverseMLP(
  167. hidden_size=self.hidden_size,
  168. intermediate_size=config.intermediate_size,
  169. hidden_act=config.hidden_act,
  170. quant_config=quant_config,
  171. )
  172. self.input_layernorm = RMSNorm(config.hidden_size,
  173. eps=config.rms_norm_eps)
  174. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  175. eps=config.rms_norm_eps)
  176. def forward(
  177. self,
  178. positions: torch.Tensor,
  179. hidden_states: torch.Tensor,
  180. kv_cache: torch.Tensor,
  181. attn_metadata: AttentionMetadata,
  182. residual: Optional[torch.Tensor],
  183. ) -> Tuple[torch.Tensor, torch.Tensor]:
  184. # Self Attention
  185. if residual is None:
  186. residual = hidden_states
  187. hidden_states = self.input_layernorm(hidden_states)
  188. else:
  189. hidden_states, residual = self.input_layernorm(
  190. hidden_states, residual)
  191. hidden_states = self.self_attn(
  192. positions=positions,
  193. hidden_states=hidden_states,
  194. kv_cache=kv_cache,
  195. attn_metadata=attn_metadata,
  196. )
  197. # Fully Connected
  198. hidden_states, residual = self.post_attention_layernorm(
  199. hidden_states, residual)
  200. hidden_states = self.mlp(hidden_states)
  201. return hidden_states, residual
  202. class XverseModel(nn.Module):
  203. def __init__(
  204. self,
  205. config: PretrainedConfig,
  206. cache_config: Optional[CacheConfig] = None,
  207. quant_config: Optional[QuantizationConfig] = None,
  208. lora_config: Optional[LoRAConfig] = None,
  209. ) -> None:
  210. super().__init__()
  211. self.config = config
  212. self.padding_idx = config.pad_token_id
  213. lora_vocab = (lora_config.lora_extra_vocab_size *
  214. (lora_config.max_loras or 1)) if lora_config else 0
  215. self.vocab_size = config.vocab_size + lora_vocab
  216. self.org_vocab_size = config.vocab_size
  217. self.embed_tokens = VocabParallelEmbedding(
  218. self.vocab_size,
  219. config.hidden_size,
  220. org_num_embeddings=config.vocab_size,
  221. )
  222. self.layers = nn.ModuleList([
  223. XverseDecoderLayer(config, cache_config, quant_config)
  224. for _ in range(config.num_hidden_layers)
  225. ])
  226. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  227. def forward(
  228. self,
  229. input_ids: torch.Tensor,
  230. positions: torch.Tensor,
  231. kv_caches: List[torch.Tensor],
  232. attn_metadata: AttentionMetadata,
  233. ) -> torch.Tensor:
  234. hidden_states = self.embed_tokens(input_ids)
  235. residual = None
  236. for i in range(len(self.layers)):
  237. layer = self.layers[i]
  238. hidden_states, residual = layer(
  239. positions,
  240. hidden_states,
  241. kv_caches[i],
  242. attn_metadata,
  243. residual,
  244. )
  245. hidden_states, _ = self.norm(hidden_states, residual)
  246. return hidden_states
  247. class XverseForCausalLM(nn.Module, SupportsLoRA):
  248. packed_modules_mapping = {
  249. "qkv_proj": [
  250. "q_proj",
  251. "k_proj",
  252. "v_proj",
  253. ],
  254. "gate_up_proj": [
  255. "gate_proj",
  256. "up_proj",
  257. ],
  258. }
  259. # LoRA specific attributes
  260. supported_lora_modules = [
  261. "qkv_proj",
  262. "o_proj",
  263. "gate_up_proj",
  264. "down_proj",
  265. "embed_tokens",
  266. "lm_head",
  267. ]
  268. embedding_modules = {
  269. "embed_tokens": "input_embeddings",
  270. "lm_head": "output_embeddings",
  271. }
  272. embedding_padding_modules = ["lm_head"]
  273. def __init__(
  274. self,
  275. config: PretrainedConfig,
  276. cache_config: Optional[CacheConfig] = None,
  277. quant_config: Optional[QuantizationConfig] = None,
  278. lora_config: Optional[LoRAConfig] = None,
  279. ) -> None:
  280. super().__init__()
  281. self.config = config
  282. self.lora_config = lora_config
  283. self.quant_config = quant_config
  284. self.model = XverseModel(config, cache_config, quant_config)
  285. self.lm_head = ParallelLMHead(config.vocab_size,
  286. config.hidden_size,
  287. quant_config=quant_config)
  288. self.logits_processor = LogitsProcessor(config.vocab_size)
  289. self.sampler = Sampler()
  290. def forward(
  291. self,
  292. input_ids: torch.Tensor,
  293. positions: torch.Tensor,
  294. kv_caches: List[torch.Tensor],
  295. attn_metadata: AttentionMetadata,
  296. intermediate_tensors: Optional[IntermediateTensors] = None,
  297. ) -> torch.Tensor:
  298. hidden_states = self.model(input_ids, positions, kv_caches,
  299. attn_metadata)
  300. return hidden_states
  301. def compute_logits(
  302. self,
  303. hidden_states: torch.Tensor,
  304. sampling_metadata: SamplingMetadata,
  305. ) -> Optional[torch.Tensor]:
  306. logits = self.logits_processor(self.lm_head, hidden_states,
  307. sampling_metadata)
  308. return logits
  309. def sample(
  310. self,
  311. logits: torch.Tensor,
  312. sampling_metadata: SamplingMetadata,
  313. ) -> Optional[SamplerOutput]:
  314. next_tokens = self.sampler(logits, sampling_metadata)
  315. return next_tokens
  316. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  317. stacked_params_mapping = [
  318. ("qkv_proj", "q_proj", "q"),
  319. ("qkv_proj", "k_proj", "k"),
  320. ("qkv_proj", "v_proj", "v"),
  321. ("gate_up_proj", "gate_proj", 0),
  322. ("gate_up_proj", "up_proj", 1),
  323. ]
  324. params_dict = dict(self.named_parameters())
  325. weights_list = list(weights)
  326. for name, loaded_weight in progress_bar(weights_list,
  327. desc="Loading modules..."):
  328. if ("rotary_emb.inv_freq" in name
  329. or "rotary_emb.cos_cached" in name
  330. or "rotary_emb.sin_cached" in name):
  331. continue
  332. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  333. if weight_name not in name:
  334. continue
  335. name = name.replace(weight_name, param_name)
  336. # Skip loading extra bias for GPTQ models.
  337. if name.endswith(".bias") and name not in params_dict:
  338. continue
  339. param = params_dict[name]
  340. weight_loader = param.weight_loader
  341. weight_loader(param, loaded_weight, shard_id)
  342. break
  343. else:
  344. # Skip loading extra bias for GPTQ models.
  345. if name.endswith(".bias") and name not in params_dict:
  346. continue
  347. param = params_dict[name]
  348. weight_loader = getattr(param, "weight_loader",
  349. default_weight_loader)
  350. weight_loader(param, loaded_weight)