# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
import math
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import LoRAConfig
from aphrodite.common.sequence import SamplerOutput
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.fused_moe import fused_moe
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig


class MiniCPMMoE(nn.Module):
    """A tensor-parallel MoE implementation that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        tp_size: Optional[int] = None,
    ):
        super().__init__()
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size // self.tp_size

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        self.gate = ReplicatedLinear(self.hidden_size,
                                     self.num_total_experts,
                                     bias=False,
                                     params_dtype=self.params_dtype,
                                     quant_config=None)

        self.ws = nn.Parameter(
            torch.empty(self.num_total_experts,
                        2 * self.intermediate_size,
                        self.hidden_size,
                        device="cuda",
                        dtype=self.params_dtype))
        self.w2s = nn.Parameter(
            torch.empty(self.num_total_experts,
                        self.hidden_size,
                        self.intermediate_size,
                        device="cuda",
                        dtype=self.params_dtype))

        set_weight_attrs(self.ws, {
            "weight_loader": self.weight_loader,
        })
        set_weight_attrs(self.w2s, {
            "weight_loader": self.weight_loader,
        })

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str, expert_id: int):
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("w1.weight"):
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w3.weight"):
            param_data[expert_id,
                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w2.weight"):
            param_data[expert_id, :, :] = loaded_weight[:, shard]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(hidden_states,
                                        self.ws,
                                        self.w2s,
                                        router_logits,
                                        self.top_k,
                                        renormalize=True,
                                        inplace=True)

        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_size)
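
# NOTE: MiniCPMMoE above shards every expert across all ranks rather than
# assigning whole experts to individual ranks. As an illustrative example
# (hypothetical sizes, not taken from any particular config): with
# num_experts=8, hidden_size=2304, intermediate_size=5760 and tp_size=2,
# each rank allocates
#     ws  : (8, 2 * 5760 // 2, 2304) = (8, 5760, 2304)   # stacked w1/w3 slices
#     w2s : (8, 2304, 5760 // 2)     = (8, 2304, 2880)   # w2 slices
# The per-rank partial outputs of fused_moe are then summed with
# tensor_model_parallel_all_reduce when tp_size > 1.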


class MiniCPMMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
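
# NOTE: gate_up_proj concatenates the gate and up projections, so per rank
# gate_up carries 2 * (intermediate_size // tp_size) features. SiluAndMul is
# expected to split that last dimension in half and compute silu(gate) * up,
# which down_proj then maps back to hidden_size.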


class MiniCPMAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        # set rope as fp32 instead of bf16
        self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache(
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        orig_dtype = q.dtype
        q, k = q.float(), k.float()
        q, k = self.rotary_emb(positions, q, k)
        q, k = q.to(orig_dtype), k.to(orig_dtype)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
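
# NOTE: the rotary embedding path above deliberately runs in float32: the
# cos/sin cache is recomputed at init (see the "set rope as fp32" comment)
# and q/k are upcast before rotary_emb, then cast back to their original
# dtype afterwards; attention itself runs on the downcast q/k/v.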


class MiniCPMDecoderLayer(nn.Module):

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = MiniCPMAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
        )
        self.num_experts = getattr(self.config, "num_experts", 0)
        if self.num_experts == 0:
            self.mlp = MiniCPMMLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
            )
        else:
            self.mlp = MiniCPMMoE(num_experts=config.num_experts,
                                  top_k=config.num_experts_per_tok,
                                  hidden_size=config.hidden_size,
                                  intermediate_size=config.intermediate_size)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + hidden_states * \
            (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states * \
            (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))

        return hidden_states, None
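
# NOTE: both residual adds above scale the sub-layer output by
# scale_depth / sqrt(num_hidden_layers) ("depth scaling"), which plain
# Llama-style layers do not do. For illustration only, hypothetical values
# scale_depth=1.4 and num_hidden_layers=40 give a multiplier of about 0.22.
# The layer returns (hidden_states, None) because the residual is folded in
# here rather than carried forward to the next layer.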


class MiniCPMModel(nn.Module):

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            MiniCPMDecoderLayer(config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        embedding = self.embed_tokens(input_ids)
        return embedding * self.config.scale_emb

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None

        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states
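
# NOTE: token embeddings are multiplied by config.scale_emb inside
# get_input_embeddings before entering the decoder stack; when inputs_embeds
# is passed in, it is used as-is. The residual returned by each decoder layer
# is always None (see MiniCPMDecoderLayer.forward), so the loop simply threads
# hidden_states through the layers and applies the final RMSNorm.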


class MiniCPMForCausalLM(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.num_experts = getattr(self.config, "num_experts", 0)
        self.quant_config = quant_config
        self.model = MiniCPMModel(config,
                                  quant_config,
                                  lora_config=lora_config)
        unpadded_vocab_size = config.vocab_size
        if lora_config:
            unpadded_vocab_size += lora_config.lora_extra_vocab_size
        if not self.config.tie_word_embeddings:
            self.lm_head = ParallelLMHead(
                unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using lora for kernel
                # compatibility
                if not lora_config else lora_config.lora_vocab_padding_size,
            )
        self.scale_width = self.config.hidden_size / self.config.dim_model_base

        self.logits_processor = LogitsProcessor(unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        hidden_states = hidden_states / self.scale_width
        if self.config.tie_word_embeddings:
            lm_head_weight = self.model.embed_tokens.weight
        else:
            lm_head_weight = self.lm_head.weight
        logits = self.logits_processor(lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits
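
    # NOTE: compute_logits divides the hidden states by
    # scale_width = hidden_size / dim_model_base before the LM head (width
    # scaling). For illustration only, hypothetical values hidden_size=2304
    # and dim_model_base=256 give a divisor of 9.0. With tie_word_embeddings
    # the embedding matrix doubles as the output projection; otherwise the
    # separate ParallelLMHead weight is used.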

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        expert_params_mapping = [
            # (param_name, weight_name, expert_id)
            ("ws" if weight_name in ["w1", "w3"] else "w2s",
             f"experts.{expert_id}.{weight_name}.weight", expert_id)
            for expert_id in range(self.num_experts)
            for weight_name in ["w1", "w2", "w3"]
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, expert_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  weight_name,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
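
# NOTE on load_weights: stacked_params_mapping fuses the separate checkpoint
# projections into the merged parameters (q/k/v -> qkv_proj, gate/up ->
# gate_up_proj), while expert_params_mapping routes per-expert tensors into
# the stacked MoE parameters. For example, a hypothetical checkpoint name
# "model.layers.0.mlp.experts.3.w1.weight" would be rewritten to
# "model.layers.0.mlp.ws" and loaded into the w1 half of expert 3's slice on
# this rank via MiniCPMMoE.weight_loader.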