# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
import math
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig, LoRAConfig
from aphrodite.common.sequence import SamplerOutput
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.fused_moe import fused_moe
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig


class MiniCPMMoE(nn.Module):
    """A tensor-parallel MoE implementation that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        tp_size: Optional[int] = None,
    ):
        super().__init__()
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size // self.tp_size

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        self.gate = ReplicatedLinear(self.hidden_size,
                                     self.num_total_experts,
                                     bias=False,
                                     params_dtype=self.params_dtype,
                                     quant_config=None)

        self.ws = nn.Parameter(
            torch.empty(self.num_total_experts,
                        2 * self.intermediate_size,
                        self.hidden_size,
                        device="cuda",
                        dtype=self.params_dtype))
        self.w2s = nn.Parameter(
            torch.empty(self.num_total_experts,
                        self.hidden_size,
                        self.intermediate_size,
                        device="cuda",
                        dtype=self.params_dtype))

        set_weight_attrs(self.ws, {
            "weight_loader": self.weight_loader,
        })
        set_weight_attrs(self.w2s, {
            "weight_loader": self.weight_loader,
        })
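
    # The checkpoint stores per-expert w1/w3/w2 matrices; w1 and w3 are packed
    # into the fused `ws` buffer (w1 in rows [0, I), w3 in rows [I, 2I)) and w2
    # into `w2s`, where I is this rank's shard of the intermediate size.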
    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str, expert_id: int):
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("w1.weight"):
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w3.weight"):
            param_data[expert_id,
                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w2.weight"):
            param_data[expert_id, :, :] = loaded_weight[:, shard]
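
    # Forward pass: the replicated gate produces router logits, `fused_moe`
    # picks the top-k experts per token and runs the expert GEMMs on the local
    # shards, and the partial outputs are all-reduced across TP ranks.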
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(hidden_states,
                                        self.ws,
                                        self.w2s,
                                        router_logits,
                                        self.top_k,
                                        renormalize=True,
                                        inplace=True)
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)
        return final_hidden_states.view(num_tokens, hidden_size)


class MiniCPMMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()
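
    # gate_up_proj emits [gate, up] concatenated along the last dim;
    # SiluAndMul computes silu(gate) * up, i.e. the SwiGLU activation,
    # before the down projection.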
    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class MiniCPMAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        # set rope as fp32 instead of bf16
        self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache(
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)
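
    # RoPE is applied in fp32: q/k are upcast, rotated, then cast back to the
    # original dtype, matching the fp32 cos/sin cache recomputed above.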
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        orig_dtype = q.dtype
        q, k = q.float(), k.float()
        q, k = self.rotary_emb(positions, q, k)
        q, k = q.to(orig_dtype), k.to(orig_dtype)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class MiniCPMDecoderLayer(nn.Module):

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = MiniCPMAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.num_experts = getattr(self.config, "num_experts", 0)
        if self.num_experts == 0:
            self.mlp = MiniCPMMLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
            )
        else:
            self.mlp = MiniCPMMoE(num_experts=config.num_experts,
                                  top_k=config.num_experts_per_tok,
                                  hidden_size=config.hidden_size,
                                  intermediate_size=config.intermediate_size)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
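
    # MiniCPM scales each residual branch by
    # scale_depth / sqrt(num_hidden_layers) rather than adding it directly.
    # The incoming `residual` argument is unused; None is returned in its
    # place so the layer keeps the same interface as other decoder layers.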
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + hidden_states * \
            (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states * \
            (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))

        return hidden_states, None


class MiniCPMModel(nn.Module):

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            MiniCPMDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
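
    # Token embeddings are multiplied by config.scale_emb, one of MiniCPM's
    # scaling factors (alongside scale_depth above and the width scaling
    # applied before the LM head below).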
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        embedding = self.embed_tokens(input_ids)
        return embedding * self.config.scale_emb

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class MiniCPMForCausalLM(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.num_experts = getattr(self.config, "num_experts", 0)
        self.quant_config = quant_config
        self.model = MiniCPMModel(config,
                                  cache_config,
                                  quant_config,
                                  lora_config=lora_config)
        unpadded_vocab_size = config.vocab_size
        if lora_config:
            unpadded_vocab_size += lora_config.lora_extra_vocab_size
        if not self.config.tie_word_embeddings:
            self.lm_head = ParallelLMHead(
                unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using lora for kernel
                # compatibility
                if not lora_config else lora_config.lora_vocab_padding_size,
            )
        self.scale_width = self.config.hidden_size / self.config.dim_model_base
        self.logits_processor = LogitsProcessor(unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states
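
    # Before the LM head, hidden states are divided by
    # scale_width = hidden_size / dim_model_base; with tied word embeddings
    # the input embedding matrix is reused as the output projection.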
    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        hidden_states = hidden_states / self.scale_width
        if self.config.tie_word_embeddings:
            lm_head_weight = self.model.embed_tokens.weight
        else:
            lm_head_weight = self.lm_head.weight
        logits = self.logits_processor(lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
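
    # Checkpoint weights arrive under HF-style names; q/k/v projections are
    # folded into `qkv_proj`, gate/up into `gate_up_proj`, and per-expert
    # w1/w2/w3 tensors are routed to the fused `ws`/`w2s` MoE buffers.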
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        expert_params_mapping = [
            # (param_name, weight_name, expert_id)
            ("ws" if weight_name in ["w1", "w3"] else "w2s",
             f"experts.{expert_id}.{weight_name}.weight", expert_id)
            for expert_id in range(self.num_experts)
            for weight_name in ["w1", "w2", "w3"]
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, expert_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  weight_name,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
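

# Usage sketch (assumptions: this model class is normally selected by the
# engine's model loader from the checkpoint's `architectures` field rather
# than constructed by hand, and the aphrodite front-end exposes a vLLM-style
# `LLM` class; the model name below is illustrative):
#
#     from aphrodite import LLM, SamplingParams
#
#     llm = LLM(model="openbmb/MiniCPM-2B-sft-bf16", trust_remote_code=True)
#     outputs = llm.generate(["Hello, MiniCPM!"],
#                            SamplingParams(temperature=0.7, max_tokens=32))
#     print(outputs[0].outputs[0].text)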