baichuan.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443
  1. # coding=utf-8
  2. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. """Inference-only BaiChuan model compatible with HuggingFace weights."""
  21. import math
  22. from typing import List, Optional, Tuple
  23. import torch
  24. from torch import nn
  25. from aphrodite.attention import Attention, AttentionMetadata
  26. from aphrodite.modeling.layers.activation import SiluAndMul
  27. from aphrodite.modeling.layers.layernorm import RMSNorm
  28. from aphrodite.modeling.layers.linear import (
  29. LinearMethodBase,
  30. MergedColumnParallelLinear,
  31. QKVParallelLinear,
  32. RowParallelLinear,
  33. ColumnParallelLinear,
  34. )
  35. from aphrodite.modeling.layers.rotary_embedding import get_rope
  36. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  37. from aphrodite.modeling.layers.sampler import Sampler
  38. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  39. VocabParallelEmbedding,
  40. ParallelLMHead,
  41. )
  42. from aphrodite.distributed import (
  43. get_tensor_model_parallel_rank,
  44. get_tensor_model_parallel_world_size,
  45. )
  46. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  47. from aphrodite.modeling.hf_downloader import (
  48. default_weight_loader,
  49. hf_model_weights_iterator,
  50. )
  51. from aphrodite.common.sequence import SamplerOutput
  52. from aphrodite.transformers_utils.configs.baichuan import BaiChuanConfig
  53. def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
  54. closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
  55. base = torch.tensor(
  56. 2**(-(2**-(math.log2(closest_power_of_2) - 3))),
  57. dtype=torch.float32,
  58. )
  59. powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
  60. slopes = torch.pow(base, powers)
  61. if closest_power_of_2 != total_num_heads:
  62. extra_base = torch.tensor(
  63. 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
  64. dtype=torch.float32,
  65. )
  66. num_remaining_heads = min(closest_power_of_2,
  67. total_num_heads - closest_power_of_2)
  68. extra_powers = torch.arange(start=1,
  69. end=1 + 2 * num_remaining_heads,
  70. step=2,
  71. dtype=torch.int32)
  72. slopes = torch.cat(
  73. [slopes, torch.pow(extra_base, extra_powers)], dim=0)
  74. return slopes
  class BaiChuanMLP(nn.Module):
      """Gated feed-forward block: gate/up projection, SiluAndMul, down projection.

      The gate and up projections are fused into one
      ``MergedColumnParallelLinear`` by default. When the quantization config of
      ``linear_method`` reports ``merge_weight() == False``, two independent
      ``ColumnParallelLinear`` layers are built instead. ``self.merge_weight``
      records which layout is in use.

      Raises:
          ValueError: if ``hidden_act`` is anything other than ``"silu"``.
      """

      def __init__(
          self,
          hidden_size: int,
          intermediate_size: int,
          hidden_act: str,
          linear_method: Optional[LinearMethodBase] = None,
      ):
          super().__init__()
          keep_separate = (linear_method is not None
                           and not linear_method.quant_config.merge_weight())
          self.merge_weight = not keep_separate
          if keep_separate:
              self.gate_proj = ColumnParallelLinear(
                  hidden_size,
                  intermediate_size,
                  bias=False,
                  linear_method=linear_method,
              )
              self.up_proj = ColumnParallelLinear(
                  hidden_size,
                  intermediate_size,
                  bias=False,
                  linear_method=linear_method,
              )
          else:
              self.gate_up_proj = MergedColumnParallelLinear(
                  hidden_size,
                  [intermediate_size] * 2,
                  bias=False,
                  linear_method=linear_method,
              )
          self.down_proj = RowParallelLinear(
              intermediate_size,
              hidden_size,
              bias=False,
              linear_method=linear_method,
          )
          if hidden_act == "silu":
              self.act_fn = SiluAndMul()
          else:
              raise ValueError(f"Unsupported activation: {hidden_act}. "
                               "Only silu is supported for now.")

      def forward(self, x):
          if not self.merge_weight:
              # Rebuild the fused layout: gate first, then up (the same order
              # as shard ids 0/1 used for gate_up_proj when loading weights).
              gate_out, _ = self.gate_proj(x)
              up_out, _ = self.up_proj(x)
              fused = torch.cat([gate_out, up_out], dim=-1)
          else:
              fused, _ = self.gate_up_proj(x)
          activated = self.act_fn(fused)
          projected, _ = self.down_proj(activated)
          return projected
  class BaiChuanAttention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper.

      Q, K and V are produced by a single fused ``W_pack`` projection. Position
      information comes from one of two schemes, selected by
      ``position_embedding``:

      * ``"ALIBI"``: per-head ALiBi slopes, sliced to this tensor-parallel
        rank's heads, are handed to the ``Attention`` backend.
      * anything else: rotary embeddings (``get_rope``) are applied to Q and K.

      Heads are split evenly across tensor-parallel ranks.
      """

      def __init__(
          self,
          hidden_size: int,
          num_heads: int,
          position_embedding: str,
          rope_theta: float = 10000,
          max_position_embeddings: int = 8192,
          linear_method: Optional[LinearMethodBase] = None,
      ):
          super().__init__()
          self.hidden_size = hidden_size
          tp_size = get_tensor_model_parallel_world_size()
          self.total_num_heads = num_heads
          assert self.total_num_heads % tp_size == 0, (
              f"num_heads ({self.total_num_heads}) must be divisible by the "
              f"tensor parallel world size ({tp_size})")
          # Heads owned by this rank.
          self.num_heads = self.total_num_heads // tp_size
          self.head_dim = hidden_size // self.total_num_heads
          self.position_embedding = position_embedding
          self.rope_theta = rope_theta
          self.max_position_embeddings = max_position_embeddings
          # Softmax scaling is the same for both position-embedding schemes.
          self.scaling = self.head_dim**-0.5
          # pylint: disable=invalid-name
          self.W_pack = QKVParallelLinear(
              hidden_size,
              self.head_dim,
              self.total_num_heads,
              self.total_num_heads,
              bias=False,
              linear_method=linear_method,
          )
          self.o_proj = RowParallelLinear(
              self.total_num_heads * self.head_dim,
              hidden_size,
              bias=False,
              linear_method=linear_method,
          )
          if self.position_embedding == "ALIBI":
              # Slopes are computed for all heads, then sliced to this rank.
              tp_rank = get_tensor_model_parallel_rank()
              head_start = tp_rank * self.num_heads
              head_end = head_start + self.num_heads
              alibi_slopes = _get_alibi_slopes(self.total_num_heads)
              alibi_slopes = alibi_slopes[head_start:head_end].tolist()
              self.attn = Attention(
                  self.num_heads,
                  self.head_dim,
                  self.scaling,
                  alibi_slopes=alibi_slopes,
              )
          else:
              # A quantization config may override the rotary layout; default
              # to NeoX style when it does not specify one.
              rope_style = (None if linear_method is None else
                            linear_method.quant_config.rope_style())
              is_neox_style = True if rope_style is None else rope_style
              self.rotary_emb = get_rope(
                  self.head_dim,
                  rotary_dim=self.head_dim,
                  max_position=self.max_position_embeddings,
                  base=self.rope_theta,
                  is_neox_style=is_neox_style,
              )
              self.attn = Attention(self.num_heads, self.head_dim, self.scaling)

      @property
      def postion_embedding(self) -> str:
          """Deprecated, misspelled alias of ``position_embedding``.

          Kept so that existing code reading the old attribute name still works.
          """
          return self.position_embedding

      def forward(
          self,
          positions: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: torch.Tensor,
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Run attention over ``hidden_states`` and project back to hidden size."""
          qkv, _ = self.W_pack(hidden_states)
          # W_pack is built with total_num_heads for both the query and the KV
          # head counts, so the three segments have equal width.
          q, k, v = qkv.chunk(chunks=3, dim=-1)
          if self.position_embedding != "ALIBI":
              q, k = self.rotary_emb(positions, q, k)
          attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
          output, _ = self.o_proj(attn_output)
          return output
  class BaiChuanDecoderLayer(nn.Module):
      """One pre-norm decoder block: RMSNorm, self-attention, RMSNorm, MLP.

      The residual stream is threaded through the two ``RMSNorm`` calls. When a
      residual is passed, the norm is called as ``norm(x, residual)`` and returns
      ``(normed, residual)``. The layer returns ``(hidden_states, residual)`` for
      the next layer to consume.
      """

      def __init__(
          self,
          config: BaiChuanConfig,
          position_embedding: str,
          linear_method: Optional[LinearMethodBase] = None,
      ):
          super().__init__()
          self.hidden_size = config.hidden_size
          self.self_attn = BaiChuanAttention(
              hidden_size=self.hidden_size,
              num_heads=config.num_attention_heads,
              position_embedding=position_embedding,
              # Older configs may omit these fields; fall back to defaults.
              rope_theta=getattr(config, "rope_theta", 10000),
              max_position_embeddings=getattr(config,
                                              "max_position_embeddings",
                                              8192),
              linear_method=linear_method,
          )
          self.mlp = BaiChuanMLP(
              hidden_size=self.hidden_size,
              intermediate_size=config.intermediate_size,
              hidden_act=config.hidden_act,
              linear_method=linear_method,
          )
          norm_eps = config.rms_norm_eps
          self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
          self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                  eps=norm_eps)

      def forward(
          self,
          positions: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: torch.Tensor,
          attn_metadata: AttentionMetadata,
          residual: Optional[torch.Tensor],
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          # Attention sub-block.
          if residual is not None:
              hidden_states, residual = self.input_layernorm(
                  hidden_states, residual)
          else:
              # First layer: the residual stream starts from the raw input.
              residual = hidden_states
              hidden_states = self.input_layernorm(hidden_states)
          attn_out = self.self_attn(
              positions=positions,
              hidden_states=hidden_states,
              kv_cache=kv_cache,
              attn_metadata=attn_metadata,
          )
          # Feed-forward sub-block.
          normed, residual = self.post_attention_layernorm(attn_out, residual)
          return self.mlp(normed), residual
  class BaiChuanModel(nn.Module):
      """Baichuan decoder stack: token embedding, N decoder layers, final RMSNorm.

      ``forward`` returns the final normalized hidden states. It does not compute
      logits; that happens in the causal-LM wrapper.
      """

      def __init__(
          self,
          config: BaiChuanConfig,
          position_embedding: str,
          linear_method: Optional[LinearMethodBase] = None,
      ):
          super().__init__()
          self.config = config
          self.padding_idx = config.pad_token_id
          self.vocab_size = config.vocab_size
          self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                     config.hidden_size,
                                                     linear_method=linear_method)
          self.layers = nn.ModuleList([
              BaiChuanDecoderLayer(config, position_embedding, linear_method)
              for _ in range(config.num_hidden_layers)
          ])
          self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Embed ``input_ids`` and run them through every decoder layer.

          Args:
              input_ids: Token ids to embed.
              positions: Positions forwarded to each layer's attention.
              kv_caches: One KV-cache entry per layer, indexed by layer number.
              attn_metadata: Attention metadata shared by all layers.

          Returns:
              Hidden states after the final RMSNorm.
          """
          hidden_states = self.embed_tokens(input_ids)
          # No residual before the first layer; each layer returns the updated
          # residual stream for the next one.
          residual = None
          for layer_idx, layer in enumerate(self.layers):
              hidden_states, residual = layer(
                  positions,
                  hidden_states,
                  kv_caches[layer_idx],
                  attn_metadata,
                  residual,
              )
          hidden_states, _ = self.norm(hidden_states, residual)
          return hidden_states
  class BaiChuanBaseForCausalLM(nn.Module):
      """Causal-LM wrapper shared by all Baichuan variants.

      Owns the decoder stack (``BaiChuanModel``), the LM head, the logits
      processor and the sampler. Subclasses only choose the position-embedding
      scheme (``"ROPE"`` or ``"ALIBI"``) passed to ``__init__``.
      """

      # Baichuan2 checkpoints are recognised by this vocabulary size. Their LM
      # head weights are L2-normalized when loaded (see ``load_weights``).
      _BAICHUAN2_VOCAB_SIZE = 125696

      def __init__(
          self,
          config,
          position_embedding: str,
          linear_method: Optional[LinearMethodBase] = None,
      ):
          super().__init__()
          self.config = config
          self.linear_method = linear_method
          self.model = BaiChuanModel(config, position_embedding, linear_method)
          self.lm_head = ParallelLMHead(config.vocab_size,
                                        config.hidden_size,
                                        linear_method=linear_method)
          self.logits_processor = LogitsProcessor(config.vocab_size,
                                                  config.tokenizer_vocab_size)
          self.sampler = Sampler()

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[torch.Tensor],
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Return the decoder's final hidden states (not logits)."""
          hidden_states = self.model(input_ids, positions, kv_caches,
                                     attn_metadata)
          return hidden_states

      def compute_logits(self, hidden_states: torch.Tensor,
                         sampling_metadata: SamplingMetadata) -> torch.Tensor:
          """Project hidden states to logits via ``lm_head`` and the processor."""
          logits = self.logits_processor(self.lm_head, hidden_states,
                                         sampling_metadata)
          return logits

      def sample(
          self,
          logits: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[SamplerOutput]:
          """Draw next tokens from ``logits`` using the configured sampler."""
          next_tokens = self.sampler(logits, sampling_metadata)
          return next_tokens

      def load_weights(
          self,
          model_name_or_path: str,
          cache_dir: Optional[str] = None,
          load_format: str = "auto",
          revision: Optional[str] = None,
      ):
          """Copy checkpoint tensors into this model's parameters.

          * ``rotary_emb.inv_freq`` tensors are skipped.
          * For Baichuan2 (detected by vocab size), ``lm_head.weight`` is
            normalized with ``torch.nn.functional.normalize`` before loading.
          * ``gate_proj``/``up_proj`` are loaded as shards 0/1 of the fused
            ``gate_up_proj``, unless the quantization config disables weight
            merging.
          * Bias tensors with no matching parameter (extra GPTQ biases) are
            skipped.
          """
          # Unlike Baichuan, Baichuan2 normalizes the head weights. Ref.:
          # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
          # Distinguish between Baichuan and Baichuan2 by checking the vocab
          # size. This depends only on the config, so decide it once.
          is_baichuan2 = self.config.vocab_size == self._BAICHUAN2_VOCAB_SIZE
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("gate_up_proj", "gate_proj", 0),
              ("gate_up_proj", "up_proj", 1),
          ]
          if (self.linear_method is not None
                  and not self.linear_method.quant_config.merge_weight()):
              # Gate/up are separate layers, so checkpoint names load as-is.
              stacked_params_mapping = []
          params_dict = dict(self.named_parameters())
          for name, loaded_weight in hf_model_weights_iterator(
                  model_name_or_path, cache_dir, load_format, revision,
                  self.config):
              if "rotary_emb.inv_freq" in name:
                  continue
              if is_baichuan2 and name == "lm_head.weight":
                  loaded_weight = torch.nn.functional.normalize(loaded_weight)
              for param_name, weight_name, shard_id in stacked_params_mapping:
                  if weight_name not in name:
                      continue
                  name = name.replace(weight_name, param_name)
                  # Skip loading extra bias for GPTQ models. `break` (not
                  # `continue`) so the rewritten name is not matched against
                  # the next mapping entry ("up_proj" is a substring of
                  # "gate_up_proj"), and the fallback `else` is not run.
                  if name.endswith(".bias") and name not in params_dict:
                      break
                  param = params_dict[name]
                  weight_loader = param.weight_loader
                  weight_loader(param, loaded_weight, shard_id)
                  break
              else:
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader",
                                          default_weight_loader)
                  weight_loader(param, loaded_weight)
  class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
      """Baichuan 13B and Baichuan2 7B/13B.

      The position-embedding scheme is chosen from ``config.hidden_size``. A
      hidden size of 4096 (Baichuan2 7B) selects rotary embeddings. Any other
      size (the 13B variants) selects ALiBi.
      """

      def __init__(self,
                   config,
                   linear_method: Optional[LinearMethodBase] = None):
          position_embedding = ("ROPE" if config.hidden_size == 4096 else
                                "ALIBI")
          super().__init__(config, position_embedding, linear_method)
  class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
      """Baichuan 7B, which always uses rotary position embeddings."""

      def __init__(self,
                   config,
                   linear_method: Optional[LinearMethodBase] = None):
          super().__init__(config=config,
                           position_embedding="ROPE",
                           linear_method=linear_method)