internlm2.py

# -*- coding: utf-8 -*-
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import progress_bar
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig


class InternLM2MLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.w2 = RowParallelLinear(intermediate_size,
                                    hidden_size,
                                    bias=False,
                                    quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.w2(x)
        return x
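# Note: gate_up_proj packs the gate (w1) and up (w3) projections into a single
# column-parallel matmul, and SiluAndMul computes silu(gate) * up on the two
# halves of its output, so this is the usual SwiGLU-style feed-forward.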


class InternLM2Attention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.wqkv = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.wqkv(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.wo(attn_output)
        return output
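# Illustration (hypothetical numbers, not taken from any particular checkpoint):
# with 32 query heads, 8 KV heads, head_dim 128 and tp_size 4, each rank ends
# up with num_heads = 8 and num_kv_heads = 2, so the qkv.split above uses
# q_size = 8 * 128 = 1024 and kv_size = 2 * 128 = 256 per rank.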


class InternLMDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.attention = InternLM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.feed_forward = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.attention_norm = RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)
        else:
            hidden_states, residual = self.attention_norm(
                hidden_states, residual)
        hidden_states = self.attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ffn_norm(hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual
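# When called with a residual tensor, RMSNorm here fuses the residual add with
# the normalization and returns (normalized_states, updated_residual); that is
# why the layer threads `residual` through instead of adding it back explicitly
# after each sub-block.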


class InternLM2Model(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.tok_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            InternLMDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.tok_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.tok_embeddings(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
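# kv_caches carries one cache tensor per decoder layer and is indexed in step
# with self.layers; the final self.norm call also receives the pending
# residual, so the last residual add is fused into the output normalization.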


class InternLM2ForCausalLM(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = InternLM2Model(config, cache_config, quant_config)
        self.output = ParallelLMHead(config.vocab_size,
                                     config.hidden_size,
                                     quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: IntermediateTensors,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.output, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens
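    # The engine drives this model in three steps: forward() produces hidden
    # states, compute_logits() projects them through the LM head (typically
    # only for the positions flagged in sampling_metadata), and sample() turns
    # those logits into next-token ids.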

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
        ]
        params_dict = dict(self.named_parameters())
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(weights_list,
                                                desc="Loading modules..."):
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                if "wqkv" in name:
                    config = self.config
                    kv_groups = (config.num_attention_heads //
                                 config.num_key_value_heads)
                    head_dim = config.hidden_size // config.num_attention_heads
                    loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
                                                       head_dim,
                                                       loaded_weight.shape[-1])
                    wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
                                             dim=1)
                    wq = wq.reshape(-1, wq.shape[-1])
                    wk = wk.reshape(-1, wk.shape[-1])
                    wv = wv.reshape(-1, wv.shape[-1])
                    weight_loader = param.weight_loader
                    weight_loader(param, wq, 'q')
                    weight_loader(param, wk, 'k')
                    weight_loader(param, wv, 'v')
                else:
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
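

# A minimal, self-contained sketch (not part of the model) of the wqkv
# de-interleave performed in load_weights above, using made-up toy shapes:
# 4 query heads, 2 KV heads, head_dim 8, hidden size 32. The view/split in
# load_weights implies the checkpoint lays wqkv out as num_kv_heads groups of
# [kv_groups query heads, 1 key head, 1 value head] stacked along the output
# dimension; the code below separates them back into contiguous q/k/v blocks.
if __name__ == "__main__":
    num_heads, num_kv_heads, head_dim, hidden = 4, 2, 8, 32
    kv_groups = num_heads // num_kv_heads  # query heads per KV head
    wqkv = torch.randn((num_heads + 2 * num_kv_heads) * head_dim, hidden)
    grouped = wqkv.view(-1, 2 + kv_groups, head_dim, hidden)
    wq, wk, wv = torch.split(grouped, [kv_groups, 1, 1], dim=1)
    # After flattening, each block has the shape QKVParallelLinear expects.
    assert wq.reshape(-1, hidden).shape == (num_heads * head_dim, hidden)
    assert wk.reshape(-1, hidden).shape == (num_kv_heads * head_dim, hidden)
    assert wv.reshape(-1, hidden).shape == (num_kv_heads * head_dim, hidden)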