mixtral.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only Mixtral model."""
  25. from typing import List, Optional, Tuple
  26. import torch
  27. from torch import nn
  28. from transformers import MixtralConfig
  29. from aphrodite.common.config import LoRAConfig
  30. from aphrodite.modeling.metadata import InputMetadata
  31. from aphrodite.modeling.layers.attention import PagedAttention
  32. from aphrodite.modeling.layers.triton_kernel.fused_moe import fused_moe
  33. from aphrodite.modeling.layers.layernorm import RMSNorm
  34. from aphrodite.modeling.layers.linear import (
  35. LinearMethodBase,
  36. QKVParallelLinear,
  37. ReplicatedLinear,
  38. RowParallelLinear,
  39. ColumnParallelLinear,
  40. )
  41. from aphrodite.modeling.layers.rotary_embedding import get_rope
  42. from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
  43. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  44. VocabParallelEmbedding,
  45. ParallelLMHead,
  46. DEFAULT_VOCAB_PADDING_SIZE,
  47. )
  48. from aphrodite.modeling.megatron.communication_op import (
  49. tensor_model_parallel_all_reduce, )
  50. from aphrodite.modeling.megatron.parallel_state import (
  51. get_tensor_model_parallel_rank,
  52. get_tensor_model_parallel_world_size,
  53. )
  54. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  55. from aphrodite.modeling.utils import set_weight_attrs
  56. from aphrodite.modeling.hf_downloader import (
  57. default_weight_loader,
  58. hf_model_weights_iterator,
  59. )
  60. from aphrodite.common.sequence import SamplerOutput
# Per-layer KV cache handed to PagedAttention: a (key_cache, value_cache)
# tensor pair, unpacked as `k_cache, v_cache = kv_cache` in
# MixtralAttention.forward.
KVCache = Tuple[torch.Tensor, torch.Tensor]
  62. class MixtralMoE(nn.Module):
  63. """A tensor-parallel MoE implementation for Mixtral that shards each expert
  64. across all ranks.
  65. Each expert's weights are sharded across all ranks and a fused MoE
  66. kernel is used for the forward pass, and finally we reduce the outputs
  67. across ranks.
  68. """
  69. def __init__(
  70. self,
  71. num_experts: int,
  72. top_k: int,
  73. hidden_size: int,
  74. intermediate_size: int,
  75. params_dtype: Optional[torch.dtype] = None,
  76. tp_size: Optional[int] = None,
  77. ):
  78. super().__init__()
  79. self.tp_size = tp_size or get_tensor_model_parallel_world_size()
  80. self.num_total_experts = num_experts
  81. self.top_k = top_k
  82. self.hidden_size = hidden_size
  83. self.intermediate_size = intermediate_size // self.tp_size
  84. if params_dtype is None:
  85. params_dtype = torch.get_default_dtype()
  86. self.params_dtype = params_dtype
  87. self.gate = ReplicatedLinear(
  88. self.hidden_size,
  89. self.num_total_experts,
  90. bias=False,
  91. params_dtype=self.params_dtype,
  92. linear_method=None,
  93. )
  94. self.ws = nn.Parameter(
  95. torch.empty(
  96. self.num_total_experts,
  97. 2 * self.intermediate_size,
  98. self.hidden_size,
  99. device="cuda",
  100. dtype=self.params_dtype,
  101. ))
  102. self.w2s = nn.Parameter(
  103. torch.empty(
  104. self.num_total_experts,
  105. self.hidden_size,
  106. self.intermediate_size,
  107. device="cuda",
  108. dtype=self.params_dtype,
  109. ))
  110. set_weight_attrs(
  111. self.ws,
  112. {
  113. "weight_loader": self.weight_loader,
  114. },
  115. )
  116. set_weight_attrs(
  117. self.w2s,
  118. {
  119. "weight_loader": self.weight_loader,
  120. },
  121. )
  122. def weight_loader(
  123. self,
  124. param: nn.Parameter,
  125. loaded_weight: torch.Tensor,
  126. weight_name: str,
  127. expert_id: int,
  128. ):
  129. tp_rank = get_tensor_model_parallel_rank()
  130. param_data = param.data
  131. shard_size = self.intermediate_size
  132. shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
  133. if weight_name.endswith("w1.weight"):
  134. param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
  135. if weight_name.endswith("w3.weight"):
  136. param_data[expert_id, shard_size:2 *
  137. shard_size, :] = (loaded_weight[shard, :])
  138. if weight_name.endswith("w2.weight"):
  139. param_data[expert_id, :, :] = loaded_weight[:, shard]
  140. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  141. batch_size, sequence_length, hidden_size = hidden_states.shape
  142. hidden_states = hidden_states.view(-1, self.hidden_size)
  143. # router_logits: (batch * sequence_length, n_experts)
  144. router_logits, _ = self.gate(hidden_states)
  145. final_hidden_states = fused_moe(
  146. hidden_states,
  147. self.ws,
  148. self.w2s,
  149. router_logits,
  150. self.top_k,
  151. renormalize=True,
  152. inplace=True,
  153. )
  154. if self.tp_size > 1:
  155. final_hidden_states = tensor_model_parallel_all_reduce(
  156. final_hidden_states)
  157. return final_hidden_states.view(batch_size, sequence_length,
  158. hidden_size)
  159. class MixtralAttention(nn.Module):
  160. def __init__(
  161. self,
  162. hidden_size: int,
  163. num_heads: int,
  164. num_kv_heads: int,
  165. max_position: int = 4096 * 32,
  166. rope_theta: float = 10000,
  167. linear_method: Optional[LinearMethodBase] = None,
  168. sliding_window: Optional[int] = None,
  169. ) -> None:
  170. super().__init__()
  171. self.hidden_size = hidden_size
  172. tp_size = get_tensor_model_parallel_world_size()
  173. self.total_num_heads = num_heads
  174. assert self.total_num_heads % tp_size == 0
  175. self.num_heads = self.total_num_heads // tp_size
  176. self.total_num_kv_heads = num_kv_heads
  177. if self.total_num_kv_heads >= tp_size:
  178. # Number of KV heads is greater than TP size, so we partition
  179. # the KV heads across multiple tensor parallel GPUs.
  180. assert self.total_num_kv_heads % tp_size == 0
  181. else:
  182. # Number of KV heads is less than TP size, so we replicate
  183. # the KV heads across multiple tensor parallel GPUs.
  184. assert tp_size % self.total_num_kv_heads == 0
  185. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  186. self.head_dim = hidden_size // self.total_num_heads
  187. self.q_size = self.num_heads * self.head_dim
  188. self.kv_size = self.num_kv_heads * self.head_dim
  189. self.scaling = self.head_dim**-0.5
  190. self.rope_theta = rope_theta
  191. self.sliding_window = sliding_window
  192. if (linear_method is not None
  193. and not linear_method.quant_config.merge_weight()):
  194. self.merge_weight = False
  195. self.q_proj = ColumnParallelLinear(
  196. hidden_size,
  197. self.q_size,
  198. bias=False,
  199. linear_method=linear_method,
  200. )
  201. self.k_proj = ColumnParallelLinear(
  202. hidden_size,
  203. self.kv_size,
  204. bias=False,
  205. linear_method=linear_method,
  206. )
  207. self.v_proj = ColumnParallelLinear(
  208. hidden_size,
  209. self.kv_size,
  210. bias=False,
  211. linear_method=linear_method,
  212. )
  213. else:
  214. self.merge_weight = True
  215. self.qkv_proj = QKVParallelLinear(
  216. hidden_size,
  217. self.head_dim,
  218. self.total_num_heads,
  219. self.total_num_kv_heads,
  220. bias=False,
  221. linear_method=linear_method,
  222. )
  223. self.o_proj = RowParallelLinear(
  224. self.total_num_heads * self.head_dim,
  225. hidden_size,
  226. bias=False,
  227. linear_method=linear_method,
  228. )
  229. is_neox_style = (True if linear_method is None
  230. or linear_method.quant_config.rope_style() is None
  231. else linear_method.quant_config.rope_style())
  232. self.rotary_emb = get_rope(
  233. self.head_dim,
  234. rotary_dim=self.head_dim,
  235. max_position=max_position,
  236. base=int(self.rope_theta),
  237. is_neox_style=is_neox_style,
  238. )
  239. self.attn = PagedAttention(
  240. self.num_heads,
  241. self.head_dim,
  242. self.scaling,
  243. num_kv_heads=self.num_kv_heads,
  244. sliding_window=self.sliding_window,
  245. )
  246. def forward(
  247. self,
  248. positions: torch.Tensor,
  249. hidden_states: torch.Tensor,
  250. kv_cache: KVCache,
  251. input_metadata: InputMetadata,
  252. ) -> torch.Tensor:
  253. if self.merge_weight:
  254. qkv, _ = self.qkv_proj(hidden_states)
  255. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
  256. dim=-1)
  257. else:
  258. q, _ = self.q_proj(hidden_states)
  259. k, _ = self.k_proj(hidden_states)
  260. v, _ = self.v_proj(hidden_states)
  261. q, k = self.rotary_emb(positions, q, k)
  262. k_cache, v_cache = kv_cache
  263. attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
  264. output, _ = self.o_proj(attn_output)
  265. return output
  266. class MixtralDecoderLayer(nn.Module):
  267. def __init__(
  268. self,
  269. config: MixtralConfig,
  270. linear_method: Optional[LinearMethodBase] = None,
  271. ) -> None:
  272. super().__init__()
  273. self.hidden_size = config.hidden_size
  274. # Requires transformers > 4.32.0
  275. rope_theta = getattr(config, "rope_theta", 10000)
  276. self.self_attn = MixtralAttention(
  277. hidden_size=self.hidden_size,
  278. num_heads=config.num_attention_heads,
  279. max_position=config.max_position_embeddings,
  280. num_kv_heads=config.num_key_value_heads,
  281. rope_theta=rope_theta,
  282. sliding_window=config.sliding_window,
  283. linear_method=linear_method,
  284. )
  285. self.block_sparse_moe = MixtralMoE(
  286. num_experts=config.num_local_experts,
  287. top_k=config.num_experts_per_tok,
  288. hidden_size=config.hidden_size,
  289. intermediate_size=config.intermediate_size,
  290. )
  291. self.input_layernorm = RMSNorm(config.hidden_size,
  292. eps=config.rms_norm_eps)
  293. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  294. eps=config.rms_norm_eps)
  295. def forward(
  296. self,
  297. positions: torch.Tensor,
  298. hidden_states: torch.Tensor,
  299. kv_cache: KVCache,
  300. input_metadata: InputMetadata,
  301. residual: Optional[torch.Tensor],
  302. ) -> torch.Tensor:
  303. # Self Attention
  304. if residual is None:
  305. residual = hidden_states
  306. hidden_states = self.input_layernorm(hidden_states)
  307. else:
  308. hidden_states, residual = self.input_layernorm(
  309. hidden_states, residual)
  310. hidden_states = self.self_attn(
  311. positions=positions,
  312. hidden_states=hidden_states,
  313. kv_cache=kv_cache,
  314. input_metadata=input_metadata,
  315. )
  316. # Fully Connected
  317. hidden_states, residual = self.post_attention_layernorm(
  318. hidden_states, residual)
  319. hidden_states = self.block_sparse_moe(hidden_states)
  320. return hidden_states, residual
  321. class MixtralModel(nn.Module):
  322. def __init__(
  323. self,
  324. config: MixtralConfig,
  325. linear_method: Optional[LinearMethodBase] = None,
  326. lora_config: Optional[LoRAConfig] = None,
  327. ) -> None:
  328. super().__init__()
  329. self.padding_idx = config.pad_token_id
  330. lora_vocab = ((lora_config.lora_extra_vocab_size *
  331. (lora_config.max_loras or 1)) if lora_config else 0)
  332. self.vocab_size = config.vocab_size + lora_vocab
  333. self.org_vocab_size = config.vocab_size
  334. self.embed_tokens = VocabParallelEmbedding(
  335. self.vocab_size,
  336. config.hidden_size,
  337. linear_method=linear_method,
  338. org_num_embeddings=config.vocab_size,
  339. )
  340. self.layers = nn.ModuleList([
  341. MixtralDecoderLayer(config, linear_method=linear_method)
  342. for _ in range(config.num_hidden_layers)
  343. ])
  344. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  345. def forward(
  346. self,
  347. input_ids: torch.Tensor,
  348. positions: torch.Tensor,
  349. kv_caches: List[KVCache],
  350. input_metadata: InputMetadata,
  351. ) -> torch.Tensor:
  352. hidden_states = self.embed_tokens(input_ids)
  353. residual = None
  354. for i in range(len(self.layers)):
  355. layer = self.layers[i]
  356. hidden_states, residual = layer(positions, hidden_states,
  357. kv_caches[i], input_metadata,
  358. residual)
  359. hidden_states, _ = self.norm(hidden_states, residual)
  360. return hidden_states
class MixtralForCausalLM(nn.Module):
    """Mixtral causal language model: a MixtralModel backbone, a parallel LM
    head, samplers, and checkpoint loading that maps HF weight names onto the
    fused QKV projection and the stacked expert tensors."""
    # Fused parameter name -> checkpoint sub-modules packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the backbone, LM head and samplers.

        Args:
            config: HF Mixtral configuration.
            linear_method: optional quantization method, passed to the
                backbone and the LM head; its quant_config also selects the
                sampling path in sample() and the weight mapping in
                load_weights().
            lora_config: optional LoRA configuration; adds
                lora_extra_vocab_size to the LM head vocabulary and sets the
                LM head padding size.
        """
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = MixtralModel(config,
                                  linear_method,
                                  lora_config=lora_config)
        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            linear_method=linear_method,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE
            # We need bigger padding if using lora for kernel
            # compatibility
            if not lora_config else lora_config.lora_vocab_padding_size,
        )
        # `sampler` is given the LM head weight plus hidden states;
        # `quant_sampler` is given the LM head's output instead (see sample()).
        self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
        self.quant_sampler = QuantSampler(self.unpadded_vocab_size,
                                          config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run the backbone and return its final hidden states.

        The LM head is not applied here; it is used in sample().
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``hidden_states``.

        When the quant config reports ``merge_weight() == False`` the LM head
        is applied as a module and its output (presumably logits) goes to
        ``quant_sampler``; otherwise ``sampler`` receives the raw LM head
        weight together with the hidden states.
        """
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            next_tokens = self.quant_sampler(self.lm_head(hidden_states),
                                             sampling_metadata)
        else:
            next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return next_tokens

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        """Load checkpoint tensors into this model's parameters.

        Each checkpoint tensor is handled by the first matching rule:
          1. ``rotary_emb.inv_freq`` entries are skipped.
          2. ``q_proj`` / ``k_proj`` / ``v_proj`` are renamed to ``qkv_proj``
             and loaded with their shard id (disabled when the quant config
             reports ``merge_weight() == False``).
          3. ``experts.{i}.w1/w3.weight`` are loaded into ``ws`` and
             ``experts.{i}.w2.weight`` into ``w2s`` via the MoE weight loader.
          4. Everything else uses the parameter's own ``weight_loader`` or
             ``default_weight_loader``.
        Bias tensors with no matching parameter are skipped (GPTQ).
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        # Unmerged quantized weights keep separate q/k/v parameters, so no
        # renaming into qkv_proj is done.
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            stacked_params_mapping = []
        expert_params_mapping = [
            # (param_name, weight_name, expert_id)
            (
                "ws" if weight_name in ["w1", "w3"] else "w2s",
                f"experts.{expert_id}.{weight_name}.weight",
                expert_id,
            ) for expert_id in range(self.config.num_local_experts)
            for weight_name in ["w1", "w2", "w3"]
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path,
                cache_dir,
                load_format,
                revision,
                self.config,
                fall_back_to_pt=False,
        ):
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # NOTE: this `continue` moves on to the next mapping entry
                # with `name` already rewritten; if nothing else matches, the
                # for-else below handles (and skips) it.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, expert_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    # The full checkpoint name (weight_name) is passed so the
                    # loader can tell w1 / w2 / w3 apart by suffix.
                    weight_loader(param,
                                  loaded_weight,
                                  weight_name,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)