mixtral_quant.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  23. """Inference-only Mixtral model."""
  24. from typing import Iterable, List, Optional, Tuple
  25. import numpy as np
  26. import torch
  27. import torch.nn.functional as F
  28. from torch import nn
  29. from transformers import MixtralConfig
  30. from aphrodite.attention import Attention, AttentionMetadata
  31. from aphrodite.common.sequence import SamplerOutput
  32. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  33. get_tensor_model_parallel_world_size,
  34. tensor_model_parallel_all_reduce)
  35. from aphrodite.modeling.layers.layernorm import RMSNorm
  36. from aphrodite.modeling.layers.linear import (QKVParallelLinear,
  37. ReplicatedLinear,
  38. RowParallelLinear)
  39. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  40. from aphrodite.modeling.layers.rotary_embedding import get_rope
  41. from aphrodite.modeling.layers.sampler import Sampler
  42. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  43. ParallelLMHead, VocabParallelEmbedding)
  44. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  45. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  46. from aphrodite.quantization.base_config import QuantizationConfig
  47. class MixtralMLP(nn.Module):
  48. def __init__(
  49. self,
  50. num_experts: int,
  51. hidden_size: int,
  52. intermediate_size: int,
  53. quant_config: Optional[QuantizationConfig] = None,
  54. ) -> None:
  55. super().__init__()
  56. self.num_experts = num_experts
  57. self.ffn_dim = intermediate_size
  58. self.hidden_dim = hidden_size
  59. self.w1 = ReplicatedLinear(self.hidden_dim,
  60. self.ffn_dim,
  61. bias=False,
  62. quant_config=quant_config)
  63. self.w2 = ReplicatedLinear(self.ffn_dim,
  64. self.hidden_dim,
  65. bias=False,
  66. quant_config=quant_config)
  67. self.w3 = ReplicatedLinear(self.hidden_dim,
  68. self.ffn_dim,
  69. bias=False,
  70. quant_config=quant_config)
  71. # TODO: Use vllm's SiluAndMul
  72. self.act_fn = nn.SiLU()
  73. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  74. w1_out, _ = self.w1(hidden_states)
  75. w1_out = self.act_fn(w1_out)
  76. w3_out, _ = self.w3(hidden_states)
  77. current_hidden_states = w1_out * w3_out
  78. current_hidden_states, _ = self.w2(current_hidden_states)
  79. return current_hidden_states
  80. class MixtralMoE(nn.Module):
  81. def __init__(
  82. self,
  83. config: MixtralConfig,
  84. quant_config: Optional[QuantizationConfig] = None,
  85. ):
  86. super().__init__()
  87. self.config = config
  88. self.rank = get_tensor_model_parallel_rank()
  89. self.tp_size = get_tensor_model_parallel_world_size()
  90. self.num_total_experts = config.num_local_experts
  91. self.top_k = config.num_experts_per_tok
  92. if self.tp_size > self.num_total_experts:
  93. raise ValueError(
  94. f"Tensor parallel size {self.tp_size} is greater than "
  95. f"the number of experts {self.num_total_experts}.")
  96. # Split experts equally between ranks
  97. self.expert_indicies = np.array_split(range(
  98. self.num_total_experts), self.tp_size)[self.rank].tolist()
  99. if not self.expert_indicies:
  100. raise ValueError(
  101. f"Rank {self.rank} has no experts assigned to it.")
  102. self.experts = nn.ModuleList([
  103. MixtralMLP(self.num_total_experts,
  104. config.hidden_size,
  105. config.intermediate_size,
  106. quant_config=quant_config)
  107. if idx in self.expert_indicies else None
  108. for idx in range(self.num_total_experts)
  109. ])
  110. self.gate = ReplicatedLinear(config.hidden_size,
  111. self.num_total_experts,
  112. bias=False,
  113. quant_config=None)
  114. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  115. num_tokens, hidden_dim = hidden_states.shape
  116. hidden_states = hidden_states.view(-1, hidden_dim)
  117. # router_logits: (num_tokens, n_experts)
  118. router_logits, _ = self.gate(hidden_states)
  119. routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
  120. routing_weights, selected_experts = torch.topk(routing_weights,
  121. self.top_k,
  122. dim=-1)
  123. routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
  124. final_hidden_states = None
  125. for expert_idx in self.expert_indicies:
  126. expert_layer = self.experts[expert_idx]
  127. expert_mask = (selected_experts == expert_idx)
  128. expert_weights = (routing_weights * expert_mask).sum(dim=-1,
  129. keepdim=True)
  130. current_hidden_states = expert_layer(hidden_states).mul_(
  131. expert_weights)
  132. if final_hidden_states is None:
  133. final_hidden_states = current_hidden_states
  134. else:
  135. final_hidden_states.add_(current_hidden_states)
  136. return tensor_model_parallel_all_reduce(final_hidden_states).view(
  137. num_tokens, hidden_dim)
  138. class MixtralAttention(nn.Module):
  139. def __init__(self,
  140. hidden_size: int,
  141. num_heads: int,
  142. num_kv_heads: int,
  143. max_position: int = 4096 * 32,
  144. rope_theta: float = 10000,
  145. quant_config: Optional[QuantizationConfig] = None,
  146. sliding_window: Optional[int] = None) -> None:
  147. super().__init__()
  148. self.hidden_size = hidden_size
  149. tp_size = get_tensor_model_parallel_world_size()
  150. self.total_num_heads = num_heads
  151. assert self.total_num_heads % tp_size == 0
  152. self.num_heads = self.total_num_heads // tp_size
  153. self.total_num_kv_heads = num_kv_heads
  154. if self.total_num_kv_heads >= tp_size:
  155. # Number of KV heads is greater than TP size, so we partition
  156. # the KV heads across multiple tensor parallel GPUs.
  157. assert self.total_num_kv_heads % tp_size == 0
  158. else:
  159. # Number of KV heads is less than TP size, so we replicate
  160. # the KV heads across multiple tensor parallel GPUs.
  161. assert tp_size % self.total_num_kv_heads == 0
  162. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  163. self.head_dim = hidden_size // self.total_num_heads
  164. self.q_size = self.num_heads * self.head_dim
  165. self.kv_size = self.num_kv_heads * self.head_dim
  166. self.scaling = self.head_dim**-0.5
  167. self.rope_theta = rope_theta
  168. self.sliding_window = sliding_window
  169. self.qkv_proj = QKVParallelLinear(
  170. hidden_size,
  171. self.head_dim,
  172. self.total_num_heads,
  173. self.total_num_kv_heads,
  174. bias=False,
  175. quant_config=quant_config,
  176. )
  177. self.o_proj = RowParallelLinear(
  178. self.total_num_heads * self.head_dim,
  179. hidden_size,
  180. bias=False,
  181. quant_config=quant_config,
  182. )
  183. self.rotary_emb = get_rope(
  184. self.head_dim,
  185. rotary_dim=self.head_dim,
  186. max_position=max_position,
  187. base=int(self.rope_theta),
  188. is_neox_style=True,
  189. )
  190. self.attn = Attention(
  191. self.num_heads,
  192. self.head_dim,
  193. self.scaling,
  194. num_kv_heads=self.num_kv_heads,
  195. sliding_window=self.sliding_window,
  196. )
  197. def forward(
  198. self,
  199. positions: torch.Tensor,
  200. hidden_states: torch.Tensor,
  201. kv_cache: torch.Tensor,
  202. attn_metadata: AttentionMetadata,
  203. ) -> torch.Tensor:
  204. qkv, _ = self.qkv_proj(hidden_states)
  205. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  206. q, k = self.rotary_emb(positions, q, k)
  207. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  208. output, _ = self.o_proj(attn_output)
  209. return output
  210. class MixtralDecoderLayer(nn.Module):
  211. def __init__(
  212. self,
  213. config: MixtralConfig,
  214. quant_config: Optional[QuantizationConfig] = None,
  215. ) -> None:
  216. super().__init__()
  217. self.hidden_size = config.hidden_size
  218. # Requires transformers > 4.32.0
  219. rope_theta = getattr(config, "rope_theta", 10000)
  220. self.self_attn = MixtralAttention(
  221. hidden_size=self.hidden_size,
  222. num_heads=config.num_attention_heads,
  223. max_position=config.max_position_embeddings,
  224. num_kv_heads=config.num_key_value_heads,
  225. rope_theta=rope_theta,
  226. sliding_window=config.sliding_window,
  227. quant_config=quant_config)
  228. self.block_sparse_moe = MixtralMoE(config=config,
  229. quant_config=quant_config)
  230. self.input_layernorm = RMSNorm(config.hidden_size,
  231. eps=config.rms_norm_eps)
  232. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  233. eps=config.rms_norm_eps)
  234. def forward(
  235. self,
  236. positions: torch.Tensor,
  237. hidden_states: torch.Tensor,
  238. kv_cache: torch.Tensor,
  239. attn_metadata: AttentionMetadata,
  240. residual: Optional[torch.Tensor],
  241. ) -> torch.Tensor:
  242. # Self Attention
  243. if residual is None:
  244. residual = hidden_states
  245. hidden_states = self.input_layernorm(hidden_states)
  246. else:
  247. hidden_states, residual = self.input_layernorm(
  248. hidden_states, residual)
  249. hidden_states = self.self_attn(
  250. positions=positions,
  251. hidden_states=hidden_states,
  252. kv_cache=kv_cache,
  253. attn_metadata=attn_metadata,
  254. )
  255. # Fully Connected
  256. hidden_states, residual = self.post_attention_layernorm(
  257. hidden_states, residual)
  258. hidden_states = self.block_sparse_moe(hidden_states)
  259. return hidden_states, residual
  260. class MixtralModel(nn.Module):
  261. def __init__(
  262. self,
  263. config: MixtralConfig,
  264. quant_config: Optional[QuantizationConfig] = None,
  265. ) -> None:
  266. super().__init__()
  267. self.padding_idx = config.pad_token_id
  268. self.vocab_size = config.vocab_size
  269. self.embed_tokens = VocabParallelEmbedding(
  270. config.vocab_size,
  271. config.hidden_size,
  272. )
  273. self.layers = nn.ModuleList([
  274. MixtralDecoderLayer(config, quant_config=quant_config)
  275. for _ in range(config.num_hidden_layers)
  276. ])
  277. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  278. def forward(
  279. self,
  280. input_ids: torch.Tensor,
  281. positions: torch.Tensor,
  282. kv_caches: List[torch.Tensor],
  283. attn_metadata: AttentionMetadata,
  284. ) -> torch.Tensor:
  285. hidden_states = self.embed_tokens(input_ids)
  286. residual = None
  287. for i in range(len(self.layers)):
  288. layer = self.layers[i]
  289. hidden_states, residual = layer(positions, hidden_states,
  290. kv_caches[i], attn_metadata,
  291. residual)
  292. hidden_states, _ = self.norm(hidden_states, residual)
  293. return hidden_states
  294. class MixtralForCausalLM(nn.Module):
  295. fall_back_to_pt_during_load = False
  296. def __init__(
  297. self,
  298. config: MixtralConfig,
  299. quant_config: Optional[QuantizationConfig] = None,
  300. ) -> None:
  301. super().__init__()
  302. self.config = config
  303. self.quant_config = quant_config
  304. self.model = MixtralModel(config, quant_config)
  305. self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
  306. self.logits_processor = LogitsProcessor(config.vocab_size)
  307. self.sampler = Sampler()
  308. def forward(
  309. self,
  310. input_ids: torch.Tensor,
  311. positions: torch.Tensor,
  312. kv_caches: List[torch.Tensor],
  313. attn_metadata: AttentionMetadata,
  314. ) -> torch.Tensor:
  315. hidden_states = self.model(input_ids, positions, kv_caches,
  316. attn_metadata)
  317. return hidden_states
  318. def compute_logits(self, hidden_states: torch.Tensor,
  319. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  320. logits = self.logits_processor(self.lm_head.weight, hidden_states,
  321. sampling_metadata)
  322. return logits
  323. def sample(
  324. self,
  325. logits: Optional[torch.Tensor],
  326. sampling_metadata: SamplingMetadata,
  327. ) -> Optional[SamplerOutput]:
  328. next_tokens = self.sampler(logits, sampling_metadata)
  329. return next_tokens
  330. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  331. stacked_params_mapping = [
  332. # (param_name, shard_name, shard_id)
  333. ("qkv_proj", "q_proj", "q"),
  334. ("qkv_proj", "k_proj", "k"),
  335. ("qkv_proj", "v_proj", "v"),
  336. ]
  337. params_dict = dict(self.named_parameters())
  338. for name, loaded_weight in weights:
  339. if "rotary_emb.inv_freq" in name:
  340. continue
  341. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  342. if weight_name not in name:
  343. continue
  344. name = name.replace(weight_name, param_name)
  345. # Skip loading extra bias for GPTQ models.
  346. if name.endswith(".bias") and name not in params_dict:
  347. continue
  348. param = params_dict[name]
  349. weight_loader = param.weight_loader
  350. weight_loader(param, loaded_weight, shard_id)
  351. break
  352. else:
  353. # Skip loading extra bias for GPTQ models.
  354. if name.endswith(".bias") and name not in params_dict:
  355. continue
  356. # Skip experts that are not assigned to this worker.
  357. if ("block_sparse_moe.experts." in name
  358. and name not in params_dict):
  359. continue
  360. param = params_dict[name]
  361. weight_loader = getattr(param, "weight_loader",
  362. default_weight_loader)
  363. weight_loader(param, loaded_weight)