Add OLMoE (#772)

* Create olmoe.py

* Update models/__init__.py

* Bump tf version

* (hopefully) fix formatting issues
Fizz~ 4 months ago
commit 8a71788372
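
For anyone who wants to exercise the new architecture, here is a minimal offline sketch (assuming aphrodite's vLLM-style `LLM` entry point; the checkpoint id is only an illustrative example):

    from aphrodite import LLM, SamplingParams

    # "OlmoeForCausalLM" is resolved through the registry entry added in
    # aphrodite/modeling/models/__init__.py below.
    llm = LLM(model="allenai/OLMoE-1B-7B-0924")
    params = SamplingParams(temperature=0.0, max_tokens=32)
    outputs = llm.generate(["OLMoE is a"], params)
    print(outputs[0].outputs[0].text)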

+ 7 - 7
aphrodite/common/grammar.py

@@ -1,17 +1,17 @@
 import collections
-from copy import deepcopy, copy
-from dataclasses import dataclass, fields
 import functools
-import regex
-import torch
-from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
-from typing import Optional, List, Set, Union
 import weakref
+from copy import copy, deepcopy
+from dataclasses import dataclass, fields
+from typing import List, Optional, Set, Union

+import regex
+import torch
 from lark import Lark
+from lark.lexer import Pattern, PatternRE, PatternStr, Token
 from lark.parsers.lalr_interactive_parser import InteractiveParser
 from lark.parsers.lalr_parser_state import ParserState
-from lark.lexer import Token, Pattern, PatternStr, PatternRE
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast


 class FastParserState(ParserState):

+ 2 - 1
aphrodite/common/logits_processor.py

@@ -1,7 +1,8 @@
 from abc import ABC, abstractmethod
-import torch
 from typing import Dict, List

+import torch
+

 class LogitsProcessor(ABC):


+ 2 - 4
aphrodite/common/test_utils.py

@@ -1,10 +1,8 @@
 import ray

 from aphrodite.common.utils import get_open_port
-from aphrodite.distributed import (
-    ensure_model_parallel_initialized,
-    init_distributed_environment,
-)
+from aphrodite.distributed import (ensure_model_parallel_initialized,
+                                   init_distributed_environment)


 def init_test_distributed_environment(

+ 8 - 9
aphrodite/endpoints/kobold/api_server.py

@@ -4,25 +4,24 @@ import argparse
 import asyncio
 import json
 import os
-
 from http import HTTPStatus
-from typing import List, Tuple, AsyncGenerator
+from typing import AsyncGenerator, List, Tuple

-from prometheus_client import make_asgi_app
-import uvicorn
 import fastapi
+import uvicorn
 from fastapi import APIRouter, Request, Response
-from fastapi.responses import JSONResponse, StreamingResponse, HTMLResponse
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
 from loguru import logger
+from prometheus_client import make_asgi_app

-from aphrodite.engine.args_tools import AsyncEngineArgs
-from aphrodite.engine.async_aphrodite import AsyncAphrodite
 from aphrodite.common.outputs import RequestOutput
-from aphrodite.common.sampling_params import SamplingParams, _SAMPLING_EPS
-from aphrodite.transformers_utils.tokenizer import get_tokenizer
+from aphrodite.common.sampling_params import _SAMPLING_EPS, SamplingParams
 from aphrodite.common.utils import random_uuid
 from aphrodite.endpoints.kobold.protocol import KAIGenerationInputSchema
+from aphrodite.engine.args_tools import AsyncEngineArgs
+from aphrodite.engine.async_aphrodite import AsyncAphrodite
+from aphrodite.transformers_utils.tokenizer import get_tokenizer

 TIMEOUT_KEEP_ALIVE = 5  # seconds


+ 1 - 0
aphrodite/endpoints/kobold/protocol.py

@@ -1,4 +1,5 @@
 from typing import List, Optional, Union
+
 from pydantic import BaseModel, Field, root_validator



+ 1 - 1
aphrodite/kv_quant/calibrate.py

@@ -8,9 +8,9 @@ from accelerate import (infer_auto_device_map, init_empty_weights,
                         load_checkpoint_in_model)
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

+from aphrodite.kv_quant.calib_dataloader import get_calib_loaders
 from aphrodite.kv_quant.calibration import CalibrationContext
 from aphrodite.kv_quant.utils import collect_target_modules
-from aphrodite.kv_quant.calib_dataloader import get_calib_loaders

 LAYER_TYPE_MAP = {
     'InternLMForCausalLM': 'InternLMDecoderLayer',

+ 4 - 3
aphrodite/kv_quant/calibration.py

@@ -3,14 +3,15 @@ from functools import partial
 from typing import Union

 import torch
-from torch import nn
 import transformers
-from transformers import PreTrainedTokenizer
 from pkg_resources import parse_version
+from torch import nn
+from transformers import PreTrainedTokenizer
+
+from aphrodite.kv_quant.observer import ActivationObserver, KVCacheObserver
 from aphrodite.kv_quant.utils import (bimap_name_mod, collect_target_modules,
                                       concat_decoder_layer_outputs,
                                       split_decoder_layer_inputs)
-from aphrodite.kv_quant.observer import ActivationObserver, KVCacheObserver


 class CalibrationContext():

+ 1 - 1
aphrodite/kv_quant/export_kv_params.py

@@ -2,9 +2,9 @@
 from pathlib import Path
 from typing import Union

+import fire
 import numpy as np
 import torch
-import fire


 def _export_sym(key_stats: dict,

+ 1 - 0
aphrodite/kv_quant/observer.py

@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Dict, Union
+
 import torch
 from torch import nn


+ 1 - 0
aphrodite/kv_quant/utils.py

@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Any, Dict, List, Tuple, Union
+
 import torch
 from torch import nn


+ 2 - 2
aphrodite/modeling/guided_decoding/lm_format_enforcer_logits_processors.py

@@ -5,8 +5,8 @@ from typing import List, Optional, Union
 import torch
 from lmformatenforcer import (CharacterLevelParser, FormatEnforcerAnalyzer,
                               TokenEnforcer, TokenEnforcerTokenizerData)
-from lmformatenforcer.integrations.transformers import \
-    build_token_enforcer_tokenizer_data
+from lmformatenforcer.integrations.transformers import (
+    build_token_enforcer_tokenizer_data)
 from transformers import PreTrainedTokenizerBase

 import aphrodite

+ 2 - 2
aphrodite/modeling/layers/ops/rand.py

@@ -1,9 +1,9 @@
+from typing import Optional, Union
+
 import torch
 import triton
 import triton.language as tl

-from typing import Optional, Union
-

 def seeded_uniform(
     *size,

+ 1 - 0
aphrodite/modeling/models/__init__.py

@@ -42,6 +42,7 @@ _GENERATION_MODELS = {
     "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
     "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
     "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
     "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
     "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
     "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
+    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
     "OrionForCausalLM": ("orion", "OrionForCausalLM"),
     "OrionForCausalLM": ("orion", "OrionForCausalLM"),
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),

+ 2 - 2
aphrodite/modeling/models/exaone.py

@@ -38,8 +38,8 @@ from aphrodite.distributed import (get_pp_group,
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               RowParallelLinear)
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler

+ 410 - 0
aphrodite/modeling/models/olmoe.py

@@ -0,0 +1,410 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only OLMoE model compatible with HuggingFace weights."""
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.common.utils import progress_bar
+from aphrodite.distributed import get_tensor_model_parallel_world_size
+from aphrodite.modeling.layers.fused_moe import FusedMoE
+from aphrodite.modeling.layers.layernorm import RMSNorm
+from aphrodite.modeling.layers.linear import (QKVParallelLinear,
+                                              ReplicatedLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
+
+
+class OlmoeMoE(nn.Module):
+    """A tensor-parallel MoE implementation for Olmoe that shards each expert
+    across all ranks.
+
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """
+
+    def __init__(self,
+                 num_experts: int,
+                 top_k: int,
+                 hidden_size: int,
+                 intermediate_size: int,
+                 params_dtype: Optional[torch.dtype] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 tp_size: Optional[int] = None,
+                 prefix: str = ""):
+        super().__init__()
+        self.hidden_size = hidden_size
+
+        # Gate always runs at half / full precision for now.
+        self.gate = ReplicatedLinear(hidden_size,
+                                     num_experts,
+                                     bias=False,
+                                     quant_config=None)
+
+        self.experts = FusedMoE(num_experts=num_experts,
+                                top_k=top_k,
+                                hidden_size=hidden_size,
+                                intermediate_size=intermediate_size,
+                                reduce_results=True,
+                                renormalize=False,
+                                quant_config=quant_config,
+                                tp_size=tp_size)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
+        hidden_dim = hidden_states.shape[-1]
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+        final_hidden_states = self.experts(hidden_states=hidden_states,
+                                           router_logits=router_logits)
+        return final_hidden_states.view(orig_shape)
+
+class OlmoeAttention(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 4096,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.q_norm = RMSNorm(hidden_size, eps=1e-5)
+        self.k_norm = RMSNorm(hidden_size, eps=1e-5)
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=True,
+        )
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config,
+                              quant_config=quant_config)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class OlmoeDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_idx: int,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          4096)
+
+        self.self_attn = OlmoeAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            cache_config=cache_config,
+            quant_config=quant_config,
+        )
+
+        self.mlp = OlmoeMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            quant_config=quant_config,
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(
+                hidden_states, residual)
+
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(
+            hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class OlmoeModel(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList([
+            OlmoeDecoderLayer(config,
+                              layer_idx,
+                              cache_config,
+                              quant_config=quant_config)
+            for layer_idx in range(config.num_hidden_layers)
+        ])
+        self.norm = RMSNorm(config.hidden_size, eps=1e-5)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(positions, hidden_states,
+                                            kv_caches[i], attn_metadata,
+                                            residual)
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class OlmoeForCausalLM(nn.Module):
+
+    fall_back_to_pt_during_load = False
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = OlmoeModel(config, cache_config, quant_config)
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata)
+        return hidden_states
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts)
+
+        params_dict = dict(self.named_parameters())
+        weights_list = list(weights)
+        for name, loaded_weight in progress_bar(
+            weights_list,
+            desc="Loading modules..."
+        ):
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if "mlp.experts" in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param,
+                                  loaded_weight,
+                                  name,
+                                  shard_id=shard_id,
+                                  expert_id=expert_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    # Remapping the name of FP8 kv-scale.
+                    if name.endswith("kv_scale"):
+                        remapped_kv_scale_name = name.replace(
+                            ".kv_scale", ".attn.kv_scale")
+                        if remapped_kv_scale_name not in params_dict:
+                            print(f"Warning: Found kv scale in the checkpoint "
+                                  f"(e.g. {name}), but not found the expected "
+                                  f"name in the model "
+                                  f"(e.g. {remapped_kv_scale_name}). "
+                                  "kv-scale is not loaded.")
+                            continue
+                        else:
+                            name = remapped_kv_scale_name
+
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)

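As a reading aid for the MoE block above: OlmoeMoE only computes router logits and delegates the rest to FusedMoE. A reference-only, unfused sketch of the routing the fused kernel performs (assuming each expert is a callable MLP; with renormalize=False the top-k softmax weights are used as-is, matching OLMoE):

    import torch

    def naive_moe(hidden_states, gate, experts, top_k):
        # hidden_states: (num_tokens, hidden_size)
        probs = torch.softmax(gate(hidden_states), dim=-1)  # (num_tokens, num_experts)
        topk_w, topk_idx = probs.topk(top_k, dim=-1)        # weights are not renormalized
        out = torch.zeros_like(hidden_states)
        for slot in range(top_k):
            for e, expert in enumerate(experts):
                mask = topk_idx[:, slot] == e
                if mask.any():
                    # weight each selected expert's output by its router probability
                    out[mask] += topk_w[mask][:, slot].unsqueeze(-1) * expert(hidden_states[mask])
        return out
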
+ 2 - 8
aphrodite/quantization/gguf_utils/gguf_reader.py

@@ -19,14 +19,8 @@ if __name__ == "__main__":
     # Allow running file in package as a script.
     sys.path.insert(0, str(Path(__file__).parent.parent))

-from .constants import (
-    GGML_QUANT_SIZES,
-    GGUF_DEFAULT_ALIGNMENT,
-    GGUF_MAGIC,
-    GGUF_VERSION,
-    GGMLQuantizationType,
-    GGUFValueType,
-)
+from .constants import (GGML_QUANT_SIZES, GGUF_DEFAULT_ALIGNMENT, GGUF_MAGIC,
+                        GGUF_VERSION, GGMLQuantizationType, GGUFValueType)

 READER_SUPPORTED_VERSIONS = [2, GGUF_VERSION]


+ 1 - 1
aphrodite/quantization/quip_utils.py

@@ -1,10 +1,10 @@
 import math
+from contextlib import suppress
 from pathlib import Path

 import scipy
 import torch
 from safetensors.torch import load_file
-from contextlib import suppress

 with suppress(ImportError):
     import aphrodite._hadamard_C as hadamard_C

+ 0 - 1
aphrodite/transformers_utils/tokenizers/baichuan.py

@@ -11,7 +11,6 @@ import sentencepiece as spm
 from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
 from transformers.utils import logging

-
 logger = logging.get_logger(__name__)

 VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}

+ 4 - 3
examples/marlin/convert.py

@@ -1,10 +1,11 @@
-import torch
 import argparse
 import copy
-from transformers import AutoModelForCausalLM, AutoTokenizer
+import gc
+
+import torch
 from auto_gptq.nn_modules.qlinear.qlinear_exllama import QuantLinear
 from marlin import Layer as MarlinLayer
-import gc
+from transformers import AutoModelForCausalLM, AutoTokenizer

 parser = argparse.ArgumentParser()
 parser.add_argument("--model-id", type=str)

+ 3 - 2
examples/offline_inference/slora_inference.py

@@ -3,11 +3,12 @@ This example shows how to use the multi-LoRA functionality for offline
 inference. Requires HuggingFace credentials for access to Llama2.
 """

-from typing import Optional, List, Tuple
+from typing import List, Optional, Tuple

 from huggingface_hub import snapshot_download

-from aphrodite import EngineArgs, AphroditeEngine, SamplingParams, RequestOutput
+from aphrodite import (AphroditeEngine, EngineArgs, RequestOutput,
+                       SamplingParams)
 from aphrodite.lora.request import LoRARequest



+ 2 - 2
requirements-common.txt

@@ -4,7 +4,7 @@ numpy < 2.0.0
 requests
 tqdm
 py-cpuinfo
-transformers == 4.44.1 # needed for llama
+transformers == 4.45.2 # needed for llama
 tokenizers >= 0.19.1
 fastapi
 aiohttp
@@ -30,4 +30,4 @@ gguf == 0.9.1
 importlib_metadata
 mistral_common >= 1.3.4
 protobuf
-pandas
+pandas

+ 7 - 8
tests/endpoints/test_openai_server.py

@@ -1,19 +1,18 @@
+# imports for guided decoding tests
+import json
 import os
+import re
 import subprocess
+import sys
 import time

-import sys
+import jsonschema
+import openai  # use the official client for correctness check
 import pytest
-import requests
 import ray
-import openai  # use the official client for correctness check
+import requests
 from huggingface_hub import snapshot_download

-# imports for guided decoding tests
-import json
-import jsonschema
-import re
-
 from aphrodite.transformers_utils.tokenizer import get_tokenizer

 MAX_SERVER_START_WAIT_S = 600  # wait for server to start for 60 seconds

+ 2 - 2
tests/endpoints/test_outlines.py

@@ -1,11 +1,11 @@
 # This unit test should be moved to a new
 # tests/test_guided_decoding directory.

-from transformers import AutoTokenizer
 import torch
+from transformers import AutoTokenizer

 from aphrodite.modeling.outlines_logits_processors import (
-    RegexLogitsProcessor, JSONLogitsProcessor)
+    JSONLogitsProcessor, RegexLogitsProcessor)

 TEST_SCHEMA = {
     "type": "object",

+ 0 - 1
tests/engine/test_detokenize.py

@@ -1,5 +1,4 @@
 import pytest
-
 from transformers import AutoTokenizer

 from aphrodite.transformers_utils.tokenizer import detokenize_incrementally

+ 2 - 2
tests/samplers/test_samplers.py

@@ -5,10 +5,10 @@ from unittest.mock import patch
 import pytest
 import torch

-from aphrodite.modeling.layers.sampler import Sampler
-from aphrodite.modeling.utils import set_random_seed
 from aphrodite.common.sequence import (SamplingParams, SequenceData,
                                        SequenceGroupMetadata)
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.utils import set_random_seed
 from aphrodite.task_handler.model_runner import ModelRunner



+ 1 - 1
tests/samplers/test_seeded_generate.py

@@ -7,8 +7,8 @@ from itertools import combinations

 import pytest

-from aphrodite.modeling.utils import set_random_seed
 from aphrodite import SamplingParams
+from aphrodite.modeling.utils import set_random_seed

 MODEL = "EleutherAI/pythia-70m-deduped"
 RANDOM_SEEDS = list(range(5))

+ 1 - 0
tests/worker/test_model_runner.py

@@ -1,4 +1,5 @@
 import random
+
 import torch

 from aphrodite.common.sequence import (SamplingParams, SequenceData,