baichuan.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  20. """Inference-only BaiChuan model compatible with HuggingFace weights."""
  21. import math
  22. from typing import List, Optional, Tuple
  23. import torch
  24. from torch import nn
  25. from aphrodite.modeling.metadata import InputMetadata
  26. from aphrodite.modeling.layers.activation import SiluAndMul
  27. from aphrodite.modeling.layers.attention import PagedAttention
  28. from aphrodite.modeling.layers.layernorm import RMSNorm
  29. from aphrodite.modeling.layers.linear import (
  30. LinearMethodBase,
  31. MergedColumnParallelLinear,
  32. QKVParallelLinear,
  33. RowParallelLinear,
  34. ColumnParallelLinear,
  35. )
  36. from aphrodite.modeling.layers.rotary_embedding import get_rope
  37. from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
  38. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  39. VocabParallelEmbedding,
  40. ParallelLMHead,
  41. )
  42. from aphrodite.modeling.megatron.parallel_state import (
  43. get_tensor_model_parallel_rank,
  44. get_tensor_model_parallel_world_size,
  45. )
  46. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  47. from aphrodite.modeling.hf_downloader import (
  48. default_weight_loader,
  49. hf_model_weights_iterator,
  50. )
  51. from aphrodite.common.sequence import SamplerOutput
  52. from aphrodite.transformers_utils.configs.baichuan import BaiChuanConfig
  53. KVCache = Tuple[torch.Tensor, torch.Tensor]
  54. def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
  55. closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
  56. base = torch.tensor(
  57. 2**(-(2**-(math.log2(closest_power_of_2) - 3))),
  58. dtype=torch.float32,
  59. )
  60. powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
  61. slopes = torch.pow(base, powers)
  62. if closest_power_of_2 != total_num_heads:
  63. extra_base = torch.tensor(
  64. 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
  65. dtype=torch.float32,
  66. )
  67. num_remaining_heads = min(closest_power_of_2,
  68. total_num_heads - closest_power_of_2)
  69. extra_powers = torch.arange(start=1,
  70. end=1 + 2 * num_remaining_heads,
  71. step=2,
  72. dtype=torch.int32)
  73. slopes = torch.cat(
  74. [slopes, torch.pow(extra_base, extra_powers)], dim=0)
  75. return slopes
  76. class BaiChuanMLP(nn.Module):
  77. def __init__(
  78. self,
  79. hidden_size: int,
  80. intermediate_size: int,
  81. hidden_act: str,
  82. linear_method: Optional[LinearMethodBase] = None,
  83. ):
  84. super().__init__()
  85. if (linear_method is not None
  86. and not linear_method.quant_config.merge_weight()):
  87. self.merge_weight = False
  88. self.gate_proj = ColumnParallelLinear(
  89. hidden_size,
  90. intermediate_size,
  91. bias=False,
  92. linear_method=linear_method,
  93. )
  94. self.up_proj = ColumnParallelLinear(
  95. hidden_size,
  96. intermediate_size,
  97. bias=False,
  98. linear_method=linear_method,
  99. )
  100. else:
  101. self.merge_weight = True
  102. self.gate_up_proj = MergedColumnParallelLinear(
  103. hidden_size,
  104. [intermediate_size] * 2,
  105. bias=False,
  106. linear_method=linear_method,
  107. )
  108. self.down_proj = RowParallelLinear(
  109. intermediate_size,
  110. hidden_size,
  111. bias=False,
  112. linear_method=linear_method,
  113. )
  114. if hidden_act != "silu":
  115. raise ValueError(f"Unsupported activation: {hidden_act}. "
  116. "Only silu is supported for now.")
  117. self.act_fn = SiluAndMul()
  118. def forward(self, x):
  119. if self.merge_weight:
  120. gate_up, _ = self.gate_up_proj(x)
  121. else:
  122. up, _ = self.up_proj(x)
  123. gate, _ = self.gate_proj(x)
  124. gate_up = torch.cat([gate, up], dim=-1)
  125. x = self.act_fn(gate_up)
  126. x, _ = self.down_proj(x)
  127. return x
  128. class BaiChuanAttention(nn.Module):
  129. """Multi-headed attention from 'Attention Is All You Need' paper"""
  130. def __init__(
  131. self,
  132. hidden_size: int,
  133. num_heads: int,
  134. position_embedding: str,
  135. rope_theta: float = 10000,
  136. max_position_embeddings: int = 8192,
  137. linear_method: Optional[LinearMethodBase] = None,
  138. ):
  139. super().__init__()
  140. self.hidden_size = hidden_size
  141. tensor_model_parallel_world_size = (
  142. get_tensor_model_parallel_world_size())
  143. self.total_num_heads = num_heads
  144. assert self.total_num_heads % tensor_model_parallel_world_size == 0
  145. self.num_heads = (self.total_num_heads //
  146. tensor_model_parallel_world_size)
  147. self.head_dim = hidden_size // self.total_num_heads
  148. self.postion_embedding = position_embedding
  149. self.rope_theta = rope_theta
  150. self.max_position_embeddings = max_position_embeddings
  151. # pylint: disable=invalid-name
  152. self.W_pack = QKVParallelLinear(
  153. hidden_size,
  154. self.head_dim,
  155. self.total_num_heads,
  156. self.total_num_heads,
  157. bias=False,
  158. linear_method=linear_method,
  159. )
  160. self.o_proj = RowParallelLinear(
  161. self.total_num_heads * self.head_dim,
  162. hidden_size,
  163. bias=False,
  164. linear_method=linear_method,
  165. )
  166. # Create the alibi slopes and slice them.
  167. if self.postion_embedding == "ALIBI":
  168. tp_rank = get_tensor_model_parallel_rank()
  169. head_start = tp_rank * self.num_heads
  170. head_end = (tp_rank + 1) * self.num_heads
  171. alibi_slopes = _get_alibi_slopes(self.total_num_heads)
  172. alibi_slopes = alibi_slopes[head_start:head_end].tolist()
  173. scaling = self.head_dim**-0.5
  174. self.attn = PagedAttention(
  175. self.num_heads,
  176. self.head_dim,
  177. scaling,
  178. alibi_slopes=alibi_slopes,
  179. )
  180. else:
  181. is_neox_style = (True if linear_method is None
  182. or linear_method.quant_config.rope_style() is None
  183. else linear_method.quant_config.rope_style())
  184. self.rotary_emb = get_rope(
  185. self.head_dim,
  186. rotary_dim=self.head_dim,
  187. max_position=self.max_position_embeddings,
  188. base=self.rope_theta,
  189. is_neox_style=is_neox_style,
  190. )
  191. self.scaling = self.head_dim**-0.5
  192. self.attn = PagedAttention(self.num_heads, self.head_dim,
  193. self.scaling)
  194. def forward(
  195. self,
  196. positions: torch.Tensor,
  197. hidden_states: torch.Tensor,
  198. kv_cache: KVCache,
  199. input_metadata: InputMetadata,
  200. ) -> torch.Tensor:
  201. qkv, _ = self.W_pack(hidden_states)
  202. q, k, v = qkv.chunk(chunks=3, dim=-1)
  203. if self.postion_embedding != "ALIBI":
  204. q, k = self.rotary_emb(positions, q, k)
  205. k_cache, v_cache = kv_cache
  206. attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
  207. output, _ = self.o_proj(attn_output)
  208. return output
  209. class BaiChuanDecoderLayer(nn.Module):
  210. def __init__(
  211. self,
  212. config: BaiChuanConfig,
  213. position_embedding: str,
  214. linear_method: Optional[LinearMethodBase] = None,
  215. ):
  216. super().__init__()
  217. self.hidden_size = config.hidden_size
  218. rope_theta = getattr(config, "rope_theta", 10000)
  219. max_position_embeddings = getattr(config, "max_position_embeddings",
  220. 8192)
  221. self.self_attn = BaiChuanAttention(
  222. hidden_size=self.hidden_size,
  223. num_heads=config.num_attention_heads,
  224. position_embedding=position_embedding,
  225. rope_theta=rope_theta,
  226. max_position_embeddings=max_position_embeddings,
  227. linear_method=linear_method,
  228. )
  229. self.mlp = BaiChuanMLP(
  230. hidden_size=self.hidden_size,
  231. intermediate_size=config.intermediate_size,
  232. hidden_act=config.hidden_act,
  233. linear_method=linear_method,
  234. )
  235. self.input_layernorm = RMSNorm(config.hidden_size,
  236. eps=config.rms_norm_eps)
  237. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  238. eps=config.rms_norm_eps)
  239. def forward(
  240. self,
  241. positions: torch.Tensor,
  242. hidden_states: torch.Tensor,
  243. kv_cache: KVCache,
  244. input_metadata: InputMetadata,
  245. residual: Optional[torch.Tensor],
  246. ) -> Tuple[torch.Tensor, torch.Tensor]:
  247. # Self Attention
  248. if residual is None:
  249. residual = hidden_states
  250. hidden_states = self.input_layernorm(hidden_states)
  251. else:
  252. hidden_states, residual = self.input_layernorm(
  253. hidden_states, residual)
  254. hidden_states = self.self_attn(
  255. positions=positions,
  256. hidden_states=hidden_states,
  257. kv_cache=kv_cache,
  258. input_metadata=input_metadata,
  259. )
  260. # Fully Connected
  261. hidden_states, residual = self.post_attention_layernorm(
  262. hidden_states, residual)
  263. hidden_states = self.mlp(hidden_states)
  264. return hidden_states, residual
  265. class BaiChuanModel(nn.Module):
  266. def __init__(
  267. self,
  268. config: BaiChuanConfig,
  269. position_embedding: str,
  270. linear_method: Optional[LinearMethodBase] = None,
  271. ):
  272. super().__init__()
  273. self.config = config
  274. self.padding_idx = config.pad_token_id
  275. self.vocab_size = config.vocab_size
  276. self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
  277. config.hidden_size,
  278. linear_method=linear_method)
  279. self.layers = nn.ModuleList([
  280. BaiChuanDecoderLayer(config, position_embedding, linear_method)
  281. for _ in range(config.num_hidden_layers)
  282. ])
  283. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  284. def forward(
  285. self,
  286. input_ids: torch.Tensor,
  287. positions: torch.Tensor,
  288. kv_caches: List[KVCache],
  289. input_metadata: InputMetadata,
  290. ) -> torch.Tensor:
  291. hidden_states = self.embed_tokens(input_ids)
  292. residual = None
  293. for i in range(len(self.layers)):
  294. layer = self.layers[i]
  295. hidden_states, residual = layer(
  296. positions,
  297. hidden_states,
  298. kv_caches[i],
  299. input_metadata,
  300. residual,
  301. )
  302. hidden_states, _ = self.norm(hidden_states, residual)
  303. return hidden_states
  304. class BaiChuanBaseForCausalLM(nn.Module):
  305. def __init__(
  306. self,
  307. config,
  308. position_embedding: str,
  309. linear_method: Optional[LinearMethodBase] = None,
  310. ):
  311. super().__init__()
  312. self.config = config
  313. self.linear_method = linear_method
  314. self.model = BaiChuanModel(config, position_embedding, linear_method)
  315. self.lm_head = ParallelLMHead(config.vocab_size,
  316. config.hidden_size,
  317. linear_method=linear_method)
  318. self.sampler = Sampler(config.vocab_size)
  319. self.quant_sampler = QuantSampler(config.vocab_size)
  320. def forward(
  321. self,
  322. input_ids: torch.Tensor,
  323. positions: torch.Tensor,
  324. kv_caches: List[KVCache],
  325. input_metadata: InputMetadata,
  326. ) -> torch.Tensor:
  327. hidden_states = self.model(input_ids, positions, kv_caches,
  328. input_metadata)
  329. return hidden_states
  330. def sample(
  331. self,
  332. hidden_states: torch.Tensor,
  333. sampling_metadata: SamplingMetadata,
  334. ) -> Optional[SamplerOutput]:
  335. if (self.linear_method is not None
  336. and not self.linear_method.quant_config.merge_weight()):
  337. next_tokens = self.quant_sampler(self.lm_head(hidden_states),
  338. sampling_metadata)
  339. else:
  340. next_tokens = self.sampler(self.lm_head.weight, hidden_states,
  341. sampling_metadata)
  342. return next_tokens
  343. def load_weights(
  344. self,
  345. model_name_or_path: str,
  346. cache_dir: Optional[str] = None,
  347. load_format: str = "auto",
  348. revision: Optional[str] = None,
  349. ):
  350. stacked_params_mapping = [
  351. # (param_name, shard_name, shard_id)
  352. ("gate_up_proj", "gate_proj", 0),
  353. ("gate_up_proj", "up_proj", 1),
  354. ]
  355. if (self.linear_method is not None
  356. and not self.linear_method.quant_config.merge_weight()):
  357. stacked_params_mapping = []
  358. params_dict = dict(self.named_parameters())
  359. for name, loaded_weight in hf_model_weights_iterator(
  360. model_name_or_path, cache_dir, load_format, revision,
  361. self.config):
  362. if "rotary_emb.inv_freq" in name:
  363. continue
  364. if name == "lm_head.weight":
  365. # Unlike Baichuan, Baichuan2 normalizes the head weights. Ref.:
  366. # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
  367. # Distinguish between Baichuan and Baichuan2 by checking the
  368. # vocab size.
  369. is_baichuan2 = self.config.vocab_size == 125696
  370. if is_baichuan2:
  371. loaded_weight = torch.nn.functional.normalize(
  372. loaded_weight)
  373. for param_name, weight_name, shard_id in stacked_params_mapping:
  374. if weight_name not in name:
  375. continue
  376. name = name.replace(weight_name, param_name)
  377. # Skip loading extra bias for GPTQ models.
  378. if name.endswith(".bias") and name not in params_dict:
  379. continue
  380. param = params_dict[name]
  381. weight_loader = param.weight_loader
  382. weight_loader(param, loaded_weight, shard_id)
  383. break
  384. else:
  385. # Skip loading extra bias for GPTQ models.
  386. if name.endswith(".bias") and name not in params_dict:
  387. continue
  388. param = params_dict[name]
  389. weight_loader = getattr(param, "weight_loader",
  390. default_weight_loader)
  391. weight_loader(param, loaded_weight)
  392. class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
  393. """Baichuan 13B and Baichuan2 7B/13B."""
  394. def __init__(self,
  395. config,
  396. linear_method: Optional[LinearMethodBase] = None):
  397. if config.hidden_size == 4096: # baichuan2 7b
  398. super().__init__(config, "ROPE", linear_method)
  399. else: # baichuan 13b, baichuan2 13b
  400. super().__init__(config, "ALIBI", linear_method)
  401. class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
  402. """Baichuan 7B."""
  403. def __init__(self,
  404. config,
  405. linear_method: Optional[LinearMethodBase] = None):
  406. super().__init__(config, "ROPE", linear_method)