internlm2.py

# -*- coding: utf-8 -*-
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.sequence import SamplerOutput
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                              MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
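

# NOTE: InternLM2MLP below is the standard gated (SwiGLU-style) feed-forward
# block. `gate_up_proj` fuses the checkpoint's `w1` (gate) and `w3` (up)
# projections into a single column-parallel matmul, and `SiluAndMul` is
# expected to split its input in half along the last dim and return
# silu(gate) * up, which is why only the "silu" activation is accepted.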
class InternLM2MLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.w2 = RowParallelLinear(intermediate_size,
                                    hidden_size,
                                    bias=False,
                                    linear_method=linear_method)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.w2(x)
        return x
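

# NOTE: InternLM2Attention implements grouped-query attention under tensor
# parallelism: query heads are split evenly across TP ranks, while KV heads
# are split when num_kv_heads >= tp_size and replicated otherwise. Per rank,
# the fused `wqkv` output is therefore expected to be
# q_size + 2 * kv_size = (num_heads + 2 * num_kv_heads) * head_dim wide,
# which is exactly how `forward` splits it back into q, k and v before RoPE
# and the attention kernel.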
class InternLM2Attention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.wqkv = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.wqkv(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.wo(attn_output)
        return output
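

# NOTE: The decoder layer below uses the pre-norm residual pattern. On the
# first call the raw hidden states are captured as the residual; after that,
# RMSNorm is assumed to accept (hidden_states, residual) and return both the
# normed activations and the updated residual, so the residual addition is
# fused into the norm instead of being done explicitly in this file.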
class InternLMDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.attention = InternLM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            linear_method=linear_method,
        )
        self.feed_forward = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        self.attention_norm = RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)
        else:
            hidden_states, residual = self.attention_norm(
                hidden_states, residual)
        hidden_states = self.attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ffn_norm(hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


class InternLM2Model(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.tok_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            InternLMDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.tok_embeddings(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
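

# NOTE: InternLM2ForCausalLM wires the decoder stack to a parallel LM head:
# `forward` returns hidden states only, `compute_logits` projects them with
# the LM head weight via the logits processor, and `sample` draws next tokens
# from those logits. `load_weights` remaps InternLM2 checkpoint names (the
# split w1/w3 MLP weights and the fused wqkv weight) onto these modules.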
class InternLM2ForCausalLM(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = InternLM2Model(config, linear_method)
        self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.output.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                if "wqkv" in name:
                    config = self.config
                    kv_groups = (config.num_attention_heads //
                                 config.num_key_value_heads)
                    head_dim = config.hidden_size // config.num_attention_heads
                    loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
                                                       head_dim,
                                                       loaded_weight.shape[-1])
                    wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
                                             dim=1)
                    wq = wq.reshape(-1, wq.shape[-1])
                    wk = wk.reshape(-1, wk.shape[-1])
                    wv = wv.reshape(-1, wv.shape[-1])
                    weight_loader = param.weight_loader
                    weight_loader(param, wq, 'q')
                    weight_loader(param, wk, 'k')
                    weight_loader(param, wv, 'v')
                else:
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
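
# NOTE on the wqkv branch above: the view/split there assumes InternLM2
# checkpoints store the fused QKV weight grouped per KV head as
# [q_0 .. q_{kv_groups-1}, k, v] blocks of head_dim rows each, rather than as
# contiguous Q, K and V matrices. As an illustrative (assumed) configuration,
# with num_attention_heads=32, num_key_value_heads=8 and hidden_size=4096
# (head_dim=128, kv_groups=4):
#   loaded_weight: ((32 + 2 * 8) * 128, 4096) = (6144, 4096)
#   .view(-1, 2 + kv_groups, head_dim, 4096)  -> (8, 6, 128, 4096)
#   torch.split(..., [4, 1, 1], dim=1)        -> wq (8, 4, 128, 4096),
#                                                wk (8, 1, 128, 4096),
#                                                wv (8, 1, 128, 4096)
# Each piece is then flattened back to 2-D and handed to QKVParallelLinear's
# weight_loader under the 'q', 'k' and 'v' shard ids.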