
internlm2.py 14 KB

# -*- coding: utf-8 -*-
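"""Inference-only InternLM2 model compatible with HuggingFace weights."""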
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.distributed import (get_pp_group,
                                   get_tensor_model_parallel_world_size)
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig

from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)
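

# Feed-forward block: the checkpoint's w1 (gate) and w3 (up) projections are
# merged into a single column-parallel gate_up_proj and fused with SiLU via
# SiluAndMul; w2 is the row-parallel down projection.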
class InternLM2MLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.w2 = RowParallelLinear(intermediate_size,
                                    hidden_size,
                                    bias=False,
                                    quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.w2(x)
        return x
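

# Self-attention with grouped-query attention (GQA). Query heads are split
# evenly across tensor-parallel ranks; KV heads are partitioned when there are
# at least as many of them as TP ranks and replicated otherwise.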
class InternLM2Attention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.key_value_groups = int(self.num_heads / self.num_kv_heads)
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.wqkv = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)
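
    # The fused wqkv output is grouped per KV head as
    # [q_0, ..., q_{g-1}, k, v] with g = key_value_groups, matching the
    # InternLM2 checkpoint layout; split_qkv undoes that grouping at runtime.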
    def split_qkv(self, qkv: torch.Tensor):
        qkv = qkv.view(-1, self.num_kv_heads, self.key_value_groups + 2,
                       self.head_dim)
        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=2)
        q = q.reshape(-1, self.q_size)
        k = k.reshape(-1, self.kv_size)
        v = v.reshape(-1, self.kv_size)
        return q, k, v

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.wqkv(hidden_states)
        q, k, v = self.split_qkv(qkv)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.wo(attn_output)
        return output
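

# One transformer block: attention_norm -> attention -> ffn_norm -> MLP.
# The residual stream is passed into RMSNorm so the residual add and the
# normalization are performed together.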
class InternLMDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.attention = InternLM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.feed_forward = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.attention_norm = RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)
        else:
            hidden_states, residual = self.attention_norm(
                hidden_states, residual)
        hidden_states = self.attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ffn_norm(hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual
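

# The decoder stack: token embeddings, the pipeline-parallel slice of decoder
# layers built by make_layers, and the final RMSNorm. Non-first PP ranks
# receive hidden_states/residual via IntermediateTensors instead of input_ids.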
class InternLM2Model(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.tok_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: InternLMDecoderLayer(config, cache_config,
                                                quant_config),
            prefix=f"{prefix}.layers")
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.tok_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.tok_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
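

# Top-level causal LM wrapper: adds the (optionally tied) LM head, logits
# processing, sampling, and checkpoint loading on top of InternLM2Model.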
class InternLM2ForCausalLM(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = InternLM2Model(config, cache_config, quant_config)
        self.output = ParallelLMHead(config.vocab_size,
                                     config.hidden_size,
                                     quant_config=quant_config)
        if self.config.tie_word_embeddings:
            self.output.weight = self.model.tok_embeddings.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: IntermediateTensors,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.output, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
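
    # Checkpoint names w1/w3 map onto the merged gate_up_proj parameter
    # (shards 0 and 1); every other weight loads under its own name. The
    # precomputed rotary inv_freq buffers are skipped because the rotary
    # embeddings are rebuilt by get_rope.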
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)