1
0

mixtral.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The vLLM team.
  5. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  6. #
  7. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  8. # and OPT implementations in this library. It has been modified from its
  9. # original forms to accommodate minor architectural differences compared
  10. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  11. #
  12. # Licensed under the Apache License, Version 2.0 (the "License");
  13. # you may not use this file except in compliance with the License.
  14. # You may obtain a copy of the License at
  15. #
  16. # http://www.apache.org/licenses/LICENSE-2.0
  17. #
  18. # Unless required by applicable law or agreed to in writing, software
  19. # distributed under the License is distributed on an "AS IS" BASIS,
  20. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  21. # See the License for the specific language governing permissions and
  22. # limitations under the License.
  23. """Inference-only Mixtral model."""
  24. from typing import Iterable, List, Optional, Tuple
  25. import torch
  26. from torch import nn
  27. from transformers import MixtralConfig
  28. from aphrodite.attention import Attention, AttentionMetadata
  29. from aphrodite.common.config import CacheConfig, LoRAConfig
  30. from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
  31. from aphrodite.common.utils import progress_bar
  32. from aphrodite.distributed import (get_pp_group,
  33. get_tensor_model_parallel_world_size)
  34. from aphrodite.modeling.layers.fused_moe import FusedMoE
  35. from aphrodite.modeling.layers.layernorm import RMSNorm
  36. from aphrodite.modeling.layers.linear import (QKVParallelLinear,
  37. ReplicatedLinear,
  38. RowParallelLinear)
  39. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  40. from aphrodite.modeling.layers.rotary_embedding import get_rope
  41. from aphrodite.modeling.layers.sampler import Sampler
  42. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  43. DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
  44. from aphrodite.modeling.model_loader.weight_utils import (
  45. default_weight_loader, maybe_remap_kv_scale_name)
  46. from aphrodite.modeling.models.utils import (is_pp_missing_parameter,
  47. make_layers)
  48. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  49. from aphrodite.quantization.base_config import QuantizationConfig
  50. from .interfaces import SupportsLoRA
  51. class MixtralMoE(nn.Module):
  52. """A tensor-parallel MoE implementation for Mixtral that shards each expert
  53. across all ranks.
  54. Each expert's weights are sharded across all ranks and a fused MoE
  55. kernel is used for the forward pass, and finally we reduce the outputs
  56. across ranks.
  57. """
  58. def __init__(self,
  59. num_experts: int,
  60. top_k: int,
  61. hidden_size: int,
  62. intermediate_size: int,
  63. params_dtype: Optional[torch.dtype] = None,
  64. quant_config: Optional[QuantizationConfig] = None,
  65. tp_size: Optional[int] = None,
  66. prefix: str = ""):
  67. super().__init__()
  68. self.hidden_size = hidden_size
  69. # Gate always runs at half / full precision for now.
  70. self.gate = ReplicatedLinear(hidden_size,
  71. num_experts,
  72. bias=False,
  73. params_dtype=params_dtype,
  74. quant_config=None,
  75. prefix=f"{prefix}.gate")
  76. self.experts = FusedMoE(num_experts=num_experts,
  77. top_k=top_k,
  78. hidden_size=hidden_size,
  79. intermediate_size=intermediate_size,
  80. params_dtype=params_dtype,
  81. reduce_results=True,
  82. renormalize=True,
  83. quant_config=quant_config,
  84. tp_size=tp_size,
  85. prefix=f"{prefix}.experts")
  86. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  87. # NOTE: hidden_states can have either 1D or 2D shape.
  88. orig_shape = hidden_states.shape
  89. hidden_states = hidden_states.view(-1, self.hidden_size)
  90. # router_logits: (num_tokens, n_experts)
  91. router_logits, _ = self.gate(hidden_states)
  92. final_hidden_states = self.experts(hidden_states, router_logits)
  93. return final_hidden_states.view(orig_shape)
  94. class MixtralAttention(nn.Module):
  95. def __init__(
  96. self,
  97. hidden_size: int,
  98. num_heads: int,
  99. num_kv_heads: int,
  100. max_position: int = 4096 * 32,
  101. rope_theta: float = 10000,
  102. cache_config: Optional[CacheConfig] = None,
  103. quant_config: Optional[QuantizationConfig] = None,
  104. prefix: str = "",
  105. ) -> None:
  106. super().__init__()
  107. self.hidden_size = hidden_size
  108. tp_size = get_tensor_model_parallel_world_size()
  109. self.total_num_heads = num_heads
  110. assert self.total_num_heads % tp_size == 0
  111. self.num_heads = self.total_num_heads // tp_size
  112. self.total_num_kv_heads = num_kv_heads
  113. if self.total_num_kv_heads >= tp_size:
  114. # Number of KV heads is greater than TP size, so we partition
  115. # the KV heads across multiple tensor parallel GPUs.
  116. assert self.total_num_kv_heads % tp_size == 0
  117. else:
  118. # Number of KV heads is less than TP size, so we replicate
  119. # the KV heads across multiple tensor parallel GPUs.
  120. assert tp_size % self.total_num_kv_heads == 0
  121. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  122. self.head_dim = hidden_size // self.total_num_heads
  123. self.q_size = self.num_heads * self.head_dim
  124. self.kv_size = self.num_kv_heads * self.head_dim
  125. self.scaling = self.head_dim**-0.5
  126. self.rope_theta = rope_theta
  127. self.qkv_proj = QKVParallelLinear(
  128. hidden_size,
  129. self.head_dim,
  130. self.total_num_heads,
  131. self.total_num_kv_heads,
  132. bias=False,
  133. quant_config=quant_config,
  134. prefix=f"{prefix}.qkv_proj",
  135. )
  136. self.o_proj = RowParallelLinear(
  137. self.total_num_heads * self.head_dim,
  138. hidden_size,
  139. bias=False,
  140. quant_config=quant_config,
  141. prefix=f"{prefix}.o_proj",
  142. )
  143. self.rotary_emb = get_rope(
  144. self.head_dim,
  145. rotary_dim=self.head_dim,
  146. max_position=max_position,
  147. base=int(self.rope_theta),
  148. is_neox_style=True,
  149. )
  150. self.attn = Attention(self.num_heads,
  151. self.head_dim,
  152. self.scaling,
  153. num_kv_heads=self.num_kv_heads,
  154. cache_config=cache_config,
  155. quant_config=quant_config)
  156. def forward(
  157. self,
  158. positions: torch.Tensor,
  159. hidden_states: torch.Tensor,
  160. kv_cache: torch.Tensor,
  161. attn_metadata: AttentionMetadata,
  162. ) -> torch.Tensor:
  163. qkv, _ = self.qkv_proj(hidden_states)
  164. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  165. q, k = self.rotary_emb(positions, q, k)
  166. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  167. output, _ = self.o_proj(attn_output)
  168. return output
  169. class MixtralDecoderLayer(nn.Module):
  170. def __init__(
  171. self,
  172. config: MixtralConfig,
  173. cache_config: Optional[CacheConfig] = None,
  174. quant_config: Optional[QuantizationConfig] = None,
  175. prefix: str = "",
  176. ) -> None:
  177. super().__init__()
  178. self.hidden_size = config.hidden_size
  179. # Requires transformers > 4.32.0
  180. rope_theta = getattr(config, "rope_theta", 10000)
  181. self.self_attn = MixtralAttention(
  182. hidden_size=self.hidden_size,
  183. num_heads=config.num_attention_heads,
  184. max_position=config.max_position_embeddings,
  185. num_kv_heads=config.num_key_value_heads,
  186. rope_theta=rope_theta,
  187. cache_config=cache_config,
  188. quant_config=quant_config,
  189. prefix=f"{prefix}.self_attn")
  190. self.block_sparse_moe = MixtralMoE(
  191. num_experts=config.num_local_experts,
  192. top_k=config.num_experts_per_tok,
  193. hidden_size=config.hidden_size,
  194. intermediate_size=config.intermediate_size,
  195. quant_config=quant_config,
  196. prefix=f"{prefix}.block_sparse_moe")
  197. self.input_layernorm = RMSNorm(config.hidden_size,
  198. eps=config.rms_norm_eps)
  199. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  200. eps=config.rms_norm_eps)
  201. def forward(
  202. self,
  203. positions: torch.Tensor,
  204. hidden_states: torch.Tensor,
  205. kv_cache: torch.Tensor,
  206. attn_metadata: AttentionMetadata,
  207. residual: Optional[torch.Tensor],
  208. ) -> torch.Tensor:
  209. # Self Attention
  210. if residual is None:
  211. residual = hidden_states
  212. hidden_states = self.input_layernorm(hidden_states)
  213. else:
  214. hidden_states, residual = self.input_layernorm(
  215. hidden_states, residual)
  216. hidden_states = self.self_attn(
  217. positions=positions,
  218. hidden_states=hidden_states,
  219. kv_cache=kv_cache,
  220. attn_metadata=attn_metadata,
  221. )
  222. # Fully Connected
  223. hidden_states, residual = self.post_attention_layernorm(
  224. hidden_states, residual)
  225. hidden_states = self.block_sparse_moe(hidden_states)
  226. return hidden_states, residual
  227. class MixtralModel(nn.Module):
  228. def __init__(
  229. self,
  230. config: MixtralConfig,
  231. cache_config: Optional[CacheConfig] = None,
  232. quant_config: Optional[QuantizationConfig] = None,
  233. lora_config: Optional[LoRAConfig] = None,
  234. prefix: str = "",
  235. ) -> None:
  236. super().__init__()
  237. self.padding_idx = config.pad_token_id
  238. lora_vocab = (lora_config.lora_extra_vocab_size *
  239. (lora_config.max_loras or 1)) if lora_config else 0
  240. self.vocab_size = config.vocab_size + lora_vocab
  241. self.org_vocab_size = config.vocab_size
  242. self.embed_tokens = VocabParallelEmbedding(
  243. self.vocab_size,
  244. config.hidden_size,
  245. org_num_embeddings=config.vocab_size,
  246. )
  247. self.start_layer, self.end_layer, self.layers = make_layers(
  248. config.num_hidden_layers,
  249. lambda prefix: MixtralDecoderLayer(
  250. config, cache_config, quant_config=quant_config, prefix=prefix
  251. ),
  252. prefix=f"{prefix}.layers")
  253. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  254. def forward(
  255. self,
  256. input_ids: torch.Tensor,
  257. positions: torch.Tensor,
  258. kv_caches: List[torch.Tensor],
  259. attn_metadata: AttentionMetadata,
  260. intermediate_tensors: Optional[IntermediateTensors],
  261. ) -> torch.Tensor:
  262. if get_pp_group().is_first_rank:
  263. hidden_states = self.embed_tokens(input_ids)
  264. residual = None
  265. else:
  266. assert intermediate_tensors is not None
  267. hidden_states = intermediate_tensors["hidden_states"]
  268. residual = intermediate_tensors["residual"]
  269. for i in range(self.start_layer, self.end_layer):
  270. layer = self.layers[i]
  271. hidden_states, residual = layer(positions, hidden_states,
  272. kv_caches[i - self.start_layer],
  273. attn_metadata, residual)
  274. if not get_pp_group().is_last_rank:
  275. return IntermediateTensors({
  276. "hidden_states": hidden_states,
  277. "residual": residual
  278. })
  279. hidden_states, _ = self.norm(hidden_states, residual)
  280. return hidden_states
  281. class MixtralForCausalLM(nn.Module, SupportsLoRA):
  282. fall_back_to_pt_during_load = False
  283. packed_modules_mapping = {
  284. "qkv_proj": [
  285. "q_proj",
  286. "k_proj",
  287. "v_proj",
  288. ],
  289. }
  290. # LoRA specific attributes
  291. supported_lora_modules = [
  292. "qkv_proj",
  293. "o_proj",
  294. "embed_tokens",
  295. "lm_head",
  296. ]
  297. embedding_modules = {
  298. "embed_tokens": "input_embeddings",
  299. "lm_head": "output_embeddings",
  300. }
  301. embedding_padding_modules = ["lm_head"]
  302. def __init__(
  303. self,
  304. config: MixtralConfig,
  305. cache_config: Optional[CacheConfig] = None,
  306. quant_config: Optional[QuantizationConfig] = None,
  307. lora_config: Optional[LoRAConfig] = None,
  308. ) -> None:
  309. super().__init__()
  310. self.config = config
  311. self.lora_config = lora_config
  312. self.model = MixtralModel(config,
  313. cache_config,
  314. quant_config,
  315. lora_config=lora_config,
  316. prefix="model")
  317. self.unpadded_vocab_size = config.vocab_size
  318. if lora_config:
  319. self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
  320. self.lm_head = ParallelLMHead(
  321. self.unpadded_vocab_size,
  322. config.hidden_size,
  323. org_num_embeddings=config.vocab_size,
  324. padding_size=DEFAULT_VOCAB_PADDING_SIZE
  325. # We need bigger padding if using lora for kernel
  326. # compatibility
  327. if not lora_config else lora_config.lora_vocab_padding_size,
  328. quant_config=quant_config,
  329. )
  330. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  331. config.vocab_size)
  332. self.sampler = Sampler()
  333. def forward(
  334. self,
  335. input_ids: torch.Tensor,
  336. positions: torch.Tensor,
  337. kv_caches: List[torch.Tensor],
  338. attn_metadata: AttentionMetadata,
  339. intermediate_tensors: Optional[IntermediateTensors] = None,
  340. ) -> torch.Tensor:
  341. hidden_states = self.model(input_ids, positions, kv_caches,
  342. attn_metadata, intermediate_tensors)
  343. return hidden_states
  344. def compute_logits(
  345. self,
  346. hidden_states: torch.Tensor,
  347. sampling_metadata: SamplingMetadata,
  348. ) -> Optional[torch.Tensor]:
  349. logits = self.logits_processor(self.lm_head, hidden_states,
  350. sampling_metadata)
  351. return logits
  352. def make_empty_intermediate_tensors(
  353. self, batch_size: int, dtype: torch.dtype,
  354. device: torch.device) -> IntermediateTensors:
  355. return IntermediateTensors({
  356. "hidden_states":
  357. torch.zeros((batch_size, self.config.hidden_size),
  358. dtype=dtype,
  359. device=device),
  360. "residual":
  361. torch.zeros((batch_size, self.config.hidden_size),
  362. dtype=dtype,
  363. device=device),
  364. })
  365. def sample(
  366. self,
  367. logits: Optional[torch.Tensor],
  368. sampling_metadata: SamplingMetadata,
  369. ) -> Optional[SamplerOutput]:
  370. next_tokens = self.sampler(logits, sampling_metadata)
  371. return next_tokens
  372. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  373. stacked_params_mapping = [
  374. # (param_name, shard_name, shard_id)
  375. ("qkv_proj", "q_proj", "q"),
  376. ("qkv_proj", "k_proj", "k"),
  377. ("qkv_proj", "v_proj", "v"),
  378. ]
  379. # Params for weights, fp8 weight scales, fp8 activation scales
  380. # (param_name, weight_name, expert_id, shard_id)
  381. expert_params_mapping = FusedMoE.make_expert_params_mapping(
  382. ckpt_gate_proj_name="w1",
  383. ckpt_down_proj_name="w2",
  384. ckpt_up_proj_name="w3",
  385. num_experts=self.config.num_local_experts)
  386. params_dict = dict(self.named_parameters())
  387. weights_list = list(weights)
  388. for name, loaded_weight in progress_bar(weights_list,
  389. desc="Loading modules..."):
  390. if "rotary_emb.inv_freq" in name:
  391. continue
  392. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  393. if weight_name not in name:
  394. continue
  395. name = name.replace(weight_name, param_name)
  396. # Skip loading extra bias for GPTQ models.
  397. if name.endswith(".bias") and name not in params_dict:
  398. continue
  399. # Skip layers on other devices.
  400. if is_pp_missing_parameter(name, self):
  401. continue
  402. param = params_dict[name]
  403. weight_loader = param.weight_loader
  404. weight_loader(param, loaded_weight, shard_id)
  405. break
  406. else:
  407. for mapping in expert_params_mapping:
  408. param_name, weight_name, expert_id, shard_id = mapping
  409. if weight_name not in name:
  410. continue
  411. name = name.replace(weight_name, param_name)
  412. # Skip layers on other devices.
  413. if is_pp_missing_parameter(name, self):
  414. continue
  415. param = params_dict[name]
  416. weight_loader = param.weight_loader
  417. weight_loader(param,
  418. loaded_weight,
  419. name,
  420. shard_id=shard_id,
  421. expert_id=expert_id)
  422. break
  423. else:
  424. # Skip loading extra bias for GPTQ models.
  425. if name.endswith(".bias") and name not in params_dict:
  426. continue
  427. # Skip layers on other devices.
  428. if is_pp_missing_parameter(name, self):
  429. continue
  430. # Remapping the name of FP8 kv-scale.
  431. name = maybe_remap_kv_scale_name(name, params_dict)
  432. if name is None:
  433. continue
  434. param = params_dict[name]
  435. weight_loader = getattr(param, "weight_loader",
  436. default_weight_loader)
  437. weight_loader(param, loaded_weight)