# -*- coding: utf-8 -*-
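"""Inference-only InternLM2 model, compatible with HuggingFace weights."""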

from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (
    LinearMethodBase,
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding,
    ParallelLMHead,
)
from aphrodite.distributed import get_tensor_model_parallel_world_size
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.hf_downloader import (
    default_weight_loader,
    hf_model_weights_iterator,
)
from aphrodite.common.sequence import SamplerOutput


class InternLM2MLP(nn.Module):
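    """SwiGLU feed-forward block: w1 (gate) and w3 (up) project to the
    intermediate size, SiluAndMul combines them, and w2 projects back
    down to the hidden size."""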

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        if (linear_method is not None
                and not linear_method.quant_config.merge_weight()):
            self.merge_weight = False
            self.w1 = ColumnParallelLinear(
                hidden_size,
                intermediate_size,
                bias=False,
                linear_method=linear_method,
            )
            self.w3 = ColumnParallelLinear(
                hidden_size,
                intermediate_size,
                bias=False,
                linear_method=linear_method,
            )
        else:
            self.merge_weight = True
            self.gate_up_proj = MergedColumnParallelLinear(
                hidden_size,
                [intermediate_size] * 2,
                bias=False,
                linear_method=linear_method,
            )
        self.w2 = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        if self.merge_weight:
            gate_up, _ = self.gate_up_proj(x)
        else:
            # Unmerged path: w1 is the gate projection, w3 the up projection.
            up, _ = self.w3(x)
            gate, _ = self.w1(x)
            gate_up = torch.cat([gate, up], dim=-1)
        x = self.act_fn(gate_up)
        x, _ = self.w2(x)
        return x


class InternLM2Attention(nn.Module):
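    """Self-attention with rotary position embeddings and grouped-query
    attention; query and KV heads are sharded across tensor-parallel
    ranks."""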

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.wqkv = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.wqkv(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.wo(attn_output)
        return output


class InternLMDecoderLayer(nn.Module):
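    """One transformer block: pre-norm self-attention followed by a pre-norm
    feed-forward network; the residual stream is threaded through the RMSNorm
    calls, which return the normalized input and the updated residual."""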

    def __init__(
        self,
        config: PretrainedConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.attention = InternLM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            linear_method=linear_method,
        )
        self.feed_forward = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        self.attention_norm = RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)
        else:
            hidden_states, residual = self.attention_norm(
                hidden_states, residual)
        hidden_states = self.attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ffn_norm(hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


class InternLM2Model(nn.Module):
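    """The InternLM2 decoder stack: token embeddings, the decoder layers,
    and a final RMSNorm, without the language-model head."""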

    def __init__(
        self,
        config: PretrainedConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.tok_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            linear_method=linear_method,
        )
        self.layers = nn.ModuleList([
            InternLMDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.tok_embeddings(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class InternLM2ForCausalLM(nn.Module):
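    """InternLM2 for causal language modeling: the decoder stack plus a
    parallel LM head, logits processor, sampler, and HuggingFace weight
    loading."""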

    def __init__(
        self,
        config: PretrainedConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = InternLM2Model(config, linear_method)
        self.output = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            linear_method=linear_method,
        )
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                config.tokenizer_vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.output, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
        ]
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            stacked_params_mapping = []
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
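                # InternLM2 checkpoints fuse wqkv with each KV group's query
                # heads stored next to that group's K and V head; split the
                # tensor back into contiguous Q, K, and V blocks before
                # handing the shards to the QKVParallelLinear weight loader.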
  327. if "wqkv" in name:
  328. config = self.config
  329. kv_groups = (config.num_attention_heads //
  330. config.num_key_value_heads)
  331. head_dim = config.hidden_size // config.num_attention_heads
  332. loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
  333. head_dim,
  334. loaded_weight.shape[-1])
  335. wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
  336. dim=1)
  337. wq = wq.reshape(-1, wq.shape[-1])
  338. wk = wk.reshape(-1, wk.shape[-1])
  339. wv = wv.reshape(-1, wv.shape[-1])
  340. weight_loader = param.weight_loader
  341. weight_loader(param, wq, "q")
  342. weight_loader(param, wk, "k")
  343. weight_loader(param, wv, "v")
  344. else:
  345. weight_loader = getattr(param, "weight_loader",
  346. default_weight_loader)
  347. weight_loader(param, loaded_weight)