# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
from transformers import MixtralConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig, LoRAConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.distributed import (get_pp_group,
                                   get_tensor_model_parallel_world_size)
from aphrodite.modeling.layers.fused_moe import FusedMoE
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from aphrodite.modeling.models.utils import (is_pp_missing_parameter,
                                             make_layers)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig

from .interfaces import SupportsLoRA


class MixtralMoE(nn.Module):
    """A tensor-parallel MoE implementation for Mixtral that shards each
    expert across all ranks.

    Each expert's weights are sharded across all ranks, a fused MoE kernel
    performs the forward pass, and the outputs are finally reduced across
    ranks. (A reference sketch of the routing math follows this class.)
    """

    def __init__(self,
                 num_experts: int,
                 top_k: int,
                 hidden_size: int,
                 intermediate_size: int,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 tp_size: Optional[int] = None,
                 prefix: str = ""):
        super().__init__()
        self.hidden_size = hidden_size

        # The gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(hidden_size,
                                     num_experts,
                                     bias=False,
                                     params_dtype=params_dtype,
                                     quant_config=None,
                                     prefix=f"{prefix}.gate")

        self.experts = FusedMoE(num_experts=num_experts,
                                top_k=top_k,
                                hidden_size=hidden_size,
                                intermediate_size=intermediate_size,
                                params_dtype=params_dtype,
                                reduce_results=True,
                                renormalize=True,
                                quant_config=quant_config,
                                tp_size=tp_size,
                                prefix=f"{prefix}.experts")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(hidden_states, router_logits)
        return final_hidden_states.view(orig_shape)
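

# NOTE: Minimal reference sketch (not used by the model) of the top-k gating
# that `FusedMoE(renormalize=True)` is expected to perform, assuming the
# standard Mixtral softmax -> top-k -> renormalize routing. The helper name
# is illustrative only.
def _reference_topk_routing(
        router_logits: torch.Tensor,
        top_k: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return (routing_weights, selected_experts) per token."""
    # Softmax over all experts, computed in float32 for stability.
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    # Keep only the top-k experts per token.
    routing_weights, selected_experts = torch.topk(probs, top_k, dim=-1)
    # renormalize=True: rescale the selected weights so they sum to 1.
    routing_weights = routing_weights / routing_weights.sum(dim=-1,
                                                            keepdim=True)
    return routing_weights, selected_experts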


class MixtralAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is at least as large as the TP size, so we
            # partition the KV heads across the tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than the TP size, so we replicate
            # the KV heads across the tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
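        # Illustrative example (assumed values, not taken from any config):
        #   total_num_heads=32, total_num_kv_heads=8, tp_size=4
        #     -> num_heads=8, num_kv_heads=2 per rank (KV heads partitioned)
        #   total_num_heads=32, total_num_kv_heads=8, tp_size=16
        #     -> num_heads=2, num_kv_heads=max(1, 8 // 16)=1 per rank
        #        (each KV head replicated across 2 ranks)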
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class MixtralDecoderLayer(nn.Module):

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn")
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe")
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual
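
# NOTE: The decoder layer uses the pre-norm residual pattern: when RMSNorm is
# called with (hidden_states, residual) it fuses the residual add with the
# normalization and returns both the normalized activations and the updated
# residual, which is why each layer returns a (hidden_states, residual) pair.
# The final residual is folded in by MixtralModel.norm.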


class MixtralModel(nn.Module):

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: MixtralDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers")
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i - self.start_layer],
                                            attn_metadata, residual)
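        # On every pipeline stage except the last, hand the running hidden
        # states and residual to the next stage instead of applying the
        # final norm.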
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class MixtralForCausalLM(nn.Module, SupportsLoRA):
    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: MixtralConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.lora_config = lora_config

        self.model = MixtralModel(config,
                                  cache_config,
                                  quant_config,
                                  lora_config=lora_config,
                                  prefix="model")
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
            quant_config=quant_config,
        )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
            "residual":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
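        # For example, a checkpoint tensor named
        #   "model.layers.0.self_attn.q_proj.weight"
        # is loaded into the fused parameter
        #   "model.layers.0.self_attn.qkv_proj.weight" with shard_id "q".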

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts)
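        # In Mixtral checkpoints each expert stores w1 (gate), w2 (down) and
        # w3 (up) projections; the mapping above folds them into the fused
        # expert parameters owned by FusedMoE.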

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)