minicpm.py
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
import math
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig, LoRAConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.fused_moe import fused_moe
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.interfaces import SupportsLoRA
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig

class MiniCPMMoE(nn.Module):
    """A tensor-parallel MoE implementation that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks, a fused MoE
    kernel is used for the forward pass, and the outputs are finally
    reduced across ranks.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        tp_size: Optional[int] = None,
    ):
        super().__init__()
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size // self.tp_size

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        self.gate = ReplicatedLinear(self.hidden_size,
                                     self.num_total_experts,
                                     bias=False,
                                     params_dtype=self.params_dtype,
                                     quant_config=None)

        self.ws = nn.Parameter(
            torch.empty(self.num_total_experts,
                        2 * self.intermediate_size,
                        self.hidden_size,
                        device="cuda",
                        dtype=self.params_dtype))
        self.w2s = nn.Parameter(
            torch.empty(self.num_total_experts,
                        self.hidden_size,
                        self.intermediate_size,
                        device="cuda",
                        dtype=self.params_dtype))

        set_weight_attrs(self.ws, {
            "weight_loader": self.weight_loader,
        })
        set_weight_attrs(self.w2s, {
            "weight_loader": self.weight_loader,
        })

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str, expert_id: int):
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("w1.weight"):
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w3.weight"):
            param_data[expert_id,
                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w2.weight"):
            param_data[expert_id, :, :] = loaded_weight[:, shard]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(hidden_states,
                                        self.ws,
                                        self.w2s,
                                        router_logits,
                                        self.top_k,
                                        renormalize=True,
                                        inplace=True)
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)
        return final_hidden_states.view(num_tokens, hidden_size)
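
# Sharding sketch for MiniCPMMoE (illustrative values, not from the original
# file): with num_experts=8, intermediate_size=5760 and tp_size=4, each rank
# allocates
#     ws:  (8, 2 * 1440, hidden_size)  # per-expert w1 (gate) and w3 (up),
#                                      # stacked along dim 1
#     w2s: (8, hidden_size, 1440)      # per-expert w2 (down)
# fused_moe computes every expert's partial output from the local shards, and
# the all-reduce in forward() sums those partial outputs across the 4 ranks.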

class MiniCPMMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
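
# Note on the MLP (not part of the original file): gate_up_proj produces a
# tensor whose last dimension is 2 * intermediate_size (gate and up halves
# concatenated); SiluAndMul splits it in half and returns silu(gate) * up, so
# down_proj sees intermediate_size features. This is the standard gated-SiLU
# (SwiGLU-style) feed-forward block.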

class MiniCPMAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across the tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across the tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        # Recompute the rope cos/sin cache so it stays in fp32 instead of bf16.
        self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache(
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # Apply rotary embeddings in fp32 for numerical stability, then cast
        # back to the original dtype.
        orig_dtype = q.dtype
        q, k = q.float(), k.float()
        q, k = self.rotary_emb(positions, q, k)
        q, k = q.to(orig_dtype), k.to(orig_dtype)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output
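
# Head-partitioning sketch for MiniCPMAttention (illustrative values, not from
# the original file): with num_heads=36, num_kv_heads=36 and tp_size=4, each
# rank computes 9 query heads and 9 KV heads. In a GQA-style configuration
# such as num_kv_heads=2 with tp_size=4, the KV heads are replicated instead
# (per-rank KV heads = max(1, 2 // 4) = 1); the asserts above enforce that the
# head counts divide evenly one way or the other.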

class MiniCPMDecoderLayer(nn.Module):

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = MiniCPMAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.num_experts = getattr(self.config, "num_experts", 0)
        if self.num_experts == 0:
            self.mlp = MiniCPMMLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
            )
        else:
            self.mlp = MiniCPMMoE(num_experts=config.num_experts,
                                  top_k=config.num_experts_per_tok,
                                  hidden_size=config.hidden_size,
                                  intermediate_size=config.intermediate_size)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + hidden_states * \
            (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states * \
            (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))

        return hidden_states, None
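
# Residual-scaling sketch (illustrative values, not from the original file):
# each sub-layer output is multiplied by scale_depth / sqrt(num_hidden_layers)
# before being added back to the residual stream. For example, with
# scale_depth = 1.4 and num_hidden_layers = 40 the factor is
# 1.4 / sqrt(40) ≈ 0.2214, which damps per-layer updates as depth grows.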

class MiniCPMModel(nn.Module):

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            MiniCPMDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        embedding = self.embed_tokens(input_ids)
        return embedding * self.config.scale_emb

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states
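
# Input/output scaling sketch (not part of the original file): token embeddings
# are multiplied by config.scale_emb on the way in (get_input_embeddings above),
# and hidden states are divided by hidden_size / dim_model_base before the LM
# head (compute_logits below). For example, with scale_emb = 12,
# hidden_size = 2304 and dim_model_base = 256 (illustrative values), embeddings
# are scaled by 12 and the pre-logits hidden states by 1 / 9.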

class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.lora_config = lora_config

        self.num_experts = getattr(self.config, "num_experts", 0)
        self.quant_config = quant_config
        self.model = MiniCPMModel(config,
                                  cache_config,
                                  quant_config,
                                  lora_config=lora_config)
        unpadded_vocab_size = config.vocab_size
        if lora_config:
            unpadded_vocab_size += lora_config.lora_extra_vocab_size
        if not self.config.tie_word_embeddings:
            self.lm_head = ParallelLMHead(
                unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using LoRA for kernel
                # compatibility.
                if not lora_config else lora_config.lora_vocab_padding_size,
                quant_config=quant_config,
            )
        self.scale_width = self.config.hidden_size / self.config.dim_model_base
        self.logits_processor = LogitsProcessor(unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        hidden_states = hidden_states / self.scale_width
        if self.config.tie_word_embeddings:
            lm_head = self.model.embed_tokens
        else:
            lm_head = self.lm_head
        logits = self.logits_processor(lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
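
    # Note on compute_logits (not part of the original file): when
    # config.tie_word_embeddings is set, the input embedding matrix doubles as
    # the output projection, so no separate lm_head is created in __init__ and
    # load_weights() skips any "lm_head.weight" found in the checkpoint. The
    # division by self.scale_width (hidden_size / dim_model_base) applies
    # MiniCPM's width-dependent output scaling before projecting to vocabulary
    # logits.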

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        expert_params_mapping = [
            # (param_name, weight_name, expert_id)
            ("ws" if weight_name in ["w1", "w3"] else "w2s",
             f"experts.{expert_id}.{weight_name}.weight", expert_id)
            for expert_id in range(self.num_experts)
            for weight_name in ["w1", "w2", "w3"]
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            # With tie_word_embeddings, we can skip loading lm_head.weight.
            # The weight may still appear in the checkpoint files, e.g. after
            # quantization, LoRA merging, or fine-tuning.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, expert_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  weight_name,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
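
# Usage sketch (assumption, not part of the original file): this class is
# normally constructed by the Aphrodite model loader rather than instantiated
# directly. Assuming the engine exposes a vLLM-style front end, serving a
# MiniCPM checkpoint would look roughly like:
#
#     from aphrodite import LLM, SamplingParams
#
#     llm = LLM(model="openbmb/MiniCPM-2B-sft-bf16", trust_remote_code=True)
#     outputs = llm.generate(["Hello, world"],
#                            SamplingParams(temperature=0.7, max_tokens=64))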