# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
import math
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig, LoRAConfig
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.fused_moe import fused_moe
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.models.interfaces import SupportsLoRA
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig


class MiniCPMMoE(nn.Module):
    """A tensor-parallel MoE implementation that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks; a fused MoE
    kernel is used for the forward pass, and the outputs are finally
    all-reduced across ranks.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        tp_size: Optional[int] = None,
    ):
        super().__init__()
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size // self.tp_size

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        self.gate = ReplicatedLinear(self.hidden_size,
                                     self.num_total_experts,
                                     bias=False,
                                     params_dtype=self.params_dtype,
                                     quant_config=None)
        self.ws = nn.Parameter(
            torch.empty(self.num_total_experts,
                        2 * self.intermediate_size,
                        self.hidden_size,
                        device="cuda",
                        dtype=self.params_dtype))
        self.w2s = nn.Parameter(
            torch.empty(self.num_total_experts,
                        self.hidden_size,
                        self.intermediate_size,
                        device="cuda",
                        dtype=self.params_dtype))

        set_weight_attrs(self.ws, {
            "weight_loader": self.weight_loader,
        })
        set_weight_attrs(self.w2s, {
            "weight_loader": self.weight_loader,
        })

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str, expert_id: int):
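        """Load one expert's checkpoint tensor into the fused parameters.

        w1/w3 (gate/up projections) are stacked into ``ws`` along dim 1 and
        w2 (down projection) goes into ``w2s``; each tensor is sliced to this
        rank's shard of the intermediate dimension.
        """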
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("w1.weight"):
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w3.weight"):
            param_data[expert_id,
                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
        if weight_name.endswith("w2.weight"):
            param_data[expert_id, :, :] = loaded_weight[:, shard]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(hidden_states,
                                        self.ws,
                                        self.w2s,
                                        router_logits,
                                        self.top_k,
                                        renormalize=True,
                                        inplace=True)
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)
        return final_hidden_states.view(num_tokens, hidden_size)


class MiniCPMMLP(nn.Module):
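    """Standard SwiGLU feed-forward block: fused gate/up projection,
    SiLU-and-multiply activation, then down projection."""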

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class MiniCPMAttention(nn.Module):
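    """Multi-head attention with tensor-parallel QKV/output projections and
    rotary position embeddings; supports grouped-query attention when
    num_kv_heads < num_heads."""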

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        # Keep the RoPE cos/sin cache in fp32 instead of bf16.
        self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache(
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
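        # Apply rotary embeddings in fp32 for numerical stability, then cast
        # q/k back to the original dtype before attention.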
        orig_dtype = q.dtype
        q, k = q.float(), k.float()
        q, k = self.rotary_emb(positions, q, k)
        q, k = q.to(orig_dtype), k.to(orig_dtype)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class MiniCPMDecoderLayer(nn.Module):
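    """A single MiniCPM transformer layer: RMSNorm + self-attention and
    RMSNorm + MLP (or MoE), with each residual branch scaled by
    scale_depth / sqrt(num_hidden_layers)."""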

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.hidden_size = config.hidden_size
        self.rope_theta = getattr(config, "rope_theta", 10000)
        self.rope_scaling = getattr(config, "rope_scaling", None)
        self.max_position_embeddings = getattr(config,
                                               "max_position_embeddings", 8192)
        self._init_attn_block()
        self._init_ffn_block()

    def _init_attn_block(self):
        self.input_layernorm = RMSNorm(self.config.hidden_size,
                                       eps=self.config.rms_norm_eps)
        self.self_attn = MiniCPMAttention(
            hidden_size=self.hidden_size,
            num_heads=self.config.num_attention_heads,
            num_kv_heads=self.config.num_key_value_heads,
            rope_theta=self.rope_theta,
            rope_scaling=self.rope_scaling,
            max_position_embeddings=self.max_position_embeddings,
            cache_config=self.cache_config,
            quant_config=self.quant_config,
        )

    def _init_ffn_block(self):
        self.post_attention_layernorm = RMSNorm(self.config.hidden_size,
                                                eps=self.config.rms_norm_eps)
        self.num_experts = getattr(self.config, "num_experts", 0)
        if self.num_experts == 0:
            self.mlp = MiniCPMMLP(
                hidden_size=self.hidden_size,
                intermediate_size=self.config.intermediate_size,
                hidden_act=self.config.hidden_act,
                quant_config=self.quant_config,
            )
        else:
            self.mlp = MiniCPMMoE(
                num_experts=self.config.num_experts,
                top_k=self.config.num_experts_per_tok,
                hidden_size=self.config.hidden_size,
                intermediate_size=self.config.intermediate_size)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = residual + hidden_states * \
            (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states * \
            (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers))

        return hidden_states, None


class MiniCPMModel(nn.Module):
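    """The MiniCPM transformer backbone: token embeddings (scaled by
    config.scale_emb), a stack of decoder layers, and a final RMSNorm."""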

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self._init_layers()
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def _init_layers(self):
        self.layers = nn.ModuleList([
            MiniCPMDecoderLayer(self.config, self.cache_config,
                                self.quant_config)
            for _ in range(self.config.num_hidden_layers)
        ])

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        embedding = self.embed_tokens(input_ids)
        return embedding * self.config.scale_emb

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
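    """MiniCPM for causal language modeling, with LoRA support and an
    optional tied LM head (see compute_logits)."""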
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "gate_up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.lora_config = lora_config
        self.cache_config = cache_config
        self.quant_config = quant_config
        self.num_experts = getattr(self.config, "num_experts", 0)
        self._init_model()
        unpadded_vocab_size = config.vocab_size
        if lora_config:
            unpadded_vocab_size += lora_config.lora_extra_vocab_size
        if not self.config.tie_word_embeddings:
            self.lm_head = ParallelLMHead(
                unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=DEFAULT_VOCAB_PADDING_SIZE
                # We need bigger padding if using lora for kernel
                # compatibility
                if not lora_config else lora_config.lora_vocab_padding_size,
                quant_config=quant_config,
            )
        self.scale_width = self.config.hidden_size / self.config.dim_model_base
        self.logits_processor = LogitsProcessor(unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def _init_model(self):
        self.model = MiniCPMModel(config=self.config,
                                  cache_config=self.cache_config,
                                  quant_config=self.quant_config,
                                  lora_config=self.lora_config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
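        # MiniCPM scales hidden states down by hidden_size / dim_model_base
        # (self.scale_width) before the LM head; with tied embeddings the
        # input embedding matrix doubles as the output projection.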
        hidden_states = hidden_states / self.scale_width
        if self.config.tie_word_embeddings:
            lm_head = self.model.embed_tokens
        else:
            lm_head = self.lm_head
        logits = self.logits_processor(lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
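        """Load HuggingFace checkpoint tensors, remapping per-projection
        weights (q/k/v, gate/up) onto the fused parameters and routing
        per-expert weights to the MoE weight loader."""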
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        expert_params_mapping = [
            # (param_name, weight_name, expert_id)
            ("ws" if weight_name in ["w1", "w3"] else "w2s",
             f"experts.{expert_id}.{weight_name}.weight", expert_id)
            for expert_id in range(self.num_experts)
            for weight_name in ["w1", "w2", "w3"]
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            # With tie_word_embeddings, lm_head.weight can be skipped; the
            # checkpoint may still contain it if the model was processed with
            # quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, expert_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  weight_name,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)