# internlm2.py
# -*- coding: utf-8 -*-
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from aphrodite.modeling.metadata import InputMetadata
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.attention import PagedAttention
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (
    LinearMethodBase,
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding,
    ParallelLMHead,
)
from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_world_size, )
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.hf_downloader import (
    default_weight_loader,
    hf_model_weights_iterator,
)
from aphrodite.common.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]
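# Each KVCache is a (key_cache, value_cache) tensor pair; the model receives
# one pair per decoder layer through the `kv_caches` argument of forward().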


class InternLM2MLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        if (linear_method is not None
                and not linear_method.quant_config.merge_weight()):
            self.merge_weight = False
            self.w1 = ColumnParallelLinear(
                hidden_size,
                intermediate_size,
                bias=False,
                linear_method=linear_method,
            )
            self.w3 = ColumnParallelLinear(
                hidden_size,
                intermediate_size,
                bias=False,
                linear_method=linear_method,
            )
        else:
            self.merge_weight = True
            self.gate_up_proj = MergedColumnParallelLinear(
                hidden_size,
                [intermediate_size] * 2,
                bias=False,
                linear_method=linear_method,
            )
        self.w2 = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        if self.merge_weight:
            gate_up, _ = self.gate_up_proj(x)
        else:
            # Unmerged path: w1 is the gate projection and w3 the up
            # projection, matching the attributes created in __init__.
            gate, _ = self.w1(x)
            up, _ = self.w3(x)
            gate_up = torch.cat([gate, up], dim=-1)
        x = self.act_fn(gate_up)
        x, _ = self.w2(x)
        return x
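
# The MLP above is a SwiGLU block. A rough sketch of the merged path,
# assuming SiluAndMul splits its input in half along the last dimension and
# returns silu(gate) * up (illustrative only, not the fused kernel):
#
#     gate, up = gate_up.chunk(2, dim=-1)
#     x = torch.nn.functional.silu(gate) * up
#     x, _ = self.w2(x)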


class InternLM2Attention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.wqkv = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = PagedAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.wqkv(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.wo(attn_output)
        return output
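
# Shape bookkeeping for the attention block, using hypothetical 7B-style
# dimensions (hidden_size=4096, 32 query heads, 8 KV heads, head_dim=128,
# tp_size=1) purely as an illustration:
#     q_size  = 32 * 128 = 4096
#     kv_size =  8 * 128 = 1024
# so wqkv produces a 4096 + 1024 + 1024 = 6144-wide tensor that forward()
# splits into q, k, v before rotary embedding and PagedAttention.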


class InternLMDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.attention = InternLM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            linear_method=linear_method,
        )
        self.feed_forward = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        self.attention_norm = RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)
        else:
            hidden_states, residual = self.attention_norm(
                hidden_states, residual)
        hidden_states = self.attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ffn_norm(hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual
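
# Residual handling: the layer relies on the fused add-and-norm form of
# RMSNorm (presumably computing norm(x + residual) and returning the new
# residual as its second output), so the per-layer data flow is roughly:
#     h, r = attention_norm(h, r);  h = attention(h)
#     h, r = ffn_norm(h, r);        h = feed_forward(h)
# with the final residual consumed by InternLM2Model.norm below.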


class InternLM2Model(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.tok_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            linear_method=linear_method,
        )
        self.layers = nn.ModuleList([
            InternLMDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.tok_embeddings(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class InternLM2ForCausalLM(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = InternLM2Model(config, linear_method)
        self.output = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            linear_method=linear_method,
        )
        self.sampler = Sampler(config.vocab_size)
        self.quant_sampler = QuantSampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            next_tokens = self.quant_sampler(self.output(hidden_states),
                                             sampling_metadata)
        else:
            next_tokens = self.sampler(self.output.weight, hidden_states,
                                       sampling_metadata)
        return next_tokens

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
        ]
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            stacked_params_mapping = []
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                if "wqkv" in name:
                    config = self.config
                    kv_groups = (config.num_attention_heads //
                                 config.num_key_value_heads)
                    head_dim = config.hidden_size // config.num_attention_heads
                    loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
                                                       head_dim,
                                                       loaded_weight.shape[-1])
                    wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
                                             dim=1)
                    wq = wq.reshape(-1, wq.shape[-1])
                    wk = wk.reshape(-1, wk.shape[-1])
                    wv = wv.reshape(-1, wv.shape[-1])
                    weight_loader = param.weight_loader
                    weight_loader(param, wq, "q")
                    weight_loader(param, wk, "k")
                    weight_loader(param, wv, "v")
                else:
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
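

# Standalone sketch (not used by the model code above) illustrating the
# interleaved wqkv checkpoint layout that load_weights() unpacks: per KV
# group the checkpoint stores `kv_groups` query heads, then one key head,
# then one value head. The dimensions below are made up for illustration.
if __name__ == "__main__":
    num_heads, num_kv_heads, head_dim, hidden = 8, 2, 4, 16
    kv_groups = num_heads // num_kv_heads
    wqkv = torch.randn(num_kv_heads * (kv_groups + 2) * head_dim, hidden)
    packed = wqkv.view(-1, 2 + kv_groups, head_dim, wqkv.shape[-1])
    wq, wk, wv = torch.split(packed, [kv_groups, 1, 1], dim=1)
    wq = wq.reshape(-1, wq.shape[-1])  # (num_heads * head_dim, hidden)
    wk = wk.reshape(-1, wk.shape[-1])  # (num_kv_heads * head_dim, hidden)
    wv = wv.reshape(-1, wv.shape[-1])  # (num_kv_heads * head_dim, hidden)
    assert wq.shape == (num_heads * head_dim, hidden)
    assert wk.shape == wv.shape == (num_kv_heads * head_dim, hidden)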