minicpm.py 22 KB


  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only MiniCPM model compatible with HuggingFace weights."""
  25. import math
  26. from typing import Any, Dict, Iterable, List, Optional, Tuple
  27. import torch
  28. from torch import nn
  29. from transformers import PretrainedConfig
  30. from aphrodite.attention import Attention, AttentionMetadata
  31. from aphrodite.common.config import CacheConfig, LoRAConfig
  32. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  33. from aphrodite.common.utils import progress_bar
  34. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  35. get_tensor_model_parallel_world_size,
  36. tensor_model_parallel_all_reduce)
  37. from aphrodite.modeling.layers.activation import SiluAndMul
  38. from aphrodite.modeling.layers.fused_moe import fused_moe
  39. from aphrodite.modeling.layers.layernorm import RMSNorm
  40. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  41. QKVParallelLinear,
  42. ReplicatedLinear,
  43. RowParallelLinear)
  44. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  45. from aphrodite.modeling.layers.rotary_embedding import get_rope
  46. from aphrodite.modeling.layers.sampler import Sampler
  47. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  48. DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
  49. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  50. from aphrodite.modeling.models.interfaces import SupportsLoRA
  51. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  52. from aphrodite.modeling.utils import set_weight_attrs
  53. from aphrodite.quantization.base_config import QuantizationConfig
  54. class MiniCPMMoE(nn.Module):
  55. """A tensor-parallel MoE implementation that shards each expert
  56. across all ranks.
  57. Each expert's weights are sharded across all ranks and a fused MoE
  58. kernel is used for the forward pass, and finally we reduce the outputs
  59. across ranks.
  60. """
  61. def __init__(
  62. self,
  63. num_experts: int,
  64. top_k: int,
  65. hidden_size: int,
  66. intermediate_size: int,
  67. params_dtype: Optional[torch.dtype] = None,
  68. tp_size: Optional[int] = None,
  69. ):
  70. super().__init__()
  71. self.tp_size = tp_size or get_tensor_model_parallel_world_size()
  72. self.num_total_experts = num_experts
  73. self.top_k = top_k
  74. self.hidden_size = hidden_size
  75. self.intermediate_size = intermediate_size // self.tp_size
  76. if params_dtype is None:
  77. params_dtype = torch.get_default_dtype()
  78. self.params_dtype = params_dtype
  79. self.gate = ReplicatedLinear(self.hidden_size,
  80. self.num_total_experts,
  81. bias=False,
  82. params_dtype=self.params_dtype,
  83. quant_config=None)
  84. self.ws = nn.Parameter(
  85. torch.empty(self.num_total_experts,
  86. 2 * self.intermediate_size,
  87. self.hidden_size,
  88. device="cuda",
  89. dtype=self.params_dtype))
  90. self.w2s = nn.Parameter(
  91. torch.empty(self.num_total_experts,
  92. self.hidden_size,
  93. self.intermediate_size,
  94. device="cuda",
  95. dtype=self.params_dtype))
  96. set_weight_attrs(self.ws, {
  97. "weight_loader": self.weight_loader,
  98. })
  99. set_weight_attrs(self.w2s, {
  100. "weight_loader": self.weight_loader,
  101. })
  102. def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
  103. weight_name: str, expert_id: int):
  104. tp_rank = get_tensor_model_parallel_rank()
  105. param_data = param.data
  106. shard_size = self.intermediate_size
  107. shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
  108. if weight_name.endswith("w1.weight"):
  109. param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
  110. if weight_name.endswith("w3.weight"):
  111. param_data[expert_id,
  112. shard_size:2 * shard_size, :] = loaded_weight[shard, :]
  113. if weight_name.endswith("w2.weight"):
  114. param_data[expert_id, :, :] = loaded_weight[:, shard]
  115. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  116. num_tokens, hidden_size = hidden_states.shape
  117. hidden_states = hidden_states.view(-1, self.hidden_size)
  118. # router_logits: (num_tokens, n_experts)
  119. router_logits, _ = self.gate(hidden_states)
  120. final_hidden_states = fused_moe(hidden_states,
  121. self.ws,
  122. self.w2s,
  123. router_logits,
  124. self.top_k,
  125. renormalize=True,
  126. inplace=True)
  127. if self.tp_size > 1:
  128. final_hidden_states = tensor_model_parallel_all_reduce(
  129. final_hidden_states)
  130. return final_hidden_states.view(num_tokens, hidden_size)
  131. class MiniCPMMLP(nn.Module):
  132. def __init__(
  133. self,
  134. hidden_size: int,
  135. intermediate_size: int,
  136. hidden_act: str,
  137. quant_config: Optional[QuantizationConfig] = None,
  138. ) -> None:
  139. super().__init__()
  140. self.gate_up_proj = MergedColumnParallelLinear(
  141. hidden_size, [intermediate_size] * 2,
  142. bias=False,
  143. quant_config=quant_config)
  144. self.down_proj = RowParallelLinear(intermediate_size,
  145. hidden_size,
  146. bias=False,
  147. quant_config=quant_config)
  148. if hidden_act != "silu":
  149. raise ValueError(f"Unsupported activation: {hidden_act}. "
  150. "Only silu is supported for now.")
  151. self.act_fn = SiluAndMul()
  152. def forward(self, x):
  153. gate_up, _ = self.gate_up_proj(x)
  154. x = self.act_fn(gate_up)
  155. x, _ = self.down_proj(x)
  156. return x
  157. class MiniCPMAttention(nn.Module):
  158. def __init__(
  159. self,
  160. hidden_size: int,
  161. num_heads: int,
  162. num_kv_heads: int,
  163. rope_theta: float = 10000,
  164. rope_scaling: Optional[Dict[str, Any]] = None,
  165. max_position_embeddings: int = 8192,
  166. cache_config: Optional[CacheConfig] = None,
  167. quant_config: Optional[QuantizationConfig] = None,
  168. ) -> None:
  169. super().__init__()
  170. self.hidden_size = hidden_size
  171. tp_size = get_tensor_model_parallel_world_size()
  172. self.total_num_heads = num_heads
  173. assert self.total_num_heads % tp_size == 0
  174. self.num_heads = self.total_num_heads // tp_size
  175. self.total_num_kv_heads = num_kv_heads
  176. if self.total_num_kv_heads >= tp_size:
  177. # Number of KV heads is greater than TP size, so we partition
  178. # the KV heads across multiple tensor parallel GPUs.
  179. assert self.total_num_kv_heads % tp_size == 0
  180. else:
  181. # Number of KV heads is less than TP size, so we replicate
  182. # the KV heads across multiple tensor parallel GPUs.
  183. assert tp_size % self.total_num_kv_heads == 0
  184. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  185. self.head_dim = hidden_size // self.total_num_heads
  186. self.q_size = self.num_heads * self.head_dim
  187. self.kv_size = self.num_kv_heads * self.head_dim
  188. self.scaling = self.head_dim**-0.5
  189. self.rope_theta = rope_theta
  190. self.max_position_embeddings = max_position_embeddings
  191. self.qkv_proj = QKVParallelLinear(
  192. hidden_size,
  193. self.head_dim,
  194. self.total_num_heads,
  195. self.total_num_kv_heads,
  196. bias=False,
  197. quant_config=quant_config,
  198. )
  199. self.o_proj = RowParallelLinear(
  200. self.total_num_heads * self.head_dim,
  201. hidden_size,
  202. bias=False,
  203. quant_config=quant_config,
  204. )
  205. self.rotary_emb = get_rope(
  206. self.head_dim,
  207. rotary_dim=self.head_dim,
  208. max_position=max_position_embeddings,
  209. base=rope_theta,
  210. rope_scaling=rope_scaling,
  211. )
  212. # set rope as fp32 instead of bf16
  213. self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache(
  214. )
  215. self.attn = Attention(self.num_heads,
  216. self.head_dim,
  217. self.scaling,
  218. num_kv_heads=self.num_kv_heads,
  219. cache_config=cache_config,
  220. quant_config=quant_config)
  221. def forward(
  222. self,
  223. positions: torch.Tensor,
  224. hidden_states: torch.Tensor,
  225. kv_cache: torch.Tensor,
  226. attn_metadata: AttentionMetadata,
  227. ) -> torch.Tensor:
  228. qkv, _ = self.qkv_proj(hidden_states)
  229. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  230. orig_dtype = q.dtype
  231. q, k = q.float(), k.float()
  232. q, k = self.rotary_emb(positions, q, k)
  233. q, k = q.to(orig_dtype), k.to(orig_dtype)
  234. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  235. output, _ = self.o_proj(attn_output)
  236. return output
  237. class MiniCPMDecoderLayer(nn.Module):
  238. def __init__(
  239. self,
  240. config,
  241. cache_config: Optional[CacheConfig] = None,
  242. quant_config: Optional[QuantizationConfig] = None,
  243. ) -> None:
  244. super().__init__()
  245. self.config = config
  246. self.hidden_size = config.hidden_size
  247. rope_theta = getattr(config, "rope_theta", 10000)
  248. rope_scaling = getattr(config, "rope_scaling", None)
  249. max_position_embeddings = getattr(config, "max_position_embeddings",
  250. 8192)
  251. self.self_attn = MiniCPMAttention(
  252. hidden_size=self.hidden_size,
  253. num_heads=config.num_attention_heads,
  254. num_kv_heads=config.num_key_value_heads,
  255. rope_theta=rope_theta,
  256. rope_scaling=rope_scaling,
  257. max_position_embeddings=max_position_embeddings,
  258. cache_config=cache_config,
  259. quant_config=quant_config,
  260. )
  261. self.num_experts = getattr(self.config, "num_experts", 0)
  262. if self.num_experts == 0:
  263. self.mlp = MiniCPMMLP(
  264. hidden_size=self.hidden_size,
  265. intermediate_size=config.intermediate_size,
  266. hidden_act=config.hidden_act,
  267. quant_config=quant_config,
  268. )
  269. else:
  270. self.mlp = MiniCPMMoE(num_experts=config.num_experts,
  271. top_k=config.num_experts_per_tok,
  272. hidden_size=config.hidden_size,
  273. intermediate_size=config.intermediate_size)
  274. self.input_layernorm = RMSNorm(config.hidden_size,
  275. eps=config.rms_norm_eps)
  276. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  277. eps=config.rms_norm_eps)
  278. def forward(
  279. self,
  280. positions: torch.Tensor,
  281. hidden_states: torch.Tensor,
  282. kv_cache: torch.Tensor,
  283. attn_metadata: AttentionMetadata,
  284. residual: Optional[torch.Tensor],
  285. ) -> Tuple[torch.Tensor, torch.Tensor]:
  286. # Self Attention
  287. residual = hidden_states
  288. hidden_states = self.input_layernorm(hidden_states)
  289. hidden_states = self.self_attn(
  290. positions=positions,
  291. hidden_states=hidden_states,
  292. kv_cache=kv_cache,
  293. attn_metadata=attn_metadata,
  294. )
  295. hidden_states = residual + hidden_states * \
  296. (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))
  297. # Fully Connected
  298. residual = hidden_states
  299. hidden_states = self.post_attention_layernorm(hidden_states)
  300. hidden_states = self.mlp(hidden_states)
  301. hidden_states = residual + hidden_states * \
  302. (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))
  303. return hidden_states, None
  304. class MiniCPMModel(nn.Module):
  305. def __init__(
  306. self,
  307. config,
  308. cache_config: Optional[CacheConfig] = None,
  309. quant_config: Optional[QuantizationConfig] = None,
  310. lora_config: Optional[LoRAConfig] = None,
  311. ) -> None:
  312. super().__init__()
  313. self.config = config
  314. self.padding_idx = config.pad_token_id
  315. lora_vocab = (lora_config.lora_extra_vocab_size *
  316. (lora_config.max_loras or 1)) if lora_config else 0
  317. self.vocab_size = config.vocab_size + lora_vocab
  318. self.org_vocab_size = config.vocab_size
  319. self.embed_tokens = VocabParallelEmbedding(
  320. self.vocab_size,
  321. config.hidden_size,
  322. org_num_embeddings=config.vocab_size,
  323. )
  324. self.layers = nn.ModuleList([
  325. MiniCPMDecoderLayer(config, cache_config, quant_config)
  326. for _ in range(config.num_hidden_layers)
  327. ])
  328. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  329. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  330. embedding = self.embed_tokens(input_ids)
  331. return embedding * self.config.scale_emb
  332. def forward(
  333. self,
  334. input_ids: torch.Tensor,
  335. positions: torch.Tensor,
  336. kv_caches: List[torch.Tensor],
  337. attn_metadata: AttentionMetadata,
  338. intermediate_tensors: Optional[IntermediateTensors] = None,
  339. inputs_embeds: Optional[torch.Tensor] = None,
  340. ) -> torch.Tensor:
  341. if inputs_embeds is not None:
  342. hidden_states = inputs_embeds
  343. else:
  344. hidden_states = self.get_input_embeddings(input_ids)
  345. residual = None
  346. for i in range(len(self.layers)):
  347. layer = self.layers[i]
  348. hidden_states, residual = layer(
  349. positions,
  350. hidden_states,
  351. kv_caches[i],
  352. attn_metadata,
  353. residual,
  354. )
  355. hidden_states = self.norm(hidden_states)
  356. return hidden_states
class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
    """MiniCPM causal language model: decoder stack, LM head, logits
    processing, sampling and HuggingFace checkpoint loading.

    Depending on ``config.num_experts``, the decoder layers are dense or MoE.
    ``load_weights`` routes per-expert checkpoint tensors into the fused
    expert parameters of the MoE layers.
    """

    # Fused parameters -> the checkpoint projections concatenated into them.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the decoder stack, the LM head (untied models only), the
        logits processor and the sampler.

        Args:
            config: HuggingFace MiniCPM config. Besides the usual fields it
                must provide ``dim_model_base`` (used for ``scale_width``)
                and ``tie_word_embeddings``.
            cache_config: forwarded to the attention layers.
            quant_config: forwarded to the decoder stack and the LM head.
            lora_config: when set, enlarges the vocabulary by
                ``lora_extra_vocab_size`` and uses LoRA vocab padding.
        """
        super().__init__()
        self.config = config
        self.lora_config = lora_config
        # 0 for dense models; load_weights builds one expert mapping entry
        # per expert and weight name.
        self.num_experts = getattr(self.config, "num_experts", 0)
        self.quant_config = quant_config
        self.model = MiniCPMModel(config,
                                  cache_config,
                                  quant_config,
                                  lora_config=lora_config)
        unpadded_vocab_size = config.vocab_size
        if lora_config:
            unpadded_vocab_size += lora_config.lora_extra_vocab_size
        # With tied embeddings no lm_head is created; compute_logits uses
        # model.embed_tokens instead.
        if not self.config.tie_word_embeddings:
            self.lm_head = ParallelLMHead(
                unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using lora for kernel
                # compatibility
                if not lora_config else lora_config.lora_vocab_padding_size,
                quant_config=quant_config,
            )
        # compute_logits divides hidden states by this width ratio before
        # applying the LM head.
        self.scale_width = self.config.hidden_size / self.config.dim_model_base
        self.logits_processor = LogitsProcessor(unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Run the decoder stack and return its final hidden states.

        Logits are produced separately by ``compute_logits``.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Scale hidden states by ``1 / scale_width`` and project to logits.

        Uses ``model.embed_tokens`` as the output projection when word
        embeddings are tied, otherwise ``lm_head``.
        """
        hidden_states = hidden_states / self.scale_width
        if self.config.tie_word_embeddings:
            lm_head = self.model.embed_tokens
        else:
            lm_head = self.lm_head
        logits = self.logits_processor(lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the model's ``Sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load HuggingFace-format ``(name, tensor)`` pairs into this model.

        Each checkpoint name is handled in this order:

        1. Rotary caches are skipped. So is ``lm_head.weight`` when word
           embeddings are tied.
        2. q/k/v and gate/up projections are renamed to their fused
           parameter and loaded into the matching shard.
        3. Otherwise, per-expert ``w1``/``w2``/``w3`` tensors are renamed to
           the fused ``ws``/``w2s`` parameters and loaded by expert id.
        4. Anything else is loaded directly, with its own ``weight_loader``
           if it has one, else ``default_weight_loader``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # Empty for dense models (num_experts == 0). The "experts.<id>."
        # prefix keeps e.g. expert 1 from matching "experts.11.".
        expert_params_mapping = [
            # (param_name, weight_name, expert_id)
            ("ws" if weight_name in ["w1", "w3"] else "w2s",
             f"experts.{expert_id}.{weight_name}.weight", expert_id)
            for expert_id in range(self.num_experts)
            for weight_name in ["w1", "w2", "w3"]
        ]
        params_dict = dict(self.named_parameters())
        # Materialized up front, presumably so progress_bar can report a
        # total; TODO confirm.
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model is
            # processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # `name` is rewritten in place, so later mapping entries (and
                # the else branches below) see the renamed value.
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # NOTE: this `continue` resumes the mapping loop rather than
                # the outer weight loop. An unmatched bias then reaches the
                # final else branch, whose identical check skips it.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a fused q/k/v or gate/up projection: try expert weights.
                for param_name, weight_name, expert_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    # The full checkpoint name is passed so the loader can
                    # pick the w1/w2/w3 destination from its suffix.
                    weight_loader(param,
                                  loaded_weight,
                                  weight_name,
                                  expert_id=expert_id)
                    break
                else:
                    # Plain parameter (embeddings, norms, o_proj, down_proj...).
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)