baichuan.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. # coding=utf-8
  2. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """Inference-only BaiChuan model compatible with HuggingFace weights."""
  21. import math
  22. from typing import Iterable, List, Optional, Tuple
  23. import torch
  24. from torch import nn
  25. from transformers import PretrainedConfig
  26. from aphrodite.attention import Attention, AttentionMetadata
  27. from aphrodite.common.config import CacheConfig, LoRAConfig
  28. from aphrodite.common.sequence import SamplerOutput
  29. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  30. get_tensor_model_parallel_world_size)
  31. from aphrodite.modeling.layers.activation import SiluAndMul
  32. from aphrodite.modeling.layers.layernorm import RMSNorm
  33. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  34. QKVParallelLinear,
  35. RowParallelLinear)
  36. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  37. from aphrodite.modeling.layers.rotary_embedding import get_rope
  38. from aphrodite.modeling.layers.sampler import Sampler
  39. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  40. ParallelLMHead, VocabParallelEmbedding)
  41. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  42. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  43. from aphrodite.quantization.base_config import QuantizationConfig
  44. def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
  45. closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
  46. base = torch.tensor(
  47. 2**(-(2**-(math.log2(closest_power_of_2) - 3))),
  48. dtype=torch.float32,
  49. )
  50. powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
  51. slopes = torch.pow(base, powers)
  52. if closest_power_of_2 != total_num_heads:
  53. extra_base = torch.tensor(
  54. 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
  55. dtype=torch.float32,
  56. )
  57. num_remaining_heads = min(closest_power_of_2,
  58. total_num_heads - closest_power_of_2)
  59. extra_powers = torch.arange(start=1,
  60. end=1 + 2 * num_remaining_heads,
  61. step=2,
  62. dtype=torch.int32)
  63. slopes = torch.cat(
  64. [slopes, torch.pow(extra_base, extra_powers)], dim=0)
  65. return slopes
  66. class BaiChuanMLP(nn.Module):
  67. def __init__(
  68. self,
  69. hidden_size: int,
  70. intermediate_size: int,
  71. hidden_act: str,
  72. quant_config: Optional[QuantizationConfig] = None,
  73. ):
  74. super().__init__()
  75. self.gate_up_proj = MergedColumnParallelLinear(
  76. hidden_size, [intermediate_size] * 2,
  77. bias=False,
  78. quant_config=quant_config)
  79. self.down_proj = RowParallelLinear(intermediate_size,
  80. hidden_size,
  81. bias=False,
  82. quant_config=quant_config)
  83. if hidden_act != "silu":
  84. raise ValueError(f"Unsupported activation: {hidden_act}. "
  85. "Only silu is supported for now.")
  86. self.act_fn = SiluAndMul()
  87. def forward(self, x):
  88. gate_up, _ = self.gate_up_proj(x)
  89. x = self.act_fn(gate_up)
  90. x, _ = self.down_proj(x)
  91. return x
  92. class BaiChuanAttention(nn.Module):
  93. """Multi-headed attention from 'Attention Is All You Need' paper"""
  94. def __init__(
  95. self,
  96. hidden_size: int,
  97. num_heads: int,
  98. position_embedding: str,
  99. rope_theta: float = 10000,
  100. max_position_embeddings: int = 8192,
  101. cache_config: Optional[CacheConfig] = None,
  102. quant_config: Optional[QuantizationConfig] = None,
  103. ):
  104. super().__init__()
  105. self.hidden_size = hidden_size
  106. tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
  107. )
  108. self.total_num_heads = num_heads
  109. assert self.total_num_heads % tensor_model_parallel_world_size == 0
  110. self.num_heads = (self.total_num_heads //
  111. tensor_model_parallel_world_size)
  112. self.head_dim = hidden_size // self.total_num_heads
  113. self.postion_embedding = position_embedding
  114. self.rope_theta = rope_theta
  115. self.max_position_embeddings = max_position_embeddings
  116. # pylint: disable=invalid-name
  117. self.W_pack = QKVParallelLinear(
  118. hidden_size,
  119. self.head_dim,
  120. self.total_num_heads,
  121. self.total_num_heads,
  122. bias=False,
  123. quant_config=quant_config,
  124. )
  125. self.o_proj = RowParallelLinear(
  126. self.total_num_heads * self.head_dim,
  127. hidden_size,
  128. bias=False,
  129. quant_config=quant_config,
  130. )
  131. # Create the alibi slopes and slice them.
  132. if self.postion_embedding == "ALIBI":
  133. tp_rank = get_tensor_model_parallel_rank()
  134. head_start = tp_rank * self.num_heads
  135. head_end = (tp_rank + 1) * self.num_heads
  136. alibi_slopes = _get_alibi_slopes(self.total_num_heads)
  137. alibi_slopes = alibi_slopes[head_start:head_end].tolist()
  138. scaling = self.head_dim**-0.5
  139. self.attn = Attention(self.num_heads,
  140. self.head_dim,
  141. scaling,
  142. alibi_slopes=alibi_slopes,
  143. quant_config=quant_config)
  144. else:
  145. self.rotary_emb = get_rope(
  146. self.head_dim,
  147. rotary_dim=self.head_dim,
  148. max_position=self.max_position_embeddings,
  149. base=self.rope_theta,
  150. )
  151. self.scaling = self.head_dim**-0.5
  152. self.attn = Attention(self.num_heads,
  153. self.head_dim,
  154. self.scaling,
  155. cache_config=cache_config,
  156. quant_config=quant_config)
  157. def forward(
  158. self,
  159. positions: torch.Tensor,
  160. hidden_states: torch.Tensor,
  161. kv_cache: torch.Tensor,
  162. attn_metadata: AttentionMetadata,
  163. ) -> torch.Tensor:
  164. qkv, _ = self.W_pack(hidden_states)
  165. q, k, v = qkv.chunk(chunks=3, dim=-1)
  166. if self.postion_embedding != "ALIBI":
  167. q, k = self.rotary_emb(positions, q, k)
  168. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  169. output, _ = self.o_proj(attn_output)
  170. return output
  171. class BaiChuanDecoderLayer(nn.Module):
  172. def __init__(self,
  173. config: PretrainedConfig,
  174. position_embedding: str,
  175. cache_config: Optional[CacheConfig] = None,
  176. quant_config: Optional[QuantizationConfig] = None):
  177. super().__init__()
  178. self.hidden_size = config.hidden_size
  179. rope_theta = getattr(config, "rope_theta", 10000)
  180. max_position_embeddings = getattr(config, "max_position_embeddings",
  181. 8192)
  182. self.self_attn = BaiChuanAttention(
  183. hidden_size=self.hidden_size,
  184. num_heads=config.num_attention_heads,
  185. position_embedding=position_embedding,
  186. rope_theta=rope_theta,
  187. max_position_embeddings=max_position_embeddings,
  188. cache_config=cache_config,
  189. quant_config=quant_config,
  190. )
  191. self.mlp = BaiChuanMLP(
  192. hidden_size=self.hidden_size,
  193. intermediate_size=config.intermediate_size,
  194. hidden_act=config.hidden_act,
  195. quant_config=quant_config,
  196. )
  197. self.input_layernorm = RMSNorm(config.hidden_size,
  198. eps=config.rms_norm_eps)
  199. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  200. eps=config.rms_norm_eps)
  201. def forward(
  202. self,
  203. positions: torch.Tensor,
  204. hidden_states: torch.Tensor,
  205. kv_cache: torch.Tensor,
  206. attn_metadata: AttentionMetadata,
  207. residual: Optional[torch.Tensor],
  208. ) -> Tuple[torch.Tensor, torch.Tensor]:
  209. # Self Attention
  210. if residual is None:
  211. residual = hidden_states
  212. hidden_states = self.input_layernorm(hidden_states)
  213. else:
  214. hidden_states, residual = self.input_layernorm(
  215. hidden_states, residual)
  216. hidden_states = self.self_attn(
  217. positions=positions,
  218. hidden_states=hidden_states,
  219. kv_cache=kv_cache,
  220. attn_metadata=attn_metadata,
  221. )
  222. # Fully Connected
  223. hidden_states, residual = self.post_attention_layernorm(
  224. hidden_states, residual)
  225. hidden_states = self.mlp(hidden_states)
  226. return hidden_states, residual
  227. class BaiChuanModel(nn.Module):
  228. def __init__(self,
  229. config: PretrainedConfig,
  230. position_embedding: str,
  231. cache_config: Optional[CacheConfig] = None,
  232. quant_config: Optional[QuantizationConfig] = None):
  233. super().__init__()
  234. self.config = config
  235. self.padding_idx = config.pad_token_id
  236. self.vocab_size = config.vocab_size
  237. self.embed_tokens = VocabParallelEmbedding(
  238. config.vocab_size,
  239. config.hidden_size,
  240. )
  241. self.layers = nn.ModuleList([
  242. BaiChuanDecoderLayer(config, position_embedding, cache_config,
  243. quant_config)
  244. for _ in range(config.num_hidden_layers)
  245. ])
  246. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  247. def forward(
  248. self,
  249. input_ids: torch.Tensor,
  250. positions: torch.Tensor,
  251. kv_caches: List[torch.Tensor],
  252. attn_metadata: AttentionMetadata,
  253. ) -> torch.Tensor:
  254. hidden_states = self.embed_tokens(input_ids)
  255. residual = None
  256. for i in range(len(self.layers)):
  257. layer = self.layers[i]
  258. hidden_states, residual = layer(
  259. positions,
  260. hidden_states,
  261. kv_caches[i],
  262. attn_metadata,
  263. residual,
  264. )
  265. hidden_states, _ = self.norm(hidden_states, residual)
  266. return hidden_states
  267. class BaiChuanBaseForCausalLM(nn.Module):
  268. packed_modules_mapping = {
  269. "W_pack": ["W_pack"],
  270. "gate_up_proj": [
  271. "gate_proj",
  272. "up_proj",
  273. ],
  274. }
  275. # LoRA specific attributes
  276. supported_lora_modules = [
  277. "W_pack",
  278. "o_proj",
  279. "gate_up_proj",
  280. "down_proj",
  281. ]
  282. embedding_modules = {}
  283. embedding_padding_modules = []
  284. def __init__(
  285. self,
  286. config,
  287. position_embedding: str,
  288. cache_config: Optional[CacheConfig] = None,
  289. quant_config: Optional[QuantizationConfig] = None,
  290. lora_config: Optional[LoRAConfig] = None,
  291. ):
  292. super().__init__()
  293. self.config = config
  294. self.quant_config = quant_config
  295. self.model = BaiChuanModel(config, position_embedding, cache_config,
  296. quant_config)
  297. self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
  298. self.logits_processor = LogitsProcessor(config.vocab_size)
  299. self.sampler = Sampler()
  300. def forward(
  301. self,
  302. input_ids: torch.Tensor,
  303. positions: torch.Tensor,
  304. kv_caches: List[torch.Tensor],
  305. attn_metadata: AttentionMetadata,
  306. ) -> torch.Tensor:
  307. hidden_states = self.model(input_ids, positions, kv_caches,
  308. attn_metadata)
  309. return hidden_states
  310. def compute_logits(self, hidden_states: torch.Tensor,
  311. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  312. logits = self.logits_processor(self.lm_head.weight, hidden_states,
  313. sampling_metadata)
  314. return logits
  315. def sample(
  316. self,
  317. logits: torch.Tensor,
  318. sampling_metadata: SamplingMetadata,
  319. ) -> Optional[SamplerOutput]:
  320. next_tokens = self.sampler(logits, sampling_metadata)
  321. return next_tokens
  322. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  323. stacked_params_mapping = [
  324. # (param_name, shard_name, shard_id)
  325. ("gate_up_proj", "gate_proj", 0),
  326. ("gate_up_proj", "up_proj", 1),
  327. ]
  328. params_dict = dict(self.named_parameters())
  329. for name, loaded_weight in weights:
  330. if "rotary_emb.inv_freq" in name:
  331. continue
  332. if name == "lm_head.weight":
  333. # Unlike Baichuan, Baichuan2 normalizes the head weights.
  334. # Refer to:
  335. # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
  336. # Distinguish between Baichuan and Baichuan2 by checking the
  337. # vocab size. This is suggested by
  338. # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
  339. is_baichuan2 = self.config.vocab_size == 125696
  340. if is_baichuan2:
  341. loaded_weight = torch.nn.functional.normalize(
  342. loaded_weight)
  343. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  344. if weight_name not in name:
  345. continue
  346. name = name.replace(weight_name, param_name)
  347. # Skip loading extra bias for GPTQ models.
  348. if name.endswith(".bias") and name not in params_dict:
  349. continue
  350. param = params_dict[name]
  351. weight_loader = param.weight_loader
  352. weight_loader(param, loaded_weight, shard_id)
  353. break
  354. else:
  355. # Skip loading extra bias for GPTQ models.
  356. if name.endswith(".bias") and name not in params_dict:
  357. continue
  358. param = params_dict[name]
  359. weight_loader = getattr(param, "weight_loader",
  360. default_weight_loader)
  361. weight_loader(param, loaded_weight)
  362. class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
  363. """Baichuan 13B and Baichuan2 7B/13B."""
  364. def __init__(
  365. self,
  366. config,
  367. cache_config: Optional[CacheConfig] = None,
  368. quant_config: Optional[QuantizationConfig] = None,
  369. lora_config: Optional[LoRAConfig] = None,
  370. ):
  371. if config.hidden_size == 4096: # baichuan2 7b
  372. super().__init__(config, "ROPE", cache_config, quant_config,
  373. lora_config)
  374. else: # baichuan 13b, baichuan2 13b
  375. super().__init__(config, "ALIBI", cache_config, quant_config,
  376. lora_config)
  377. class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
  378. """Baichuan 7B."""
  379. def __init__(
  380. self,
  381. config,
  382. cache_config: Optional[CacheConfig] = None,
  383. quant_config: Optional[QuantizationConfig] = None,
  384. lora_config: Optional[LoRAConfig] = None,
  385. ):
  386. super().__init__(config, "ROPE", cache_config, quant_config,
  387. lora_config)