# mixtral.py (18 KB)
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only Mixtral model."""
  24. from typing import Iterable, List, Optional, Tuple
  25. import torch
  26. from torch import nn
  27. from transformers import MixtralConfig
  28. from aphrodite.attention import Attention, AttentionMetadata
  29. from aphrodite.common.config import CacheConfig, LoRAConfig
  30. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  31. from aphrodite.common.utils import print_warning_once
  32. from aphrodite.distributed import get_tensor_model_parallel_world_size
  33. from aphrodite.modeling.layers.fused_moe import FusedMoE
  34. from aphrodite.modeling.layers.layernorm import RMSNorm
  35. from aphrodite.modeling.layers.linear import (QKVParallelLinear,
  36. ReplicatedLinear,
  37. RowParallelLinear)
  38. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  39. from aphrodite.modeling.layers.rotary_embedding import get_rope
  40. from aphrodite.modeling.layers.sampler import Sampler
  41. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  42. DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
  43. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  44. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  45. from aphrodite.quantization.base_config import QuantizationConfig
  46. from .interfaces import SupportsLoRA
  47. class MixtralMoE(nn.Module):
  48. """A tensor-parallel MoE implementation for Mixtral that shards each expert
  49. across all ranks.
  50. Each expert's weights are sharded across all ranks and a fused MoE
  51. kernel is used for the forward pass, and finally we reduce the outputs
  52. across ranks.
  53. """
  54. def __init__(self,
  55. num_experts: int,
  56. top_k: int,
  57. hidden_size: int,
  58. intermediate_size: int,
  59. params_dtype: Optional[torch.dtype] = None,
  60. quant_config: Optional[QuantizationConfig] = None,
  61. tp_size: Optional[int] = None):
  62. super().__init__()
  63. self.hidden_size = hidden_size
  64. # Gate always runs at half / full precision for now.
  65. self.gate = ReplicatedLinear(hidden_size,
  66. num_experts,
  67. bias=False,
  68. params_dtype=params_dtype,
  69. quant_config=None)
  70. self.experts = FusedMoE(num_experts=num_experts,
  71. top_k=top_k,
  72. hidden_size=hidden_size,
  73. intermediate_size=intermediate_size,
  74. params_dtype=params_dtype,
  75. reduce_results=True,
  76. renormalize=True,
  77. quant_config=quant_config,
  78. tp_size=tp_size)
  79. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  80. num_tokens, hidden_size = hidden_states.shape
  81. hidden_states = hidden_states.view(-1, self.hidden_size)
  82. # router_logits: (num_tokens, n_experts)
  83. router_logits, _ = self.gate(hidden_states)
  84. final_hidden_states = self.experts(hidden_states, router_logits)
  85. return final_hidden_states.view(num_tokens, hidden_size)
  86. class MixtralAttention(nn.Module):
  87. def __init__(
  88. self,
  89. hidden_size: int,
  90. num_heads: int,
  91. num_kv_heads: int,
  92. max_position: int = 4096 * 32,
  93. rope_theta: float = 10000,
  94. cache_config: Optional[CacheConfig] = None,
  95. quant_config: Optional[QuantizationConfig] = None,
  96. ) -> None:
  97. super().__init__()
  98. self.hidden_size = hidden_size
  99. tp_size = get_tensor_model_parallel_world_size()
  100. self.total_num_heads = num_heads
  101. assert self.total_num_heads % tp_size == 0
  102. self.num_heads = self.total_num_heads // tp_size
  103. self.total_num_kv_heads = num_kv_heads
  104. if self.total_num_kv_heads >= tp_size:
  105. # Number of KV heads is greater than TP size, so we partition
  106. # the KV heads across multiple tensor parallel GPUs.
  107. assert self.total_num_kv_heads % tp_size == 0
  108. else:
  109. # Number of KV heads is less than TP size, so we replicate
  110. # the KV heads across multiple tensor parallel GPUs.
  111. assert tp_size % self.total_num_kv_heads == 0
  112. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  113. self.head_dim = hidden_size // self.total_num_heads
  114. self.q_size = self.num_heads * self.head_dim
  115. self.kv_size = self.num_kv_heads * self.head_dim
  116. self.scaling = self.head_dim**-0.5
  117. self.rope_theta = rope_theta
  118. self.qkv_proj = QKVParallelLinear(
  119. hidden_size,
  120. self.head_dim,
  121. self.total_num_heads,
  122. self.total_num_kv_heads,
  123. bias=False,
  124. quant_config=quant_config,
  125. )
  126. self.o_proj = RowParallelLinear(
  127. self.total_num_heads * self.head_dim,
  128. hidden_size,
  129. bias=False,
  130. quant_config=quant_config,
  131. )
  132. self.rotary_emb = get_rope(
  133. self.head_dim,
  134. rotary_dim=self.head_dim,
  135. max_position=max_position,
  136. base=int(self.rope_theta),
  137. is_neox_style=True,
  138. )
  139. self.attn = Attention(self.num_heads,
  140. self.head_dim,
  141. self.scaling,
  142. num_kv_heads=self.num_kv_heads,
  143. cache_config=cache_config,
  144. quant_config=quant_config)
  145. def forward(
  146. self,
  147. positions: torch.Tensor,
  148. hidden_states: torch.Tensor,
  149. kv_cache: torch.Tensor,
  150. attn_metadata: AttentionMetadata,
  151. ) -> torch.Tensor:
  152. qkv, _ = self.qkv_proj(hidden_states)
  153. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  154. q, k = self.rotary_emb(positions, q, k)
  155. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  156. output, _ = self.o_proj(attn_output)
  157. return output
  158. class MixtralDecoderLayer(nn.Module):
  159. def __init__(
  160. self,
  161. config: MixtralConfig,
  162. cache_config: Optional[CacheConfig] = None,
  163. quant_config: Optional[QuantizationConfig] = None,
  164. ) -> None:
  165. super().__init__()
  166. self.hidden_size = config.hidden_size
  167. # Requires transformers > 4.32.0
  168. rope_theta = getattr(config, "rope_theta", 10000)
  169. self.self_attn = MixtralAttention(
  170. hidden_size=self.hidden_size,
  171. num_heads=config.num_attention_heads,
  172. max_position=config.max_position_embeddings,
  173. num_kv_heads=config.num_key_value_heads,
  174. rope_theta=rope_theta,
  175. cache_config=cache_config,
  176. quant_config=quant_config)
  177. self.block_sparse_moe = MixtralMoE(
  178. num_experts=config.num_local_experts,
  179. top_k=config.num_experts_per_tok,
  180. hidden_size=config.hidden_size,
  181. intermediate_size=config.intermediate_size,
  182. quant_config=quant_config)
  183. self.input_layernorm = RMSNorm(config.hidden_size,
  184. eps=config.rms_norm_eps)
  185. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  186. eps=config.rms_norm_eps)
  187. def forward(
  188. self,
  189. positions: torch.Tensor,
  190. hidden_states: torch.Tensor,
  191. kv_cache: torch.Tensor,
  192. attn_metadata: AttentionMetadata,
  193. residual: Optional[torch.Tensor],
  194. ) -> torch.Tensor:
  195. # Self Attention
  196. if residual is None:
  197. residual = hidden_states
  198. hidden_states = self.input_layernorm(hidden_states)
  199. else:
  200. hidden_states, residual = self.input_layernorm(
  201. hidden_states, residual)
  202. hidden_states = self.self_attn(
  203. positions=positions,
  204. hidden_states=hidden_states,
  205. kv_cache=kv_cache,
  206. attn_metadata=attn_metadata,
  207. )
  208. # Fully Connected
  209. hidden_states, residual = self.post_attention_layernorm(
  210. hidden_states, residual)
  211. hidden_states = self.block_sparse_moe(hidden_states)
  212. return hidden_states, residual
  213. class MixtralModel(nn.Module):
  214. def __init__(
  215. self,
  216. config: MixtralConfig,
  217. cache_config: Optional[CacheConfig] = None,
  218. quant_config: Optional[QuantizationConfig] = None,
  219. lora_config: Optional[LoRAConfig] = None,
  220. ) -> None:
  221. super().__init__()
  222. self.padding_idx = config.pad_token_id
  223. lora_vocab = (lora_config.lora_extra_vocab_size *
  224. (lora_config.max_loras or 1)) if lora_config else 0
  225. self.vocab_size = config.vocab_size + lora_vocab
  226. self.org_vocab_size = config.vocab_size
  227. self.embed_tokens = VocabParallelEmbedding(
  228. self.vocab_size,
  229. config.hidden_size,
  230. org_num_embeddings=config.vocab_size,
  231. )
  232. self.layers = nn.ModuleList([
  233. MixtralDecoderLayer(config,
  234. cache_config,
  235. quant_config=quant_config)
  236. for _ in range(config.num_hidden_layers)
  237. ])
  238. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  239. def forward(
  240. self,
  241. input_ids: torch.Tensor,
  242. positions: torch.Tensor,
  243. kv_caches: List[torch.Tensor],
  244. attn_metadata: AttentionMetadata,
  245. ) -> torch.Tensor:
  246. hidden_states = self.embed_tokens(input_ids)
  247. residual = None
  248. for i in range(len(self.layers)):
  249. layer = self.layers[i]
  250. hidden_states, residual = layer(positions, hidden_states,
  251. kv_caches[i], attn_metadata,
  252. residual)
  253. hidden_states, _ = self.norm(hidden_states, residual)
  254. return hidden_states
class MixtralForCausalLM(nn.Module, SupportsLoRA):
    """Mixtral with a language-model head, logits processor and sampler.

    Wraps ``MixtralModel`` and adds ``compute_logits``, ``sample`` and
    ``load_weights`` on top of the backbone's ``forward``.
    """

    # NOTE(review): presumably read by the model loader - confirm.
    fall_back_to_pt_during_load = False

    # The fused ``qkv_proj`` parameter and the separate projection names it
    # packs together.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the backbone, LM head, logits processor and sampler.

        Args:
            config: Model configuration.
            cache_config: Forwarded to ``MixtralModel``.
            quant_config: Forwarded to ``MixtralModel`` and the LM head.
            lora_config: When truthy, enlarges the LM head vocabulary by
                ``lora_extra_vocab_size`` and selects the LoRA padding size.
        """
        super().__init__()
        self.config = config
        self.lora_config = lora_config
        self.model = MixtralModel(config,
                                  cache_config,
                                  quant_config,
                                  lora_config=lora_config)
        # Vocabulary size before padding; grows when LoRA adds tokens.
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
            quant_config=quant_config,
        )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Return the backbone's hidden states for ``input_ids``.

        ``intermediate_tensors`` is accepted but not used here.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Run the logits processor with the LM head on ``hidden_states``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        Each ``(name, tensor)`` pair is dispatched in this order:

        1. Names containing ``rotary_emb.inv_freq`` are skipped.
        2. Names containing ``q_proj``/``k_proj``/``v_proj`` are loaded
           into the fused ``qkv_proj`` parameter with the shard id.
        3. Names matching an expert weight, weight scale or input scale
           are loaded into the matching fused expert parameter with the
           expert id and shard id.
        4. Anything else is loaded under its own name, after skipping
           unknown ``.bias`` entries and remapping ``kv_scale`` names.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # One entry per (expert, w1/w2/w3) for each of the three tensor
        # kinds. ``shard_id`` is the index of the weight name in
        # ["w1", "w2", "w3"], i.e. 0, 1 or 2.
        expert_params_mapping = [
            # These are the weight scales for the experts
            # (param_name, weight_name, expert_id, shard_id)
            ("experts.w13_scale"
             if weight_name in ["w1", "w3"] else "experts.w2_scale",
             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
             shard_id) for expert_id in range(self.config.num_local_experts)
            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
        ] + [
            # These are the weights for the experts
            # (param_name, weight_name, expert_id, shard_id)
            ("experts.w13_weight"
             if weight_name in ["w1", "w3"] else "experts.w2_weight",
             f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
            for expert_id in range(self.config.num_local_experts)
            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
        ] + [
            # These are the activation scales for the experts
            # (param_name, weight_name, expert_id, shard_id)
            ("experts.a13_scale"
             if weight_name in ["w1", "w3"] else "experts.a2_scale",
             f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
             shard_id) for expert_id in range(self.config.num_local_experts)
            for shard_id, weight_name in enumerate(["w1", "w2", "w3"])
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # Stage 1: separate q/k/v projections -> fused qkv_proj.
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # NOTE(review): this `continue` advances the inner mapping
                # loop, and `name` has already been rewritten at this point.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Stage 2 (no stacked mapping loaded the tensor): per-expert
                # tensors -> fused expert parameters.
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  weight_name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Stage 3 (no expert mapping matched): load the tensor
                    # under its own, possibly remapped, name.
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale")
                        if remapped_kv_scale_name not in params_dict:
                            print_warning_once(
                                "Found kv scale in the checkpoint "
                                f"(e.g. {name}), but not found the expected "
                                f"name in the model "
                                f"(e.g. {remapped_kv_scale_name}). "
                                "kv-scale is not loaded.")
                            continue
                        else:
                            name = remapped_kv_scale_name
                    param = params_dict[name]
                    # Parameters without their own loader fall back to
                    # default_weight_loader.
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)