internlm2.py

# -*- coding: utf-8 -*-
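"""Inference-only InternLM2 model, compatible with HuggingFace weights."""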
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig


class InternLM2MLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.w2 = RowParallelLinear(intermediate_size,
                                    hidden_size,
                                    bias=False,
                                    quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.w2(x)
        return x


class InternLM2Attention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.key_value_groups = int(self.num_heads / self.num_kv_heads)
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.wqkv = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)
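
    # InternLM2's fused `wqkv` projection packs each KV group's query heads
    # together with that group's key and value head, so the projection output
    # is reshaped to (num_tokens, num_kv_heads, key_value_groups + 2, head_dim)
    # before being split into q, k, and v.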
    def split_qkv(self, qkv: torch.Tensor):
        # Use self.head_dim rather than a hard-coded head size so the reshape
        # stays correct for any head dimension.
        qkv = qkv.view(-1, self.num_kv_heads, self.key_value_groups + 2,
                       self.head_dim)
        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=2)
        q = q.reshape(-1, self.q_size)
        k = k.reshape(-1, self.kv_size)
        v = v.reshape(-1, self.kv_size)
        return q, k, v

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.wqkv(hidden_states)
        q, k, v = self.split_qkv(qkv)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.wo(attn_output)
        return output


class InternLMDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.attention = InternLM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.feed_forward = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.attention_norm = RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
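        # The RMSNorm layers fuse the residual addition: when a residual is
        # passed in, they return both the normalized activations and the
        # updated residual, which is threaded through to the next layer.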
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)
        else:
            hidden_states, residual = self.attention_norm(
                hidden_states, residual)
        hidden_states = self.attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ffn_norm(hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


class InternLM2Model(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.tok_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            InternLMDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.tok_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.tok_embeddings(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class InternLM2ForCausalLM(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = InternLM2Model(config, cache_config, quant_config)
        self.output = ParallelLMHead(config.vocab_size,
                                     config.hidden_size,
                                     quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: IntermediateTensors,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.output, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
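        # InternLM2 checkpoints store the gate and up projections as separate
        # `w1` and `w3` weights; they are loaded into the fused `gate_up_proj`
        # parameter as shards 0 and 1.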
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)