# mistral.py
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import MistralConfig

from aphrodite.modeling.metadata import InputMetadata
from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.attention import PagedAttentionWithRoPE
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.quantized_linear import ParallelLinear
from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from aphrodite.modeling.megatron.layers import VocabParallelEmbedding
from aphrodite.modeling.quantization_utils import QuantizationConfig
from aphrodite.modeling.hf_downloader import (
    convert_pyslice_to_tensor, hf_model_weights_iterator,
    load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
    get_parallel_weight)
from aphrodite.common.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


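# MistralMLP: the gated feed-forward block. A fused column-parallel
# projection produces the concatenated [gate, up] activations, SiluAndMul
# applies silu(gate) * up, and a row-parallel projection maps back to the
# hidden size.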
class MistralMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = ParallelLinear.column(hidden_size,
                                                  2 * intermediate_size,
                                                  bias=False,
                                                  gather_output=False,
                                                  quant_config=quant_config)
        self.down_proj = ParallelLinear.row(intermediate_size,
                                            hidden_size,
                                            bias=False,
                                            input_is_parallel=True,
                                            quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


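# MistralAttention: tensor-parallel grouped-query attention. Query and
# key/value heads are split evenly across tensor-parallel ranks, the Q/K/V
# projections are fused into a single column-parallel linear, and
# PagedAttentionWithRoPE applies rotary embeddings and attends over the
# paged KV cache with an optional sliding window.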
class MistralAttention(nn.Module):

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 max_position: int = 4096 * 32,
                 rope_theta: float = 10000,
                 quant_config: Optional[QuantizationConfig] = None,
                 sliding_window: Optional[int] = None) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = self.total_num_kv_heads // tp_size
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window

        self.qkv_proj = ParallelLinear.column(
            hidden_size,
            (self.total_num_heads + 2 * self.total_num_kv_heads) *
            self.head_dim,
            bias=False,
            gather_output=False,
            quant_config=quant_config,
        )
        self.o_proj = ParallelLinear.row(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
        )
        self.attn = PagedAttentionWithRoPE(self.num_heads,
                                           self.head_dim,
                                           self.scaling,
                                           base=self.rope_theta,
                                           max_position=max_position,
                                           rotary_dim=self.head_dim,
                                           num_kv_heads=self.num_kv_heads,
                                           sliding_window=self.sliding_window)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        output, _ = self.o_proj(attn_output)
        return output


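# MistralDecoderLayer: a pre-norm transformer block. RMSNorm feeds the
# self-attention, the result is added back to the residual stream, and a
# second RMSNorm feeds the MLP before the final residual add.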
class MistralDecoderLayer(nn.Module):

    def __init__(
        self,
        config: MistralConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MistralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            quant_config=quant_config,
            sliding_window=config.sliding_window)
        self.mlp = MistralMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


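# MistralModel: token embedding, the stack of decoder layers, and the final
# RMSNorm. The embedding table is rounded up to a multiple of 64 entries
# before building the vocab-parallel embedding.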
class MistralModel(nn.Module):

    def __init__(
        self,
        config: MistralConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        vocab_size = ((config.vocab_size + 63) // 64) * 64
        self.embed_tokens = VocabParallelEmbedding(
            vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            MistralDecoderLayer(config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for i in range(len(self.layers)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.layers[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


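# MistralForCausalLM: wraps MistralModel with an unquantized column-parallel
# LM head and a Sampler; the sampler operates on the original (unpadded)
# config.vocab_size.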
class MistralForCausalLM(nn.Module):

    def __init__(
        self,
        config: MistralConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = MistralModel(config, quant_config)
        vocab_size = ((config.vocab_size + 63) // 64) * 64
        # NOTE: The LM head is not quantized.
        self.lm_head = ParallelLinear.column(config.hidden_size,
                                             vocab_size,
                                             bias=False,
                                             gather_output=False,
                                             quant_config=None)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)
        return next_tokens

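    # Weight loading: HuggingFace checkpoints store separate q_proj/k_proj/
    # v_proj and gate_proj/up_proj tensors. load_weights copies them into the
    # fused qkv_proj and gate_up_proj parameters and shards column-/row-
    # parallel weights across tensor-parallel ranks.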
    column_parallel_layers = []
    row_parallel_layers = ["o_proj", "down_proj"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        column_parallel_weights, row_parallel_weights = get_parallel_weight(
            self)
        column_weight_suffixes = (
            self.quant_config.get_col_parallel_tensor_names()
        ) if self.quant_config is not None else ["weight", "bias"]

        tp_size = get_tensor_model_parallel_world_size()
        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
        q_proj_shard_size = (self.config.hidden_size // tp_size)
        kv_proj_shard_size = (self.config.hidden_size //
                              self.config.num_attention_heads *
                              self.config.num_key_value_heads // tp_size)
        attention_weight_specs = [
            # (weight_name, shard_size, offset)
            ("q_proj", q_proj_shard_size, 0),
            ("k_proj", kv_proj_shard_size, q_proj_shard_size),
            ("v_proj", kv_proj_shard_size,
             q_proj_shard_size + kv_proj_shard_size),
        ]

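        # Iterate over every tensor in the checkpoint, skipping rotary
        # inv_freq buffers and handling the transposed/packed layouts used by
        # quantized checkpoints.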
        state_dict = self.state_dict()

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue

            packed_dim = None
            is_transposed = False
            if self.quant_config is not None:
                packed_dim = self.quant_config.get_packed_dim(name)
                is_transposed = self.quant_config.is_transposed(name)
            if is_transposed:
                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                loaded_weight = loaded_weight.T

            is_attention_weight = False
            for weight_name, shard_size, offset in attention_weight_specs:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, "qkv_proj")
                if name not in state_dict:  # pylint: disable=unsupported-membership-test
                    break
                param = state_dict[name]  # pylint: disable=unsubscriptable-object
                if is_transposed:
                    param = param.T

                if packed_dim is not None:
                    shard_dim = 0 if not is_transposed else 1
                    if packed_dim == shard_dim:
                        shard_size //= self.quant_config.pack_factor
                        offset //= self.quant_config.pack_factor

                if any(
                        name.endswith(suffix)
                        for suffix in column_weight_suffixes):
                    loaded_weight = loaded_weight[
                        shard_size * tensor_model_parallel_rank:shard_size *
                        (tensor_model_parallel_rank + 1)]
                    param_slice = param.data[offset:offset + shard_size]
                else:
                    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                    param_slice = param.data

                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_attention_weight = True
                break
            if is_attention_weight:
                continue

            is_gate_up_weight = False
            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, "gate_up_proj")
                if name not in state_dict:  # pylint: disable=unsupported-membership-test
                    break
                param = state_dict[name]  # pylint: disable=unsubscriptable-object
                if is_transposed:
                    param = param.T
                shard_size = param.shape[0] // 2
                if any(
                        name.endswith(suffix)
                        for suffix in column_weight_suffixes):
                    loaded_weight = loaded_weight[
                        shard_size * tensor_model_parallel_rank:shard_size *
                        (tensor_model_parallel_rank + 1)]
                    param_slice = param.data[shard_size *
                                             stride_id:shard_size *
                                             (stride_id + 1)]
                else:
                    loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                    param_slice = param.data
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_gate_up_weight = True
                break
            if is_gate_up_weight:
                continue

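            # Everything else is either the padded vocab embedding / LM head
            # or a plain tensor-parallel weight.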
            if name not in state_dict:  # pylint: disable=unsupported-membership-test
                continue
            param = state_dict[name]  # pylint: disable=unsubscriptable-object
            if is_transposed:
                param = param.T

            if "embed_tokens" in name or "lm_head" in name:
                load_padded_tensor_parallel_vocab(param, loaded_weight,
                                                  tensor_model_parallel_rank)
                continue

            load_tensor_parallel_weights(param, loaded_weight, name,
                                         column_parallel_weights,
                                         row_parallel_weights,
                                         tensor_model_parallel_rank)
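

# Usage sketch (an assumption for illustration, not part of this module):
# inside the engine's model loader one would typically do something like
#     config = MistralConfig.from_pretrained("mistralai/Mistral-7B-v0.1")
#     model = MistralForCausalLM(config)
#     model.load_weights("mistralai/Mistral-7B-v0.1")
# before the model is handed to the worker that supplies the paged KV caches
# and InputMetadata at forward time.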