mixtral.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only Mixtral model."""
  25. from typing import Iterable, List, Optional, Tuple
  26. import torch
  27. from torch import nn
  28. from transformers import MixtralConfig
  29. from aphrodite.attention import Attention, AttentionMetadata
  30. from aphrodite.common.config import LoRAConfig
  31. from aphrodite.common.sequence import SamplerOutput
  32. from aphrodite.common.utils import print_warning_once
  33. from aphrodite.distributed import (get_tensor_model_parallel_rank,
  34. get_tensor_model_parallel_world_size,
  35. tensor_model_parallel_all_reduce)
  36. from aphrodite.modeling.layers.fused_moe import fused_moe
  37. from aphrodite.modeling.layers.layernorm import RMSNorm
  38. from aphrodite.modeling.layers.linear import (LinearMethodBase,
  39. QKVParallelLinear,
  40. ReplicatedLinear,
  41. RowParallelLinear)
  42. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  43. from aphrodite.quantization.fp8 import (Fp8LinearMethod,
  44. per_tensor_quantize)
  45. from aphrodite.modeling.layers.rotary_embedding import get_rope
  46. from aphrodite.modeling.layers.sampler import Sampler
  47. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  48. DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
  49. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  50. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  51. from aphrodite.modeling.utils import set_weight_attrs
  52. class MixtralMoE(nn.Module):
  53. """A tensor-parallel MoE implementation for Mixtral that shards each expert
  54. across all ranks.
  55. Each expert's weights are sharded across all ranks and a fused MoE
  56. kernel is used for the forward pass, and finally we reduce the outputs
  57. across ranks.
  58. """
  59. def __init__(
  60. self,
  61. num_experts: int,
  62. top_k: int,
  63. hidden_size: int,
  64. intermediate_size: int,
  65. params_dtype: Optional[torch.dtype] = None,
  66. tp_size: Optional[int] = None,
  67. linear_method: Optional[LinearMethodBase] = None,
  68. ):
  69. super().__init__()
  70. self.tp_size = tp_size or get_tensor_model_parallel_world_size()
  71. self.num_total_experts = num_experts
  72. self.top_k = top_k
  73. self.hidden_size = hidden_size
  74. self.intermediate_size = intermediate_size // self.tp_size
  75. # FIXME(pcmoritz): Make this more general to support different
  76. # quantization schemes
  77. self.use_fp8 = isinstance(linear_method, Fp8LinearMethod)
  78. if params_dtype is None:
  79. params_dtype = torch.get_default_dtype()
  80. self.params_dtype = params_dtype
  81. self.gate = ReplicatedLinear(self.hidden_size,
  82. self.num_total_experts,
  83. bias=False,
  84. params_dtype=self.params_dtype,
  85. linear_method=None)
  86. self.ws = nn.Parameter(
  87. torch.empty(self.num_total_experts,
  88. 2 * self.intermediate_size,
  89. self.hidden_size,
  90. device="cuda",
  91. dtype=self.params_dtype))
  92. self.w2s = nn.Parameter(
  93. torch.empty(self.num_total_experts,
  94. self.hidden_size,
  95. self.intermediate_size,
  96. device="cuda",
  97. dtype=self.params_dtype))
  98. # Scaling factors for FP8 weights
  99. self.ws_scale = nn.Parameter(
  100. torch.ones(
  101. self.num_total_experts, device="cuda", dtype=torch.float32),
  102. requires_grad=False) if self.use_fp8 else None
  103. self.w2s_scale = nn.Parameter(
  104. torch.ones(
  105. self.num_total_experts, device="cuda", dtype=torch.float32),
  106. requires_grad=False) if self.use_fp8 else None
  107. set_weight_attrs(self.ws, {
  108. "weight_loader": self.weight_loader,
  109. })
  110. set_weight_attrs(self.w2s, {
  111. "weight_loader": self.weight_loader,
  112. })
  113. def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
  114. weight_name: str, expert_id: int):
  115. tp_rank = get_tensor_model_parallel_rank()
  116. param_data = param.data
  117. shard_size = self.intermediate_size
  118. shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
  119. if weight_name.endswith("w1.weight"):
  120. param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
  121. if weight_name.endswith("w3.weight"):
  122. param_data[expert_id,
  123. shard_size:2 * shard_size, :] = loaded_weight[shard, :]
  124. if weight_name.endswith("w2.weight"):
  125. param_data[expert_id, :, :] = loaded_weight[:, shard]
  126. def process_weights_after_loading(self):
  127. if self.use_fp8:
  128. ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn)
  129. w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn)
  130. for expert in range(self.num_total_experts):
  131. ws[expert, :, :], self.ws_scale[expert] = per_tensor_quantize(
  132. self.ws.data[expert, :, :])
  133. w2s[expert, :, :], self.w2s_scale[
  134. expert] = per_tensor_quantize(self.w2s.data[expert, :, :])
  135. self.ws = nn.Parameter(ws, requires_grad=False)
  136. self.w2s = nn.Parameter(w2s, requires_grad=False)
  137. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  138. num_tokens, hidden_size = hidden_states.shape
  139. hidden_states = hidden_states.view(-1, self.hidden_size)
  140. # router_logits: (num_tokens, n_experts)
  141. router_logits, _ = self.gate(hidden_states)
  142. final_hidden_states = fused_moe(hidden_states,
  143. self.ws,
  144. self.w2s,
  145. router_logits,
  146. self.top_k,
  147. renormalize=True,
  148. inplace=True,
  149. use_fp8=self.use_fp8,
  150. w1_scale=self.ws_scale,
  151. w2_scale=self.w2s_scale)
  152. if self.tp_size > 1:
  153. final_hidden_states = tensor_model_parallel_all_reduce(
  154. final_hidden_states)
  155. return final_hidden_states.view(num_tokens, hidden_size)
  156. class MixtralAttention(nn.Module):
  157. def __init__(self,
  158. hidden_size: int,
  159. num_heads: int,
  160. num_kv_heads: int,
  161. max_position: int = 4096 * 32,
  162. rope_theta: float = 10000,
  163. linear_method: Optional[LinearMethodBase] = None,
  164. sliding_window: Optional[int] = None) -> None:
  165. super().__init__()
  166. self.hidden_size = hidden_size
  167. tp_size = get_tensor_model_parallel_world_size()
  168. self.total_num_heads = num_heads
  169. assert self.total_num_heads % tp_size == 0
  170. self.num_heads = self.total_num_heads // tp_size
  171. self.total_num_kv_heads = num_kv_heads
  172. if self.total_num_kv_heads >= tp_size:
  173. # Number of KV heads is greater than TP size, so we partition
  174. # the KV heads across multiple tensor parallel GPUs.
  175. assert self.total_num_kv_heads % tp_size == 0
  176. else:
  177. # Number of KV heads is less than TP size, so we replicate
  178. # the KV heads across multiple tensor parallel GPUs.
  179. assert tp_size % self.total_num_kv_heads == 0
  180. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  181. self.head_dim = hidden_size // self.total_num_heads
  182. self.q_size = self.num_heads * self.head_dim
  183. self.kv_size = self.num_kv_heads * self.head_dim
  184. self.scaling = self.head_dim**-0.5
  185. self.rope_theta = rope_theta
  186. self.sliding_window = sliding_window
  187. if isinstance(linear_method, Fp8LinearMethod):
  188. print_warning_once(
  189. "For Mixtral FP8 quantization, we currently do not quantize "
  190. "the attention layers until their FP8 performance is improved."
  191. )
  192. linear_method = None
  193. self.qkv_proj = QKVParallelLinear(
  194. hidden_size,
  195. self.head_dim,
  196. self.total_num_heads,
  197. self.total_num_kv_heads,
  198. bias=False,
  199. linear_method=linear_method,
  200. )
  201. self.o_proj = RowParallelLinear(
  202. self.total_num_heads * self.head_dim,
  203. hidden_size,
  204. bias=False,
  205. linear_method=linear_method,
  206. )
  207. self.rotary_emb = get_rope(
  208. self.head_dim,
  209. rotary_dim=self.head_dim,
  210. max_position=max_position,
  211. base=int(self.rope_theta),
  212. is_neox_style=True,
  213. )
  214. self.attn = Attention(
  215. self.num_heads,
  216. self.head_dim,
  217. self.scaling,
  218. num_kv_heads=self.num_kv_heads,
  219. sliding_window=self.sliding_window,
  220. )
  221. def forward(
  222. self,
  223. positions: torch.Tensor,
  224. hidden_states: torch.Tensor,
  225. kv_cache: torch.Tensor,
  226. attn_metadata: AttentionMetadata,
  227. ) -> torch.Tensor:
  228. qkv, _ = self.qkv_proj(hidden_states)
  229. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  230. q, k = self.rotary_emb(positions, q, k)
  231. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  232. output, _ = self.o_proj(attn_output)
  233. return output
  234. class MixtralDecoderLayer(nn.Module):
  235. def __init__(
  236. self,
  237. config: MixtralConfig,
  238. linear_method: Optional[LinearMethodBase] = None,
  239. ) -> None:
  240. super().__init__()
  241. self.hidden_size = config.hidden_size
  242. # Requires transformers > 4.32.0
  243. rope_theta = getattr(config, "rope_theta", 10000)
  244. self.self_attn = MixtralAttention(
  245. hidden_size=self.hidden_size,
  246. num_heads=config.num_attention_heads,
  247. max_position=config.max_position_embeddings,
  248. num_kv_heads=config.num_key_value_heads,
  249. rope_theta=rope_theta,
  250. sliding_window=config.sliding_window,
  251. linear_method=linear_method)
  252. self.block_sparse_moe = MixtralMoE(
  253. num_experts=config.num_local_experts,
  254. top_k=config.num_experts_per_tok,
  255. hidden_size=config.hidden_size,
  256. intermediate_size=config.intermediate_size,
  257. linear_method=linear_method)
  258. self.input_layernorm = RMSNorm(config.hidden_size,
  259. eps=config.rms_norm_eps)
  260. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  261. eps=config.rms_norm_eps)
  262. def forward(
  263. self,
  264. positions: torch.Tensor,
  265. hidden_states: torch.Tensor,
  266. kv_cache: torch.Tensor,
  267. attn_metadata: AttentionMetadata,
  268. residual: Optional[torch.Tensor],
  269. ) -> torch.Tensor:
  270. # Self Attention
  271. if residual is None:
  272. residual = hidden_states
  273. hidden_states = self.input_layernorm(hidden_states)
  274. else:
  275. hidden_states, residual = self.input_layernorm(
  276. hidden_states, residual)
  277. hidden_states = self.self_attn(
  278. positions=positions,
  279. hidden_states=hidden_states,
  280. kv_cache=kv_cache,
  281. attn_metadata=attn_metadata,
  282. )
  283. # Fully Connected
  284. hidden_states, residual = self.post_attention_layernorm(
  285. hidden_states, residual)
  286. hidden_states = self.block_sparse_moe(hidden_states)
  287. return hidden_states, residual
  288. class MixtralModel(nn.Module):
  289. def __init__(
  290. self,
  291. config: MixtralConfig,
  292. linear_method: Optional[LinearMethodBase] = None,
  293. lora_config: Optional[LoRAConfig] = None,
  294. ) -> None:
  295. super().__init__()
  296. self.padding_idx = config.pad_token_id
  297. lora_vocab = (lora_config.lora_extra_vocab_size *
  298. (lora_config.max_loras or 1)) if lora_config else 0
  299. self.vocab_size = config.vocab_size + lora_vocab
  300. self.org_vocab_size = config.vocab_size
  301. self.embed_tokens = VocabParallelEmbedding(
  302. self.vocab_size,
  303. config.hidden_size,
  304. org_num_embeddings=config.vocab_size,
  305. )
  306. self.layers = nn.ModuleList([
  307. MixtralDecoderLayer(config, linear_method=linear_method)
  308. for _ in range(config.num_hidden_layers)
  309. ])
  310. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  311. def forward(
  312. self,
  313. input_ids: torch.Tensor,
  314. positions: torch.Tensor,
  315. kv_caches: List[torch.Tensor],
  316. attn_metadata: AttentionMetadata,
  317. ) -> torch.Tensor:
  318. hidden_states = self.embed_tokens(input_ids)
  319. residual = None
  320. for i in range(len(self.layers)):
  321. layer = self.layers[i]
  322. hidden_states, residual = layer(positions, hidden_states,
  323. kv_caches[i], attn_metadata,
  324. residual)
  325. hidden_states, _ = self.norm(hidden_states, residual)
  326. return hidden_states
  327. class MixtralForCausalLM(nn.Module):
  328. fall_back_to_pt_during_load = False
  329. packed_modules_mapping = {
  330. "qkv_proj": [
  331. "q_proj",
  332. "k_proj",
  333. "v_proj",
  334. ],
  335. }
  336. # LoRA specific attributes
  337. supported_lora_modules = [
  338. "qkv_proj",
  339. "o_proj",
  340. "embed_tokens",
  341. "lm_head",
  342. ]
  343. embedding_modules = {
  344. "embed_tokens": "input_embeddings",
  345. "lm_head": "output_embeddings",
  346. }
  347. embedding_padding_modules = ["lm_head"]
  348. def __init__(
  349. self,
  350. config: MixtralConfig,
  351. linear_method: Optional[LinearMethodBase] = None,
  352. lora_config: Optional[LoRAConfig] = None,
  353. ) -> None:
  354. super().__init__()
  355. self.config = config
  356. self.linear_method = linear_method
  357. self.model = MixtralModel(config,
  358. linear_method,
  359. lora_config=lora_config)
  360. self.unpadded_vocab_size = config.vocab_size
  361. if lora_config:
  362. self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
  363. self.lm_head = ParallelLMHead(
  364. self.unpadded_vocab_size,
  365. config.hidden_size,
  366. org_num_embeddings=config.vocab_size,
  367. padding_size=DEFAULT_VOCAB_PADDING_SIZE
  368. # We need bigger padding if using lora for kernel
  369. # compatibility
  370. if not lora_config else lora_config.lora_vocab_padding_size,
  371. )
  372. self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
  373. config.vocab_size)
  374. self.sampler = Sampler()
  375. def forward(
  376. self,
  377. input_ids: torch.Tensor,
  378. positions: torch.Tensor,
  379. kv_caches: List[torch.Tensor],
  380. attn_metadata: AttentionMetadata,
  381. ) -> torch.Tensor:
  382. hidden_states = self.model(input_ids, positions, kv_caches,
  383. attn_metadata)
  384. return hidden_states
  385. def compute_logits(self, hidden_states: torch.Tensor,
  386. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  387. logits = self.logits_processor(self.lm_head.weight, hidden_states,
  388. sampling_metadata)
  389. return logits
  390. def sample(
  391. self,
  392. logits: Optional[torch.Tensor],
  393. sampling_metadata: SamplingMetadata,
  394. ) -> Optional[SamplerOutput]:
  395. next_tokens = self.sampler(logits, sampling_metadata)
  396. return next_tokens
  397. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  398. stacked_params_mapping = [
  399. # (param_name, shard_name, shard_id)
  400. ("qkv_proj", "q_proj", "q"),
  401. ("qkv_proj", "k_proj", "k"),
  402. ("qkv_proj", "v_proj", "v"),
  403. ]
  404. expert_params_mapping = [
  405. # (param_name, weight_name, expert_id)
  406. ("ws" if weight_name in ["w1", "w3"] else "w2s",
  407. f"experts.{expert_id}.{weight_name}.weight", expert_id)
  408. for expert_id in range(self.config.num_local_experts)
  409. for weight_name in ["w1", "w2", "w3"]
  410. ]
  411. params_dict = dict(self.named_parameters())
  412. for name, loaded_weight in weights:
  413. if "rotary_emb.inv_freq" in name:
  414. continue
  415. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  416. if weight_name not in name:
  417. continue
  418. name = name.replace(weight_name, param_name)
  419. # Skip loading extra bias for GPTQ models.
  420. if name.endswith(".bias") and name not in params_dict:
  421. continue
  422. param = params_dict[name]
  423. weight_loader = param.weight_loader
  424. weight_loader(param, loaded_weight, shard_id)
  425. break
  426. else:
  427. for param_name, weight_name, expert_id in expert_params_mapping:
  428. if weight_name not in name:
  429. continue
  430. name = name.replace(weight_name, param_name)
  431. param = params_dict[name]
  432. weight_loader = param.weight_loader
  433. weight_loader(param,
  434. loaded_weight,
  435. weight_name,
  436. expert_id=expert_id)
  437. break
  438. else:
  439. # Skip loading extra bias for GPTQ models.
  440. if name.endswith(".bias") and name not in params_dict:
  441. continue
  442. param = params_dict[name]
  443. weight_loader = getattr(param, "weight_loader",
  444. default_weight_loader)
  445. weight_loader(param, loaded_weight)