1
0

internlm2.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. # -*- coding: utf-8 -*-
  2. from typing import Any, Dict, Iterable, List, Optional, Tuple
  3. import torch
  4. from torch import nn
  5. from transformers import PretrainedConfig
  6. from aphrodite.attention import Attention, AttentionMetadata
  7. from aphrodite.common.config import CacheConfig
  8. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  9. from aphrodite.distributed import get_tensor_model_parallel_world_size
  10. from aphrodite.modeling.layers.activation import SiluAndMul
  11. from aphrodite.modeling.layers.layernorm import RMSNorm
  12. from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
  13. QKVParallelLinear,
  14. RowParallelLinear)
  15. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  16. from aphrodite.modeling.layers.rotary_embedding import get_rope
  17. from aphrodite.modeling.layers.sampler import Sampler
  18. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  19. ParallelLMHead, VocabParallelEmbedding)
  20. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  21. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  22. from aphrodite.quantization.base_config import QuantizationConfig
  23. class InternLM2MLP(nn.Module):
  24. def __init__(
  25. self,
  26. hidden_size: int,
  27. intermediate_size: int,
  28. hidden_act: str,
  29. quant_config: Optional[QuantizationConfig] = None,
  30. ) -> None:
  31. super().__init__()
  32. self.gate_up_proj = MergedColumnParallelLinear(
  33. hidden_size, [intermediate_size] * 2,
  34. bias=False,
  35. quant_config=quant_config)
  36. self.w2 = RowParallelLinear(intermediate_size,
  37. hidden_size,
  38. bias=False,
  39. quant_config=quant_config)
  40. if hidden_act != "silu":
  41. raise ValueError(f"Unsupported activation: {hidden_act}. "
  42. "Only silu is supported for now.")
  43. self.act_fn = SiluAndMul()
  44. def forward(self, x):
  45. gate_up, _ = self.gate_up_proj(x)
  46. x = self.act_fn(gate_up)
  47. x, _ = self.w2(x)
  48. return x
  49. class InternLM2Attention(nn.Module):
  50. def __init__(
  51. self,
  52. hidden_size: int,
  53. num_heads: int,
  54. num_kv_heads: int,
  55. rope_theta: float = 10000,
  56. rope_scaling: Optional[Dict[str, Any]] = None,
  57. max_position_embeddings: int = 8192,
  58. cache_config: Optional[CacheConfig] = None,
  59. quant_config: Optional[QuantizationConfig] = None,
  60. ) -> None:
  61. super().__init__()
  62. self.hidden_size = hidden_size
  63. tp_size = get_tensor_model_parallel_world_size()
  64. self.total_num_heads = num_heads
  65. assert self.total_num_heads % tp_size == 0
  66. self.num_heads = self.total_num_heads // tp_size
  67. self.total_num_kv_heads = num_kv_heads
  68. if self.total_num_kv_heads >= tp_size:
  69. # Number of KV heads is greater than TP size, so we partition
  70. # the KV heads across multiple tensor parallel GPUs.
  71. assert self.total_num_kv_heads % tp_size == 0
  72. else:
  73. # Number of KV heads is less than TP size, so we replicate
  74. # the KV heads across multiple tensor parallel GPUs.
  75. assert tp_size % self.total_num_kv_heads == 0
  76. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  77. self.head_dim = hidden_size // self.total_num_heads
  78. self.q_size = self.num_heads * self.head_dim
  79. self.kv_size = self.num_kv_heads * self.head_dim
  80. self.scaling = self.head_dim**-0.5
  81. self.rope_theta = rope_theta
  82. self.max_position_embeddings = max_position_embeddings
  83. self.wqkv = QKVParallelLinear(
  84. hidden_size,
  85. self.head_dim,
  86. self.total_num_heads,
  87. self.total_num_kv_heads,
  88. bias=False,
  89. quant_config=quant_config,
  90. )
  91. self.wo = RowParallelLinear(
  92. self.total_num_heads * self.head_dim,
  93. hidden_size,
  94. bias=False,
  95. quant_config=quant_config,
  96. )
  97. self.rotary_emb = get_rope(
  98. self.head_dim,
  99. rotary_dim=self.head_dim,
  100. max_position=max_position_embeddings,
  101. base=rope_theta,
  102. rope_scaling=rope_scaling,
  103. )
  104. self.attn = Attention(self.num_heads,
  105. self.head_dim,
  106. self.scaling,
  107. num_kv_heads=self.num_kv_heads,
  108. cache_config=cache_config,
  109. quant_config=quant_config)
  110. def forward(
  111. self,
  112. positions: torch.Tensor,
  113. hidden_states: torch.Tensor,
  114. kv_cache: torch.Tensor,
  115. attn_metadata: AttentionMetadata,
  116. ) -> torch.Tensor:
  117. qkv, _ = self.wqkv(hidden_states)
  118. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  119. q, k = self.rotary_emb(positions, q, k)
  120. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  121. output, _ = self.wo(attn_output)
  122. return output
  123. class InternLMDecoderLayer(nn.Module):
  124. def __init__(
  125. self,
  126. config: PretrainedConfig,
  127. cache_config: Optional[CacheConfig] = None,
  128. quant_config: Optional[QuantizationConfig] = None,
  129. ) -> None:
  130. super().__init__()
  131. self.hidden_size = config.hidden_size
  132. rope_theta = getattr(config, "rope_theta", 10000)
  133. rope_scaling = getattr(config, "rope_scaling", None)
  134. max_position_embeddings = getattr(config, "max_position_embeddings",
  135. 8192)
  136. self.attention = InternLM2Attention(
  137. hidden_size=self.hidden_size,
  138. num_heads=config.num_attention_heads,
  139. num_kv_heads=config.num_key_value_heads,
  140. rope_theta=rope_theta,
  141. rope_scaling=rope_scaling,
  142. max_position_embeddings=max_position_embeddings,
  143. cache_config=cache_config,
  144. quant_config=quant_config,
  145. )
  146. self.feed_forward = InternLM2MLP(
  147. hidden_size=self.hidden_size,
  148. intermediate_size=config.intermediate_size,
  149. hidden_act=config.hidden_act,
  150. quant_config=quant_config,
  151. )
  152. self.attention_norm = RMSNorm(config.hidden_size,
  153. eps=config.rms_norm_eps)
  154. self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  155. def forward(
  156. self,
  157. positions: torch.Tensor,
  158. hidden_states: torch.Tensor,
  159. kv_cache: torch.Tensor,
  160. attn_metadata: AttentionMetadata,
  161. residual: Optional[torch.Tensor],
  162. ) -> Tuple[torch.Tensor, torch.Tensor]:
  163. # Self Attention
  164. if residual is None:
  165. residual = hidden_states
  166. hidden_states = self.attention_norm(hidden_states)
  167. else:
  168. hidden_states, residual = self.attention_norm(
  169. hidden_states, residual)
  170. hidden_states = self.attention(
  171. positions=positions,
  172. hidden_states=hidden_states,
  173. kv_cache=kv_cache,
  174. attn_metadata=attn_metadata,
  175. )
  176. # Fully Connected
  177. hidden_states, residual = self.ffn_norm(hidden_states, residual)
  178. hidden_states = self.feed_forward(hidden_states)
  179. return hidden_states, residual
  180. class InternLM2Model(nn.Module):
  181. def __init__(
  182. self,
  183. config: PretrainedConfig,
  184. cache_config: Optional[CacheConfig] = None,
  185. quant_config: Optional[QuantizationConfig] = None,
  186. ) -> None:
  187. super().__init__()
  188. self.config = config
  189. self.padding_idx = config.pad_token_id
  190. self.vocab_size = config.vocab_size
  191. self.tok_embeddings = VocabParallelEmbedding(
  192. config.vocab_size,
  193. config.hidden_size,
  194. )
  195. self.layers = nn.ModuleList([
  196. InternLMDecoderLayer(config, cache_config, quant_config)
  197. for _ in range(config.num_hidden_layers)
  198. ])
  199. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  200. def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
  201. return self.tok_embeddings(input_ids)
  202. def forward(
  203. self,
  204. input_ids: torch.Tensor,
  205. positions: torch.Tensor,
  206. kv_caches: List[torch.Tensor],
  207. attn_metadata: AttentionMetadata,
  208. intermediate_tensors: Optional[IntermediateTensors] = None,
  209. inputs_embeds: Optional[torch.Tensor] = None,
  210. ) -> torch.Tensor:
  211. if inputs_embeds is not None:
  212. hidden_states = inputs_embeds
  213. else:
  214. hidden_states = self.tok_embeddings(input_ids)
  215. residual = None
  216. for i in range(len(self.layers)):
  217. layer = self.layers[i]
  218. hidden_states, residual = layer(
  219. positions,
  220. hidden_states,
  221. kv_caches[i],
  222. attn_metadata,
  223. residual,
  224. )
  225. hidden_states, _ = self.norm(hidden_states, residual)
  226. return hidden_states
  227. class InternLM2ForCausalLM(nn.Module):
  228. def __init__(
  229. self,
  230. config: PretrainedConfig,
  231. cache_config: Optional[CacheConfig] = None,
  232. quant_config: Optional[QuantizationConfig] = None,
  233. ) -> None:
  234. super().__init__()
  235. self.config = config
  236. self.quant_config = quant_config
  237. self.model = InternLM2Model(config, cache_config, quant_config)
  238. self.output = ParallelLMHead(config.vocab_size,
  239. config.hidden_size,
  240. quant_config=quant_config)
  241. self.logits_processor = LogitsProcessor(config.vocab_size)
  242. self.sampler = Sampler()
  243. def forward(
  244. self,
  245. input_ids: torch.Tensor,
  246. positions: torch.Tensor,
  247. kv_caches: List[torch.Tensor],
  248. attn_metadata: AttentionMetadata,
  249. intermediate_tensors: IntermediateTensors,
  250. ) -> torch.Tensor:
  251. hidden_states = self.model(input_ids, positions, kv_caches,
  252. attn_metadata)
  253. return hidden_states
  254. def compute_logits(self, hidden_states: torch.Tensor,
  255. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  256. logits = self.logits_processor(self.output, hidden_states,
  257. sampling_metadata)
  258. return logits
  259. def sample(
  260. self,
  261. logits: torch.Tensor,
  262. sampling_metadata: SamplingMetadata,
  263. ) -> Optional[SamplerOutput]:
  264. next_tokens = self.sampler(logits, sampling_metadata)
  265. return next_tokens
  266. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  267. stacked_params_mapping = [
  268. # (param_name, shard_name, shard_id)
  269. ("gate_up_proj", "w1", 0),
  270. ("gate_up_proj", "w3", 1),
  271. ]
  272. params_dict = dict(self.named_parameters())
  273. for name, loaded_weight in weights:
  274. if "rotary_emb.inv_freq" in name:
  275. continue
  276. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  277. if weight_name not in name:
  278. continue
  279. name = name.replace(weight_name, param_name)
  280. # Skip loading extra bias for GPTQ models.
  281. if name.endswith(".bias") and name not in params_dict:
  282. continue
  283. param = params_dict[name]
  284. weight_loader = param.weight_loader
  285. weight_loader(param, loaded_weight, shard_id)
  286. break
  287. else:
  288. # Skip loading extra bias for GPTQ models.
  289. if name.endswith(".bias") and name not in params_dict:
  290. continue
  291. param = params_dict[name]
  292. if "wqkv" in name:
  293. config = self.config
  294. kv_groups = (config.num_attention_heads //
  295. config.num_key_value_heads)
  296. head_dim = config.hidden_size // config.num_attention_heads
  297. loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
  298. head_dim,
  299. loaded_weight.shape[-1])
  300. wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
  301. dim=1)
  302. wq = wq.reshape(-1, wq.shape[-1])
  303. wk = wk.reshape(-1, wk.shape[-1])
  304. wv = wv.reshape(-1, wv.shape[-1])
  305. weight_loader = param.weight_loader
  306. weight_loader(param, wq, 'q')
  307. weight_loader(param, wk, 'k')
  308. weight_loader(param, wv, 'v')
  309. else:
  310. weight_loader = getattr(param, "weight_loader",
  311. default_weight_loader)
  312. weight_loader(param, loaded_weight)