baichuan.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. # coding=utf-8
  2. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """Inference-only BaiChuan model compatible with HuggingFace weights."""
  21. import math
  22. from typing import Iterable, List, Optional, Tuple
  23. import torch
  24. from torch import nn
  25. from transformers import PretrainedConfig
  26. from aphrodite.attention import Attention, AttentionMetadata
  27. from aphrodite.common.config import LoRAConfig
  28. from aphrodite.common.sequence import SamplerOutput
  29. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  30. get_tensor_model_parallel_world_size)
  31. from aphrodite.modeling.layers.activation import SiluAndMul
  32. from aphrodite.modeling.layers.layernorm import RMSNorm
  33. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  34. QKVParallelLinear,
  35. RowParallelLinear)
  36. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  37. from aphrodite.modeling.layers.rotary_embedding import get_rope
  38. from aphrodite.modeling.layers.sampler import Sampler
  39. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  40. ParallelLMHead, VocabParallelEmbedding)
  41. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  42. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  43. from aphrodite.quantization.base_config import QuantizationConfig
  44. def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
  45. closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
  46. base = torch.tensor(
  47. 2**(-(2**-(math.log2(closest_power_of_2) - 3))),
  48. dtype=torch.float32,
  49. )
  50. powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
  51. slopes = torch.pow(base, powers)
  52. if closest_power_of_2 != total_num_heads:
  53. extra_base = torch.tensor(
  54. 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
  55. dtype=torch.float32,
  56. )
  57. num_remaining_heads = min(closest_power_of_2,
  58. total_num_heads - closest_power_of_2)
  59. extra_powers = torch.arange(start=1,
  60. end=1 + 2 * num_remaining_heads,
  61. step=2,
  62. dtype=torch.int32)
  63. slopes = torch.cat(
  64. [slopes, torch.pow(extra_base, extra_powers)], dim=0)
  65. return slopes
  66. class BaiChuanMLP(nn.Module):
  67. def __init__(
  68. self,
  69. hidden_size: int,
  70. intermediate_size: int,
  71. hidden_act: str,
  72. quant_config: Optional[QuantizationConfig] = None,
  73. ):
  74. super().__init__()
  75. self.gate_up_proj = MergedColumnParallelLinear(
  76. hidden_size, [intermediate_size] * 2,
  77. bias=False,
  78. quant_config=quant_config)
  79. self.down_proj = RowParallelLinear(intermediate_size,
  80. hidden_size,
  81. bias=False,
  82. quant_config=quant_config)
  83. if hidden_act != "silu":
  84. raise ValueError(f"Unsupported activation: {hidden_act}. "
  85. "Only silu is supported for now.")
  86. self.act_fn = SiluAndMul()
  87. def forward(self, x):
  88. gate_up, _ = self.gate_up_proj(x)
  89. x = self.act_fn(gate_up)
  90. x, _ = self.down_proj(x)
  91. return x
  92. class BaiChuanAttention(nn.Module):
  93. """Multi-headed attention from 'Attention Is All You Need' paper"""
  94. def __init__(
  95. self,
  96. hidden_size: int,
  97. num_heads: int,
  98. position_embedding: str,
  99. rope_theta: float = 10000,
  100. max_position_embeddings: int = 8192,
  101. quant_config: Optional[QuantizationConfig] = None,
  102. ):
  103. super().__init__()
  104. self.hidden_size = hidden_size
  105. tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
  106. )
  107. self.total_num_heads = num_heads
  108. assert self.total_num_heads % tensor_model_parallel_world_size == 0
  109. self.num_heads = (self.total_num_heads //
  110. tensor_model_parallel_world_size)
  111. self.head_dim = hidden_size // self.total_num_heads
  112. self.postion_embedding = position_embedding
  113. self.rope_theta = rope_theta
  114. self.max_position_embeddings = max_position_embeddings
  115. # pylint: disable=invalid-name
  116. self.W_pack = QKVParallelLinear(
  117. hidden_size,
  118. self.head_dim,
  119. self.total_num_heads,
  120. self.total_num_heads,
  121. bias=False,
  122. quant_config=quant_config,
  123. )
  124. self.o_proj = RowParallelLinear(
  125. self.total_num_heads * self.head_dim,
  126. hidden_size,
  127. bias=False,
  128. quant_config=quant_config,
  129. )
  130. # Create the alibi slopes and slice them.
  131. if self.postion_embedding == "ALIBI":
  132. tp_rank = get_tensor_model_parallel_rank()
  133. head_start = tp_rank * self.num_heads
  134. head_end = (tp_rank + 1) * self.num_heads
  135. alibi_slopes = _get_alibi_slopes(self.total_num_heads)
  136. alibi_slopes = alibi_slopes[head_start:head_end].tolist()
  137. scaling = self.head_dim**-0.5
  138. self.attn = Attention(self.num_heads,
  139. self.head_dim,
  140. scaling,
  141. alibi_slopes=alibi_slopes)
  142. else:
  143. self.rotary_emb = get_rope(
  144. self.head_dim,
  145. rotary_dim=self.head_dim,
  146. max_position=self.max_position_embeddings,
  147. base=self.rope_theta,
  148. )
  149. self.scaling = self.head_dim**-0.5
  150. self.attn = Attention(self.num_heads, self.head_dim, self.scaling)
  151. def forward(
  152. self,
  153. positions: torch.Tensor,
  154. hidden_states: torch.Tensor,
  155. kv_cache: torch.Tensor,
  156. attn_metadata: AttentionMetadata,
  157. ) -> torch.Tensor:
  158. qkv, _ = self.W_pack(hidden_states)
  159. q, k, v = qkv.chunk(chunks=3, dim=-1)
  160. if self.postion_embedding != "ALIBI":
  161. q, k = self.rotary_emb(positions, q, k)
  162. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  163. output, _ = self.o_proj(attn_output)
  164. return output
  165. class BaiChuanDecoderLayer(nn.Module):
  166. def __init__(self,
  167. config: PretrainedConfig,
  168. position_embedding: str,
  169. quant_config: Optional[QuantizationConfig] = None):
  170. super().__init__()
  171. self.hidden_size = config.hidden_size
  172. rope_theta = getattr(config, "rope_theta", 10000)
  173. max_position_embeddings = getattr(config, "max_position_embeddings",
  174. 8192)
  175. self.self_attn = BaiChuanAttention(
  176. hidden_size=self.hidden_size,
  177. num_heads=config.num_attention_heads,
  178. position_embedding=position_embedding,
  179. rope_theta=rope_theta,
  180. max_position_embeddings=max_position_embeddings,
  181. quant_config=quant_config,
  182. )
  183. self.mlp = BaiChuanMLP(
  184. hidden_size=self.hidden_size,
  185. intermediate_size=config.intermediate_size,
  186. hidden_act=config.hidden_act,
  187. quant_config=quant_config,
  188. )
  189. self.input_layernorm = RMSNorm(config.hidden_size,
  190. eps=config.rms_norm_eps)
  191. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  192. eps=config.rms_norm_eps)
  193. def forward(
  194. self,
  195. positions: torch.Tensor,
  196. hidden_states: torch.Tensor,
  197. kv_cache: torch.Tensor,
  198. attn_metadata: AttentionMetadata,
  199. residual: Optional[torch.Tensor],
  200. ) -> Tuple[torch.Tensor, torch.Tensor]:
  201. # Self Attention
  202. if residual is None:
  203. residual = hidden_states
  204. hidden_states = self.input_layernorm(hidden_states)
  205. else:
  206. hidden_states, residual = self.input_layernorm(
  207. hidden_states, residual)
  208. hidden_states = self.self_attn(
  209. positions=positions,
  210. hidden_states=hidden_states,
  211. kv_cache=kv_cache,
  212. attn_metadata=attn_metadata,
  213. )
  214. # Fully Connected
  215. hidden_states, residual = self.post_attention_layernorm(
  216. hidden_states, residual)
  217. hidden_states = self.mlp(hidden_states)
  218. return hidden_states, residual
  219. class BaiChuanModel(nn.Module):
  220. def __init__(self,
  221. config: PretrainedConfig,
  222. position_embedding: str,
  223. quant_config: Optional[QuantizationConfig] = None):
  224. super().__init__()
  225. self.config = config
  226. self.padding_idx = config.pad_token_id
  227. self.vocab_size = config.vocab_size
  228. self.embed_tokens = VocabParallelEmbedding(
  229. config.vocab_size,
  230. config.hidden_size,
  231. )
  232. self.layers = nn.ModuleList([
  233. BaiChuanDecoderLayer(config, position_embedding, quant_config)
  234. for _ in range(config.num_hidden_layers)
  235. ])
  236. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  237. def forward(
  238. self,
  239. input_ids: torch.Tensor,
  240. positions: torch.Tensor,
  241. kv_caches: List[torch.Tensor],
  242. attn_metadata: AttentionMetadata,
  243. ) -> torch.Tensor:
  244. hidden_states = self.embed_tokens(input_ids)
  245. residual = None
  246. for i in range(len(self.layers)):
  247. layer = self.layers[i]
  248. hidden_states, residual = layer(
  249. positions,
  250. hidden_states,
  251. kv_caches[i],
  252. attn_metadata,
  253. residual,
  254. )
  255. hidden_states, _ = self.norm(hidden_states, residual)
  256. return hidden_states
  257. class BaiChuanBaseForCausalLM(nn.Module):
  258. packed_modules_mapping = {
  259. "W_pack": ["W_pack"],
  260. "gate_up_proj": [
  261. "gate_proj",
  262. "up_proj",
  263. ],
  264. }
  265. # LoRA specific attributes
  266. supported_lora_modules = [
  267. "W_pack",
  268. "o_proj",
  269. "gate_up_proj",
  270. "down_proj",
  271. ]
  272. embedding_modules = {}
  273. embedding_padding_modules = []
  274. def __init__(
  275. self,
  276. config,
  277. position_embedding: str,
  278. quant_config: Optional[QuantizationConfig] = None,
  279. lora_config: Optional[LoRAConfig] = None,
  280. ):
  281. super().__init__()
  282. self.config = config
  283. self.quant_config = quant_config
  284. self.model = BaiChuanModel(config, position_embedding, quant_config)
  285. self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
  286. self.logits_processor = LogitsProcessor(config.vocab_size)
  287. self.sampler = Sampler()
  288. def forward(
  289. self,
  290. input_ids: torch.Tensor,
  291. positions: torch.Tensor,
  292. kv_caches: List[torch.Tensor],
  293. attn_metadata: AttentionMetadata,
  294. ) -> torch.Tensor:
  295. hidden_states = self.model(input_ids, positions, kv_caches,
  296. attn_metadata)
  297. return hidden_states
  298. def compute_logits(self, hidden_states: torch.Tensor,
  299. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  300. logits = self.logits_processor(self.lm_head.weight, hidden_states,
  301. sampling_metadata)
  302. return logits
  303. def sample(
  304. self,
  305. logits: torch.Tensor,
  306. sampling_metadata: SamplingMetadata,
  307. ) -> Optional[SamplerOutput]:
  308. next_tokens = self.sampler(logits, sampling_metadata)
  309. return next_tokens
  310. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  311. stacked_params_mapping = [
  312. # (param_name, shard_name, shard_id)
  313. ("gate_up_proj", "gate_proj", 0),
  314. ("gate_up_proj", "up_proj", 1),
  315. ]
  316. params_dict = dict(self.named_parameters())
  317. for name, loaded_weight in weights:
  318. if "rotary_emb.inv_freq" in name:
  319. continue
  320. if name == "lm_head.weight":
  321. # Unlike Baichuan, Baichuan2 normalizes the head weights.
  322. # Refer to:
  323. # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
  324. # Distinguish between Baichuan and Baichuan2 by checking the
  325. # vocab size. This is suggested by
  326. # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
  327. is_baichuan2 = self.config.vocab_size == 125696
  328. if is_baichuan2:
  329. loaded_weight = torch.nn.functional.normalize(
  330. loaded_weight)
  331. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  332. if weight_name not in name:
  333. continue
  334. name = name.replace(weight_name, param_name)
  335. # Skip loading extra bias for GPTQ models.
  336. if name.endswith(".bias") and name not in params_dict:
  337. continue
  338. param = params_dict[name]
  339. weight_loader = param.weight_loader
  340. weight_loader(param, loaded_weight, shard_id)
  341. break
  342. else:
  343. # Skip loading extra bias for GPTQ models.
  344. if name.endswith(".bias") and name not in params_dict:
  345. continue
  346. param = params_dict[name]
  347. weight_loader = getattr(param, "weight_loader",
  348. default_weight_loader)
  349. weight_loader(param, loaded_weight)
  350. class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
  351. """Baichuan 13B and Baichuan2 7B/13B."""
  352. def __init__(
  353. self,
  354. config,
  355. quant_config: Optional[QuantizationConfig] = None,
  356. lora_config: Optional[LoRAConfig] = None,
  357. ):
  358. if config.hidden_size == 4096: # baichuan2 7b
  359. super().__init__(config, "ROPE", quant_config, lora_config)
  360. else: # baichuan 13b, baichuan2 13b
  361. super().__init__(config, "ALIBI", quant_config, lora_config)
  362. class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
  363. """Baichuan 7B."""
  364. def __init__(
  365. self,
  366. config,
  367. quant_config: Optional[QuantizationConfig] = None,
  368. lora_config: Optional[LoRAConfig] = None,
  369. ):
  370. super().__init__(config, "ROPE", quant_config, lora_config)