baichuan.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438
  1. # coding=utf-8
  2. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """Inference-only BaiChuan model compatible with HuggingFace weights."""
  21. import math
  22. from typing import Iterable, List, Optional, Tuple
  23. import torch
  24. from torch import nn
  25. from transformers import PretrainedConfig
  26. from aphrodite.attention import Attention, AttentionMetadata
  27. from aphrodite.common.config import CacheConfig, LoRAConfig
  28. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  29. from aphrodite.common.utils import progress_bar
  30. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  31. get_tensor_model_parallel_world_size)
  32. from aphrodite.modeling.layers.activation import SiluAndMul
  33. from aphrodite.modeling.layers.layernorm import RMSNorm
  34. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  35. QKVParallelLinear,
  36. RowParallelLinear)
  37. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  38. from aphrodite.modeling.layers.rotary_embedding import get_rope
  39. from aphrodite.modeling.layers.sampler import Sampler
  40. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  41. ParallelLMHead, VocabParallelEmbedding)
  42. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  43. from aphrodite.modeling.models.interfaces import SupportsLoRA
  44. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  45. from aphrodite.quantization.base_config import QuantizationConfig
  46. def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
  47. closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
  48. base = torch.tensor(
  49. 2**(-(2**-(math.log2(closest_power_of_2) - 3))),
  50. dtype=torch.float32,
  51. )
  52. powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
  53. slopes = torch.pow(base, powers)
  54. if closest_power_of_2 != total_num_heads:
  55. extra_base = torch.tensor(
  56. 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
  57. dtype=torch.float32,
  58. )
  59. num_remaining_heads = min(closest_power_of_2,
  60. total_num_heads - closest_power_of_2)
  61. extra_powers = torch.arange(start=1,
  62. end=1 + 2 * num_remaining_heads,
  63. step=2,
  64. dtype=torch.int32)
  65. slopes = torch.cat(
  66. [slopes, torch.pow(extra_base, extra_powers)], dim=0)
  67. return slopes
  68. class BaiChuanMLP(nn.Module):
  69. def __init__(
  70. self,
  71. hidden_size: int,
  72. intermediate_size: int,
  73. hidden_act: str,
  74. quant_config: Optional[QuantizationConfig] = None,
  75. ):
  76. super().__init__()
  77. self.gate_up_proj = MergedColumnParallelLinear(
  78. hidden_size, [intermediate_size] * 2,
  79. bias=False,
  80. quant_config=quant_config)
  81. self.down_proj = RowParallelLinear(intermediate_size,
  82. hidden_size,
  83. bias=False,
  84. quant_config=quant_config)
  85. if hidden_act != "silu":
  86. raise ValueError(f"Unsupported activation: {hidden_act}. "
  87. "Only silu is supported for now.")
  88. self.act_fn = SiluAndMul()
  89. def forward(self, x):
  90. gate_up, _ = self.gate_up_proj(x)
  91. x = self.act_fn(gate_up)
  92. x, _ = self.down_proj(x)
  93. return x
  94. class BaiChuanAttention(nn.Module):
  95. """Multi-headed attention from 'Attention Is All You Need' paper"""
  96. def __init__(
  97. self,
  98. hidden_size: int,
  99. num_heads: int,
  100. position_embedding: str,
  101. rope_theta: float = 10000,
  102. max_position_embeddings: int = 8192,
  103. cache_config: Optional[CacheConfig] = None,
  104. quant_config: Optional[QuantizationConfig] = None,
  105. ):
  106. super().__init__()
  107. self.hidden_size = hidden_size
  108. tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
  109. )
  110. self.total_num_heads = num_heads
  111. assert self.total_num_heads % tensor_model_parallel_world_size == 0
  112. self.num_heads = (self.total_num_heads //
  113. tensor_model_parallel_world_size)
  114. self.head_dim = hidden_size // self.total_num_heads
  115. self.postion_embedding = position_embedding
  116. self.rope_theta = rope_theta
  117. self.max_position_embeddings = max_position_embeddings
  118. # pylint: disable=invalid-name
  119. self.W_pack = QKVParallelLinear(
  120. hidden_size,
  121. self.head_dim,
  122. self.total_num_heads,
  123. self.total_num_heads,
  124. bias=False,
  125. quant_config=quant_config,
  126. )
  127. self.o_proj = RowParallelLinear(
  128. self.total_num_heads * self.head_dim,
  129. hidden_size,
  130. bias=False,
  131. quant_config=quant_config,
  132. )
  133. # Create the alibi slopes and slice them.
  134. if self.postion_embedding == "ALIBI":
  135. tp_rank = get_tensor_model_parallel_rank()
  136. head_start = tp_rank * self.num_heads
  137. head_end = (tp_rank + 1) * self.num_heads
  138. alibi_slopes = _get_alibi_slopes(self.total_num_heads)
  139. alibi_slopes = alibi_slopes[head_start:head_end].tolist()
  140. scaling = self.head_dim**-0.5
  141. self.attn = Attention(self.num_heads,
  142. self.head_dim,
  143. scaling,
  144. alibi_slopes=alibi_slopes,
  145. quant_config=quant_config)
  146. else:
  147. self.rotary_emb = get_rope(
  148. self.head_dim,
  149. rotary_dim=self.head_dim,
  150. max_position=self.max_position_embeddings,
  151. base=self.rope_theta,
  152. )
  153. self.scaling = self.head_dim**-0.5
  154. self.attn = Attention(self.num_heads,
  155. self.head_dim,
  156. self.scaling,
  157. cache_config=cache_config,
  158. quant_config=quant_config)
  159. def forward(
  160. self,
  161. positions: torch.Tensor,
  162. hidden_states: torch.Tensor,
  163. kv_cache: torch.Tensor,
  164. attn_metadata: AttentionMetadata,
  165. ) -> torch.Tensor:
  166. qkv, _ = self.W_pack(hidden_states)
  167. q, k, v = qkv.chunk(chunks=3, dim=-1)
  168. if self.postion_embedding != "ALIBI":
  169. q, k = self.rotary_emb(positions, q, k)
  170. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  171. output, _ = self.o_proj(attn_output)
  172. return output
  173. class BaiChuanDecoderLayer(nn.Module):
  174. def __init__(self,
  175. config: PretrainedConfig,
  176. position_embedding: str,
  177. cache_config: Optional[CacheConfig] = None,
  178. quant_config: Optional[QuantizationConfig] = None):
  179. super().__init__()
  180. self.hidden_size = config.hidden_size
  181. rope_theta = getattr(config, "rope_theta", 10000)
  182. max_position_embeddings = getattr(config, "max_position_embeddings",
  183. 8192)
  184. self.self_attn = BaiChuanAttention(
  185. hidden_size=self.hidden_size,
  186. num_heads=config.num_attention_heads,
  187. position_embedding=position_embedding,
  188. rope_theta=rope_theta,
  189. max_position_embeddings=max_position_embeddings,
  190. cache_config=cache_config,
  191. quant_config=quant_config,
  192. )
  193. self.mlp = BaiChuanMLP(
  194. hidden_size=self.hidden_size,
  195. intermediate_size=config.intermediate_size,
  196. hidden_act=config.hidden_act,
  197. quant_config=quant_config,
  198. )
  199. self.input_layernorm = RMSNorm(config.hidden_size,
  200. eps=config.rms_norm_eps)
  201. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  202. eps=config.rms_norm_eps)
  203. def forward(
  204. self,
  205. positions: torch.Tensor,
  206. hidden_states: torch.Tensor,
  207. kv_cache: torch.Tensor,
  208. attn_metadata: AttentionMetadata,
  209. residual: Optional[torch.Tensor],
  210. ) -> Tuple[torch.Tensor, torch.Tensor]:
  211. # Self Attention
  212. if residual is None:
  213. residual = hidden_states
  214. hidden_states = self.input_layernorm(hidden_states)
  215. else:
  216. hidden_states, residual = self.input_layernorm(
  217. hidden_states, residual)
  218. hidden_states = self.self_attn(
  219. positions=positions,
  220. hidden_states=hidden_states,
  221. kv_cache=kv_cache,
  222. attn_metadata=attn_metadata,
  223. )
  224. # Fully Connected
  225. hidden_states, residual = self.post_attention_layernorm(
  226. hidden_states, residual)
  227. hidden_states = self.mlp(hidden_states)
  228. return hidden_states, residual
  229. class BaiChuanModel(nn.Module):
  230. def __init__(self,
  231. config: PretrainedConfig,
  232. position_embedding: str,
  233. cache_config: Optional[CacheConfig] = None,
  234. quant_config: Optional[QuantizationConfig] = None):
  235. super().__init__()
  236. self.config = config
  237. self.padding_idx = config.pad_token_id
  238. self.vocab_size = config.vocab_size
  239. self.embed_tokens = VocabParallelEmbedding(
  240. config.vocab_size,
  241. config.hidden_size,
  242. )
  243. self.layers = nn.ModuleList([
  244. BaiChuanDecoderLayer(config, position_embedding, cache_config,
  245. quant_config)
  246. for _ in range(config.num_hidden_layers)
  247. ])
  248. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  249. def forward(
  250. self,
  251. input_ids: torch.Tensor,
  252. positions: torch.Tensor,
  253. kv_caches: List[torch.Tensor],
  254. attn_metadata: AttentionMetadata,
  255. ) -> torch.Tensor:
  256. hidden_states = self.embed_tokens(input_ids)
  257. residual = None
  258. for i in range(len(self.layers)):
  259. layer = self.layers[i]
  260. hidden_states, residual = layer(
  261. positions,
  262. hidden_states,
  263. kv_caches[i],
  264. attn_metadata,
  265. residual,
  266. )
  267. hidden_states, _ = self.norm(hidden_states, residual)
  268. return hidden_states
  269. class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
  270. packed_modules_mapping = {
  271. "W_pack": ["W_pack"],
  272. "gate_up_proj": [
  273. "gate_proj",
  274. "up_proj",
  275. ],
  276. }
  277. # LoRA specific attributes
  278. supported_lora_modules = [
  279. "W_pack",
  280. "o_proj",
  281. "gate_up_proj",
  282. "down_proj",
  283. ]
  284. embedding_modules = {}
  285. embedding_padding_modules = []
  286. def __init__(
  287. self,
  288. config: PretrainedConfig,
  289. position_embedding: str,
  290. cache_config: Optional[CacheConfig] = None,
  291. quant_config: Optional[QuantizationConfig] = None,
  292. lora_config: Optional[LoRAConfig] = None,
  293. ):
  294. super().__init__()
  295. self.config = config
  296. self.lora_config: Optional[LoRAConfig] = lora_config
  297. self.quant_config = quant_config
  298. self.model = BaiChuanModel(config, position_embedding, cache_config,
  299. quant_config)
  300. self.lm_head = ParallelLMHead(config.vocab_size,
  301. config.hidden_size,
  302. quant_config=quant_config)
  303. self.logits_processor = LogitsProcessor(config.vocab_size)
  304. self.sampler = Sampler()
  305. def forward(
  306. self,
  307. input_ids: torch.Tensor,
  308. positions: torch.Tensor,
  309. kv_caches: List[torch.Tensor],
  310. attn_metadata: AttentionMetadata,
  311. intermediate_tensors: Optional[IntermediateTensors] = None,
  312. ) -> torch.Tensor:
  313. hidden_states = self.model(input_ids, positions, kv_caches,
  314. attn_metadata)
  315. return hidden_states
  316. def compute_logits(
  317. self,
  318. hidden_states: torch.Tensor,
  319. sampling_metadata: SamplingMetadata,
  320. ) -> Optional[torch.Tensor]:
  321. logits = self.logits_processor(self.lm_head, hidden_states,
  322. sampling_metadata)
  323. return logits
  324. def sample(
  325. self,
  326. logits: torch.Tensor,
  327. sampling_metadata: SamplingMetadata,
  328. ) -> Optional[SamplerOutput]:
  329. next_tokens = self.sampler(logits, sampling_metadata)
  330. return next_tokens
  331. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  332. stacked_params_mapping = [
  333. # (param_name, shard_name, shard_id)
  334. ("gate_up_proj", "gate_proj", 0),
  335. ("gate_up_proj", "up_proj", 1),
  336. ]
  337. params_dict = dict(self.named_parameters())
  338. weights_list = list(weights)
  339. for name, loaded_weight in progress_bar(weights_list,
  340. desc="Loading modules..."):
  341. if "rotary_emb.inv_freq" in name:
  342. continue
  343. if name == "lm_head.weight":
  344. # Unlike Baichuan, Baichuan2 normalizes the head weights.
  345. # Refer to:
  346. # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
  347. # Distinguish between Baichuan and Baichuan2 by checking the
  348. # vocab size. This is suggested by
  349. # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
  350. is_baichuan2 = self.config.vocab_size == 125696
  351. if is_baichuan2:
  352. loaded_weight = torch.nn.functional.normalize(
  353. loaded_weight)
  354. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  355. if weight_name not in name:
  356. continue
  357. name = name.replace(weight_name, param_name)
  358. # Skip loading extra bias for GPTQ models.
  359. if name.endswith(".bias") and name not in params_dict:
  360. continue
  361. param = params_dict[name]
  362. weight_loader = param.weight_loader
  363. weight_loader(param, loaded_weight, shard_id)
  364. break
  365. else:
  366. # Skip loading extra bias for GPTQ models.
  367. if name.endswith(".bias") and name not in params_dict:
  368. continue
  369. param = params_dict[name]
  370. weight_loader = getattr(param, "weight_loader",
  371. default_weight_loader)
  372. weight_loader(param, loaded_weight)
  373. class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
  374. """Baichuan 13B and Baichuan2 7B/13B."""
  375. def __init__(
  376. self,
  377. config,
  378. cache_config: Optional[CacheConfig] = None,
  379. quant_config: Optional[QuantizationConfig] = None,
  380. lora_config: Optional[LoRAConfig] = None,
  381. ):
  382. if config.hidden_size == 4096: # baichuan2 7b
  383. super().__init__(config, "ROPE", cache_config, quant_config,
  384. lora_config)
  385. else: # baichuan 13b, baichuan2 13b
  386. super().__init__(config, "ALIBI", cache_config, quant_config,
  387. lora_config)
  388. class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
  389. """Baichuan 7B."""
  390. def __init__(
  391. self,
  392. config,
  393. cache_config: Optional[CacheConfig] = None,
  394. quant_config: Optional[QuantizationConfig] = None,
  395. lora_config: Optional[LoRAConfig] = None,
  396. ):
  397. super().__init__(config, "ROPE", cache_config, quant_config,
  398. lora_config)