minicpm3.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2024 The ModelBest team.
  5. # Copyright 2023 The PygmalionAI team.
  6. # Copyright 2023 The vLLM team.
  7. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  10. # and OPT implementations in this library. It has been modified from its
  11. # original forms to accommodate minor architectural differences compared
  12. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  13. #
  14. # Licensed under the Apache License, Version 2.0 (the "License");
  15. # you may not use this file except in compliance with the License.
  16. # You may obtain a copy of the License at
  17. #
  18. # http://www.apache.org/licenses/LICENSE-2.0
  19. #
  20. # Unless required by applicable law or agreed to in writing, software
  21. # distributed under the License is distributed on an "AS IS" BASIS,
  22. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  23. # See the License for the specific language governing permissions and
  24. # limitations under the License.
  25. """Inference-only MiniCPM3 model compatible with HuggingFace weights."""
  26. from typing import Any, Dict, Optional
  27. import torch
  28. from torch import nn
  29. from aphrodite.attention import Attention, AttentionMetadata
  30. from aphrodite.common.config import CacheConfig
  31. from aphrodite.distributed import get_tensor_model_parallel_world_size
  32. from aphrodite.modeling.layers.layernorm import RMSNorm
  33. from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
  34. ReplicatedLinear,
  35. RowParallelLinear)
  36. from aphrodite.modeling.layers.rotary_embedding import get_rope
  37. from aphrodite.modeling.models.minicpm import (MiniCPMDecoderLayer,
  38. MiniCPMForCausalLM,
  39. MiniCPMModel)
  40. from aphrodite.quantization.base_config import QuantizationConfig
  class MiniCPM3Attention(nn.Module):
      """Low-rank-projected self-attention used by MiniCPM3 decoder layers.

      Queries pass through a rank-``q_lora_rank`` bottleneck
      (``q_a_proj`` -> RMSNorm -> ``q_b_proj``). Keys and values are rebuilt
      from a rank-``kv_lora_rank`` latent (``kv_a_proj_with_mqa`` -> RMSNorm
      -> ``kv_b_proj``); the same ``kv_a_proj_with_mqa`` output also carries
      one rotary key slice of width ``qk_rope_head_dim`` that is shared by
      every head. NOTE(review): this looks like the DeepSeek-V2 style
      "multi-head latent attention" layout -- confirm against the reference
      model before relying on that name.

      Each query/key head is ``qk_nope_head_dim + qk_rope_head_dim`` wide and
      rotary embedding is applied only to the trailing ``qk_rope_head_dim``
      channels. Values are ``v_head_dim`` wide; they are zero-padded up to the
      query/key head size so a single ``Attention`` layer (constructed with
      head size ``qk_head_dim``) can consume q, k and v, and the padding is
      sliced off again before ``o_proj``.

      Args:
          config: Model config object; only ``rms_norm_eps`` is read here.
          hidden_size: Width of the incoming hidden states.
          num_heads: Total attention heads across all tensor-parallel ranks.
          qk_nope_head_dim: Per-head query/key channels without rotary
              embedding.
          qk_rope_head_dim: Per-head query/key channels that receive rotary
              embedding.
          v_head_dim: Per-head value channels. Must not exceed
              ``qk_nope_head_dim + qk_rope_head_dim``.
          q_lora_rank: Bottleneck width of the query projection.
          kv_lora_rank: Bottleneck width of the key/value latent.
          rope_theta: Rotary embedding base, forwarded to ``get_rope``.
          rope_scaling: Optional rotary scaling config, forwarded to
              ``get_rope``.
          max_position_embeddings: Maximum position, forwarded to
              ``get_rope``.
          cache_config: Forwarded to ``Attention``.
          quant_config: Forwarded to every linear layer and to ``Attention``.

      Raises:
          ValueError: If ``num_heads`` is not divisible by the
              tensor-parallel world size, or if ``v_head_dim`` is larger than
              the query/key head size (zero-padding cannot shrink values, and
              a negative pad would silently crop them).
      """

      def __init__(
          self,
          config,
          hidden_size: int,
          num_heads: int,
          qk_nope_head_dim: int,
          qk_rope_head_dim: int,
          v_head_dim: int,
          q_lora_rank: int,
          kv_lora_rank: int,
          rope_theta: float = 10000,
          rope_scaling: Optional[Dict[str, Any]] = None,
          max_position_embeddings: int = 8192,
          cache_config: Optional[CacheConfig] = None,
          quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
          self.hidden_size = hidden_size
          self.qk_nope_head_dim = qk_nope_head_dim
          self.qk_rope_head_dim = qk_rope_head_dim
          self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
          self.v_head_dim = v_head_dim
          self.q_lora_rank = q_lora_rank
          self.kv_lora_rank = kv_lora_rank
          self.num_heads = num_heads

          # Heads are divided evenly across tensor-parallel ranks.
          tp_size = get_tensor_model_parallel_world_size()
          if self.num_heads % tp_size != 0:
              raise ValueError(
                  f"num_heads ({self.num_heads}) must be divisible by the "
                  f"tensor-parallel world size ({tp_size}).")
          self.num_local_heads = num_heads // tp_size

          # forward() zero-pads values up to qk_head_dim; a wider value head
          # would turn that pad negative and silently crop the values.
          if self.v_head_dim > self.qk_head_dim:
              raise ValueError(
                  f"v_head_dim ({self.v_head_dim}) must not exceed "
                  f"qk_nope_head_dim + qk_rope_head_dim ({self.qk_head_dim}); "
                  "values are zero-padded to the query/key head size.")

          self.scaling = self.qk_head_dim**-0.5
          self.rope_theta = rope_theta
          self.max_position_embeddings = max_position_embeddings

          # Query path: hidden -> q_lora_rank -> all heads * qk_head_dim.
          self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                           self.q_lora_rank,
                                           bias=False,
                                           quant_config=quant_config)
          self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                       eps=config.rms_norm_eps)
          self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                               self.num_heads * self.qk_head_dim,
                                               bias=False,
                                               quant_config=quant_config)

          # Key/value path: hidden -> (kv latent | shared rotary key slice).
          self.kv_a_proj_with_mqa = ReplicatedLinear(self.hidden_size,
                                                     self.kv_lora_rank +
                                                     self.qk_rope_head_dim,
                                                     bias=False,
                                                     quant_config=quant_config)
          self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                        eps=config.rms_norm_eps)
          self.kv_b_proj = ColumnParallelLinear(
              self.kv_lora_rank,
              self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
              bias=False,
              quant_config=quant_config)

          # O projection.
          self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                          self.hidden_size,
                                          bias=False,
                                          quant_config=quant_config)

          self.rotary_emb = get_rope(
              self.qk_rope_head_dim,
              rotary_dim=self.qk_rope_head_dim,
              max_position=max_position_embeddings,
              base=rope_theta,
              rope_scaling=rope_scaling,
          )
          # q, k and v all use head size qk_head_dim (v is padded in forward).
          self.attn = Attention(self.num_local_heads,
                                self.qk_head_dim,
                                self.scaling,
                                num_kv_heads=self.num_local_heads,
                                cache_config=cache_config,
                                quant_config=quant_config)

      def forward(
          self,
          positions: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: torch.Tensor,
          attn_metadata: AttentionMetadata,
      ) -> torch.Tensor:
          """Run attention over a flattened batch of tokens.

          Args:
              positions: Token positions, passed to the rotary embedding.
              hidden_states: Input activations; presumably
                  ``(num_tokens, hidden_size)`` -- TODO confirm with caller.
              kv_cache: Cache tensor, forwarded unchanged to ``self.attn``.
              attn_metadata: Backend metadata, forwarded to ``self.attn``.

          Returns:
              The first output of ``o_proj`` applied to the attention result.
          """
          # Query path: down-project, normalize, up-project, split per head.
          q, _ = self.q_a_proj(hidden_states)
          q = self.q_a_layernorm(q)
          q, _ = self.q_b_proj(q)
          q = q.view(-1, self.num_local_heads, self.qk_head_dim)
          # q_pe is a view of q's rotary channels; it is rotated below and the
          # result is written back into q.
          _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim],
                            dim=-1)

          # Key/value path: one projection yields the compressed latent plus a
          # single rotary key slice that every head shares.
          latent_cache, _ = self.kv_a_proj_with_mqa(hidden_states)
          kv_a, _ = latent_cache.split(
              [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
          latent_cache = latent_cache.unsqueeze(1)
          kv_a = self.kv_a_layernorm(kv_a.contiguous())
          kv, _ = self.kv_b_proj(kv_a)
          kv = kv.view(-1, self.num_local_heads,
                       self.qk_nope_head_dim + self.v_head_dim)
          k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim],
                               dim=-1)
          k_pe = latent_cache[:, :, self.kv_lora_rank:]

          # Rotary embedding touches only the rope channels.
          q_pe, k_pe = self.rotary_emb(
              positions,
              q_pe.reshape(-1, self.num_local_heads * self.qk_rope_head_dim),
              k_pe.reshape(-1, self.qk_rope_head_dim))
          q_pe = q_pe.view(-1, self.num_local_heads, self.qk_rope_head_dim)
          k_pe = k_pe.view(-1, 1, self.qk_rope_head_dim)

          # Reassemble full-width heads; the one-head-wide k_pe broadcasts
          # across all local heads.
          q[..., self.qk_nope_head_dim:] = q_pe
          k = torch.empty_like(q)
          k[..., :self.qk_nope_head_dim] = k_nope
          k[..., self.qk_nope_head_dim:] = k_pe

          q = q.reshape(-1, self.num_local_heads * self.qk_head_dim)
          k = k.view(-1, self.num_local_heads * self.qk_head_dim)
          # Zero-pad values to the shared head size self.attn was built with
          # (non-negative pad is guaranteed by the __init__ check).
          v = torch.nn.functional.pad(
              v, [0, self.qk_head_dim - self.v_head_dim],
              value=0).view(-1, self.num_local_heads * self.qk_head_dim)

          attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
          # Drop the padded value channels before the output projection.
          attn_output = attn_output.view(
              -1, self.num_local_heads,
              self.qk_head_dim)[..., :self.v_head_dim].reshape(
                  -1, self.num_local_heads * self.v_head_dim)

          output, _ = self.o_proj(attn_output)
          return output
  class MiniCPM3DecoderLayer(MiniCPMDecoderLayer):
      """MiniCPM decoder layer whose attention block is ``MiniCPM3Attention``.

      Only the attention-construction hook is overridden; every other part of
      the layer is inherited from ``MiniCPMDecoderLayer``.
      """

      def _init_attn_block(self):
          """Create the pre-attention RMSNorm and the MiniCPM3 attention module.

          Head geometry and low-rank sizes come from ``self.config``; the rope,
          cache and quantization settings are read from attributes on ``self``
          (presumably populated by the base-class constructor before this
          hook runs -- TODO confirm).
          """
          cfg = self.config

          self.input_layernorm = RMSNorm(cfg.hidden_size,
                                         eps=cfg.rms_norm_eps)

          self.self_attn = MiniCPM3Attention(
              config=cfg,
              hidden_size=self.hidden_size,
              num_heads=cfg.num_attention_heads,
              qk_nope_head_dim=cfg.qk_nope_head_dim,
              qk_rope_head_dim=cfg.qk_rope_head_dim,
              v_head_dim=cfg.v_head_dim,
              q_lora_rank=cfg.q_lora_rank,
              kv_lora_rank=cfg.kv_lora_rank,
              rope_theta=self.rope_theta,
              rope_scaling=self.rope_scaling,
              max_position_embeddings=self.max_position_embeddings,
              cache_config=self.cache_config,
              quant_config=self.quant_config,
          )
  class MiniCPM3Model(MiniCPMModel):
      """MiniCPM backbone whose decoder stack is built from MiniCPM3 layers."""

      def _init_layers(self):
          """Fill ``self.layers`` with ``config.num_hidden_layers`` decoder layers.

          Every layer receives the same config, cache config and
          quantization config.
          """
          cfg = self.config
          cache_cfg = self.cache_config
          quant_cfg = self.quant_config
          self.layers = nn.ModuleList([
              MiniCPM3DecoderLayer(cfg, cache_cfg, quant_cfg)
              for _ in range(cfg.num_hidden_layers)
          ])
  class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
      """MiniCPM causal-LM wrapper whose backbone is ``MiniCPM3Model``."""

      def _init_model(self):
          """Instantiate the MiniCPM3 backbone and store it as ``self.model``.

          The config, cache, quantization and LoRA settings held on ``self``
          are passed through unchanged.
          """
          backbone_kwargs = {
              "config": self.config,
              "cache_config": self.cache_config,
              "quant_config": self.quant_config,
              "lora_config": self.lora_config,
          }
          self.model = MiniCPM3Model(**backbone_kwargs)