mixtral_quant.py
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The PygmalionAI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import MixtralConfig

from aphrodite.modeling.metadata import InputMetadata
from aphrodite.modeling.layers.attention import PagedAttention
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (
    LinearMethodBase,
    ReplicatedLinear,
    QKVParallelLinear,
    RowParallelLinear,
    ColumnParallelLinear,
)
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler, QuantSampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding,
    ParallelLMHead,
)
from aphrodite.modeling.megatron.communication_op import (
    tensor_model_parallel_all_reduce, )
from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.modeling.hf_downloader import (
    default_weight_loader,
    hf_model_weights_iterator,
)
from aphrodite.common.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class MixtralMLP(nn.Module):
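    """Feed-forward network for a single Mixtral expert.

    Computes ``w2(SiLU(w1(x)) * w3(x))``. The projections use
    ``ReplicatedLinear`` (no tensor-parallel sharding) because whole experts,
    rather than individual expert weights, are partitioned across ranks in
    ``MixtralMoE``.
    """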

    def __init__(
        self,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.num_experts = num_experts
        self.ffn_dim = intermediate_size
        self.hidden_dim = hidden_size

        self.w1 = ReplicatedLinear(
            self.hidden_dim,
            self.ffn_dim,
            bias=False,
            linear_method=linear_method,
        )
        self.w2 = ReplicatedLinear(
            self.ffn_dim,
            self.hidden_dim,
            bias=False,
            linear_method=linear_method,
        )
        self.w3 = ReplicatedLinear(
            self.hidden_dim,
            self.ffn_dim,
            bias=False,
            linear_method=linear_method,
        )

        # TODO: Use Aphrodite's SiluAndMul
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        w1_out, _ = self.w1(hidden_states)
        w1_out = self.act_fn(w1_out)
        w3_out, _ = self.w3(hidden_states)
        current_hidden_states = w1_out * w3_out
        current_hidden_states, _ = self.w2(current_hidden_states)
        return current_hidden_states


class MixtralMoE(nn.Module):
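    """Sparse mixture-of-experts block.

    Experts are split across tensor-parallel ranks: each rank instantiates and
    evaluates only its own subset, and the per-rank partial outputs are
    combined with a tensor-parallel all-reduce. The router (``gate``) is a
    replicated, unquantized linear layer that selects the top-k experts per
    token.
    """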

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.num_total_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.num_total_experts}.")

        # Split experts equally between ranks.
        self.expert_indicies = np.array_split(range(
            self.num_total_experts), self.tp_size)[self.rank].tolist()
        if not self.expert_indicies:
            raise ValueError(
                f"Rank {self.rank} has no experts assigned to it.")

        self.experts = nn.ModuleList([
            MixtralMLP(
                self.num_total_experts,
                config.hidden_size,
                config.intermediate_size,
                linear_method=linear_method,
            ) if idx in self.expert_indicies else None
            for idx in range(self.num_total_experts)
        ])
        self.gate = ReplicatedLinear(
            config.hidden_size,
            self.num_total_experts,
            bias=False,
            linear_method=None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits, _ = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights,
                                                       self.top_k,
                                                       dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
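
        # Evaluate only the experts assigned to this rank. Tokens whose top-k
        # choices exclude a given expert contribute zero to it (via
        # ``expert_mask``); contributions from experts owned by other ranks
        # are added in by the all-reduce below.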
        final_hidden_states = None
        for expert_idx in self.expert_indicies:
            expert_layer = self.experts[expert_idx]
            expert_mask = selected_experts == expert_idx
            expert_weights = (routing_weights * expert_mask).sum(dim=-1,
                                                                 keepdim=True)

            current_hidden_states = expert_layer(hidden_states).mul_(
                expert_weights)
            if final_hidden_states is None:
                final_hidden_states = current_hidden_states
            else:
                final_hidden_states.add_(current_hidden_states)

        return tensor_model_parallel_all_reduce(final_hidden_states).view(
            batch_size, sequence_length, hidden_dim)


class MixtralAttention(nn.Module):
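    """Multi-head attention with rotary embeddings and a paged KV cache.

    Supports grouped-query attention (``num_kv_heads`` may be smaller than
    ``num_heads``) and an optional sliding window. Depending on the
    quantization config, the q/k/v projections are either fused into a single
    ``QKVParallelLinear`` or kept as separate column-parallel layers.
    """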

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        linear_method: Optional[LinearMethodBase] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across the tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window
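
        # If the quantization config reports that weights cannot be merged,
        # keep separate q/k/v projections so each quantized shard can be
        # loaded as-is instead of being fused into one QKV matrix.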
        if (linear_method is not None
                and not linear_method.quant_config.merge_weight()):
            self.merge_weight = False
            self.q_proj = ColumnParallelLinear(
                hidden_size,
                self.q_size,
                bias=False,
                linear_method=linear_method,
            )
            self.k_proj = ColumnParallelLinear(
                hidden_size,
                self.kv_size,
                bias=False,
                linear_method=linear_method,
            )
            self.v_proj = ColumnParallelLinear(
                hidden_size,
                self.kv_size,
                bias=False,
                linear_method=linear_method,
            )
        else:
            self.merge_weight = True
            self.qkv_proj = QKVParallelLinear(
                hidden_size,
                self.head_dim,
                self.total_num_heads,
                self.total_num_kv_heads,
                bias=False,
                linear_method=linear_method,
            )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = PagedAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            sliding_window=self.sliding_window,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        if self.merge_weight:
            qkv, _ = self.qkv_proj(hidden_states)
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                                dim=-1)
        else:
            q, _ = self.q_proj(hidden_states)
            k, _ = self.k_proj(hidden_states)
            v, _ = self.v_proj(hidden_states)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class MixtralDecoderLayer(nn.Module):
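    """One Mixtral transformer block: attention followed by the MoE FFN.

    Both sub-layers use pre-normalization with ``RMSNorm``; the norm layers
    also fuse the residual addition, returning the updated residual alongside
    the normalized hidden states.
    """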

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            sliding_window=config.sliding_window,
            linear_method=linear_method,
        )
        self.block_sparse_moe = MixtralMoE(config=config,
                                           linear_method=linear_method)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual


class MixtralModel(nn.Module):
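    """The Mixtral transformer stack: token embeddings, the decoder layers,
    and the final RMSNorm. Returns the last hidden states for the LM head.
    """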

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            linear_method=linear_method,
        )
        self.layers = nn.ModuleList([
            MixtralDecoderLayer(config, linear_method=linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i], input_metadata,
                                            residual)
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class MixtralForCausalLM(nn.Module):
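    """Mixtral model with a parallel LM head for causal language modeling.

    ``sample`` routes through ``QuantSampler`` when the quantization config
    keeps weights unmerged, and through the regular ``Sampler`` otherwise;
    ``load_weights`` maps Hugging Face checkpoint names onto this module,
    fusing q/k/v shards and skipping experts owned by other ranks.
    """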

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = MixtralModel(config, linear_method)
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            linear_method=linear_method,
        )
        self.sampler = Sampler(config.vocab_size)
        self.quant_sampler = QuantSampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            next_tokens = self.quant_sampler(self.lm_head(hidden_states),
                                             sampling_metadata)
        else:
            next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return next_tokens

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
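        # With an unmerged (quantized) QKV projection, checkpoint names map
        # 1:1 onto this module's parameters, so no stacking is needed.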
        if (self.linear_method is not None
                and not self.linear_method.quant_config.merge_weight()):
            stacked_params_mapping = []
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path,
                cache_dir,
                load_format,
                revision,
                self.config,
                fall_back_to_pt=False,
        ):
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip experts that are not assigned to this worker.
                if ("block_sparse_moe.experts." in name
                        and name not in params_dict):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
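

# Usage sketch (illustrative only): Aphrodite's engine normally constructs and
# drives this class itself, after the tensor-parallel process group has been
# initialized. The calls below just mirror the interface defined above, with
# the Hugging Face repo name given as an assumed example.
#
#   config = MixtralConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
#   model = MixtralForCausalLM(config, linear_method=None)
#   model.load_weights("mistralai/Mixtral-8x7B-v0.1")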