olmoe.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
  1. # Licensed under the Apache License, Version 2.0 (the "License");
  2. # you may not use this file except in compliance with the License.
  3. # You may obtain a copy of the License at
  4. #
  5. # http://www.apache.org/licenses/LICENSE-2.0
  6. #
  7. # Unless required by applicable law or agreed to in writing, software
  8. # distributed under the License is distributed on an "AS IS" BASIS,
  9. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. # See the License for the specific language governing permissions and
  11. # limitations under the License.
  12. """Inference-only OLMoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from aphrodite.attention import Attention, AttentionMetadata
from aphrodite.common.config import CacheConfig
from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.common.utils import progress_bar
from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   tensor_model_parallel_all_gather)
from aphrodite.modeling.layers.fused_moe import FusedMoE
from aphrodite.modeling.layers.layernorm import RMSNorm
from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                              ReplicatedLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.rotary_embedding import get_rope
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
from aphrodite.modeling.sampling_metadata import SamplingMetadata
from aphrodite.quantization.base_config import QuantizationConfig
  35. class OlmoeMoE(nn.Module):
  36. """A tensor-parallel MoE implementation for Olmoe that shards each expert
  37. across all ranks.
  38. Each expert's weights are sharded across all ranks and a fused MoE
  39. kernel is used for the forward pass, and finally we reduce the outputs
  40. across ranks.
  41. """
  42. def __init__(self,
  43. num_experts: int,
  44. top_k: int,
  45. hidden_size: int,
  46. intermediate_size: int,
  47. params_dtype: Optional[torch.dtype] = None,
  48. quant_config: Optional[QuantizationConfig] = None,
  49. tp_size: Optional[int] = None,
  50. prefix: str = ""):
  51. super().__init__()
  52. self.hidden_size = hidden_size
  53. # Gate always runs at half / full precision for now.
  54. self.gate = ReplicatedLinear(hidden_size,
  55. num_experts,
  56. bias=False,
  57. quant_config=None)
  58. self.experts = FusedMoE(num_experts=num_experts,
  59. top_k=top_k,
  60. hidden_size=hidden_size,
  61. intermediate_size=intermediate_size,
  62. reduce_results=True,
  63. renormalize=False,
  64. quant_config=quant_config,
  65. tp_size=tp_size)
  66. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  67. # NOTE: hidden_states can have either 1D or 2D shape.
  68. orig_shape = hidden_states.shape
  69. hidden_dim = hidden_states.shape[-1]
  70. hidden_states = hidden_states.view(-1, hidden_dim)
  71. # router_logits: (num_tokens, n_experts)
  72. router_logits, _ = self.gate(hidden_states)
  73. final_hidden_states = self.experts(hidden_states=hidden_states,
  74. router_logits=router_logits)
  75. return final_hidden_states.view(orig_shape)
  76. class OlmoeAttention(nn.Module):
  77. def __init__(
  78. self,
  79. hidden_size: int,
  80. num_heads: int,
  81. num_kv_heads: int,
  82. rope_theta: float = 10000,
  83. rope_scaling: Optional[Dict[str, Any]] = None,
  84. max_position_embeddings: int = 4096,
  85. cache_config: Optional[CacheConfig] = None,
  86. quant_config: Optional[QuantizationConfig] = None,
  87. ) -> None:
  88. super().__init__()
  89. self.hidden_size = hidden_size
  90. tp_size = get_tensor_model_parallel_world_size()
  91. self.total_num_heads = num_heads
  92. assert self.total_num_heads % tp_size == 0
  93. self.num_heads = self.total_num_heads // tp_size
  94. self.total_num_kv_heads = num_kv_heads
  95. if self.total_num_kv_heads >= tp_size:
  96. # Number of KV heads is greater than TP size, so we partition
  97. # the KV heads across multiple tensor parallel GPUs.
  98. assert self.total_num_kv_heads % tp_size == 0
  99. else:
  100. # Number of KV heads is less than TP size, so we replicate
  101. # the KV heads across multiple tensor parallel GPUs.
  102. assert tp_size % self.total_num_kv_heads == 0
  103. self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
  104. self.head_dim = hidden_size // self.total_num_heads
  105. self.q_size = self.num_heads * self.head_dim
  106. self.kv_size = self.num_kv_heads * self.head_dim
  107. self.scaling = self.head_dim**-0.5
  108. self.rope_theta = rope_theta
  109. self.max_position_embeddings = max_position_embeddings
  110. self.qkv_proj = QKVParallelLinear(
  111. hidden_size,
  112. self.head_dim,
  113. self.total_num_heads,
  114. self.total_num_kv_heads,
  115. bias=False,
  116. quant_config=quant_config,
  117. )
  118. self.q_norm = RMSNorm(hidden_size, eps=1e-5)
  119. self.k_norm = RMSNorm(hidden_size, eps=1e-5)
  120. self.o_proj = RowParallelLinear(
  121. self.total_num_heads * self.head_dim,
  122. hidden_size,
  123. bias=False,
  124. quant_config=quant_config,
  125. )
  126. self.rotary_emb = get_rope(
  127. self.head_dim,
  128. rotary_dim=self.head_dim,
  129. max_position=max_position_embeddings,
  130. base=rope_theta,
  131. rope_scaling=rope_scaling,
  132. is_neox_style=True,
  133. )
  134. self.attn = Attention(self.num_heads,
  135. self.head_dim,
  136. self.scaling,
  137. num_kv_heads=self.num_kv_heads,
  138. cache_config=cache_config,
  139. quant_config=quant_config)
  140. def forward(
  141. self,
  142. positions: torch.Tensor,
  143. hidden_states: torch.Tensor,
  144. kv_cache: torch.Tensor,
  145. attn_metadata: AttentionMetadata,
  146. ) -> torch.Tensor:
  147. qkv, _ = self.qkv_proj(hidden_states)
  148. q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  149. q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
  150. q, k = self.rotary_emb(positions, q, k)
  151. attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  152. output, _ = self.o_proj(attn_output)
  153. return output
  154. class OlmoeDecoderLayer(nn.Module):
  155. def __init__(
  156. self,
  157. config: PretrainedConfig,
  158. layer_idx: int,
  159. cache_config: Optional[CacheConfig] = None,
  160. quant_config: Optional[QuantizationConfig] = None,
  161. ) -> None:
  162. super().__init__()
  163. self.hidden_size = config.hidden_size
  164. rope_theta = getattr(config, "rope_theta", 10000)
  165. rope_scaling = getattr(config, "rope_scaling", None)
  166. max_position_embeddings = getattr(config, "max_position_embeddings",
  167. 4096)
  168. self.self_attn = OlmoeAttention(
  169. hidden_size=self.hidden_size,
  170. num_heads=config.num_attention_heads,
  171. num_kv_heads=config.num_key_value_heads,
  172. rope_theta=rope_theta,
  173. rope_scaling=rope_scaling,
  174. max_position_embeddings=max_position_embeddings,
  175. cache_config=cache_config,
  176. quant_config=quant_config,
  177. )
  178. self.mlp = OlmoeMoE(
  179. num_experts=config.num_experts,
  180. top_k=config.num_experts_per_tok,
  181. hidden_size=config.hidden_size,
  182. intermediate_size=config.intermediate_size,
  183. quant_config=quant_config,
  184. )
  185. self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
  186. self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
  187. def forward(
  188. self,
  189. positions: torch.Tensor,
  190. hidden_states: torch.Tensor,
  191. kv_cache: torch.Tensor,
  192. attn_metadata: AttentionMetadata,
  193. residual: Optional[torch.Tensor],
  194. ) -> torch.Tensor:
  195. # Self Attention
  196. if residual is None:
  197. residual = hidden_states
  198. hidden_states = self.input_layernorm(hidden_states)
  199. else:
  200. hidden_states, residual = self.input_layernorm(
  201. hidden_states, residual)
  202. hidden_states = self.self_attn(
  203. positions=positions,
  204. hidden_states=hidden_states,
  205. kv_cache=kv_cache,
  206. attn_metadata=attn_metadata,
  207. )
  208. # Fully Connected
  209. hidden_states, residual = self.post_attention_layernorm(
  210. hidden_states, residual)
  211. hidden_states = self.mlp(hidden_states)
  212. return hidden_states, residual
  213. class OlmoeModel(nn.Module):
  214. def __init__(
  215. self,
  216. config: PretrainedConfig,
  217. cache_config: Optional[CacheConfig] = None,
  218. quant_config: Optional[QuantizationConfig] = None,
  219. ) -> None:
  220. super().__init__()
  221. self.padding_idx = config.pad_token_id
  222. self.vocab_size = config.vocab_size
  223. self.embed_tokens = VocabParallelEmbedding(
  224. config.vocab_size,
  225. config.hidden_size,
  226. )
  227. self.layers = nn.ModuleList([
  228. OlmoeDecoderLayer(config,
  229. layer_idx,
  230. cache_config,
  231. quant_config=quant_config)
  232. for layer_idx in range(config.num_hidden_layers)
  233. ])
  234. self.norm = RMSNorm(config.hidden_size, eps=1e-5)
  235. def forward(
  236. self,
  237. input_ids: torch.Tensor,
  238. positions: torch.Tensor,
  239. kv_caches: List[torch.Tensor],
  240. attn_metadata: AttentionMetadata,
  241. ) -> torch.Tensor:
  242. hidden_states = self.embed_tokens(input_ids)
  243. residual = None
  244. for i in range(len(self.layers)):
  245. layer = self.layers[i]
  246. hidden_states, residual = layer(positions, hidden_states,
  247. kv_caches[i], attn_metadata,
  248. residual)
  249. hidden_states, _ = self.norm(hidden_states, residual)
  250. return hidden_states
class OlmoeForCausalLM(nn.Module):
    """OLMoE causal language model.

    Wraps ``OlmoeModel`` with a vocabulary projection (``lm_head``), a
    ``LogitsProcessor`` and a ``Sampler``. It also provides
    ``load_weights``, which maps checkpoint tensor names onto this module's
    parameters.
    """

    # NOTE(review): presumably read by the model loader to decide whether to
    # fall back to PyTorch-format weight files — confirm against the loader.
    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = OlmoeModel(config, cache_config, quant_config)
        # Separate output projection: no weight sharing with
        # model.embed_tokens is set up here.
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Return the final normalised hidden states (not logits).

        ``intermediate_tensors`` is accepted but not used by this model.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate token selection to the ``Sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
        """Copy checkpoint ``(name, tensor)`` pairs into this module's params.

        Each name is handled by the first of three paths that matches it:

        1. Stacked projections: ``q_proj``/``k_proj``/``v_proj`` are loaded
           as shards of the fused ``qkv_proj``.
        2. Expert weights: names matched by ``expert_params_mapping`` are
           loaded into the ``FusedMoE`` parameters with their expert id and
           shard id.
        3. Everything else goes through the parameter's own ``weight_loader``,
           or ``default_weight_loader`` if it has none. A ``kv_scale`` name is
           first remapped to ``.attn.kv_scale``.

        Raises:
            KeyError: if a name that reaches path 2 or 3 (after any rewrite)
                is not a parameter of this module, except the explicitly
                skipped ``.bias`` / ``kv_scale`` cases.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            # NOTE(review): nothing in this model defines a ``gate_up_proj``
            # (the MLP is MoE-only), and expert names are skipped below, so
            # these two entries appear to be unused here.
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)
        params_dict = dict(self.named_parameters())
        # NOTE(review): materialising the iterator keeps references to every
        # checkpoint tensor alive for the whole load; presumably done so
        # progress_bar can know the total count — confirm.
        weights_list = list(weights)
        for name, loaded_weight in progress_bar(
            weights_list,
            desc="Loading modules..."
        ):
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                # NOTE(review): ``name`` is rewritten in place. If the
                # rewritten name is not a parameter, the ``continue`` below
                # keeps the rewritten name, and that name (not the checkpoint
                # original) reaches the ``else`` paths.
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked projection: try the per-expert mapping.
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    # The FusedMoE loader also receives the (rewritten)
                    # parameter name plus shard and expert ids.
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Plain parameter (embeddings, norms, router gate, ...).
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale")
                        if remapped_kv_scale_name not in params_dict:
                            print(f"Warning: Found kv scale in the checkpoint "
                                  f"(e.g. {name}), but not found the expected "
                                  f"name in the model "
                                  f"(e.g. {remapped_kv_scale_name}). "
                                  "kv-scale is not loaded.")
                            continue
                        else:
                            name = remapped_kv_scale_name
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)