mixtral.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only Mixtral model."""
  25. from typing import List, Optional, Tuple
  26. import numpy as np
  27. import torch
  28. import torch.nn.functional as F
  29. from torch import nn
  30. from transformers import MixtralConfig
  31. from aphrodite.modeling.metadata import InputMetadata
  32. from aphrodite.modeling.layers.attention import PagedAttention
  33. from aphrodite.modeling.layers.layernorm import RMSNorm
  34. from aphrodite.modeling.layers.linear import (LinearMethodBase,
  35. ReplicatedLinear,
  36. QKVParallelLinear,
  37. RowParallelLinear,
  38. ColumnParallelLinear)
  39. from aphrodite.modeling.layers.rotary_embedding import get_rope
  40. from aphrodite.modeling.layers.sampler import Sampler
  41. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  42. VocabParallelEmbedding, ParallelLMHead)
  43. from aphrodite.modeling.megatron.communication_op import (
  44. tensor_model_parallel_all_reduce)
  45. from aphrodite.modeling.megatron.parallel_state import (
  46. get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
  47. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  48. from aphrodite.modeling.hf_downloader import (default_weight_loader,
  49. hf_model_weights_iterator)
  50. from aphrodite.common.sequence import SamplerOutput
# Per-layer attention cache pair; MixtralAttention.forward unpacks it as
# ``k_cache, v_cache = kv_cache`` and hands both to PagedAttention.
KVCache = Tuple[torch.Tensor, torch.Tensor]
  52. class MixtralMLP(nn.Module):
  53. def __init__(
  54. self,
  55. num_experts: int,
  56. hidden_size: int,
  57. intermediate_size: int,
  58. linear_method: Optional[LinearMethodBase] = None,
  59. ) -> None:
  60. super().__init__()
  61. self.num_experts = num_experts
  62. self.ffn_dim = intermediate_size
  63. self.hidden_dim = hidden_size
  64. self.w1 = ReplicatedLinear(self.hidden_dim,
  65. self.ffn_dim,
  66. bias=False,
  67. linear_method=linear_method)
  68. self.w2 = ReplicatedLinear(self.ffn_dim,
  69. self.hidden_dim,
  70. bias=False,
  71. linear_method=linear_method)
  72. self.w3 = ReplicatedLinear(self.hidden_dim,
  73. self.ffn_dim,
  74. bias=False,
  75. linear_method=linear_method)
  76. # TODO: Use Aphrodite's SiluAndMul
  77. self.act_fn = nn.SiLU()
  78. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  79. w1_out, _ = self.w1(hidden_states)
  80. w1_out = self.act_fn(w1_out)
  81. w3_out, _ = self.w3(hidden_states)
  82. current_hidden_states = w1_out * w3_out
  83. current_hidden_states, _ = self.w2(current_hidden_states)
  84. return current_hidden_states
  85. class MixtralMoE(nn.Module):
  86. def __init__(
  87. self,
  88. config: MixtralConfig,
  89. linear_method: Optional[LinearMethodBase] = None,
  90. ):
  91. super().__init__()
  92. self.config = config
  93. self.rank = get_tensor_model_parallel_rank()
  94. self.tp_size = get_tensor_model_parallel_world_size()
  95. self.num_total_experts = config.num_local_experts
  96. self.top_k = config.num_experts_per_tok
  97. if self.tp_size > self.num_total_experts:
  98. raise ValueError(
  99. f"Tensor parallel size {self.tp_size} is greater than "
  100. f"the number of experts {self.num_total_experts}.")
  101. # Split experts equally between ranks
  102. self.expert_indicies = np.array_split(range(
  103. self.num_total_experts), self.tp_size)[self.rank].tolist()
  104. if not self.expert_indicies:
  105. raise ValueError(
  106. f"Rank {self.rank} has no experts assigned to it.")
  107. self.experts = nn.ModuleList([
  108. MixtralMLP(self.num_total_experts,
  109. config.hidden_size,
  110. config.intermediate_size,
  111. linear_method=linear_method)
  112. if idx in self.expert_indicies else None
  113. for idx in range(self.num_total_experts)
  114. ])
  115. self.gate = ReplicatedLinear(config.hidden_size,
  116. self.num_total_experts,
  117. bias=False,
  118. linear_method=None)
  119. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  120. batch_size, sequence_length, hidden_dim = hidden_states.shape
  121. hidden_states = hidden_states.view(-1, hidden_dim)
  122. # router_logits: (batch * sequence_length, n_experts)
  123. router_logits, _ = self.gate(hidden_states)
  124. routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
  125. routing_weights, selected_experts = torch.topk(routing_weights,
  126. self.top_k,
  127. dim=-1)
  128. routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
  129. final_hidden_states = None
  130. for expert_idx in self.expert_indicies:
  131. expert_layer = self.experts[expert_idx]
  132. expert_mask = (selected_experts == expert_idx)
  133. expert_weights = (routing_weights * expert_mask).sum(dim=-1,
  134. keepdim=True)
  135. current_hidden_states = expert_layer(hidden_states).mul_(
  136. expert_weights)
  137. if final_hidden_states is None:
  138. final_hidden_states = current_hidden_states
  139. else:
  140. final_hidden_states.add_(current_hidden_states)
  141. return tensor_model_parallel_all_reduce(final_hidden_states).view(
  142. batch_size, sequence_length, hidden_dim)
  143. class MixtralAttention(nn.Module):
  144. def __init__(self,
  145. hidden_size: int,
  146. num_heads: int,
  147. num_kv_heads: int,
  148. max_position: int = 4096 * 32,
  149. rope_theta: float = 10000,
  150. linear_method: Optional[LinearMethodBase] = None,
  151. sliding_window: Optional[int] = None) -> None:
  152. super().__init__()
  153. self.hidden_size = hidden_size
  154. tp_size = get_tensor_model_parallel_world_size()
  155. self.total_num_heads = num_heads
  156. assert self.total_num_heads % tp_size == 0
  157. self.num_heads = self.total_num_heads // tp_size
  158. self.total_num_kv_heads = num_kv_heads
  159. if self.total_num_kv_heads >= tp_size:
  160. # Number of KV heads is greater than TP size, so we partition
  161. # the KV heads across multiple tensor parallel GPUs.
  162. assert self.total_num_kv_heads % tp_size == 0
  163. else:
  164. # Number of KV heads is less than TP size, so we replicate
  165. # the KV heads across multiple tensor parallel GPUs.
  166. assert tp_size % self.total_num_kv_heads == 0
  167. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  168. self.head_dim = hidden_size // self.total_num_heads
  169. self.q_size = self.num_heads * self.head_dim
  170. self.kv_size = self.num_kv_heads * self.head_dim
  171. self.scaling = self.head_dim**-0.5
  172. self.rope_theta = rope_theta
  173. self.sliding_window = sliding_window
  174. if linear_method is not None and not linear_method.quant_config.merge_weight(
  175. ):
  176. self.merge_weight = False
  177. self.q_proj = ColumnParallelLinear(hidden_size,
  178. self.q_size,
  179. bias=False,
  180. linear_method=linear_method)
  181. self.k_proj = ColumnParallelLinear(hidden_size,
  182. self.kv_size,
  183. bias=False,
  184. linear_method=linear_method)
  185. self.v_proj = ColumnParallelLinear(hidden_size,
  186. self.kv_size,
  187. bias=False,
  188. linear_method=linear_method)
  189. else:
  190. self.merge_weight = True
  191. self.qkv_proj = QKVParallelLinear(
  192. hidden_size,
  193. self.head_dim,
  194. self.total_num_heads,
  195. self.total_num_kv_heads,
  196. bias=False,
  197. linear_method=linear_method,
  198. )
  199. self.o_proj = RowParallelLinear(
  200. self.total_num_heads * self.head_dim,
  201. hidden_size,
  202. bias=False,
  203. linear_method=linear_method,
  204. )
  205. is_neox_style = True if linear_method is None or linear_method.quant_config.rope_style(
  206. ) is None else linear_method.quant_config.rope_style()
  207. self.rotary_emb = get_rope(
  208. self.head_dim,
  209. rotary_dim=self.head_dim,
  210. max_position=max_position,
  211. base=int(self.rope_theta),
  212. is_neox_style=is_neox_style,
  213. )
  214. self.attn = PagedAttention(
  215. self.num_heads,
  216. self.head_dim,
  217. self.scaling,
  218. num_kv_heads=self.num_kv_heads,
  219. sliding_window=self.sliding_window,
  220. )
  221. def forward(
  222. self,
  223. positions: torch.Tensor,
  224. hidden_states: torch.Tensor,
  225. kv_cache: KVCache,
  226. input_metadata: InputMetadata,
  227. ) -> torch.Tensor:
  228. if self.merge_weight:
  229. qkv, _ = self.qkv_proj(hidden_states)
  230. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
  231. dim=-1)
  232. else:
  233. q, _ = self.q_proj(hidden_states)
  234. k, _ = self.k_proj(hidden_states)
  235. v, _ = self.v_proj(hidden_states)
  236. q, k = self.rotary_emb(positions, q, k)
  237. k_cache, v_cache = kv_cache
  238. attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
  239. output, _ = self.o_proj(attn_output)
  240. return output
  241. class MixtralDecoderLayer(nn.Module):
  242. def __init__(
  243. self,
  244. config: MixtralConfig,
  245. linear_method: Optional[LinearMethodBase] = None,
  246. ) -> None:
  247. super().__init__()
  248. self.hidden_size = config.hidden_size
  249. # Requires transformers > 4.32.0
  250. rope_theta = getattr(config, "rope_theta", 10000)
  251. self.self_attn = MixtralAttention(
  252. hidden_size=self.hidden_size,
  253. num_heads=config.num_attention_heads,
  254. max_position=config.max_position_embeddings,
  255. num_kv_heads=config.num_key_value_heads,
  256. rope_theta=rope_theta,
  257. sliding_window=config.sliding_window,
  258. linear_method=linear_method)
  259. self.block_sparse_moe = MixtralMoE(config=config,
  260. linear_method=linear_method)
  261. self.input_layernorm = RMSNorm(config.hidden_size,
  262. eps=config.rms_norm_eps)
  263. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  264. eps=config.rms_norm_eps)
  265. def forward(
  266. self,
  267. positions: torch.Tensor,
  268. hidden_states: torch.Tensor,
  269. kv_cache: KVCache,
  270. input_metadata: InputMetadata,
  271. residual: Optional[torch.Tensor],
  272. ) -> torch.Tensor:
  273. # Self Attention
  274. if residual is None:
  275. residual = hidden_states
  276. hidden_states = self.input_layernorm(hidden_states)
  277. else:
  278. hidden_states, residual = self.input_layernorm(
  279. hidden_states, residual)
  280. hidden_states = self.self_attn(
  281. positions=positions,
  282. hidden_states=hidden_states,
  283. kv_cache=kv_cache,
  284. input_metadata=input_metadata,
  285. )
  286. # Fully Connected
  287. hidden_states, residual = self.post_attention_layernorm(
  288. hidden_states, residual)
  289. hidden_states = self.block_sparse_moe(hidden_states)
  290. return hidden_states, residual
  291. class MixtralModel(nn.Module):
  292. def __init__(
  293. self,
  294. config: MixtralConfig,
  295. linear_method: Optional[LinearMethodBase] = None,
  296. ) -> None:
  297. super().__init__()
  298. self.padding_idx = config.pad_token_id
  299. self.vocab_size = config.vocab_size
  300. self.embed_tokens = VocabParallelEmbedding(
  301. config.vocab_size,
  302. config.hidden_size,
  303. linear_method=linear_method,
  304. )
  305. self.layers = nn.ModuleList([
  306. MixtralDecoderLayer(config, linear_method=linear_method)
  307. for _ in range(config.num_hidden_layers)
  308. ])
  309. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  310. def forward(
  311. self,
  312. input_ids: torch.Tensor,
  313. positions: torch.Tensor,
  314. kv_caches: List[KVCache],
  315. input_metadata: InputMetadata,
  316. ) -> torch.Tensor:
  317. hidden_states = self.embed_tokens(input_ids)
  318. residual = None
  319. for i in range(len(self.layers)):
  320. layer = self.layers[i]
  321. hidden_states, residual = layer(positions, hidden_states,
  322. kv_caches[i], input_metadata,
  323. residual)
  324. hidden_states, _ = self.norm(hidden_states, residual)
  325. return hidden_states
class MixtralForCausalLM(nn.Module):
    """Mixtral decoder stack plus LM head, sampler and checkpoint loader."""

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        # Kept so load_weights can consult quant_config.merge_weight().
        self.linear_method = linear_method
        self.model = MixtralModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      linear_method=linear_method)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run the decoder stack and return its final hidden states.

        No logits are computed here; ``sample`` applies ``lm_head``.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pass ``lm_head(hidden_states)`` to the sampler and return its output."""
        next_tokens = self.sampler(self.lm_head(hidden_states),
                                   sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None) -> None:
        """Copy checkpoint tensors into this module's parameters.

        Checkpoint ``q_proj``/``k_proj``/``v_proj`` tensors are loaded as
        shards "q"/"k"/"v" of the fused ``qkv_proj`` parameter, unless the
        quant config disables weight merging. The following are skipped:
        ``rotary_emb.inv_freq`` entries, ``.bias`` tensors with no matching
        parameter, and expert weights for experts not built on this rank
        (MixtralMoE stores ``None`` for those, so they are absent from
        ``named_parameters``). Any other checkpoint name without a matching
        parameter raises ``KeyError`` from ``params_dict[name]``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        # Unmerged quant configs keep separate q/k/v projections (see
        # MixtralAttention), so checkpoint names are loaded unchanged.
        if self.linear_method is not None and not self.linear_method.quant_config.merge_weight(
        ):
            stacked_params_mapping = []
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path,
                cache_dir,
                load_format,
                revision,
                fall_back_to_pt=False):
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # NOTE(review): this `continue` resumes the *inner* mapping
                # loop with the already-renamed `name`. Because "v_proj" is a
                # substring of "qkv_proj", the name can be rewritten again.
                # The bias is still discarded by the `.bias` check in the
                # `else` branch, but skipping the checkpoint entry directly
                # would be clearer.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached when no stacked mapping loaded the tensor.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip experts that are not assigned to this worker.
                if ("block_sparse_moe.experts." in name
                        and name not in params_dict):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)