# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import List, Optional

import numpy as np
import torch
from torch import nn
from transformers import MixtralConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import LoRAConfig
from aphrodite.modeling.layers.fused_moe import fused_topk
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              LinearMethodBase,
                                              MergedColumnParallelLinear,
                                              QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear,
                                              UnquantizedLinearMethod)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                              hf_model_weights_iterator)
from aphrodite.common.sequence import SamplerOutput
from aphrodite.quantization.gguf import GGUFLinearMethod


class MixtralMLP(nn.Module):
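    """Per-expert SwiGLU feed-forward layer.

    Only instantiated on the non-fused MoE path, where each tensor-parallel
    rank builds just the experts assigned to it (see MixtralMoE).
    """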

    def __init__(
        self,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.num_experts = num_experts
        self.ffn_dim = intermediate_size
        self.hidden_dim = hidden_size

        self.w1 = ReplicatedLinear(self.hidden_dim,
                                   self.ffn_dim,
                                   bias=False,
                                   linear_method=linear_method)
        self.w2 = ReplicatedLinear(self.ffn_dim,
                                   self.hidden_dim,
                                   bias=False,
                                   linear_method=linear_method)
        self.w3 = ReplicatedLinear(self.hidden_dim,
                                   self.ffn_dim,
                                   bias=False,
                                   linear_method=linear_method)

        # TODO: Use aphrodite's SiluAndMul
        self.act_fn = nn.SiLU()
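
    # SwiGLU: each expert computes w2(silu(w1(x)) * w3(x)).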
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        w1_out, _ = self.w1(hidden_states)
        w1_out = self.act_fn(w1_out)
        w3_out, _ = self.w3(hidden_states)
        current_hidden_states = w1_out * w3_out
        current_hidden_states, _ = self.w2(current_hidden_states)
        return current_hidden_states


class MixtralMoE(nn.Module):
    """A tensor-parallel MoE implementation for Mixtral.

    On the fused path, each expert's weights are sharded across all ranks,
    a fused MoE kernel handles the forward pass, and the outputs are
    all-reduced across ranks. Quantization methods that cannot use the
    fused kernel instead assign whole experts to ranks and run them as
    MixtralMLP modules.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        tp_size: Optional[int] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size // self.tp_size
        self.linear_method = linear_method

        if self.linear_method is None:
            self.linear_method = UnquantizedLinearMethod()
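
        # The router gate only goes through the quantized linear path for
        # GGUF checkpoints; for every other method it stays unquantized.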
        self.gate = ReplicatedLinear(
            self.hidden_size,
            self.num_total_experts,
            bias=False,
            linear_method=linear_method
            if isinstance(linear_method, GGUFLinearMethod) else None)
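
        # Quantization methods without a fused MoE kernel get whole experts
        # assigned per rank (MixtralMLP); otherwise expert weights are stored
        # merged so the fused kernel can consume them directly.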
        if not isinstance(
                self.linear_method, UnquantizedLinearMethod
        ) and not self.linear_method.quant_config.support_fused_moe():
            if self.tp_size > self.num_total_experts:
                raise ValueError(
                    f"Tensor parallel size {self.tp_size} is greater than "
                    f"the number of experts {self.num_total_experts}.")
            # Split experts equally between ranks
            self.expert_indicies = np.array_split(
                range(self.num_total_experts),
                self.tp_size)[self.rank].tolist()
            if not self.expert_indicies:
                raise ValueError(
                    f"Rank {self.rank} has no experts assigned to it.")
            self.experts = nn.ModuleList([
                MixtralMLP(self.num_total_experts,
                           hidden_size,
                           intermediate_size,
                           linear_method=linear_method)
                if idx in self.expert_indicies else None
                for idx in range(self.num_total_experts)
            ])
        else:
            self.ws = MergedColumnParallelLinear(hidden_size,
                                                 [intermediate_size] * 2,
                                                 bias=False,
                                                 linear_method=linear_method,
                                                 num_experts=num_experts)
            self.w2s = RowParallelLinear(intermediate_size,
                                         hidden_size,
                                         bias=False,
                                         linear_method=linear_method,
                                         num_experts=num_experts)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
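
        # Non-fused path: run only this rank's experts and scale their outputs
        # by the routing weights. Fused path: a single kernel applies top-k
        # routing and the merged expert weights. Either way, per-rank partial
        # results are summed by the all-reduce at the end.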
        if not isinstance(
                self.linear_method, UnquantizedLinearMethod
        ) and not self.linear_method.quant_config.support_fused_moe():
            routing_weights, selected_experts = fused_topk(router_logits,
                                                           self.top_k,
                                                           renormalize=True)
            final_hidden_states = None
            for expert_idx in self.expert_indicies:
                expert_layer = self.experts[expert_idx]
                expert_mask = (selected_experts == expert_idx)
                expert_weights = (routing_weights * expert_mask).sum(
                    dim=-1, keepdim=True)
                current_hidden_states = expert_layer(hidden_states).mul_(
                    expert_weights)
                if final_hidden_states is None:
                    final_hidden_states = current_hidden_states
                else:
                    final_hidden_states.add_(current_hidden_states)
        else:
            final_hidden_states = self.linear_method.apply_moe_weights(
                self.ws.linear_weights,
                self.w2s.linear_weights,
                hidden_states,
                router_logits,
                self.top_k,
                renormalize=True,
            )

        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_size)


class MixtralAttention(nn.Module):

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 max_position: int = 4096 * 32,
                 rope_theta: float = 10000,
                 linear_method: Optional[LinearMethodBase] = None,
                 sliding_window: Optional[int] = None) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window
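
        # Some quantization methods cannot load a merged QKV weight, so fall
        # back to separate q/k/v projections when merge_weight() is False.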
        if (linear_method is not None
                and not linear_method.quant_config.merge_weight()):
            self.merge_weight = False
            self.q_proj = ColumnParallelLinear(hidden_size,
                                               self.total_num_heads *
                                               self.head_dim,
                                               bias=False,
                                               linear_method=linear_method)
            self.k_proj = ColumnParallelLinear(hidden_size,
                                               self.total_num_kv_heads *
                                               self.head_dim,
                                               bias=False,
                                               linear_method=linear_method)
            self.v_proj = ColumnParallelLinear(hidden_size,
                                               self.total_num_kv_heads *
                                               self.head_dim,
                                               bias=False,
                                               linear_method=linear_method)
        else:
            self.merge_weight = True
            self.qkv_proj = QKVParallelLinear(
                hidden_size,
                self.head_dim,
                self.total_num_heads,
                self.total_num_kv_heads,
                bias=False,
                linear_method=linear_method,
            )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
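
        # quant_config.rope_style() lets a quantized checkpoint override the
        # default NeoX-style rotary layout; None means "use the default".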
        is_neox_style = (True if linear_method is None
                         or linear_method.quant_config.rope_style() is None
                         else linear_method.quant_config.rope_style())
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=is_neox_style,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            sliding_window=self.sliding_window,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        if self.merge_weight:
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                                dim=-1)
        else:
            q, _ = self.q_proj(hidden_states)
            k, _ = self.k_proj(hidden_states)
            v, _ = self.v_proj(hidden_states)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class MixtralDecoderLayer(nn.Module):

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            sliding_window=config.sliding_window,
            linear_method=linear_method)
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            linear_method=linear_method)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
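
    # RMSNorm returns (normed_hidden_states, residual) when a residual is
    # passed in, fusing the residual add with the normalization.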
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual


class MixtralModel(nn.Module):

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
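        # Reserve extra vocabulary slots for tokens added by LoRA adapters.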
        lora_vocab = (lora_config.lora_extra_vocab_size *
                      (lora_config.max_loras or 1)) if lora_config else 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            linear_method=linear_method,
            org_num_embeddings=config.vocab_size,
        )
        self.layers = nn.ModuleList([
            MixtralDecoderLayer(config, linear_method=linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i], attn_metadata,
                                            residual)
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class MixtralForCausalLM(nn.Module):
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = MixtralModel(config,
                                  linear_method,
                                  lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            linear_method=linear_method,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.tokenizer_vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            stacked_params_mapping = []
        expert_params_mapping = [
            # (param_name, weight_name, shard_id, expert_id)
            ("ws" if weight_name in ["w1", "w3"] else "w2s",
             f"experts.{expert_id}.{weight_name}", shard_id, expert_id)
            for expert_id in range(self.config.num_local_experts)
            for weight_name, shard_id in [("w1", 0), ("w3", 1), ("w2", None)]
        ] if self.linear_method is None or (
            self.linear_method.quant_config.support_fused_moe()) else []
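
        # Each checkpoint tensor falls into one of three cases below: a q/k/v
        # shard that folds into the packed qkv_proj, an expert weight that
        # folds into the merged ws/w2s tensors, or a regular parameter that
        # is loaded directly.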
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path,
                cache_dir,
                load_format,
                revision,
                self.config,
                fall_back_to_pt=False):
            if "rotary_emb.inv_freq" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for (param_name, weight_name, shard_id,
                     expert_id) in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    if shard_id is None:
                        weight_loader(param,
                                      loaded_weight,
                                      expert_id=expert_id)
                    else:
                        weight_loader(param,
                                      loaded_weight,
                                      shard_id,
                                      expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    # Skip experts that are not assigned to this worker.
                    if ("block_sparse_moe.experts." in name
                            and name not in params_dict):
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)