mixtral_quant.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only Mixtral model."""
  24. from typing import Iterable, List, Optional, Tuple
  25. import numpy as np
  26. import torch
  27. import torch.nn.functional as F
  28. from torch import nn
  29. from transformers import MixtralConfig
  30. from aphrodite.attention import Attention, AttentionMetadata
  31. from aphrodite.common.config import CacheConfig
  32. from aphrodite.common.sequence import IntermediateTensors
  33. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  34. get_tensor_model_parallel_world_size,
  35. tensor_model_parallel_all_reduce)
  36. from aphrodite.modeling.layers.layernorm import RMSNorm
  37. from aphrodite.modeling.layers.linear import (QKVParallelLinear,
  38. ReplicatedLinear,
  39. RowParallelLinear)
  40. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  41. from aphrodite.modeling.layers.rotary_embedding import get_rope
  42. from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
  43. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  44. ParallelLMHead, VocabParallelEmbedding)
  45. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  46. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  47. from aphrodite.quantization.base_config import QuantizationConfig
  48. class MixtralMLP(nn.Module):
  49. def __init__(
  50. self,
  51. num_experts: int,
  52. hidden_size: int,
  53. intermediate_size: int,
  54. quant_config: Optional[QuantizationConfig] = None,
  55. ) -> None:
  56. super().__init__()
  57. self.num_experts = num_experts
  58. self.ffn_dim = intermediate_size
  59. self.hidden_dim = hidden_size
  60. self.w1 = ReplicatedLinear(self.hidden_dim,
  61. self.ffn_dim,
  62. bias=False,
  63. quant_config=quant_config)
  64. self.w2 = ReplicatedLinear(self.ffn_dim,
  65. self.hidden_dim,
  66. bias=False,
  67. quant_config=quant_config)
  68. self.w3 = ReplicatedLinear(self.hidden_dim,
  69. self.ffn_dim,
  70. bias=False,
  71. quant_config=quant_config)
  72. # TODO: Use Aphrodite's SiluAndMul
  73. self.act_fn = nn.SiLU()
  74. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  75. w1_out, _ = self.w1(hidden_states)
  76. w1_out = self.act_fn(w1_out)
  77. w3_out, _ = self.w3(hidden_states)
  78. current_hidden_states = w1_out * w3_out
  79. current_hidden_states, _ = self.w2(current_hidden_states)
  80. return current_hidden_states
  81. class MixtralMoE(nn.Module):
  82. def __init__(
  83. self,
  84. config: MixtralConfig,
  85. quant_config: Optional[QuantizationConfig] = None,
  86. ):
  87. super().__init__()
  88. self.config = config
  89. self.rank = get_tensor_model_parallel_rank()
  90. self.tp_size = get_tensor_model_parallel_world_size()
  91. self.num_total_experts = config.num_local_experts
  92. self.top_k = config.num_experts_per_tok
  93. if self.tp_size > self.num_total_experts:
  94. raise ValueError(
  95. f"Tensor parallel size {self.tp_size} is greater than "
  96. f"the number of experts {self.num_total_experts}.")
  97. # Split experts equally between ranks
  98. self.expert_indicies = np.array_split(range(
  99. self.num_total_experts), self.tp_size)[self.rank].tolist()
  100. if not self.expert_indicies:
  101. raise ValueError(
  102. f"Rank {self.rank} has no experts assigned to it.")
  103. self.experts = nn.ModuleList([
  104. MixtralMLP(self.num_total_experts,
  105. config.hidden_size,
  106. config.intermediate_size,
  107. quant_config=quant_config)
  108. if idx in self.expert_indicies else None
  109. for idx in range(self.num_total_experts)
  110. ])
  111. self.gate = ReplicatedLinear(config.hidden_size,
  112. self.num_total_experts,
  113. bias=False,
  114. quant_config=None)
  115. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  116. num_tokens, hidden_dim = hidden_states.shape
  117. hidden_states = hidden_states.view(-1, hidden_dim)
  118. # router_logits: (num_tokens, n_experts)
  119. router_logits, _ = self.gate(hidden_states)
  120. routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
  121. routing_weights, selected_experts = torch.topk(routing_weights,
  122. self.top_k,
  123. dim=-1)
  124. routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
  125. final_hidden_states = None
  126. for expert_idx in self.expert_indicies:
  127. expert_layer = self.experts[expert_idx]
  128. expert_mask = (selected_experts == expert_idx)
  129. expert_weights = (routing_weights * expert_mask).sum(dim=-1,
  130. keepdim=True)
  131. current_hidden_states = expert_layer(hidden_states).mul_(
  132. expert_weights)
  133. if final_hidden_states is None:
  134. final_hidden_states = current_hidden_states
  135. else:
  136. final_hidden_states.add_(current_hidden_states)
  137. return tensor_model_parallel_all_reduce(final_hidden_states).view(
  138. num_tokens, hidden_dim)
  139. class MixtralAttention(nn.Module):
  140. def __init__(
  141. self,
  142. hidden_size: int,
  143. num_heads: int,
  144. num_kv_heads: int,
  145. max_position: int = 4096 * 32,
  146. rope_theta: float = 10000,
  147. quant_config: Optional[QuantizationConfig] = None,
  148. cache_config: Optional[CacheConfig] = None,
  149. ) -> None:
  150. super().__init__()
  151. self.hidden_size = hidden_size
  152. tp_size = get_tensor_model_parallel_world_size()
  153. self.total_num_heads = num_heads
  154. assert self.total_num_heads % tp_size == 0
  155. self.num_heads = self.total_num_heads // tp_size
  156. self.total_num_kv_heads = num_kv_heads
  157. if self.total_num_kv_heads >= tp_size:
  158. # Number of KV heads is greater than TP size, so we partition
  159. # the KV heads across multiple tensor parallel GPUs.
  160. assert self.total_num_kv_heads % tp_size == 0
  161. else:
  162. # Number of KV heads is less than TP size, so we replicate
  163. # the KV heads across multiple tensor parallel GPUs.
  164. assert tp_size % self.total_num_kv_heads == 0
  165. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  166. self.head_dim = hidden_size // self.total_num_heads
  167. self.q_size = self.num_heads * self.head_dim
  168. self.kv_size = self.num_kv_heads * self.head_dim
  169. self.scaling = self.head_dim**-0.5
  170. self.rope_theta = rope_theta
  171. self.qkv_proj = QKVParallelLinear(
  172. hidden_size,
  173. self.head_dim,
  174. self.total_num_heads,
  175. self.total_num_kv_heads,
  176. bias=False,
  177. quant_config=quant_config,
  178. )
  179. self.o_proj = RowParallelLinear(
  180. self.total_num_heads * self.head_dim,
  181. hidden_size,
  182. bias=False,
  183. quant_config=quant_config,
  184. )
  185. self.rotary_emb = get_rope(
  186. self.head_dim,
  187. rotary_dim=self.head_dim,
  188. max_position=max_position,
  189. base=int(self.rope_theta),
  190. is_neox_style=True,
  191. )
  192. self.attn = Attention(self.num_heads,
  193. self.head_dim,
  194. self.scaling,
  195. num_kv_heads=self.num_kv_heads,
  196. cache_config=cache_config,
  197. quant_config=quant_config)
  198. def forward(
  199. self,
  200. positions: torch.Tensor,
  201. hidden_states: torch.Tensor,
  202. kv_cache: torch.Tensor,
  203. attn_metadata: AttentionMetadata,
  204. ) -> torch.Tensor:
  205. qkv, _ = self.qkv_proj(hidden_states)
  206. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  207. q, k = self.rotary_emb(positions, q, k)
  208. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  209. output, _ = self.o_proj(attn_output)
  210. return output
  211. class MixtralDecoderLayer(nn.Module):
  212. def __init__(
  213. self,
  214. config: MixtralConfig,
  215. cache_config: Optional[CacheConfig] = None,
  216. quant_config: Optional[QuantizationConfig] = None,
  217. ) -> None:
  218. super().__init__()
  219. self.hidden_size = config.hidden_size
  220. # Requires transformers > 4.32.0
  221. rope_theta = getattr(config, "rope_theta", 10000)
  222. self.self_attn = MixtralAttention(
  223. hidden_size=self.hidden_size,
  224. num_heads=config.num_attention_heads,
  225. max_position=config.max_position_embeddings,
  226. num_kv_heads=config.num_key_value_heads,
  227. rope_theta=rope_theta,
  228. cache_config=cache_config,
  229. quant_config=quant_config)
  230. self.block_sparse_moe = MixtralMoE(config=config,
  231. quant_config=quant_config)
  232. self.input_layernorm = RMSNorm(config.hidden_size,
  233. eps=config.rms_norm_eps)
  234. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  235. eps=config.rms_norm_eps)
  236. def forward(
  237. self,
  238. positions: torch.Tensor,
  239. hidden_states: torch.Tensor,
  240. kv_cache: torch.Tensor,
  241. attn_metadata: AttentionMetadata,
  242. residual: Optional[torch.Tensor],
  243. ) -> torch.Tensor:
  244. # Self Attention
  245. if residual is None:
  246. residual = hidden_states
  247. hidden_states = self.input_layernorm(hidden_states)
  248. else:
  249. hidden_states, residual = self.input_layernorm(
  250. hidden_states, residual)
  251. hidden_states = self.self_attn(
  252. positions=positions,
  253. hidden_states=hidden_states,
  254. kv_cache=kv_cache,
  255. attn_metadata=attn_metadata,
  256. )
  257. # Fully Connected
  258. hidden_states, residual = self.post_attention_layernorm(
  259. hidden_states, residual)
  260. hidden_states = self.block_sparse_moe(hidden_states)
  261. return hidden_states, residual
  262. class MixtralModel(nn.Module):
  263. def __init__(
  264. self,
  265. config: MixtralConfig,
  266. cache_config: Optional[CacheConfig] = None,
  267. quant_config: Optional[QuantizationConfig] = None,
  268. ) -> None:
  269. super().__init__()
  270. self.padding_idx = config.pad_token_id
  271. self.vocab_size = config.vocab_size
  272. self.embed_tokens = VocabParallelEmbedding(
  273. config.vocab_size,
  274. config.hidden_size,
  275. )
  276. self.layers = nn.ModuleList([
  277. MixtralDecoderLayer(config,
  278. cache_config,
  279. quant_config=quant_config)
  280. for _ in range(config.num_hidden_layers)
  281. ])
  282. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  283. def forward(
  284. self,
  285. input_ids: torch.Tensor,
  286. positions: torch.Tensor,
  287. kv_caches: List[torch.Tensor],
  288. attn_metadata: AttentionMetadata,
  289. ) -> torch.Tensor:
  290. hidden_states = self.embed_tokens(input_ids)
  291. residual = None
  292. for i in range(len(self.layers)):
  293. layer = self.layers[i]
  294. hidden_states, residual = layer(positions, hidden_states,
  295. kv_caches[i], attn_metadata,
  296. residual)
  297. hidden_states, _ = self.norm(hidden_states, residual)
  298. return hidden_states
  299. class MixtralForCausalLM(nn.Module):
  300. fall_back_to_pt_during_load = False
  301. def __init__(
  302. self,
  303. config: MixtralConfig,
  304. cache_config: Optional[CacheConfig] = None,
  305. quant_config: Optional[QuantizationConfig] = None,
  306. ) -> None:
  307. super().__init__()
  308. self.config = config
  309. self.quant_config = quant_config
  310. self.model = MixtralModel(config, cache_config, quant_config)
  311. self.lm_head = ParallelLMHead(config.vocab_size,
  312. config.hidden_size,
  313. quant_config=quant_config)
  314. self.logits_processor = LogitsProcessor(config.vocab_size)
  315. self.sampler = Sampler()
  316. def forward(
  317. self,
  318. input_ids: torch.Tensor,
  319. positions: torch.Tensor,
  320. kv_caches: List[torch.Tensor],
  321. attn_metadata: AttentionMetadata,
  322. intermediate_tensors: Optional[IntermediateTensors] = None,
  323. ) -> torch.Tensor:
  324. hidden_states = self.model(input_ids, positions, kv_caches,
  325. attn_metadata)
  326. return hidden_states
  327. def compute_logits(
  328. self,
  329. hidden_states: torch.Tensor,
  330. sampling_metadata: SamplingMetadata,
  331. ) -> Optional[torch.Tensor]:
  332. logits = self.logits_processor(self.lm_head, hidden_states,
  333. sampling_metadata)
  334. return logits
  335. def sample(
  336. self,
  337. logits: Optional[torch.Tensor],
  338. sampling_metadata: SamplingMetadata,
  339. ) -> Optional[SamplerOutput]:
  340. next_tokens = self.sampler(logits, sampling_metadata)
  341. return next_tokens
  342. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  343. stacked_params_mapping = [
  344. # (param_name, shard_name, shard_id)
  345. ("qkv_proj", "q_proj", "q"),
  346. ("qkv_proj", "k_proj", "k"),
  347. ("qkv_proj", "v_proj", "v"),
  348. ]
  349. params_dict = dict(self.named_parameters())
  350. for name, loaded_weight in weights:
  351. if "rotary_emb.inv_freq" in name:
  352. continue
  353. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  354. if weight_name not in name:
  355. continue
  356. name = name.replace(weight_name, param_name)
  357. # Skip loading extra bias for GPTQ models.
  358. if name.endswith(".bias") and name not in params_dict:
  359. continue
  360. param = params_dict[name]
  361. weight_loader = param.weight_loader
  362. weight_loader(param, loaded_weight, shard_id)
  363. break
  364. else:
  365. # Skip loading extra bias for GPTQ models.
  366. if name.endswith(".bias") and name not in params_dict:
  367. continue
  368. # Skip experts that are not assigned to this worker.
  369. if ("block_sparse_moe.experts." in name
  370. and name not in params_dict):
  371. continue
  372. param = params_dict[name]
  373. weight_loader = getattr(param, "weight_loader",
  374. default_weight_loader)
  375. weight_loader(param, loaded_weight)