olmoe.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. # Licensed under the Apache License, Version 2.0 (the "License");
  2. # you may not use this file except in compliance with the License.
  3. # You may obtain a copy of the License at
  4. #
  5. # http://www.apache.org/licenses/LICENSE-2.0
  6. #
  7. # Unless required by applicable law or agreed to in writing, software
  8. # distributed under the License is distributed on an "AS IS" BASIS,
  9. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. # See the License for the specific language governing permissions and
  11. # limitations under the License.
  12. """Inference-only OLMoE model compatible with HuggingFace weights."""
  13. from typing import Any, Dict, Iterable, List, Optional, Tuple
  14. import torch
  15. from torch import nn
  16. from transformers import PretrainedConfig
  17. from aphrodite.attention import Attention, AttentionMetadata
  18. from aphrodite.common.config import CacheConfig
  19. from aphrodite.common.sequence import IntermediateTensors
  20. from aphrodite.distributed import get_tensor_model_parallel_world_size
  21. from aphrodite.modeling.layers.fused_moe import FusedMoE
  22. from aphrodite.modeling.layers.layernorm import RMSNorm
  23. from aphrodite.modeling.layers.linear import (QKVParallelLinear,
  24. ReplicatedLinear,
  25. RowParallelLinear)
  26. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  27. from aphrodite.modeling.layers.rotary_embedding import get_rope
  28. from aphrodite.modeling.layers.sampler import Sampler, SamplerOutput
  29. from aphrodite.modeling.layers.vocab_parallel_embedding import (
  30. ParallelLMHead, VocabParallelEmbedding)
  31. from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
  32. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  33. from aphrodite.quantization.base_config import QuantizationConfig
  34. class OlmoeMoE(nn.Module):
  35. """A tensor-parallel MoE implementation for Olmoe that shards each expert
  36. across all ranks.
  37. Each expert's weights are sharded across all ranks and a fused MoE
  38. kernel is used for the forward pass, and finally we reduce the outputs
  39. across ranks.
  40. """
  41. def __init__(self,
  42. num_experts: int,
  43. top_k: int,
  44. hidden_size: int,
  45. intermediate_size: int,
  46. params_dtype: Optional[torch.dtype] = None,
  47. quant_config: Optional[QuantizationConfig] = None,
  48. tp_size: Optional[int] = None,
  49. prefix: str = ""):
  50. super().__init__()
  51. self.hidden_size = hidden_size
  52. # Gate always runs at half / full precision for now.
  53. self.gate = ReplicatedLinear(hidden_size,
  54. num_experts,
  55. bias=False,
  56. quant_config=None)
  57. self.experts = FusedMoE(num_experts=num_experts,
  58. top_k=top_k,
  59. hidden_size=hidden_size,
  60. intermediate_size=intermediate_size,
  61. reduce_results=True,
  62. renormalize=False,
  63. quant_config=quant_config,
  64. tp_size=tp_size)
  65. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  66. # NOTE: hidden_states can have either 1D or 2D shape.
  67. orig_shape = hidden_states.shape
  68. hidden_dim = hidden_states.shape[-1]
  69. hidden_states = hidden_states.view(-1, hidden_dim)
  70. # router_logits: (num_tokens, n_experts)
  71. router_logits, _ = self.gate(hidden_states)
  72. final_hidden_states = self.experts(hidden_states=hidden_states,
  73. router_logits=router_logits)
  74. return final_hidden_states.view(orig_shape)
  75. class OlmoeAttention(nn.Module):
  76. def __init__(
  77. self,
  78. hidden_size: int,
  79. num_heads: int,
  80. num_kv_heads: int,
  81. rope_theta: float = 10000,
  82. rope_scaling: Optional[Dict[str, Any]] = None,
  83. max_position_embeddings: int = 4096,
  84. cache_config: Optional[CacheConfig] = None,
  85. quant_config: Optional[QuantizationConfig] = None,
  86. ) -> None:
  87. super().__init__()
  88. self.hidden_size = hidden_size
  89. tp_size = get_tensor_model_parallel_world_size()
  90. self.total_num_heads = num_heads
  91. assert self.total_num_heads % tp_size == 0
  92. self.num_heads = self.total_num_heads // tp_size
  93. self.total_num_kv_heads = num_kv_heads
  94. if self.total_num_kv_heads >= tp_size:
  95. # Number of KV heads is greater than TP size, so we partition
  96. # the KV heads across multiple tensor parallel GPUs.
  97. assert self.total_num_kv_heads % tp_size == 0
  98. else:
  99. # Number of KV heads is less than TP size, so we replicate
  100. # the KV heads across multiple tensor parallel GPUs.
  101. assert tp_size % self.total_num_kv_heads == 0
  102. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  103. self.head_dim = hidden_size // self.total_num_heads
  104. self.q_size = self.num_heads * self.head_dim
  105. self.kv_size = self.num_kv_heads * self.head_dim
  106. self.scaling = self.head_dim**-0.5
  107. self.rope_theta = rope_theta
  108. self.max_position_embeddings = max_position_embeddings
  109. self.qkv_proj = QKVParallelLinear(
  110. hidden_size,
  111. self.head_dim,
  112. self.total_num_heads,
  113. self.total_num_kv_heads,
  114. bias=False,
  115. quant_config=quant_config,
  116. )
  117. self.q_norm = RMSNorm(hidden_size, eps=1e-5)
  118. self.k_norm = RMSNorm(hidden_size, eps=1e-5)
  119. self.o_proj = RowParallelLinear(
  120. self.total_num_heads * self.head_dim,
  121. hidden_size,
  122. bias=False,
  123. quant_config=quant_config,
  124. )
  125. self.rotary_emb = get_rope(
  126. self.head_dim,
  127. rotary_dim=self.head_dim,
  128. max_position=max_position_embeddings,
  129. base=rope_theta,
  130. rope_scaling=rope_scaling,
  131. is_neox_style=True,
  132. )
  133. self.attn = Attention(self.num_heads,
  134. self.head_dim,
  135. self.scaling,
  136. num_kv_heads=self.num_kv_heads,
  137. cache_config=cache_config,
  138. quant_config=quant_config)
  139. def forward(
  140. self,
  141. positions: torch.Tensor,
  142. hidden_states: torch.Tensor,
  143. kv_cache: torch.Tensor,
  144. attn_metadata: AttentionMetadata,
  145. ) -> torch.Tensor:
  146. qkv, _ = self.qkv_proj(hidden_states)
  147. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  148. q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
  149. q, k = self.rotary_emb(positions, q, k)
  150. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  151. output, _ = self.o_proj(attn_output)
  152. return output
  153. class OlmoeDecoderLayer(nn.Module):
  154. def __init__(
  155. self,
  156. config: PretrainedConfig,
  157. layer_idx: int,
  158. cache_config: Optional[CacheConfig] = None,
  159. quant_config: Optional[QuantizationConfig] = None,
  160. ) -> None:
  161. super().__init__()
  162. self.hidden_size = config.hidden_size
  163. rope_theta = getattr(config, "rope_theta", 10000)
  164. rope_scaling = getattr(config, "rope_scaling", None)
  165. max_position_embeddings = getattr(config, "max_position_embeddings",
  166. 4096)
  167. self.self_attn = OlmoeAttention(
  168. hidden_size=self.hidden_size,
  169. num_heads=config.num_attention_heads,
  170. num_kv_heads=config.num_key_value_heads,
  171. rope_theta=rope_theta,
  172. rope_scaling=rope_scaling,
  173. max_position_embeddings=max_position_embeddings,
  174. cache_config=cache_config,
  175. quant_config=quant_config,
  176. )
  177. self.mlp = OlmoeMoE(
  178. num_experts=config.num_experts,
  179. top_k=config.num_experts_per_tok,
  180. hidden_size=config.hidden_size,
  181. intermediate_size=config.intermediate_size,
  182. quant_config=quant_config,
  183. )
  184. self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
  185. self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
  186. def forward(
  187. self,
  188. positions: torch.Tensor,
  189. hidden_states: torch.Tensor,
  190. kv_cache: torch.Tensor,
  191. attn_metadata: AttentionMetadata,
  192. residual: Optional[torch.Tensor],
  193. ) -> torch.Tensor:
  194. # Self Attention
  195. if residual is None:
  196. residual = hidden_states
  197. hidden_states = self.input_layernorm(hidden_states)
  198. else:
  199. hidden_states, residual = self.input_layernorm(
  200. hidden_states, residual)
  201. hidden_states = self.self_attn(
  202. positions=positions,
  203. hidden_states=hidden_states,
  204. kv_cache=kv_cache,
  205. attn_metadata=attn_metadata,
  206. )
  207. # Fully Connected
  208. hidden_states, residual = self.post_attention_layernorm(
  209. hidden_states, residual)
  210. hidden_states = self.mlp(hidden_states)
  211. return hidden_states, residual
  212. class OlmoeModel(nn.Module):
  213. def __init__(
  214. self,
  215. config: PretrainedConfig,
  216. cache_config: Optional[CacheConfig] = None,
  217. quant_config: Optional[QuantizationConfig] = None,
  218. ) -> None:
  219. super().__init__()
  220. self.padding_idx = config.pad_token_id
  221. self.vocab_size = config.vocab_size
  222. self.embed_tokens = VocabParallelEmbedding(
  223. config.vocab_size,
  224. config.hidden_size,
  225. )
  226. self.layers = nn.ModuleList([
  227. OlmoeDecoderLayer(config,
  228. layer_idx,
  229. cache_config,
  230. quant_config=quant_config)
  231. for layer_idx in range(config.num_hidden_layers)
  232. ])
  233. self.norm = RMSNorm(config.hidden_size, eps=1e-5)
  234. def forward(
  235. self,
  236. input_ids: torch.Tensor,
  237. positions: torch.Tensor,
  238. kv_caches: List[torch.Tensor],
  239. attn_metadata: AttentionMetadata,
  240. ) -> torch.Tensor:
  241. hidden_states = self.embed_tokens(input_ids)
  242. residual = None
  243. for i in range(len(self.layers)):
  244. layer = self.layers[i]
  245. hidden_states, residual = layer(positions, hidden_states,
  246. kv_caches[i], attn_metadata,
  247. residual)
  248. hidden_states, _ = self.norm(hidden_states, residual)
  249. return hidden_states
  250. class OlmoeForCausalLM(nn.Module):
  251. fall_back_to_pt_during_load = False
  252. def __init__(
  253. self,
  254. config: PretrainedConfig,
  255. cache_config: Optional[CacheConfig] = None,
  256. quant_config: Optional[QuantizationConfig] = None,
  257. ) -> None:
  258. super().__init__()
  259. self.config = config
  260. self.quant_config = quant_config
  261. self.model = OlmoeModel(config, cache_config, quant_config)
  262. self.lm_head = ParallelLMHead(config.vocab_size,
  263. config.hidden_size,
  264. quant_config=quant_config)
  265. self.logits_processor = LogitsProcessor(config.vocab_size)
  266. self.sampler = Sampler()
  267. def forward(
  268. self,
  269. input_ids: torch.Tensor,
  270. positions: torch.Tensor,
  271. kv_caches: List[torch.Tensor],
  272. attn_metadata: AttentionMetadata,
  273. intermediate_tensors: Optional[IntermediateTensors] = None,
  274. ) -> torch.Tensor:
  275. hidden_states = self.model(input_ids, positions, kv_caches,
  276. attn_metadata)
  277. return hidden_states
  278. def compute_logits(self, hidden_states: torch.Tensor,
  279. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  280. logits = self.logits_processor(self.lm_head, hidden_states,
  281. sampling_metadata)
  282. return logits
  283. def sample(
  284. self,
  285. logits: Optional[torch.Tensor],
  286. sampling_metadata: SamplingMetadata,
  287. ) -> Optional[SamplerOutput]:
  288. next_tokens = self.sampler(logits, sampling_metadata)
  289. return next_tokens
  290. def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  291. stacked_params_mapping = [
  292. # (param_name, shard_name, shard_id)
  293. ("qkv_proj", "q_proj", "q"),
  294. ("qkv_proj", "k_proj", "k"),
  295. ("qkv_proj", "v_proj", "v"),
  296. ("gate_up_proj", "gate_proj", 0),
  297. ("gate_up_proj", "up_proj", 1),
  298. ]
  299. # Params for weights, fp8 weight scales, fp8 activation scales
  300. # (param_name, weight_name, expert_id, shard_id)
  301. expert_params_mapping = FusedMoE.make_expert_params_mapping(
  302. ckpt_gate_proj_name="gate_proj",
  303. ckpt_down_proj_name="down_proj",
  304. ckpt_up_proj_name="up_proj",
  305. num_experts=self.config.num_experts)
  306. params_dict = dict(self.named_parameters())
  307. for name, loaded_weight in weights:
  308. if "rotary_emb.inv_freq" in name:
  309. continue
  310. for (param_name, weight_name, shard_id) in stacked_params_mapping:
  311. # Skip non-stacked layers and experts (experts handled below).
  312. if weight_name not in name:
  313. continue
  314. # We have mlp.experts[0].gate_proj in the checkpoint.
  315. # Since we handle the experts below in expert_params_mapping,
  316. # we need to skip here BEFORE we update the name, otherwise
  317. # name will be updated to mlp.experts[0].gate_up_proj, which
  318. # will then be updated below in expert_params_mapping
  319. # for mlp.experts[0].gate_gate_up_proj, which breaks load.
  320. if "mlp.experts" in name:
  321. continue
  322. name = name.replace(weight_name, param_name)
  323. # Skip loading extra bias for GPTQ models.
  324. if name.endswith(".bias") and name not in params_dict:
  325. continue
  326. if name not in params_dict:
  327. continue
  328. param = params_dict[name]
  329. weight_loader = param.weight_loader
  330. weight_loader(param, loaded_weight, shard_id)
  331. break
  332. else:
  333. for mapping in expert_params_mapping:
  334. param_name, weight_name, expert_id, shard_id = mapping
  335. if weight_name not in name:
  336. continue
  337. name = name.replace(weight_name, param_name)
  338. param = params_dict[name]
  339. weight_loader = param.weight_loader
  340. weight_loader(param,
  341. loaded_weight,
  342. name,
  343. shard_id=shard_id,
  344. expert_id=expert_id)
  345. break
  346. else:
  347. # Skip loading extra bias for GPTQ models.
  348. if name.endswith(".bias") and name not in params_dict:
  349. continue
  350. # Remapping the name of FP8 kv-scale.
  351. if name.endswith("kv_scale"):
  352. remapped_kv_scale_name = name.replace(
  353. ".kv_scale", ".attn.kv_scale")
  354. if remapped_kv_scale_name not in params_dict:
  355. print(f"Warning: Found kv scale in the checkpoint "
  356. f"(e.g. {name}), but not found the expected "
  357. f"name in the model "
  358. f"(e.g. {remapped_kv_scale_name}). "
  359. "kv-scale is not loaded.")
  360. continue
  361. else:
  362. name = remapped_kv_scale_name
  363. param = params_dict[name]
  364. weight_loader = getattr(param, "weight_loader",
  365. default_weight_loader)
  366. weight_loader(param, loaded_weight)