internlm2.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. # -*- coding: utf-8 -*-
  2. from typing import Any, Dict, Iterable, List, Optional, Tuple
  3. import torch
  4. from torch import nn
  5. from transformers import PretrainedConfig
  6. from aphrodite.attention import Attention, AttentionMetadata
  7. from aphrodite.common.config import CacheConfig
  8. from aphrodite.common.sequence import SamplerOutput
  9. from aphrodite.distributed import get_tensor_model_parallel_world_size
  10. from aphrodite.modeling.layers.activation import SiluAndMul
  11. from aphrodite.modeling.layers.layernorm import RMSNorm
  12. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  13. QKVParallelLinear,
  14. RowParallelLinear)
  15. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  16. from aphrodite.modeling.layers.rotary_embedding import get_rope
  17. from aphrodite.modeling.layers.sampler import Sampler
  18. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  19. ParallelLMHead, VocabParallelEmbedding)
  20. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  21. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  22. from aphrodite.quantization.base_config import QuantizationConfig
  23. class InternLM2MLP(nn.Module):
  24. def __init__(
  25. self,
  26. hidden_size: int,
  27. intermediate_size: int,
  28. hidden_act: str,
  29. quant_config: Optional[QuantizationConfig] = None,
  30. ) -> None:
  31. super().__init__()
  32. self.gate_up_proj = MergedColumnParallelLinear(
  33. hidden_size, [intermediate_size] * 2,
  34. bias=False,
  35. quant_config=quant_config)
  36. self.w2 = RowParallelLinear(intermediate_size,
  37. hidden_size,
  38. bias=False,
  39. quant_config=quant_config)
  40. if hidden_act != "silu":
  41. raise ValueError(f"Unsupported activation: {hidden_act}. "
  42. "Only silu is supported for now.")
  43. self.act_fn = SiluAndMul()
  44. def forward(self, x):
  45. gate_up, _ = self.gate_up_proj(x)
  46. x = self.act_fn(gate_up)
  47. x, _ = self.w2(x)
  48. return x
  49. class InternLM2Attention(nn.Module):
  50. def __init__(
  51. self,
  52. hidden_size: int,
  53. num_heads: int,
  54. num_kv_heads: int,
  55. rope_theta: float = 10000,
  56. rope_scaling: Optional[Dict[str, Any]] = None,
  57. max_position_embeddings: int = 8192,
  58. cache_config: Optional[CacheConfig] = None,
  59. quant_config: Optional[QuantizationConfig] = None,
  60. ) -> None:
  61. super().__init__()
  62. self.hidden_size = hidden_size
  63. tp_size = get_tensor_model_parallel_world_size()
  64. self.total_num_heads = num_heads
  65. assert self.total_num_heads % tp_size == 0
  66. self.num_heads = self.total_num_heads // tp_size
  67. self.total_num_kv_heads = num_kv_heads
  68. if self.total_num_kv_heads >= tp_size:
  69. # Number of KV heads is greater than TP size, so we partition
  70. # the KV heads across multiple tensor parallel GPUs.
  71. assert self.total_num_kv_heads % tp_size == 0
  72. else:
  73. # Number of KV heads is less than TP size, so we replicate
  74. # the KV heads across multiple tensor parallel GPUs.
  75. assert tp_size % self.total_num_kv_heads == 0
  76. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  77. self.head_dim = hidden_size // self.total_num_heads
  78. self.q_size = self.num_heads * self.head_dim
  79. self.kv_size = self.num_kv_heads * self.head_dim
  80. self.scaling = self.head_dim**-0.5
  81. self.rope_theta = rope_theta
  82. self.max_position_embeddings = max_position_embeddings
  83. self.wqkv = QKVParallelLinear(
  84. hidden_size,
  85. self.head_dim,
  86. self.total_num_heads,
  87. self.total_num_kv_heads,
  88. bias=False,
  89. quant_config=quant_config,
  90. )
  91. self.wo = RowParallelLinear(
  92. self.total_num_heads * self.head_dim,
  93. hidden_size,
  94. bias=False,
  95. quant_config=quant_config,
  96. )
  97. self.rotary_emb = get_rope(
  98. self.head_dim,
  99. rotary_dim=self.head_dim,
  100. max_position=max_position_embeddings,
  101. base=rope_theta,
  102. rope_scaling=rope_scaling,
  103. )
  104. self.attn = Attention(self.num_heads,
  105. self.head_dim,
  106. self.scaling,
  107. num_kv_heads=self.num_kv_heads,
  108. cache_config=cache_config,
  109. quant_config=quant_config)
  110. def forward(
  111. self,
  112. positions: torch.Tensor,
  113. hidden_states: torch.Tensor,
  114. kv_cache: torch.Tensor,
  115. attn_metadata: AttentionMetadata,
  116. ) -> torch.Tensor:
  117. qkv, _ = self.wqkv(hidden_states)
  118. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  119. q, k = self.rotary_emb(positions, q, k)
  120. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  121. output, _ = self.wo(attn_output)
  122. return output
  123. class InternLMDecoderLayer(nn.Module):
  124. def __init__(
  125. self,
  126. config: PretrainedConfig,
  127. cache_config: Optional[CacheConfig] = None,
  128. quant_config: Optional[QuantizationConfig] = None,
  129. ) -> None:
  130. super().__init__()
  131. self.hidden_size = config.hidden_size
  132. rope_theta = getattr(config, "rope_theta", 10000)
  133. rope_scaling = getattr(config, "rope_scaling", None)
  134. max_position_embeddings = getattr(config, "max_position_embeddings",
  135. 8192)
  136. self.attention = InternLM2Attention(
  137. hidden_size=self.hidden_size,
  138. num_heads=config.num_attention_heads,
  139. num_kv_heads=config.num_key_value_heads,
  140. rope_theta=rope_theta,
  141. rope_scaling=rope_scaling,
  142. max_position_embeddings=max_position_embeddings,
  143. cache_config=cache_config,
  144. quant_config=quant_config,
  145. )
  146. self.feed_forward = InternLM2MLP(
  147. hidden_size=self.hidden_size,
  148. intermediate_size=config.intermediate_size,
  149. hidden_act=config.hidden_act,
  150. quant_config=quant_config,
  151. )
  152. self.attention_norm = RMSNorm(config.hidden_size,
  153. eps=config.rms_norm_eps)
  154. self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  155. def forward(
  156. self,
  157. positions: torch.Tensor,
  158. hidden_states: torch.Tensor,
  159. kv_cache: torch.Tensor,
  160. attn_metadata: AttentionMetadata,
  161. residual: Optional[torch.Tensor],
  162. ) -> Tuple[torch.Tensor, torch.Tensor]:
  163. # Self Attention
  164. if residual is None:
  165. residual = hidden_states
  166. hidden_states = self.attention_norm(hidden_states)
  167. else:
  168. hidden_states, residual = self.attention_norm(
  169. hidden_states, residual)
  170. hidden_states = self.attention(
  171. positions=positions,
  172. hidden_states=hidden_states,
  173. kv_cache=kv_cache,
  174. attn_metadata=attn_metadata,
  175. )
  176. # Fully Connected
  177. hidden_states, residual = self.ffn_norm(hidden_states, residual)
  178. hidden_states = self.feed_forward(hidden_states)
  179. return hidden_states, residual
  180. class InternLM2Model(nn.Module):
  181. def __init__(
  182. self,
  183. config: PretrainedConfig,
  184. cache_config: Optional[CacheConfig] = None,
  185. quant_config: Optional[QuantizationConfig] = None,
  186. ) -> None:
  187. super().__init__()
  188. self.config = config
  189. self.padding_idx = config.pad_token_id
  190. self.vocab_size = config.vocab_size
  191. self.tok_embeddings = VocabParallelEmbedding(
  192. config.vocab_size,
  193. config.hidden_size,
  194. )
  195. self.layers = nn.ModuleList([
  196. InternLMDecoderLayer(config, cache_config, quant_config)
  197. for _ in range(config.num_hidden_layers)
  198. ])
  199. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  200. def forward(
  201. self,
  202. input_ids: torch.Tensor,
  203. positions: torch.Tensor,
  204. kv_caches: List[torch.Tensor],
  205. attn_metadata: AttentionMetadata,
  206. ) -> torch.Tensor:
  207. hidden_states = self.tok_embeddings(input_ids)
  208. residual = None
  209. for i in range(len(self.layers)):
  210. layer = self.layers[i]
  211. hidden_states, residual = layer(
  212. positions,
  213. hidden_states,
  214. kv_caches[i],
  215. attn_metadata,
  216. residual,
  217. )
  218. hidden_states, _ = self.norm(hidden_states, residual)
  219. return hidden_states
  220. class InternLM2ForCausalLM(nn.Module):
  221. def __init__(
  222. self,
  223. config: PretrainedConfig,
  224. cache_config: Optional[CacheConfig] = None,
  225. quant_config: Optional[QuantizationConfig] = None,
  226. ) -> None:
  227. super().__init__()
  228. self.config = config
  229. self.quant_config = quant_config
  230. self.model = InternLM2Model(config, cache_config, quant_config)
  231. self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
  232. self.logits_processor = LogitsProcessor(config.vocab_size)
  233. self.sampler = Sampler()
  234. def forward(
  235. self,
  236. input_ids: torch.Tensor,
  237. positions: torch.Tensor,
  238. kv_caches: List[torch.Tensor],
  239. attn_metadata: AttentionMetadata,
  240. ) -> torch.Tensor:
  241. hidden_states = self.model(input_ids, positions, kv_caches,
  242. attn_metadata)
  243. return hidden_states
  244. def compute_logits(self, hidden_states: torch.Tensor,
  245. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  246. logits = self.logits_processor(self.output.weight, hidden_states,
  247. sampling_metadata)
  248. return logits
  249. def sample(
  250. self,
  251. logits: torch.Tensor,
  252. sampling_metadata: SamplingMetadata,
  253. ) -> Optional[SamplerOutput]:
  254. next_tokens = self.sampler(logits, sampling_metadata)
  255. return next_tokens
  256. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  257. stacked_params_mapping = [
  258. # (param_name, shard_name, shard_id)
  259. ("gate_up_proj", "w1", 0),
  260. ("gate_up_proj", "w3", 1),
  261. ]
  262. params_dict = dict(self.named_parameters())
  263. for name, loaded_weight in weights:
  264. if "rotary_emb.inv_freq" in name:
  265. continue
  266. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  267. if weight_name not in name:
  268. continue
  269. name = name.replace(weight_name, param_name)
  270. # Skip loading extra bias for GPTQ models.
  271. if name.endswith(".bias") and name not in params_dict:
  272. continue
  273. param = params_dict[name]
  274. weight_loader = param.weight_loader
  275. weight_loader(param, loaded_weight, shard_id)
  276. break
  277. else:
  278. # Skip loading extra bias for GPTQ models.
  279. if name.endswith(".bias") and name not in params_dict:
  280. continue
  281. param = params_dict[name]
  282. if "wqkv" in name:
  283. config = self.config
  284. kv_groups = (config.num_attention_heads //
  285. config.num_key_value_heads)
  286. head_dim = config.hidden_size // config.num_attention_heads
  287. loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
  288. head_dim,
  289. loaded_weight.shape[-1])
  290. wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
  291. dim=1)
  292. wq = wq.reshape(-1, wq.shape[-1])
  293. wk = wk.reshape(-1, wk.shape[-1])
  294. wv = wv.reshape(-1, wv.shape[-1])
  295. weight_loader = param.weight_loader
  296. weight_loader(param, wq, 'q')
  297. weight_loader(param, wk, 'k')
  298. weight_loader(param, wv, 'v')
  299. else:
  300. weight_loader = getattr(param, "weight_loader",
  301. default_weight_loader)
  302. weight_loader(param, loaded_weight)