mixtral_quant.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only Mixtral model."""
  24. from typing import Iterable, List, Optional, Tuple
  25. import numpy as np
  26. import torch
  27. import torch.nn.functional as F
  28. from torch import nn
  29. from transformers import MixtralConfig
  30. from aphrodite.attention import Attention, AttentionMetadata
  31. from aphrodite.common.config import CacheConfig
  32. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  33. from aphrodite.common.utils import progress_bar
  34. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  35. get_tensor_model_parallel_world_size,
  36. tensor_model_parallel_all_reduce)
  37. from aphrodite.modeling.layers.layernorm import RMSNorm
  38. from aphrodite.modeling.layers.linear import (QKVParallelLinear,
  39. ReplicatedLinear,
  40. RowParallelLinear)
  41. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  42. from aphrodite.modeling.layers.rotary_embedding import get_rope
  43. from aphrodite.modeling.layers.sampler import Sampler
  44. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  45. ParallelLMHead, VocabParallelEmbedding)
  46. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  47. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  48. from aphrodite.quantization.base_config import QuantizationConfig
  49. class MixtralMLP(nn.Module):
  50. def __init__(
  51. self,
  52. num_experts: int,
  53. hidden_size: int,
  54. intermediate_size: int,
  55. quant_config: Optional[QuantizationConfig] = None,
  56. ) -> None:
  57. super().__init__()
  58. self.num_experts = num_experts
  59. self.ffn_dim = intermediate_size
  60. self.hidden_dim = hidden_size
  61. self.w1 = ReplicatedLinear(self.hidden_dim,
  62. self.ffn_dim,
  63. bias=False,
  64. quant_config=quant_config)
  65. self.w2 = ReplicatedLinear(self.ffn_dim,
  66. self.hidden_dim,
  67. bias=False,
  68. quant_config=quant_config)
  69. self.w3 = ReplicatedLinear(self.hidden_dim,
  70. self.ffn_dim,
  71. bias=False,
  72. quant_config=quant_config)
  73. # TODO: Use Aphrodite's SiluAndMul
  74. self.act_fn = nn.SiLU()
  75. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  76. w1_out, _ = self.w1(hidden_states)
  77. w1_out = self.act_fn(w1_out)
  78. w3_out, _ = self.w3(hidden_states)
  79. current_hidden_states = w1_out * w3_out
  80. current_hidden_states, _ = self.w2(current_hidden_states)
  81. return current_hidden_states
  82. class MixtralMoE(nn.Module):
  83. def __init__(
  84. self,
  85. config: MixtralConfig,
  86. quant_config: Optional[QuantizationConfig] = None,
  87. ):
  88. super().__init__()
  89. self.config = config
  90. self.rank = get_tensor_model_parallel_rank()
  91. self.tp_size = get_tensor_model_parallel_world_size()
  92. self.num_total_experts = config.num_local_experts
  93. self.top_k = config.num_experts_per_tok
  94. if self.tp_size > self.num_total_experts:
  95. raise ValueError(
  96. f"Tensor parallel size {self.tp_size} is greater than "
  97. f"the number of experts {self.num_total_experts}.")
  98. # Split experts equally between ranks
  99. self.expert_indicies = np.array_split(range(
  100. self.num_total_experts), self.tp_size)[self.rank].tolist()
  101. if not self.expert_indicies:
  102. raise ValueError(
  103. f"Rank {self.rank} has no experts assigned to it.")
  104. self.experts = nn.ModuleList([
  105. MixtralMLP(self.num_total_experts,
  106. config.hidden_size,
  107. config.intermediate_size,
  108. quant_config=quant_config)
  109. if idx in self.expert_indicies else None
  110. for idx in range(self.num_total_experts)
  111. ])
  112. self.gate = ReplicatedLinear(config.hidden_size,
  113. self.num_total_experts,
  114. bias=False,
  115. quant_config=None)
  116. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  117. num_tokens, hidden_dim = hidden_states.shape
  118. hidden_states = hidden_states.view(-1, hidden_dim)
  119. # router_logits: (num_tokens, n_experts)
  120. router_logits, _ = self.gate(hidden_states)
  121. routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
  122. routing_weights, selected_experts = torch.topk(routing_weights,
  123. self.top_k,
  124. dim=-1)
  125. routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
  126. final_hidden_states = None
  127. for expert_idx in self.expert_indicies:
  128. expert_layer = self.experts[expert_idx]
  129. expert_mask = (selected_experts == expert_idx)
  130. expert_weights = (routing_weights * expert_mask).sum(dim=-1,
  131. keepdim=True)
  132. current_hidden_states = expert_layer(hidden_states).mul_(
  133. expert_weights)
  134. if final_hidden_states is None:
  135. final_hidden_states = current_hidden_states
  136. else:
  137. final_hidden_states.add_(current_hidden_states)
  138. return tensor_model_parallel_all_reduce(final_hidden_states).view(
  139. num_tokens, hidden_dim)
  140. class MixtralAttention(nn.Module):
  141. def __init__(
  142. self,
  143. hidden_size: int,
  144. num_heads: int,
  145. num_kv_heads: int,
  146. max_position: int = 4096 * 32,
  147. rope_theta: float = 10000,
  148. quant_config: Optional[QuantizationConfig] = None,
  149. cache_config: Optional[CacheConfig] = None,
  150. ) -> None:
  151. super().__init__()
  152. self.hidden_size = hidden_size
  153. tp_size = get_tensor_model_parallel_world_size()
  154. self.total_num_heads = num_heads
  155. assert self.total_num_heads % tp_size == 0
  156. self.num_heads = self.total_num_heads // tp_size
  157. self.total_num_kv_heads = num_kv_heads
  158. if self.total_num_kv_heads >= tp_size:
  159. # Number of KV heads is greater than TP size, so we partition
  160. # the KV heads across multiple tensor parallel GPUs.
  161. assert self.total_num_kv_heads % tp_size == 0
  162. else:
  163. # Number of KV heads is less than TP size, so we replicate
  164. # the KV heads across multiple tensor parallel GPUs.
  165. assert tp_size % self.total_num_kv_heads == 0
  166. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  167. self.head_dim = hidden_size // self.total_num_heads
  168. self.q_size = self.num_heads * self.head_dim
  169. self.kv_size = self.num_kv_heads * self.head_dim
  170. self.scaling = self.head_dim**-0.5
  171. self.rope_theta = rope_theta
  172. self.qkv_proj = QKVParallelLinear(
  173. hidden_size,
  174. self.head_dim,
  175. self.total_num_heads,
  176. self.total_num_kv_heads,
  177. bias=False,
  178. quant_config=quant_config,
  179. )
  180. self.o_proj = RowParallelLinear(
  181. self.total_num_heads * self.head_dim,
  182. hidden_size,
  183. bias=False,
  184. quant_config=quant_config,
  185. )
  186. self.rotary_emb = get_rope(
  187. self.head_dim,
  188. rotary_dim=self.head_dim,
  189. max_position=max_position,
  190. base=int(self.rope_theta),
  191. is_neox_style=True,
  192. )
  193. self.attn = Attention(self.num_heads,
  194. self.head_dim,
  195. self.scaling,
  196. num_kv_heads=self.num_kv_heads,
  197. cache_config=cache_config,
  198. quant_config=quant_config)
  199. def forward(
  200. self,
  201. positions: torch.Tensor,
  202. hidden_states: torch.Tensor,
  203. kv_cache: torch.Tensor,
  204. attn_metadata: AttentionMetadata,
  205. ) -> torch.Tensor:
  206. qkv, _ = self.qkv_proj(hidden_states)
  207. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  208. q, k = self.rotary_emb(positions, q, k)
  209. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  210. output, _ = self.o_proj(attn_output)
  211. return output
  212. class MixtralDecoderLayer(nn.Module):
  213. def __init__(
  214. self,
  215. config: MixtralConfig,
  216. cache_config: Optional[CacheConfig] = None,
  217. quant_config: Optional[QuantizationConfig] = None,
  218. ) -> None:
  219. super().__init__()
  220. self.hidden_size = config.hidden_size
  221. # Requires transformers > 4.32.0
  222. rope_theta = getattr(config, "rope_theta", 10000)
  223. self.self_attn = MixtralAttention(
  224. hidden_size=self.hidden_size,
  225. num_heads=config.num_attention_heads,
  226. max_position=config.max_position_embeddings,
  227. num_kv_heads=config.num_key_value_heads,
  228. rope_theta=rope_theta,
  229. cache_config=cache_config,
  230. quant_config=quant_config)
  231. self.block_sparse_moe = MixtralMoE(config=config,
  232. quant_config=quant_config)
  233. self.input_layernorm = RMSNorm(config.hidden_size,
  234. eps=config.rms_norm_eps)
  235. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  236. eps=config.rms_norm_eps)
  237. def forward(
  238. self,
  239. positions: torch.Tensor,
  240. hidden_states: torch.Tensor,
  241. kv_cache: torch.Tensor,
  242. attn_metadata: AttentionMetadata,
  243. residual: Optional[torch.Tensor],
  244. ) -> torch.Tensor:
  245. # Self Attention
  246. if residual is None:
  247. residual = hidden_states
  248. hidden_states = self.input_layernorm(hidden_states)
  249. else:
  250. hidden_states, residual = self.input_layernorm(
  251. hidden_states, residual)
  252. hidden_states = self.self_attn(
  253. positions=positions,
  254. hidden_states=hidden_states,
  255. kv_cache=kv_cache,
  256. attn_metadata=attn_metadata,
  257. )
  258. # Fully Connected
  259. hidden_states, residual = self.post_attention_layernorm(
  260. hidden_states, residual)
  261. hidden_states = self.block_sparse_moe(hidden_states)
  262. return hidden_states, residual
  263. class MixtralModel(nn.Module):
  264. def __init__(
  265. self,
  266. config: MixtralConfig,
  267. cache_config: Optional[CacheConfig] = None,
  268. quant_config: Optional[QuantizationConfig] = None,
  269. ) -> None:
  270. super().__init__()
  271. self.padding_idx = config.pad_token_id
  272. self.vocab_size = config.vocab_size
  273. self.embed_tokens = VocabParallelEmbedding(
  274. config.vocab_size,
  275. config.hidden_size,
  276. )
  277. self.layers = nn.ModuleList([
  278. MixtralDecoderLayer(config,
  279. cache_config,
  280. quant_config=quant_config)
  281. for _ in range(config.num_hidden_layers)
  282. ])
  283. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  284. def forward(
  285. self,
  286. input_ids: torch.Tensor,
  287. positions: torch.Tensor,
  288. kv_caches: List[torch.Tensor],
  289. attn_metadata: AttentionMetadata,
  290. ) -> torch.Tensor:
  291. hidden_states = self.embed_tokens(input_ids)
  292. residual = None
  293. for i in range(len(self.layers)):
  294. layer = self.layers[i]
  295. hidden_states, residual = layer(positions, hidden_states,
  296. kv_caches[i], attn_metadata,
  297. residual)
  298. hidden_states, _ = self.norm(hidden_states, residual)
  299. return hidden_states
  300. class MixtralForCausalLM(nn.Module):
  301. fall_back_to_pt_during_load = False
  302. def __init__(
  303. self,
  304. config: MixtralConfig,
  305. cache_config: Optional[CacheConfig] = None,
  306. quant_config: Optional[QuantizationConfig] = None,
  307. ) -> None:
  308. super().__init__()
  309. self.config = config
  310. self.quant_config = quant_config
  311. self.model = MixtralModel(config, cache_config, quant_config)
  312. self.lm_head = ParallelLMHead(config.vocab_size,
  313. config.hidden_size,
  314. quant_config=quant_config)
  315. self.logits_processor = LogitsProcessor(config.vocab_size)
  316. self.sampler = Sampler()
  317. def forward(
  318. self,
  319. input_ids: torch.Tensor,
  320. positions: torch.Tensor,
  321. kv_caches: List[torch.Tensor],
  322. attn_metadata: AttentionMetadata,
  323. intermediate_tensors: Optional[IntermediateTensors] = None,
  324. ) -> torch.Tensor:
  325. hidden_states = self.model(input_ids, positions, kv_caches,
  326. attn_metadata)
  327. return hidden_states
  328. def compute_logits(
  329. self,
  330. hidden_states: torch.Tensor,
  331. sampling_metadata: SamplingMetadata,
  332. ) -> Optional[torch.Tensor]:
  333. logits = self.logits_processor(self.lm_head, hidden_states,
  334. sampling_metadata)
  335. return logits
  336. def sample(
  337. self,
  338. logits: Optional[torch.Tensor],
  339. sampling_metadata: SamplingMetadata,
  340. ) -> Optional[SamplerOutput]:
  341. next_tokens = self.sampler(logits, sampling_metadata)
  342. return next_tokens
  343. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  344. stacked_params_mapping = [
  345. # (param_name, shard_name, shard_id)
  346. ("qkv_proj", "q_proj", "q"),
  347. ("qkv_proj", "k_proj", "k"),
  348. ("qkv_proj", "v_proj", "v"),
  349. ]
  350. params_dict = dict(self.named_parameters())
  351. weights_list = list(weights)
  352. for name, loaded_weight in progress_bar(weights_list,
  353. desc="Loading modules..."):
  354. if "rotary_emb.inv_freq" in name:
  355. continue
  356. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  357. if weight_name not in name:
  358. continue
  359. name = name.replace(weight_name, param_name)
  360. # Skip loading extra bias for GPTQ models.
  361. if name.endswith(".bias") and name not in params_dict:
  362. continue
  363. param = params_dict[name]
  364. weight_loader = param.weight_loader
  365. weight_loader(param, loaded_weight, shard_id)
  366. break
  367. else:
  368. # Skip loading extra bias for GPTQ models.
  369. if name.endswith(".bias") and name not in params_dict:
  370. continue
  371. # Skip experts that are not assigned to this worker.
  372. if ("block_sparse_moe.experts." in name
  373. and name not in params_dict):
  374. continue
  375. param = params_dict[name]
  376. weight_loader = getattr(param, "weight_loader",
  377. default_weight_loader)
  378. weight_loader(param, loaded_weight)