# internlm2.py
# -*- coding: utf-8 -*-
from functools import partial
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import PretrainedConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors
from aphrodite.distributed import (get_pp_group,
                                   get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   split_tensor_along_last_dim,
                                   tensor_model_parallel_all_gather)
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig

from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)


class InternLM2MLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.w2 = RowParallelLinear(intermediate_size,
                                    hidden_size,
                                    bias=False,
                                    quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()
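
    # NOTE: the checkpoint's gate (`w1`) and up (`w3`) projections are merged
    # into the single column-parallel `gate_up_proj` (see the stacked-params
    # mapping in `load_weights`), and `SiluAndMul` splits its input in half
    # along the last dim to compute silu(gate) * up, so the forward below is
    # equivalent to the reference w2(silu(w1(x)) * w3(x)).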
    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.w2(x)
        return x


class InternLM2Attention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.total_num_heads = num_heads
        assert self.total_num_heads % self.tp_size == 0
        self.num_heads = self.total_num_heads // self.tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= self.tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % self.tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert self.tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.key_value_groups = int(self.num_heads / self.num_kv_heads)
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.wqkv = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)
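
    # NOTE: the checkpoint's packed `wqkv` interleaves each KV-head group's
    # query heads with that group's key and value head, while
    # `QKVParallelLinear` shards it as if it were a plain [q; k; v] stack.
    # `split_qkv` therefore all-gathers the projected activations across TP
    # ranks, restores the original interleaved layout, slices out contiguous
    # q, k and v, and re-splits them so each rank keeps only its own heads.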
    def split_qkv(self, qkv: torch.Tensor):
        seq_len = qkv.shape[0]
        if self.tp_size > 1:
            qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
            qkv = tensor_model_parallel_all_gather(qkv)
            qkv = torch.split(qkv, qkv_map, dim=-1)
            qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
            qkv = torch.cat(qkv, dim=-1)

        qkv = qkv.view(seq_len, self.total_num_kv_heads,
                       self.key_value_groups + 2, self.head_dim)
        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
        q = q.reshape(seq_len, self.q_size * self.tp_size)
        k = k.reshape(seq_len, self.kv_size * self.tp_size)
        v = v.reshape(seq_len, self.kv_size * self.tp_size)

        if self.tp_size > 1:
            splitter = partial(split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]
        return q, k, v
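
    # Shape sketch for `split_qkv` above, assuming hypothetical sizes of
    # 32 query heads, 8 KV heads, head_dim 128 and tp_size 2: each rank's
    # `wqkv` output has q_size + 2 * kv_size = 2048 + 1024 columns; after the
    # all-gather and regrouping, `qkv` is viewed as
    # (seq_len, 8, key_value_groups + 2, 128) = (seq_len, 8, 6, 128) before
    # being split back into the rank-local q, k and v.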
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.wqkv(hidden_states)
        q, k, v = self.split_qkv(qkv)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.wo(attn_output)
        return output


class InternLMDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.attention = InternLM2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.feed_forward = InternLM2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.attention_norm = RMSNorm(config.hidden_size,
                                      eps=config.rms_norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
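
    # NOTE: when RMSNorm is called with a residual tensor it fuses the
    # residual add with the normalization and returns both the normalized
    # activations and the updated residual, which is why `forward` threads
    # `residual` through rather than adding it explicitly.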
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.attention_norm(hidden_states)
        else:
            hidden_states, residual = self.attention_norm(
                hidden_states, residual)
        hidden_states = self.attention(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.ffn_norm(hidden_states, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual


class InternLM2Model(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.tok_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: InternLMDecoderLayer(config, cache_config,
                                                quant_config),
            prefix=f"{prefix}.layers")
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.tok_embeddings(input_ids)
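
    # NOTE: under pipeline parallelism only layers [start_layer, end_layer)
    # are instantiated on this rank; non-first ranks receive `hidden_states`
    # and `residual` through `intermediate_tensors`, and non-last ranks
    # return an IntermediateTensors instead of the final hidden states.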
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.tok_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class InternLM2ForCausalLM(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = InternLM2Model(config, cache_config, quant_config)
        self.output = ParallelLMHead(config.vocab_size,
                                     config.hidden_size,
                                     quant_config=quant_config)
        if self.config.tie_word_embeddings:
            self.output.weight = self.model.tok_embeddings.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: IntermediateTensors,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states
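
    # NOTE: logits are computed against `self.output` (the parallel LM head),
    # whose weight is shared with `tok_embeddings` when the config requests
    # tied word embeddings.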
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.output, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
        ]
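        # The checkpoint's `w1` (gate) and `w3` (up) weights map onto the two
        # shards of the merged `gate_up_proj`; everything else, including the
        # packed `wqkv`, falls through to the parameter's own weight_loader
        # (or `default_weight_loader`) below.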
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)