mixtral.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only Mixtral model."""
from typing import List, Optional, Tuple

import numpy as np
import torch
from torch import nn
from transformers import MixtralConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import LoRAConfig
from aphrodite.common.sequence import SamplerOutput
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.hf_downloader import (default_weight_loader,
                                              hf_model_weights_iterator)
from aphrodite.modeling.layers.fused_moe import fused_topk
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (
    ColumnParallelLinear, LinearMethodBase, MergedColumnParallelLinear,
    QKVParallelLinear, ReplicatedLinear, RowParallelLinear,
    UnquantizedLinearMethod)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
  50. class MixtralMLP(nn.Module):
  51. def __init__(
  52. self,
  53. num_experts: int,
  54. hidden_size: int,
  55. intermediate_size: int,
  56. linear_method: Optional[LinearMethodBase] = None,
  57. ) -> None:
  58. super().__init__()
  59. self.num_experts = num_experts
  60. self.ffn_dim = intermediate_size
  61. self.hidden_dim = hidden_size
  62. self.w1 = ReplicatedLinear(self.hidden_dim,
  63. self.ffn_dim,
  64. bias=False,
  65. linear_method=linear_method)
  66. self.w2 = ReplicatedLinear(self.ffn_dim,
  67. self.hidden_dim,
  68. bias=False,
  69. linear_method=linear_method)
  70. self.w3 = ReplicatedLinear(self.hidden_dim,
  71. self.ffn_dim,
  72. bias=False,
  73. linear_method=linear_method)
  74. # TODO: Use aphrodite's SiluAndMul
  75. self.act_fn = nn.SiLU()
  76. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  77. w1_out, _ = self.w1(hidden_states)
  78. w1_out = self.act_fn(w1_out)
  79. w3_out, _ = self.w3(hidden_states)
  80. current_hidden_states = w1_out * w3_out
  81. current_hidden_states, _ = self.w2(current_hidden_states)
  82. return current_hidden_states
  83. class MixtralMoE(nn.Module):
  84. """A tensor-parallel MoE implementation for Mixtral that shards each expert
  85. across all ranks.
  86. Each expert's weights are sharded across all ranks and a fused MoE
  87. kernel is used for the forward pass, and finally we reduce the outputs
  88. across ranks.
  89. """
  90. def __init__(
  91. self,
  92. num_experts: int,
  93. top_k: int,
  94. hidden_size: int,
  95. intermediate_size: int,
  96. tp_size: Optional[int] = None,
  97. linear_method: Optional[LinearMethodBase] = None,
  98. ):
  99. super().__init__()
  100. self.rank = get_tensor_model_parallel_rank()
  101. self.tp_size = tp_size or get_tensor_model_parallel_world_size()
  102. self.num_total_experts = num_experts
  103. self.top_k = top_k
  104. self.hidden_size = hidden_size
  105. self.intermediate_size = intermediate_size // self.tp_size
  106. self.linear_method = linear_method
  107. if self.linear_method is None:
  108. self.linear_method = UnquantizedLinearMethod()
  109. self.gate = ReplicatedLinear(self.hidden_size,
  110. self.num_total_experts,
  111. bias=False,
  112. linear_method=None)
  113. if not isinstance(
  114. self.linear_method, UnquantizedLinearMethod
  115. ) and not self.linear_method.quant_config.support_fused_moe():
  116. if self.tp_size > self.num_total_experts:
  117. raise ValueError(
  118. f"Tensor parallel size {self.tp_size} is greater than "
  119. f"the number of experts {self.num_total_experts}.")
  120. # Split experts equally between ranks
  121. self.expert_indicies = np.array_split(
  122. range(self.num_total_experts),
  123. self.tp_size)[self.rank].tolist()
  124. if not self.expert_indicies:
  125. raise ValueError(
  126. f"Rank {self.rank} has no experts assigned to it.")
  127. self.experts = nn.ModuleList([
  128. MixtralMLP(self.num_total_experts,
  129. hidden_size,
  130. intermediate_size,
  131. linear_method=linear_method)
  132. if idx in self.expert_indicies else None
  133. for idx in range(self.num_total_experts)
  134. ])
  135. else:
  136. self.ws = MergedColumnParallelLinear(hidden_size,
  137. [intermediate_size] * 2,
  138. bias=False,
  139. linear_method=linear_method,
  140. num_experts=num_experts)
  141. self.w2s = RowParallelLinear(intermediate_size,
  142. hidden_size,
  143. bias=False,
  144. linear_method=linear_method,
  145. num_experts=num_experts)
  146. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  147. num_tokens, hidden_size = hidden_states.shape
  148. hidden_states = hidden_states.view(-1, self.hidden_size)
  149. # router_logits: (num_tokens, n_experts)
  150. router_logits, _ = self.gate(hidden_states)
  151. if not isinstance(
  152. self.linear_method, UnquantizedLinearMethod
  153. ) and not self.linear_method.quant_config.support_fused_moe():
  154. routing_weights, selected_experts = fused_topk(router_logits,
  155. self.top_k,
  156. renormalize=True)
  157. final_hidden_states = None
  158. for expert_idx in self.expert_indicies:
  159. expert_layer = self.experts[expert_idx]
  160. expert_mask = (selected_experts == expert_idx)
  161. expert_weights = (routing_weights * expert_mask).sum(
  162. dim=-1, keepdim=True)
  163. current_hidden_states = expert_layer(hidden_states).mul_(
  164. expert_weights)
  165. if final_hidden_states is None:
  166. final_hidden_states = current_hidden_states
  167. else:
  168. final_hidden_states.add_(current_hidden_states)
  169. else:
  170. final_hidden_states = self.linear_method.apply_moe_weights(
  171. self.ws.linear_weights,
  172. self.w2s.linear_weights,
  173. hidden_states,
  174. router_logits,
  175. self.top_k,
  176. renormalize=True,
  177. )
  178. if self.tp_size > 1:
  179. final_hidden_states = tensor_model_parallel_all_reduce(
  180. final_hidden_states)
  181. return final_hidden_states.view(num_tokens, hidden_size)
  182. class MixtralAttention(nn.Module):
  183. def __init__(self,
  184. hidden_size: int,
  185. num_heads: int,
  186. num_kv_heads: int,
  187. max_position: int = 4096 * 32,
  188. rope_theta: float = 10000,
  189. linear_method: Optional[LinearMethodBase] = None,
  190. sliding_window: Optional[int] = None) -> None:
  191. super().__init__()
  192. self.hidden_size = hidden_size
  193. tp_size = get_tensor_model_parallel_world_size()
  194. self.total_num_heads = num_heads
  195. assert self.total_num_heads % tp_size == 0
  196. self.num_heads = self.total_num_heads // tp_size
  197. self.total_num_kv_heads = num_kv_heads
  198. if self.total_num_kv_heads >= tp_size:
  199. # Number of KV heads is greater than TP size, so we partition
  200. # the KV heads across multiple tensor parallel GPUs.
  201. assert self.total_num_kv_heads % tp_size == 0
  202. else:
  203. # Number of KV heads is less than TP size, so we replicate
  204. # the KV heads across multiple tensor parallel GPUs.
  205. assert tp_size % self.total_num_kv_heads == 0
  206. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  207. self.head_dim = hidden_size // self.total_num_heads
  208. self.q_size = self.num_heads * self.head_dim
  209. self.kv_size = self.num_kv_heads * self.head_dim
  210. self.scaling = self.head_dim**-0.5
  211. self.rope_theta = rope_theta
  212. self.sliding_window = sliding_window
  213. if (linear_method is not None
  214. and not linear_method.quant_config.merge_weight()):
  215. self.merge_weight = False
  216. self.q_proj = ColumnParallelLinear(hidden_size,
  217. self.total_num_heads *
  218. self.head_dim,
  219. bias=False,
  220. linear_method=linear_method)
  221. self.k_proj = ColumnParallelLinear(hidden_size,
  222. self.total_num_kv_heads *
  223. self.head_dim,
  224. bias=False,
  225. linear_method=linear_method)
  226. self.v_proj = ColumnParallelLinear(hidden_size,
  227. self.total_num_kv_heads *
  228. self.head_dim,
  229. bias=False,
  230. linear_method=linear_method)
  231. else:
  232. self.merge_weight = True
  233. self.qkv_proj = QKVParallelLinear(
  234. hidden_size,
  235. self.head_dim,
  236. self.total_num_heads,
  237. self.total_num_kv_heads,
  238. bias=False,
  239. linear_method=linear_method,
  240. )
  241. self.o_proj = RowParallelLinear(
  242. self.total_num_heads * self.head_dim,
  243. hidden_size,
  244. bias=False,
  245. linear_method=linear_method,
  246. )
  247. self.rotary_emb = get_rope(
  248. self.head_dim,
  249. rotary_dim=self.head_dim,
  250. max_position=max_position,
  251. base=int(self.rope_theta),
  252. is_neox_style=True,
  253. )
  254. self.attn = Attention(
  255. self.num_heads,
  256. self.head_dim,
  257. self.scaling,
  258. num_kv_heads=self.num_kv_heads,
  259. sliding_window=self.sliding_window,
  260. )
  261. def forward(
  262. self,
  263. positions: torch.Tensor,
  264. hidden_states: torch.Tensor,
  265. kv_cache: torch.Tensor,
  266. attn_metadata: AttentionMetadata,
  267. ) -> torch.Tensor:
  268. if self.merge_weight:
  269. qkv, _ = self.qkv_proj(hidden_states)
  270. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
  271. dim=-1)
  272. else:
  273. q, _ = self.q_proj(hidden_states)
  274. k, _ = self.k_proj(hidden_states)
  275. v, _ = self.v_proj(hidden_states)
  276. q, k = self.rotary_emb(positions, q, k)
  277. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  278. output, _ = self.o_proj(attn_output)
  279. return output
  280. class MixtralDecoderLayer(nn.Module):
  281. def __init__(
  282. self,
  283. config: MixtralConfig,
  284. linear_method: Optional[LinearMethodBase] = None,
  285. ) -> None:
  286. super().__init__()
  287. self.hidden_size = config.hidden_size
  288. # Requires transformers > 4.32.0
  289. rope_theta = getattr(config, "rope_theta", 10000)
  290. self.self_attn = MixtralAttention(
  291. hidden_size=self.hidden_size,
  292. num_heads=config.num_attention_heads,
  293. max_position=config.max_position_embeddings,
  294. num_kv_heads=config.num_key_value_heads,
  295. rope_theta=rope_theta,
  296. sliding_window=config.sliding_window,
  297. linear_method=linear_method)
  298. self.block_sparse_moe = MixtralMoE(
  299. num_experts=config.num_local_experts,
  300. top_k=config.num_experts_per_tok,
  301. hidden_size=config.hidden_size,
  302. intermediate_size=config.intermediate_size,
  303. linear_method=linear_method)
  304. self.input_layernorm = RMSNorm(config.hidden_size,
  305. eps=config.rms_norm_eps)
  306. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  307. eps=config.rms_norm_eps)
  308. def forward(
  309. self,
  310. positions: torch.Tensor,
  311. hidden_states: torch.Tensor,
  312. kv_cache: torch.Tensor,
  313. attn_metadata: AttentionMetadata,
  314. residual: Optional[torch.Tensor],
  315. ) -> torch.Tensor:
  316. # Self Attention
  317. if residual is None:
  318. residual = hidden_states
  319. hidden_states = self.input_layernorm(hidden_states)
  320. else:
  321. hidden_states, residual = self.input_layernorm(
  322. hidden_states, residual)
  323. hidden_states = self.self_attn(
  324. positions=positions,
  325. hidden_states=hidden_states,
  326. kv_cache=kv_cache,
  327. attn_metadata=attn_metadata,
  328. )
  329. # Fully Connected
  330. hidden_states, residual = self.post_attention_layernorm(
  331. hidden_states, residual)
  332. hidden_states = self.block_sparse_moe(hidden_states)
  333. return hidden_states, residual
  334. class MixtralModel(nn.Module):
  335. def __init__(
  336. self,
  337. config: MixtralConfig,
  338. linear_method: Optional[LinearMethodBase] = None,
  339. lora_config: Optional[LoRAConfig] = None,
  340. ) -> None:
  341. super().__init__()
  342. self.padding_idx = config.pad_token_id
  343. lora_vocab = (lora_config.lora_extra_vocab_size *
  344. (lora_config.max_loras or 1)) if lora_config else 0
  345. self.vocab_size = config.vocab_size + lora_vocab
  346. self.org_vocab_size = config.vocab_size
  347. self.embed_tokens = VocabParallelEmbedding(
  348. self.vocab_size,
  349. config.hidden_size,
  350. linear_method=linear_method,
  351. org_num_embeddings=config.vocab_size,
  352. )
  353. self.layers = nn.ModuleList([
  354. MixtralDecoderLayer(config, linear_method=linear_method)
  355. for _ in range(config.num_hidden_layers)
  356. ])
  357. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  358. def forward(
  359. self,
  360. input_ids: torch.Tensor,
  361. positions: torch.Tensor,
  362. kv_caches: List[torch.Tensor],
  363. attn_metadata: AttentionMetadata,
  364. ) -> torch.Tensor:
  365. hidden_states = self.embed_tokens(input_ids)
  366. residual = None
  367. for i in range(len(self.layers)):
  368. layer = self.layers[i]
  369. hidden_states, residual = layer(positions, hidden_states,
  370. kv_caches[i], attn_metadata,
  371. residual)
  372. hidden_states, _ = self.norm(hidden_states, residual)
  373. return hidden_states
  374. class MixtralForCausalLM(nn.Module):
  375. packed_modules_mapping = {
  376. "qkv_proj": [
  377. "q_proj",
  378. "k_proj",
  379. "v_proj",
  380. ],
  381. }
  382. # LoRA specific attributes
  383. supported_lora_modules = [
  384. "qkv_proj",
  385. "o_proj",
  386. "embed_tokens",
  387. "lm_head",
  388. ]
  389. embedding_modules = {
  390. "embed_tokens": "input_embeddings",
  391. "lm_head": "output_embeddings",
  392. }
  393. embedding_padding_modules = ["lm_head"]
  394. def __init__(
  395. self,
  396. config: MixtralConfig,
  397. linear_method: Optional[LinearMethodBase] = None,
  398. lora_config: Optional[LoRAConfig] = None,
  399. ) -> None:
  400. super().__init__()
  401. self.config = config
  402. self.linear_method = linear_method
  403. self.model = MixtralModel(config,
  404. linear_method,
  405. lora_config=lora_config)
  406. self.unpadded_vocab_size = config.vocab_size
  407. if lora_config:
  408. self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
  409. self.lm_head = ParallelLMHead(
  410. self.unpadded_vocab_size,
  411. config.hidden_size,
  412. linear_method=linear_method,
  413. org_num_embeddings=config.vocab_size,
  414. padding_size=DEFAULT_VOCAB_PADDING_SIZE
  415. # We need bigger padding if using lora for kernel
  416. # compatibility
  417. if not lora_config else lora_config.lora_vocab_padding_size,
  418. )
  419. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  420. config.vocab_size)
  421. self.sampler = Sampler()
  422. def forward(
  423. self,
  424. input_ids: torch.Tensor,
  425. positions: torch.Tensor,
  426. kv_caches: List[torch.Tensor],
  427. attn_metadata: AttentionMetadata,
  428. ) -> torch.Tensor:
  429. hidden_states = self.model(input_ids, positions, kv_caches,
  430. attn_metadata)
  431. return hidden_states
  432. def compute_logits(self, hidden_states: torch.Tensor,
  433. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  434. logits = self.logits_processor(self.lm_head, hidden_states,
  435. sampling_metadata)
  436. return logits
  437. def sample(
  438. self,
  439. logits: Optional[torch.Tensor],
  440. sampling_metadata: SamplingMetadata,
  441. ) -> Optional[SamplerOutput]:
  442. next_tokens = self.sampler(logits, sampling_metadata)
  443. return next_tokens
  444. def load_weights(self,
  445. model_name_or_path: str,
  446. cache_dir: Optional[str] = None,
  447. load_format: str = "auto",
  448. revision: Optional[str] = None):
  449. stacked_params_mapping = [
  450. # (param_name, shard_name, shard_id)
  451. ("qkv_proj", "q_proj", "q"),
  452. ("qkv_proj", "k_proj", "k"),
  453. ("qkv_proj", "v_proj", "v"),
  454. ]
  455. if (self.linear_method is not None
  456. and not self.linear_method.quant_config.merge_weight()):
  457. stacked_params_mapping = []
  458. expert_params_mapping = [
  459. # (param_name, weight_name, shard_id, expert_id)
  460. ("ws" if weight_name in ["w1", "w3"] else "w2s",
  461. f"experts.{expert_id}.{weight_name}", shard_id, expert_id)
  462. for expert_id in range(self.config.num_local_experts)
  463. for weight_name, shard_id in [("w1", 0), ("w3", 1), ("w2", None)]
  464. ] if self.linear_method is None or (
  465. self.linear_method.quant_config.support_fused_moe()) else []
  466. params_dict = dict(self.named_parameters())
  467. for name, loaded_weight in hf_model_weights_iterator(
  468. model_name_or_path,
  469. cache_dir,
  470. load_format,
  471. revision,
  472. self.config,
  473. fall_back_to_pt=False):
  474. if "rotary_emb.inv_freq" in name:
  475. continue
  476. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  477. if weight_name not in name:
  478. continue
  479. name = name.replace(weight_name, param_name)
  480. # Skip loading extra bias for GPTQ models.
  481. if name.endswith(".bias") and name not in params_dict:
  482. continue
  483. param = params_dict[name]
  484. weight_loader = param.weight_loader
  485. weight_loader(param, loaded_weight, shard_id)
  486. break
  487. else:
  488. for (param_name, weight_name, shard_id,
  489. expert_id) in expert_params_mapping:
  490. if weight_name not in name:
  491. continue
  492. name = name.replace(weight_name, param_name)
  493. if name.endswith(".bias") and name not in params_dict:
  494. continue
  495. param = params_dict[name]
  496. weight_loader = param.weight_loader
  497. if shard_id is None:
  498. weight_loader(param,
  499. loaded_weight,
  500. expert_id=expert_id)
  501. else:
  502. weight_loader(param,
  503. loaded_weight,
  504. shard_id,
  505. expert_id=expert_id)
  506. break
  507. else:
  508. # Skip loading extra bias for GPTQ models.
  509. if name.endswith(".bias") and name not in params_dict:
  510. continue
  511. # Skip experts that are not assigned to this worker.
  512. if ("block_sparse_moe.experts." in name
  513. and name not in params_dict):
  514. continue
  515. param = params_dict[name]
  516. weight_loader = getattr(param, "weight_loader",
  517. default_weight_loader)
  518. weight_loader(param, loaded_weight)