1
0

mixtral.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only Mixtral model."""
  25. from typing import List, Optional, Tuple
  26. import torch
  27. from torch import nn
  28. from transformers import MixtralConfig
  29. from aphrodite.common.config import LoRAConfig
  30. from aphrodite.modeling.metadata import InputMetadata
  31. from aphrodite.modeling.layers.attention import PagedAttention
  32. from aphrodite.modeling.layers.triton_kernel.fused_moe import fused_moe
  33. from aphrodite.modeling.layers.layernorm import RMSNorm
  34. from aphrodite.modeling.layers.linear import (LinearMethodBase,
  35. QKVParallelLinear,
  36. ReplicatedLinear,
  37. RowParallelLinear,
  38. ColumnParallelLinear)
  39. from aphrodite.modeling.layers.rotary_embedding import get_rope
  40. from aphrodite.modeling.layers.sampler import Sampler
  41. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  42. VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
  43. from aphrodite.modeling.megatron.communication_op import (
  44. tensor_model_parallel_all_reduce)
  45. from aphrodite.modeling.megatron.parallel_state import (
  46. get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
  47. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  48. from aphrodite.modeling.utils import set_weight_attrs
  49. from aphrodite.modeling.hf_downloader import (default_weight_loader,
  50. hf_model_weights_iterator)
  51. from aphrodite.common.sequence import SamplerOutput
  # Per-layer KV cache handed to each attention layer, which unpacks it as
  # (key_cache, value_cache).
  KVCache = Tuple[torch.Tensor, torch.Tensor]
  class MixtralMoE(nn.Module):
      """A tensor-parallel MoE implementation for Mixtral that shards each expert
      across all ranks.

      Every expert's weights are sharded along the intermediate dimension
      across all tensor-parallel ranks. A fused MoE kernel computes the forward
      pass on the local shards, and the partial outputs are then all-reduced
      across ranks.
      """

      def __init__(
          self,
          num_experts: int,
          top_k: int,
          hidden_size: int,
          intermediate_size: int,
          params_dtype: Optional[torch.dtype] = None,
          tp_size: Optional[int] = None,
      ):
          """Create the router gate and this rank's expert weight shards.

          Args:
              num_experts: Total number of experts; every rank holds a shard
                  of each of them.
              top_k: Number of experts selected per token by the router.
              hidden_size: Model hidden dimension.
              intermediate_size: Unsharded expert FFN dimension; each rank
                  keeps ``intermediate_size // tp_size`` of it.
              params_dtype: dtype for the gate and expert weights. Defaults
                  to ``torch.get_default_dtype()``.
              tp_size: Tensor-parallel world size. When not given (falsy),
                  the current tensor-parallel world size is used.

          Raises:
              ValueError: If ``intermediate_size`` is not divisible by the
                  tensor-parallel size. Floor-dividing would otherwise
                  silently drop the trailing columns of every expert.
          """
          super().__init__()
          self.tp_size = tp_size or get_tensor_model_parallel_world_size()
          if intermediate_size % self.tp_size != 0:
              raise ValueError(
                  f"intermediate_size ({intermediate_size}) must be divisible "
                  f"by the tensor parallel size ({self.tp_size}).")
          self.num_total_experts = num_experts
          self.top_k = top_k
          self.hidden_size = hidden_size
          # Per-rank shard of the expert intermediate dimension.
          self.intermediate_size = intermediate_size // self.tp_size
          if params_dtype is None:
              params_dtype = torch.get_default_dtype()
          self.params_dtype = params_dtype
          # Router gate: hidden_size -> one logit per expert.
          self.gate = ReplicatedLinear(self.hidden_size,
                                       self.num_total_experts,
                                       bias=False,
                                       params_dtype=self.params_dtype,
                                       linear_method=None)
          # ws stacks the w1 and w3 shards of every expert:
          # (num_experts, 2 * intermediate_shard, hidden_size).
          # NOTE: allocated directly on "cuda".
          self.ws = nn.Parameter(
              torch.empty(self.num_total_experts,
                          2 * self.intermediate_size,
                          self.hidden_size,
                          device="cuda",
                          dtype=self.params_dtype))
          # w2s holds the w2 shards: (num_experts, hidden_size, intermediate_shard).
          self.w2s = nn.Parameter(
              torch.empty(self.num_total_experts,
                          self.hidden_size,
                          self.intermediate_size,
                          device="cuda",
                          dtype=self.params_dtype))
          set_weight_attrs(self.ws, {
              "weight_loader": self.weight_loader,
          })
          set_weight_attrs(self.w2s, {
              "weight_loader": self.weight_loader,
          })

      def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                        weight_name: str, expert_id: int):
          """Copy this rank's shard of one checkpoint expert weight into ``param``.

          ``w1`` fills rows ``[0, shard)`` and ``w3`` fills rows
          ``[shard, 2 * shard)`` of ``param[expert_id]``; both take rows
          ``[rank * shard, (rank + 1) * shard)`` of the checkpoint tensor.
          ``w2`` takes the matching column slice instead. A ``weight_name``
          with none of these suffixes is ignored.

          Args:
              param: Target stacked parameter (``ws`` or ``w2s``).
              loaded_weight: Full, unsharded checkpoint tensor.
              weight_name: Checkpoint key fragment ending in
                  ``w1.weight``, ``w3.weight`` or ``w2.weight``.
              expert_id: Index of the expert along dim 0 of ``param``.
          """
          tp_rank = get_tensor_model_parallel_rank()
          param_data = param.data
          shard_size = self.intermediate_size
          shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
          # The suffixes are mutually exclusive, so at most one branch applies.
          if weight_name.endswith("w1.weight"):
              param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
          elif weight_name.endswith("w3.weight"):
              param_data[expert_id,
                         shard_size:2 * shard_size, :] = loaded_weight[shard, :]
          elif weight_name.endswith("w2.weight"):
              param_data[expert_id, :, :] = loaded_weight[:, shard]

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          """Route tokens to their top-k experts and combine the results.

          Args:
              hidden_states: 3-D tensor whose last dim is ``hidden_size``
                  (the shape unpack below requires exactly three dims).

          Returns:
              Tensor with the same shape as ``hidden_states``. When
              ``tp_size > 1`` it is all-reduced across ranks.
          """
          batch_size, sequence_length, hidden_size = hidden_states.shape
          hidden_states = hidden_states.view(-1, self.hidden_size)
          # router_logits: (batch * sequence_length, n_experts)
          router_logits, _ = self.gate(hidden_states)
          # NOTE(review): inplace=True presumably lets the kernel reuse the
          # input's storage, so callers must not rely on hidden_states
          # afterwards -- confirm against fused_moe.
          final_hidden_states = fused_moe(hidden_states,
                                          self.ws,
                                          self.w2s,
                                          router_logits,
                                          self.top_k,
                                          renormalize=True,
                                          inplace=True)
          if self.tp_size > 1:
              # Each rank only computed its intermediate shard; sum partials.
              final_hidden_states = tensor_model_parallel_all_reduce(
                  final_hidden_states)
          return final_hidden_states.view(batch_size, sequence_length,
                                          hidden_size)
  class MixtralAttention(nn.Module):
      """Mixtral self-attention with rotary embeddings and a paged KV cache.

      Query heads are split evenly across tensor-parallel ranks. KV heads are
      split as well when there are at least as many of them as ranks, and
      replicated otherwise.
      """

      def __init__(self,
                   hidden_size: int,
                   num_heads: int,
                   num_kv_heads: int,
                   max_position: int = 4096 * 32,
                   rope_theta: float = 10000,
                   linear_method: Optional[LinearMethodBase] = None,
                   sliding_window: Optional[int] = None) -> None:
          """
          Args:
              hidden_size: Model hidden dimension.
              num_heads: Total number of query heads across all ranks.
              num_kv_heads: Total number of key/value heads across all ranks.
              max_position: ``max_position`` passed to the rotary embedding.
              rope_theta: Rotary embedding base (truncated to ``int``).
              linear_method: Optional (e.g. quantized) linear method. When its
                  quant config reports ``merge_weight()`` as false, separate
                  q/k/v projections are built instead of a fused QKV one.
              sliding_window: Optional attention sliding-window size.

          Raises:
              ValueError: If separate q/k/v projections are required but the
                  KV heads would have to be replicated across ranks, which a
                  plain ``ColumnParallelLinear`` cannot express.
          """
          super().__init__()
          self.hidden_size = hidden_size
          tp_size = get_tensor_model_parallel_world_size()
          self.total_num_heads = num_heads
          assert self.total_num_heads % tp_size == 0
          self.num_heads = self.total_num_heads // tp_size
          self.total_num_kv_heads = num_kv_heads
          if self.total_num_kv_heads >= tp_size:
              # There are at least as many KV heads as TP ranks, so we
              # partition the KV heads across the tensor parallel GPUs.
              assert self.total_num_kv_heads % tp_size == 0
          else:
              # There are fewer KV heads than TP ranks, so we replicate
              # the KV heads across multiple tensor parallel GPUs.
              assert tp_size % self.total_num_kv_heads == 0
          self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
          self.head_dim = hidden_size // self.total_num_heads
          # Per-rank widths of the q and k/v projection outputs.
          self.q_size = self.num_heads * self.head_dim
          self.kv_size = self.num_kv_heads * self.head_dim
          self.scaling = self.head_dim**-0.5
          self.rope_theta = rope_theta
          self.sliding_window = sliding_window
          if (linear_method is not None
                  and not linear_method.quant_config.merge_weight()):
              self.merge_weight = False
              if self.total_num_kv_heads < tp_size:
                  raise ValueError(
                      "Replicating KV heads across tensor parallel ranks "
                      "requires merged QKV weights, but the quantization "
                      f"method disables them ({self.total_num_kv_heads} KV "
                      f"heads, tp_size={tp_size}).")
              # ColumnParallelLinear, like QKVParallelLinear below, takes the
              # total (unsharded) output size and splits it across ranks.
              # Passing the per-rank q_size/kv_size here would shard them a
              # second time whenever tp_size > 1.
              self.q_proj = ColumnParallelLinear(
                  hidden_size,
                  self.total_num_heads * self.head_dim,
                  bias=False,
                  linear_method=linear_method)
              self.k_proj = ColumnParallelLinear(
                  hidden_size,
                  self.total_num_kv_heads * self.head_dim,
                  bias=False,
                  linear_method=linear_method)
              self.v_proj = ColumnParallelLinear(
                  hidden_size,
                  self.total_num_kv_heads * self.head_dim,
                  bias=False,
                  linear_method=linear_method)
          else:
              self.merge_weight = True
              self.qkv_proj = QKVParallelLinear(
                  hidden_size,
                  self.head_dim,
                  self.total_num_heads,
                  self.total_num_kv_heads,
                  bias=False,
                  linear_method=linear_method,
              )
          self.o_proj = RowParallelLinear(
              self.total_num_heads * self.head_dim,
              hidden_size,
              bias=False,
              linear_method=linear_method,
          )
          # The quant config may override the rotary style; default to NeoX.
          rope_style = (None if linear_method is None else
                        linear_method.quant_config.rope_style())
          is_neox_style = True if rope_style is None else rope_style
          self.rotary_emb = get_rope(
              self.head_dim,
              rotary_dim=self.head_dim,
              max_position=max_position,
              base=int(self.rope_theta),
              is_neox_style=is_neox_style,
          )
          self.attn = PagedAttention(
              self.num_heads,
              self.head_dim,
              self.scaling,
              num_kv_heads=self.num_kv_heads,
              sliding_window=self.sliding_window,
          )

      def forward(
          self,
          positions: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: KVCache,
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          """Project to q/k/v, apply rotary embeddings, attend, project back.

          Args:
              positions: Token positions fed to the rotary embedding.
              hidden_states: Input activations.
              kv_cache: ``(key_cache, value_cache)`` pair for this layer.
              input_metadata: Attention metadata forwarded to PagedAttention.

          Returns:
              The attention output after ``o_proj``.
          """
          if self.merge_weight:
              qkv, _ = self.qkv_proj(hidden_states)
              q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                                  dim=-1)
          else:
              q, _ = self.q_proj(hidden_states)
              k, _ = self.k_proj(hidden_states)
              v, _ = self.v_proj(hidden_states)
          q, k = self.rotary_emb(positions, q, k)
          k_cache, v_cache = kv_cache
          attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
          output, _ = self.o_proj(attn_output)
          return output
  class MixtralDecoderLayer(nn.Module):
      """One Mixtral transformer block: RMSNorm -> attention -> RMSNorm -> MoE."""

      def __init__(
          self,
          config: MixtralConfig,
          linear_method: Optional[LinearMethodBase] = None,
      ) -> None:
          """
          Args:
              config: HuggingFace Mixtral configuration.
              linear_method: Optional linear method for the attention
                  projections. The MoE block is built without it.
          """
          super().__init__()
          self.hidden_size = config.hidden_size
          # Requires transformers > 4.32.0
          rope_theta = getattr(config, "rope_theta", 10000)
          self.self_attn = MixtralAttention(
              hidden_size=self.hidden_size,
              num_heads=config.num_attention_heads,
              max_position=config.max_position_embeddings,
              num_kv_heads=config.num_key_value_heads,
              rope_theta=rope_theta,
              sliding_window=config.sliding_window,
              linear_method=linear_method)
          self.block_sparse_moe = MixtralMoE(
              num_experts=config.num_local_experts,
              top_k=config.num_experts_per_tok,
              hidden_size=config.hidden_size,
              intermediate_size=config.intermediate_size)
          self.input_layernorm = RMSNorm(config.hidden_size,
                                         eps=config.rms_norm_eps)
          self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                  eps=config.rms_norm_eps)

      def forward(
          self,
          positions: torch.Tensor,
          hidden_states: torch.Tensor,
          kv_cache: KVCache,
          input_metadata: InputMetadata,
          residual: Optional[torch.Tensor],
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Run the block and return ``(hidden_states, residual)``.

          The residual is threaded through the RMSNorm calls rather than added
          here. ``RMSNorm`` called with a residual returns a
          ``(normalized, residual)`` pair; presumably it fuses the residual
          add -- confirm in RMSNorm. Pass ``residual=None`` for the first
          layer.
          """
          # Self Attention
          if residual is None:
              residual = hidden_states
              hidden_states = self.input_layernorm(hidden_states)
          else:
              hidden_states, residual = self.input_layernorm(
                  hidden_states, residual)
          hidden_states = self.self_attn(
              positions=positions,
              hidden_states=hidden_states,
              kv_cache=kv_cache,
              input_metadata=input_metadata,
          )
          # Fully Connected
          hidden_states, residual = self.post_attention_layernorm(
              hidden_states, residual)
          hidden_states = self.block_sparse_moe(hidden_states)
          return hidden_states, residual
  class MixtralModel(nn.Module):
      """Mixtral decoder stack: token embedding, decoder layers, final norm."""

      def __init__(
          self,
          config: MixtralConfig,
          linear_method: Optional[LinearMethodBase] = None,
          lora_config: Optional[LoRAConfig] = None,
      ) -> None:
          """
          Args:
              config: HuggingFace Mixtral configuration.
              linear_method: Optional linear method shared by all layers.
              lora_config: When given, the embedding table is enlarged by
                  ``lora_extra_vocab_size * (max_loras or 1)`` rows.
          """
          super().__init__()
          self.padding_idx = config.pad_token_id
          lora_vocab = (lora_config.lora_extra_vocab_size *
                        (lora_config.max_loras or 1)) if lora_config else 0
          self.vocab_size = config.vocab_size + lora_vocab
          self.org_vocab_size = config.vocab_size
          self.embed_tokens = VocabParallelEmbedding(
              self.vocab_size,
              config.hidden_size,
              linear_method=linear_method,
              org_num_embeddings=config.vocab_size,
          )
          self.layers = nn.ModuleList([
              MixtralDecoderLayer(config, linear_method=linear_method)
              for _ in range(config.num_hidden_layers)
          ])
          self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[KVCache],
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          """Embed ``input_ids``, run every decoder layer, apply the final norm.

          ``kv_caches[i]`` is the cache for ``self.layers[i]``.
          """
          hidden_states = self.embed_tokens(input_ids)
          residual = None
          for i, layer in enumerate(self.layers):
              hidden_states, residual = layer(positions, hidden_states,
                                              kv_caches[i], input_metadata,
                                              residual)
          hidden_states, _ = self.norm(hidden_states, residual)
          return hidden_states
  class MixtralForCausalLM(nn.Module):
      """Mixtral decoder with an LM head, sampler and checkpoint loader."""

      packed_modules_mapping = {
          "qkv_proj": [
              "q_proj",
              "k_proj",
              "v_proj",
          ],
      }

      # LoRA specific attributes
      supported_lora_modules = [
          "qkv_proj",
          "o_proj",
          "embed_tokens",
          "lm_head",
      ]
      embedding_modules = {
          "embed_tokens": "input_embeddings",
          "lm_head": "output_embeddings",
      }
      embedding_padding_modules = ["lm_head"]

      def __init__(
          self,
          config: MixtralConfig,
          linear_method: Optional[LinearMethodBase] = None,
          lora_config: Optional[LoRAConfig] = None,
      ) -> None:
          """
          Args:
              config: HuggingFace Mixtral configuration.
              linear_method: Optional (e.g. quantized) linear method.
              lora_config: Optional LoRA configuration. It enlarges the LM
                  head vocabulary and its padding size.
          """
          super().__init__()
          self.config = config
          self.linear_method = linear_method
          self.model = MixtralModel(config,
                                    linear_method,
                                    lora_config=lora_config)
          self.unpadded_vocab_size = config.vocab_size
          if lora_config:
              self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
          # We need bigger padding if using lora for kernel compatibility.
          padding_size = (DEFAULT_VOCAB_PADDING_SIZE if not lora_config else
                          lora_config.lora_vocab_padding_size)
          self.lm_head = ParallelLMHead(
              self.unpadded_vocab_size,
              config.hidden_size,
              linear_method=linear_method,
              org_num_embeddings=config.vocab_size,
              padding_size=padding_size,
          )
          self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)

      def forward(
          self,
          input_ids: torch.Tensor,
          positions: torch.Tensor,
          kv_caches: List[KVCache],
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          """Return the final hidden states; logits are computed in ``sample``."""
          hidden_states = self.model(input_ids, positions, kv_caches,
                                     input_metadata)
          return hidden_states

      def sample(
          self,
          hidden_states: Optional[torch.Tensor],
          sampling_metadata: SamplingMetadata,
      ) -> Optional[SamplerOutput]:
          """Apply the LM head to ``hidden_states`` and sample next tokens."""
          next_tokens = self.sampler(self.lm_head(hidden_states),
                                     sampling_metadata)
          return next_tokens

      def load_weights(self,
                       model_name_or_path: str,
                       cache_dir: Optional[str] = None,
                       load_format: str = "auto",
                       revision: Optional[str] = None):
          """Load checkpoint tensors into this model's parameters.

          Checkpoint ``q_proj``/``k_proj``/``v_proj`` tensors are routed into
          the fused ``qkv_proj``, unless the quantization method disables
          weight merging. Per-expert ``w1``/``w3``/``w2`` tensors are routed
          into the stacked ``ws``/``w2s`` MoE parameters. Two kinds of
          tensors are skipped: ``rotary_emb.inv_freq`` entries, and bias
          tensors with no matching parameter (extra GPTQ biases).
          """
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              ("qkv_proj", "q_proj", "q"),
              ("qkv_proj", "k_proj", "k"),
              ("qkv_proj", "v_proj", "v"),
          ]
          if (self.linear_method is not None
                  and not self.linear_method.quant_config.merge_weight()):
              stacked_params_mapping = []
          expert_params_mapping = [
              # (param_name, weight_name, expert_id)
              ("ws" if weight_name in ["w1", "w3"] else "w2s",
               f"experts.{expert_id}.{weight_name}.weight", expert_id)
              for expert_id in range(self.config.num_local_experts)
              for weight_name in ["w1", "w2", "w3"]
          ]
          params_dict = dict(self.named_parameters())
          for name, loaded_weight in hf_model_weights_iterator(
                  model_name_or_path,
                  cache_dir,
                  load_format,
                  revision,
                  self.config,
                  fall_back_to_pt=False):
              if "rotary_emb.inv_freq" in name:
                  continue
              self._load_checkpoint_tensor(name, loaded_weight, params_dict,
                                           stacked_params_mapping,
                                           expert_params_mapping)

      @staticmethod
      def _load_checkpoint_tensor(
          name: str,
          loaded_weight: torch.Tensor,
          params_dict: dict,
          stacked_params_mapping: List[Tuple[str, str, str]],
          expert_params_mapping: List[Tuple[str, str, int]],
      ) -> None:
          """Route one checkpoint tensor to exactly one parameter, or skip it.

          Raises:
              KeyError: If a non-bias tensor has no matching parameter.
          """
          for param_name, weight_name, shard_id in stacked_params_mapping:
              if weight_name not in name:
                  continue
              name = name.replace(weight_name, param_name)
              # Skip loading extra bias for GPTQ models. Return rather than
              # continuing this loop: the rewritten name must not be matched
              # again (e.g. "qkv_proj" itself contains "v_proj").
              if name.endswith(".bias") and name not in params_dict:
                  return
              param = params_dict[name]
              param.weight_loader(param, loaded_weight, shard_id)
              return
          for param_name, weight_name, expert_id in expert_params_mapping:
              if weight_name not in name:
                  continue
              name = name.replace(weight_name, param_name)
              param = params_dict[name]
              param.weight_loader(param,
                                  loaded_weight,
                                  weight_name,
                                  expert_id=expert_id)
              return
          # Skip loading extra bias for GPTQ models.
          if name.endswith(".bias") and name not in params_dict:
              return
          param = params_dict[name]
          weight_loader = getattr(param, "weight_loader",
                                  default_weight_loader)
          weight_loader(param, loaded_weight)