minicpm.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only MiniCPM model compatible with HuggingFace weights."""
  24. import math
  25. from typing import Any, Dict, Iterable, List, Optional, Tuple
  26. import torch
  27. from torch import nn
  28. from transformers import PretrainedConfig
  29. from aphrodite.attention import Attention, AttentionMetadata
  30. from aphrodite.common.config import CacheConfig, LoRAConfig
  31. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  32. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  33. get_tensor_model_parallel_world_size,
  34. tensor_model_parallel_all_reduce)
  35. from aphrodite.modeling.layers.activation import SiluAndMul
  36. from aphrodite.modeling.layers.fused_moe import fused_moe
  37. from aphrodite.modeling.layers.layernorm import RMSNorm
  38. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  39. QKVParallelLinear,
  40. ReplicatedLinear,
  41. RowParallelLinear)
  42. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  43. from aphrodite.modeling.layers.rotary_embedding import get_rope
  44. from aphrodite.modeling.layers.sampler import Sampler
  45. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  46. DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
  47. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  48. from aphrodite.modeling.models.interfaces import SupportsLoRA
  49. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  50. from aphrodite.modeling.utils import set_weight_attrs
  51. from aphrodite.quantization.base_config import QuantizationConfig
  52. class MiniCPMMoE(nn.Module):
  53. """A tensor-parallel MoE implementation that shards each expert
  54. across all ranks.
  55. Each expert's weights are sharded across all ranks and a fused MoE
  56. kernel is used for the forward pass, and finally we reduce the outputs
  57. across ranks.
  58. """
  59. def __init__(
  60. self,
  61. num_experts: int,
  62. top_k: int,
  63. hidden_size: int,
  64. intermediate_size: int,
  65. params_dtype: Optional[torch.dtype] = None,
  66. tp_size: Optional[int] = None,
  67. ):
  68. super().__init__()
  69. self.tp_size = tp_size or get_tensor_model_parallel_world_size()
  70. self.num_total_experts = num_experts
  71. self.top_k = top_k
  72. self.hidden_size = hidden_size
  73. self.intermediate_size = intermediate_size // self.tp_size
  74. if params_dtype is None:
  75. params_dtype = torch.get_default_dtype()
  76. self.params_dtype = params_dtype
  77. self.gate = ReplicatedLinear(self.hidden_size,
  78. self.num_total_experts,
  79. bias=False,
  80. params_dtype=self.params_dtype,
  81. quant_config=None)
  82. self.ws = nn.Parameter(
  83. torch.empty(self.num_total_experts,
  84. 2 * self.intermediate_size,
  85. self.hidden_size,
  86. device="cuda",
  87. dtype=self.params_dtype))
  88. self.w2s = nn.Parameter(
  89. torch.empty(self.num_total_experts,
  90. self.hidden_size,
  91. self.intermediate_size,
  92. device="cuda",
  93. dtype=self.params_dtype))
  94. set_weight_attrs(self.ws, {
  95. "weight_loader": self.weight_loader,
  96. })
  97. set_weight_attrs(self.w2s, {
  98. "weight_loader": self.weight_loader,
  99. })
  100. def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
  101. weight_name: str, expert_id: int):
  102. tp_rank = get_tensor_model_parallel_rank()
  103. param_data = param.data
  104. shard_size = self.intermediate_size
  105. shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
  106. if weight_name.endswith("w1.weight"):
  107. param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
  108. if weight_name.endswith("w3.weight"):
  109. param_data[expert_id,
  110. shard_size:2 * shard_size, :] = loaded_weight[shard, :]
  111. if weight_name.endswith("w2.weight"):
  112. param_data[expert_id, :, :] = loaded_weight[:, shard]
  113. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  114. num_tokens, hidden_size = hidden_states.shape
  115. hidden_states = hidden_states.view(-1, self.hidden_size)
  116. # router_logits: (num_tokens, n_experts)
  117. router_logits, _ = self.gate(hidden_states)
  118. final_hidden_states = fused_moe(hidden_states,
  119. self.ws,
  120. self.w2s,
  121. router_logits,
  122. self.top_k,
  123. renormalize=True,
  124. inplace=True)
  125. if self.tp_size > 1:
  126. final_hidden_states = tensor_model_parallel_all_reduce(
  127. final_hidden_states)
  128. return final_hidden_states.view(num_tokens, hidden_size)
  129. class MiniCPMMLP(nn.Module):
  130. def __init__(
  131. self,
  132. hidden_size: int,
  133. intermediate_size: int,
  134. hidden_act: str,
  135. quant_config: Optional[QuantizationConfig] = None,
  136. ) -> None:
  137. super().__init__()
  138. self.gate_up_proj = MergedColumnParallelLinear(
  139. hidden_size, [intermediate_size] * 2,
  140. bias=False,
  141. quant_config=quant_config)
  142. self.down_proj = RowParallelLinear(intermediate_size,
  143. hidden_size,
  144. bias=False,
  145. quant_config=quant_config)
  146. if hidden_act != "silu":
  147. raise ValueError(f"Unsupported activation: {hidden_act}. "
  148. "Only silu is supported for now.")
  149. self.act_fn = SiluAndMul()
  150. def forward(self, x):
  151. gate_up, _ = self.gate_up_proj(x)
  152. x = self.act_fn(gate_up)
  153. x, _ = self.down_proj(x)
  154. return x
  155. class MiniCPMAttention(nn.Module):
  156. def __init__(
  157. self,
  158. hidden_size: int,
  159. num_heads: int,
  160. num_kv_heads: int,
  161. rope_theta: float = 10000,
  162. rope_scaling: Optional[Dict[str, Any]] = None,
  163. max_position_embeddings: int = 8192,
  164. cache_config: Optional[CacheConfig] = None,
  165. quant_config: Optional[QuantizationConfig] = None,
  166. ) -> None:
  167. super().__init__()
  168. self.hidden_size = hidden_size
  169. tp_size = get_tensor_model_parallel_world_size()
  170. self.total_num_heads = num_heads
  171. assert self.total_num_heads % tp_size == 0
  172. self.num_heads = self.total_num_heads // tp_size
  173. self.total_num_kv_heads = num_kv_heads
  174. if self.total_num_kv_heads >= tp_size:
  175. # Number of KV heads is greater than TP size, so we partition
  176. # the KV heads across multiple tensor parallel GPUs.
  177. assert self.total_num_kv_heads % tp_size == 0
  178. else:
  179. # Number of KV heads is less than TP size, so we replicate
  180. # the KV heads across multiple tensor parallel GPUs.
  181. assert tp_size % self.total_num_kv_heads == 0
  182. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  183. self.head_dim = hidden_size // self.total_num_heads
  184. self.q_size = self.num_heads * self.head_dim
  185. self.kv_size = self.num_kv_heads * self.head_dim
  186. self.scaling = self.head_dim**-0.5
  187. self.rope_theta = rope_theta
  188. self.max_position_embeddings = max_position_embeddings
  189. self.qkv_proj = QKVParallelLinear(
  190. hidden_size,
  191. self.head_dim,
  192. self.total_num_heads,
  193. self.total_num_kv_heads,
  194. bias=False,
  195. quant_config=quant_config,
  196. )
  197. self.o_proj = RowParallelLinear(
  198. self.total_num_heads * self.head_dim,
  199. hidden_size,
  200. bias=False,
  201. quant_config=quant_config,
  202. )
  203. self.rotary_emb = get_rope(
  204. self.head_dim,
  205. rotary_dim=self.head_dim,
  206. max_position=max_position_embeddings,
  207. base=rope_theta,
  208. rope_scaling=rope_scaling,
  209. )
  210. # set rope as fp32 instead of bf16
  211. self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache(
  212. )
  213. self.attn = Attention(self.num_heads,
  214. self.head_dim,
  215. self.scaling,
  216. num_kv_heads=self.num_kv_heads,
  217. cache_config=cache_config,
  218. quant_config=quant_config)
  219. def forward(
  220. self,
  221. positions: torch.Tensor,
  222. hidden_states: torch.Tensor,
  223. kv_cache: torch.Tensor,
  224. attn_metadata: AttentionMetadata,
  225. ) -> torch.Tensor:
  226. qkv, _ = self.qkv_proj(hidden_states)
  227. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  228. orig_dtype = q.dtype
  229. q, k = q.float(), k.float()
  230. q, k = self.rotary_emb(positions, q, k)
  231. q, k = q.to(orig_dtype), k.to(orig_dtype)
  232. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  233. output, _ = self.o_proj(attn_output)
  234. return output
  235. class MiniCPMDecoderLayer(nn.Module):
  236. def __init__(
  237. self,
  238. config,
  239. cache_config: Optional[CacheConfig] = None,
  240. quant_config: Optional[QuantizationConfig] = None,
  241. ) -> None:
  242. super().__init__()
  243. self.config = config
  244. self.hidden_size = config.hidden_size
  245. rope_theta = getattr(config, "rope_theta", 10000)
  246. rope_scaling = getattr(config, "rope_scaling", None)
  247. max_position_embeddings = getattr(config, "max_position_embeddings",
  248. 8192)
  249. self.self_attn = MiniCPMAttention(
  250. hidden_size=self.hidden_size,
  251. num_heads=config.num_attention_heads,
  252. num_kv_heads=config.num_key_value_heads,
  253. rope_theta=rope_theta,
  254. rope_scaling=rope_scaling,
  255. max_position_embeddings=max_position_embeddings,
  256. cache_config=cache_config,
  257. quant_config=quant_config,
  258. )
  259. self.num_experts = getattr(self.config, "num_experts", 0)
  260. if self.num_experts == 0:
  261. self.mlp = MiniCPMMLP(
  262. hidden_size=self.hidden_size,
  263. intermediate_size=config.intermediate_size,
  264. hidden_act=config.hidden_act,
  265. quant_config=quant_config,
  266. )
  267. else:
  268. self.mlp = MiniCPMMoE(num_experts=config.num_experts,
  269. top_k=config.num_experts_per_tok,
  270. hidden_size=config.hidden_size,
  271. intermediate_size=config.intermediate_size)
  272. self.input_layernorm = RMSNorm(config.hidden_size,
  273. eps=config.rms_norm_eps)
  274. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  275. eps=config.rms_norm_eps)
  276. def forward(
  277. self,
  278. positions: torch.Tensor,
  279. hidden_states: torch.Tensor,
  280. kv_cache: torch.Tensor,
  281. attn_metadata: AttentionMetadata,
  282. residual: Optional[torch.Tensor],
  283. ) -> Tuple[torch.Tensor, torch.Tensor]:
  284. # Self Attention
  285. residual = hidden_states
  286. hidden_states = self.input_layernorm(hidden_states)
  287. hidden_states = self.self_attn(
  288. positions=positions,
  289. hidden_states=hidden_states,
  290. kv_cache=kv_cache,
  291. attn_metadata=attn_metadata,
  292. )
  293. hidden_states = residual + hidden_states * \
  294. (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))
  295. # Fully Connected
  296. residual = hidden_states
  297. hidden_states = self.post_attention_layernorm(hidden_states)
  298. hidden_states = self.mlp(hidden_states)
  299. hidden_states = residual + hidden_states * \
  300. (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))
  301. return hidden_states, None
  302. class MiniCPMModel(nn.Module):
  303. def __init__(
  304. self,
  305. config,
  306. cache_config: Optional[CacheConfig] = None,
  307. quant_config: Optional[QuantizationConfig] = None,
  308. lora_config: Optional[LoRAConfig] = None,
  309. ) -> None:
  310. super().__init__()
  311. self.config = config
  312. self.padding_idx = config.pad_token_id
  313. lora_vocab = (lora_config.lora_extra_vocab_size *
  314. (lora_config.max_loras or 1)) if lora_config else 0
  315. self.vocab_size = config.vocab_size + lora_vocab
  316. self.org_vocab_size = config.vocab_size
  317. self.embed_tokens = VocabParallelEmbedding(
  318. self.vocab_size,
  319. config.hidden_size,
  320. org_num_embeddings=config.vocab_size,
  321. )
  322. self.layers = nn.ModuleList([
  323. MiniCPMDecoderLayer(config, cache_config, quant_config)
  324. for _ in range(config.num_hidden_layers)
  325. ])
  326. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  327. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  328. embedding = self.embed_tokens(input_ids)
  329. return embedding * self.config.scale_emb
  330. def forward(
  331. self,
  332. input_ids: torch.Tensor,
  333. positions: torch.Tensor,
  334. kv_caches: List[torch.Tensor],
  335. attn_metadata: AttentionMetadata,
  336. inputs_embeds: Optional[torch.Tensor] = None,
  337. ) -> torch.Tensor:
  338. if inputs_embeds is not None:
  339. hidden_states = inputs_embeds
  340. else:
  341. hidden_states = self.get_input_embeddings(input_ids)
  342. residual = None
  343. for i in range(len(self.layers)):
  344. layer = self.layers[i]
  345. hidden_states, residual = layer(
  346. positions,
  347. hidden_states,
  348. kv_caches[i],
  349. attn_metadata,
  350. residual,
  351. )
  352. hidden_states = self.norm(hidden_states)
  353. return hidden_states
  354. class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
  355. packed_modules_mapping = {
  356. "qkv_proj": [
  357. "q_proj",
  358. "k_proj",
  359. "v_proj",
  360. ],
  361. "gate_up_proj": [
  362. "gate_proj",
  363. "up_proj",
  364. ],
  365. }
  366. # LoRA specific attributes
  367. supported_lora_modules = [
  368. "qkv_proj",
  369. "o_proj",
  370. "gate_up_proj",
  371. "down_proj",
  372. "embed_tokens",
  373. "lm_head",
  374. ]
  375. embedding_modules = {
  376. "embed_tokens": "input_embeddings",
  377. "lm_head": "output_embeddings",
  378. }
  379. embedding_padding_modules = ["lm_head"]
  380. def __init__(
  381. self,
  382. config: PretrainedConfig,
  383. cache_config: Optional[CacheConfig] = None,
  384. quant_config: Optional[QuantizationConfig] = None,
  385. lora_config: Optional[LoRAConfig] = None,
  386. ) -> None:
  387. super().__init__()
  388. self.config = config
  389. self.lora_config = lora_config
  390. self.num_experts = getattr(self.config, "num_experts", 0)
  391. self.quant_config = quant_config
  392. self.model = MiniCPMModel(config,
  393. cache_config,
  394. quant_config,
  395. lora_config=lora_config)
  396. unpadded_vocab_size = config.vocab_size
  397. if lora_config:
  398. unpadded_vocab_size += lora_config.lora_extra_vocab_size
  399. if not self.config.tie_word_embeddings:
  400. self.lm_head = ParallelLMHead(
  401. unpadded_vocab_size,
  402. config.hidden_size,
  403. org_num_embeddings=config.vocab_size,
  404. padding_size=DEFAULT_VOCAB_PADDING_SIZE
  405. # We need bigger padding if using lora for kernel
  406. # compatibility
  407. if not lora_config else lora_config.lora_vocab_padding_size,
  408. quant_config=quant_config,
  409. )
  410. self.scale_width = self.config.hidden_size / self.config.dim_model_base
  411. self.logits_processor = LogitsProcessor(unpadded_vocab_size,
  412. config.vocab_size)
  413. self.sampler = Sampler()
  414. def forward(
  415. self,
  416. input_ids: torch.Tensor,
  417. positions: torch.Tensor,
  418. kv_caches: List[torch.Tensor],
  419. attn_metadata: AttentionMetadata,
  420. intermediate_tensors: Optional[IntermediateTensors] = None,
  421. ) -> torch.Tensor:
  422. hidden_states = self.model(input_ids, positions, kv_caches,
  423. attn_metadata)
  424. return hidden_states
  425. def compute_logits(self, hidden_states: torch.Tensor,
  426. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  427. hidden_states = hidden_states / self.scale_width
  428. if self.config.tie_word_embeddings:
  429. lm_head = self.model.embed_tokens
  430. else:
  431. lm_head = self.lm_head
  432. logits = self.logits_processor(lm_head, hidden_states,
  433. sampling_metadata)
  434. return logits
  435. def sample(
  436. self,
  437. logits: torch.Tensor,
  438. sampling_metadata: SamplingMetadata,
  439. ) -> Optional[SamplerOutput]:
  440. next_tokens = self.sampler(logits, sampling_metadata)
  441. return next_tokens
  442. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  443. stacked_params_mapping = [
  444. # (param_name, shard_name, shard_id)
  445. ("qkv_proj", "q_proj", "q"),
  446. ("qkv_proj", "k_proj", "k"),
  447. ("qkv_proj", "v_proj", "v"),
  448. ("gate_up_proj", "gate_proj", 0),
  449. ("gate_up_proj", "up_proj", 1),
  450. ]
  451. expert_params_mapping = [
  452. # (param_name, weight_name, expert_id)
  453. ("ws" if weight_name in ["w1", "w3"] else "w2s",
  454. f"experts.{expert_id}.{weight_name}.weight", expert_id)
  455. for expert_id in range(self.num_experts)
  456. for weight_name in ["w1", "w2", "w3"]
  457. ]
  458. params_dict = dict(self.named_parameters())
  459. for name, loaded_weight in weights:
  460. if "rotary_emb.inv_freq" in name:
  461. continue
  462. if ("rotary_emb.cos_cached" in name
  463. or "rotary_emb.sin_cached" in name):
  464. # Models trained using ColossalAI may include these tensors in
  465. # the checkpoint. Skip them.
  466. continue
  467. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  468. if weight_name not in name:
  469. continue
  470. name = name.replace(weight_name, param_name)
  471. # Skip loading extra bias for GPTQ models.
  472. if name.endswith(".bias") and name not in params_dict:
  473. continue
  474. param = params_dict[name]
  475. weight_loader = param.weight_loader
  476. weight_loader(param, loaded_weight, shard_id)
  477. break
  478. else:
  479. for param_name, weight_name, expert_id in expert_params_mapping:
  480. if weight_name not in name:
  481. continue
  482. name = name.replace(weight_name, param_name)
  483. param = params_dict[name]
  484. weight_loader = param.weight_loader
  485. weight_loader(param,
  486. loaded_weight,
  487. weight_name,
  488. expert_id=expert_id)
  489. break
  490. else:
  491. # Skip loading extra bias for GPTQ models.
  492. if name.endswith(".bias") and name not in params_dict:
  493. continue
  494. param = params_dict[name]
  495. weight_loader = getattr(param, "weight_loader",
  496. default_weight_loader)
  497. weight_loader(param, loaded_weight)