mistral.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. # coding=utf-8
  2. # Adapted from
  3. # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
  4. # Copyright 2023 The PygmalionAI team.
  5. # Copyright 2023 The vLLM team.
  6. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  7. #
  8. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  9. # and OPT implementations in this library. It has been modified from its
  10. # original forms to accommodate minor architectural differences compared
  11. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  12. #
  13. # Licensed under the Apache License, Version 2.0 (the "License");
  14. # you may not use this file except in compliance with the License.
  15. # You may obtain a copy of the License at
  16. #
  17. # http://www.apache.org/licenses/LICENSE-2.0
  18. #
  19. # Unless required by applicable law or agreed to in writing, software
  20. # distributed under the License is distributed on an "AS IS" BASIS,
  21. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  22. # See the License for the specific language governing permissions and
  23. # limitations under the License.
  24. """Inference-only Mistral model compatible with HuggingFace weights."""
  25. from typing import List, Optional, Tuple
  26. import torch
  27. from torch import nn
  28. from transformers import MistralConfig
  29. from aphrodite.modeling.metadata import InputMetadata
  30. from aphrodite.modeling.layers.activation import SiluAndMul
  31. from aphrodite.modeling.layers.attention import PagedAttention
  32. from aphrodite.modeling.layers.layernorm import RMSNorm
  33. from aphrodite.modeling.layers.linear import (LinearMethodBase,
  34. MergedColumnParallelLinear,
  35. QKVParallelLinear,
  36. RowParallelLinear)
  37. from aphrodite.modeling.layers.rotary_embedding import get_rope
  38. from aphrodite.modeling.layers.sampler import Sampler
  39. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  40. VocabParallelEmbedding, ParallelLMHead)
  41. from aphrodite.modeling.megatron.parallel_state import (
  42. get_tensor_model_parallel_world_size)
  43. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  44. from aphrodite.modeling.hf_downloader import (default_weight_loader,
  45. hf_model_weights_iterator)
  46. from aphrodite.common.sequence import SamplerOutput
  47. KVCache = Tuple[torch.Tensor, torch.Tensor]
  48. class MistralMLP(nn.Module):
  49. def __init__(
  50. self,
  51. hidden_size: int,
  52. intermediate_size: int,
  53. hidden_act: str,
  54. linear_method: Optional[LinearMethodBase] = None,
  55. ) -> None:
  56. super().__init__()
  57. self.gate_up_proj = MergedColumnParallelLinear(
  58. hidden_size, [intermediate_size] * 2,
  59. bias=False,
  60. linear_method=linear_method)
  61. self.down_proj = RowParallelLinear(intermediate_size,
  62. hidden_size,
  63. bias=False,
  64. linear_method=linear_method)
  65. if hidden_act != "silu":
  66. raise ValueError(f"Unsupported activation: {hidden_act}. "
  67. "Only silu is supported for now.")
  68. self.act_fn = SiluAndMul()
  69. def forward(self, x):
  70. gate_up, _ = self.gate_up_proj(x)
  71. x = self.act_fn(gate_up)
  72. x, _ = self.down_proj(x)
  73. return x
  74. class MistralAttention(nn.Module):
  75. def __init__(self,
  76. hidden_size: int,
  77. num_heads: int,
  78. num_kv_heads: int,
  79. max_position: int = 4096 * 32,
  80. rope_theta: float = 10000,
  81. linear_method: Optional[LinearMethodBase] = None,
  82. sliding_window: Optional[int] = None) -> None:
  83. super().__init__()
  84. self.hidden_size = hidden_size
  85. tp_size = get_tensor_model_parallel_world_size()
  86. self.total_num_heads = num_heads
  87. assert self.total_num_heads % tp_size == 0
  88. self.num_heads = self.total_num_heads // tp_size
  89. self.total_num_kv_heads = num_kv_heads
  90. if self.total_num_kv_heads >= tp_size:
  91. # Number of KV heads is greater than TP size, so we partition
  92. # the KV heads across multiple tensor parallel GPUs.
  93. assert self.total_num_kv_heads % tp_size == 0
  94. else:
  95. # Number of KV heads is less than TP size, so we replicate
  96. # the KV heads across multiple tensor parallel GPUs.
  97. assert tp_size % self.total_num_kv_heads == 0
  98. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  99. self.head_dim = hidden_size // self.total_num_heads
  100. self.q_size = self.num_heads * self.head_dim
  101. self.kv_size = self.num_kv_heads * self.head_dim
  102. self.scaling = self.head_dim**-0.5
  103. self.rope_theta = rope_theta
  104. self.sliding_window = sliding_window
  105. self.qkv_proj = QKVParallelLinear(
  106. hidden_size,
  107. self.head_dim,
  108. self.total_num_heads,
  109. self.total_num_kv_heads,
  110. bias=False,
  111. linear_method=linear_method,
  112. )
  113. self.o_proj = RowParallelLinear(
  114. self.total_num_heads * self.head_dim,
  115. hidden_size,
  116. bias=False,
  117. linear_method=linear_method,
  118. )
  119. self.rotary_emb = get_rope(
  120. self.head_dim,
  121. rotary_dim=self.head_dim,
  122. max_position=max_position,
  123. base=self.rope_theta,
  124. )
  125. self.attn = PagedAttention(self.num_heads,
  126. self.head_dim,
  127. self.scaling,
  128. num_kv_heads=self.num_kv_heads,
  129. sliding_window=self.sliding_window)
  130. def forward(
  131. self,
  132. positions: torch.Tensor,
  133. hidden_states: torch.Tensor,
  134. kv_cache: KVCache,
  135. input_metadata: InputMetadata,
  136. cache_event: Optional[torch.cuda.Event],
  137. ) -> torch.Tensor:
  138. qkv, _ = self.qkv_proj(hidden_states)
  139. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  140. q, k = self.rotary_emb(positions, q, k)
  141. k_cache, v_cache = kv_cache
  142. attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
  143. cache_event)
  144. output, _ = self.o_proj(attn_output)
  145. return output
  146. class MistralDecoderLayer(nn.Module):
  147. def __init__(
  148. self,
  149. config: MistralConfig,
  150. linear_method: Optional[LinearMethodBase] = None,
  151. ) -> None:
  152. super().__init__()
  153. self.hidden_size = config.hidden_size
  154. # Requires transformers > 4.32.0
  155. rope_theta = getattr(config, "rope_theta", 10000)
  156. self.self_attn = MistralAttention(
  157. hidden_size=self.hidden_size,
  158. num_heads=config.num_attention_heads,
  159. max_position=config.max_position_embeddings,
  160. num_kv_heads=config.num_key_value_heads,
  161. rope_theta=rope_theta,
  162. linear_method=linear_method,
  163. sliding_window=config.sliding_window)
  164. self.mlp = MistralMLP(
  165. hidden_size=self.hidden_size,
  166. intermediate_size=config.intermediate_size,
  167. hidden_act=config.hidden_act,
  168. linear_method=linear_method,
  169. )
  170. self.input_layernorm = RMSNorm(config.hidden_size,
  171. eps=config.rms_norm_eps)
  172. self.post_attention_layernorm = RMSNorm(config.hidden_size,
  173. eps=config.rms_norm_eps)
  174. def forward(
  175. self,
  176. positions: torch.Tensor,
  177. hidden_states: torch.Tensor,
  178. kv_cache: KVCache,
  179. input_metadata: InputMetadata,
  180. cache_event: Optional[torch.cuda.Event],
  181. residual: Optional[torch.Tensor],
  182. ) -> Tuple[torch.Tensor, torch.Tensor]:
  183. # Self Attention
  184. if residual is None:
  185. residual = hidden_states
  186. hidden_states = self.input_layernorm(hidden_states)
  187. else:
  188. hidden_states, residual = self.input_layernorm(
  189. hidden_states, residual)
  190. hidden_states = self.self_attn(
  191. positions=positions,
  192. hidden_states=hidden_states,
  193. kv_cache=kv_cache,
  194. input_metadata=input_metadata,
  195. cache_event=cache_event,
  196. )
  197. # Fully Connected
  198. hidden_states, residual = self.post_attention_layernorm(
  199. hidden_states, residual)
  200. hidden_states = self.mlp(hidden_states)
  201. return hidden_states, residual
  202. class MistralModel(nn.Module):
  203. def __init__(
  204. self,
  205. config: MistralConfig,
  206. linear_method: Optional[LinearMethodBase] = None,
  207. ) -> None:
  208. super().__init__()
  209. self.config = config
  210. self.padding_idx = config.pad_token_id
  211. self.vocab_size = config.vocab_size
  212. self.embed_tokens = VocabParallelEmbedding(
  213. config.vocab_size,
  214. config.hidden_size,
  215. )
  216. self.layers = nn.ModuleList([
  217. MistralDecoderLayer(config, linear_method)
  218. for _ in range(config.num_hidden_layers)
  219. ])
  220. self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  221. def forward(
  222. self,
  223. input_ids: torch.Tensor,
  224. positions: torch.Tensor,
  225. kv_caches: List[KVCache],
  226. input_metadata: InputMetadata,
  227. cache_events: Optional[List[torch.cuda.Event]],
  228. ) -> torch.Tensor:
  229. hidden_states = self.embed_tokens(input_ids)
  230. residual = None
  231. for i in range(len(self.layers)):
  232. cache_event = None if cache_events is None else cache_events[i]
  233. layer = self.layers[i]
  234. hidden_states, residual = layer(
  235. positions,
  236. hidden_states,
  237. kv_caches[i],
  238. input_metadata,
  239. cache_event,
  240. residual,
  241. )
  242. hidden_states, _ = self.norm(hidden_states, residual)
  243. return hidden_states
  244. class MistralForCausalLM(nn.Module):
  245. def __init__(
  246. self,
  247. config: MistralConfig,
  248. linear_method: Optional[LinearMethodBase] = None,
  249. ) -> None:
  250. super().__init__()
  251. self.config = config
  252. self.linear_method = linear_method
  253. self.model = MistralModel(config, linear_method)
  254. self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
  255. self.sampler = Sampler(config.vocab_size)
  256. def forward(
  257. self,
  258. input_ids: torch.Tensor,
  259. positions: torch.Tensor,
  260. kv_caches: List[KVCache],
  261. input_metadata: InputMetadata,
  262. cache_events: Optional[List[torch.cuda.Event]],
  263. ) -> torch.Tensor:
  264. hidden_states = self.model(input_ids, positions, kv_caches,
  265. input_metadata, cache_events)
  266. return hidden_states
  267. def sample(
  268. self,
  269. hidden_states: torch.Tensor,
  270. sampling_metadata: SamplingMetadata,
  271. ) -> SamplerOutput:
  272. next_tokens = self.sampler(self.lm_head.weight, hidden_states,
  273. sampling_metadata)
  274. return next_tokens
  275. def load_weights(self,
  276. model_name_or_path: str,
  277. cache_dir: Optional[str] = None,
  278. load_format: str = "auto",
  279. revision: Optional[str] = None):
  280. stacked_params_mapping = [
  281. # (param_name, shard_name, shard_id)
  282. ("qkv_proj", "q_proj", "q"),
  283. ("qkv_proj", "k_proj", "k"),
  284. ("qkv_proj", "v_proj", "v"),
  285. ("gate_up_proj", "gate_proj", 0),
  286. ("gate_up_proj", "up_proj", 1),
  287. ]
  288. params_dict = dict(self.named_parameters())
  289. for name, loaded_weight in hf_model_weights_iterator(
  290. model_name_or_path, cache_dir, load_format, revision):
  291. if "rotary_emb.inv_freq" in name:
  292. continue
  293. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  294. if weight_name not in name:
  295. continue
  296. name = name.replace(weight_name, param_name)
  297. # Skip loading extra bias for GPTQ models.
  298. if name.endswith(".bias") and name not in params_dict:
  299. continue
  300. param = params_dict[name]
  301. weight_loader = param.weight_loader
  302. weight_loader(param, loaded_weight, shard_id)
  303. break
  304. else:
  305. # Skip loading extra bias for GPTQ models.
  306. if name.endswith(".bias") and name not in params_dict:
  307. continue
  308. param = params_dict[name]
  309. weight_loader = getattr(param, "weight_loader",
  310. default_weight_loader)
  311. weight_loader(param, loaded_weight)