
chore: generalize linear_method to be quant_method (#540)

Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
AlpinDale 7 months ago
parent commit b178ae4b4a
49 changed files with 963 additions and 955 deletions
  1. aphrodite/lora/layers.py  +13 -17
  2. aphrodite/modeling/layers/linear.py  +86 -81
  3. aphrodite/modeling/model_loader/loader.py  +20 -21
  4. aphrodite/modeling/model_loader/tensorizer.py  +6 -6
  5. aphrodite/modeling/models/__init__.py  +6 -6
  6. aphrodite/modeling/models/baichuan.py  +23 -22
  7. aphrodite/modeling/models/bloom.py  +16 -17
  8. aphrodite/modeling/models/chatglm.py  +18 -18
  9. aphrodite/modeling/models/commandr.py  +18 -18
  10. aphrodite/modeling/models/dbrx.py  +17 -17
  11. aphrodite/modeling/models/decilm.py  +3 -4
  12. aphrodite/modeling/models/deepseek.py  +21 -24
  13. aphrodite/modeling/models/falcon.py  +16 -17
  14. aphrodite/modeling/models/gemma.py  +21 -22
  15. aphrodite/modeling/models/gpt2.py  +16 -17
  16. aphrodite/modeling/models/gpt_bigcode.py  +16 -17
  17. aphrodite/modeling/models/gpt_j.py  +16 -17
  18. aphrodite/modeling/models/gpt_neox.py  +16 -17
  19. aphrodite/modeling/models/internlm2.py  +16 -16
  20. aphrodite/modeling/models/jais.py  +16 -17
  21. aphrodite/modeling/models/llama.py  +15 -17
  22. aphrodite/modeling/models/llava.py  +4 -4
  23. aphrodite/modeling/models/minicpm.py  +17 -18
  24. aphrodite/modeling/models/mixtral.py  +21 -22
  25. aphrodite/modeling/models/mixtral_quant.py  +21 -22
  26. aphrodite/modeling/models/mpt.py  +16 -16
  27. aphrodite/modeling/models/olmo.py  +160 -173
  28. aphrodite/modeling/models/opt.py  +18 -19
  29. aphrodite/modeling/models/orion.py  +16 -16
  30. aphrodite/modeling/models/phi.py  +17 -18
  31. aphrodite/modeling/models/qwen.py  +16 -16
  32. aphrodite/modeling/models/qwen2.py  +16 -17
  33. aphrodite/modeling/models/qwen2_moe.py  +21 -24
  34. aphrodite/modeling/models/stablelm.py  +14 -14
  35. aphrodite/modeling/models/starcoder2.py  +15 -16
  36. aphrodite/modeling/models/xverse.py  +16 -16
  37. aphrodite/quantization/__init__.py  +4 -4
  38. aphrodite/quantization/aqlm.py  +9 -6
  39. aphrodite/quantization/awq.py  +12 -9
  40. aphrodite/quantization/base_config.py  +25 -3
  41. aphrodite/quantization/bitsandbytes.py  +16 -12
  42. aphrodite/quantization/eetq.py  +14 -11
  43. aphrodite/quantization/exl2.py  +13 -10
  44. aphrodite/quantization/fp8.py  +27 -39
  45. aphrodite/quantization/gguf.py  +13 -10
  46. aphrodite/quantization/gptq.py  +12 -9
  47. aphrodite/quantization/marlin.py  +9 -6
  48. aphrodite/quantization/quip.py  +11 -7
  49. aphrodite/quantization/squeezellm.py  +15 -10
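
A note on reading this diff: the change is one mechanical rename applied across the tree. Layers no longer carry a linear_method (a LinearMethodBase); they resolve a quant_method from an optional QuantizationConfig, and apply_weights(...) becomes apply(...) with the layer itself passed in. The helper below is only a sketch of the renamed call path, not aphrodite code; it assumes base_layer already exposes the new quant_method attribute shown in the hunks that follow.

from typing import Optional

import torch


def lora_apply(base_layer: torch.nn.Module, x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Old call path: base_layer.linear_method.apply_weights(
    #                    base_layer.linear_weights, x, bias)
    # New call path: the quant method receives the layer and reads the
    # weights that create_weights() registered on it.
    return base_layer.quant_method.apply(base_layer, x, bias)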

+ 13 - 17
aphrodite/lora/layers.py

@@ -388,10 +388,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.indices = base_indices
         self.indices_len = indices_len
 
-    def apply_weights(self, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]) -> torch.Tensor:
-        output = self.base_layer.linear_method.apply_weights(
-            self.base_layer.linear_weights, x, bias)
+    def apply(self, x: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
         _apply_lora(
             x,
             self.lora_a_stacked,
@@ -415,7 +414,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
                 if not self.base_layer.skip_bias_add else None)
 
         # Matrix multiply.
-        output_parallel = self.apply_weights(input_, bias)
+        output_parallel = self.apply(input_, bias)
         if self.base_layer.gather_output:
             # All-gather across the partitions.
             output = tensor_model_parallel_all_gather(output_parallel)
@@ -526,10 +525,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
                 index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
                     lora_b[1].T, non_blocking=True)
 
-    def apply_weights(self, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]) -> torch.Tensor:
-        output = self.base_layer.linear_method.apply_weights(
-            self.base_layer.linear_weights, x, bias)
+    def apply(self, x: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
         _apply_lora_packed_nslice(
             x,
             self.lora_a_stacked,
@@ -768,10 +766,9 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
                 index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
                     lora_a[2].T, non_blocking=True)
 
-    def apply_weights(self, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]) -> torch.Tensor:
-        output = self.base_layer.linear_method.apply_weights(
-            self.base_layer.linear_weights, x, bias)
+    def apply(self, x: torch.Tensor,
+              bias: Optional[torch.Tensor]) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
         _apply_lora_packed_nslice(
             x,
             self.lora_a_stacked,
@@ -863,9 +860,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.indices = base_indices
         self.indices_len = indices_len
 
-    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
-        output = self.base_layer.linear_method.apply_weights(
-            self.base_layer.linear_weights, x)
+    def apply(self, x: torch.Tensor) -> torch.Tensor:
+        output = self.base_layer.quant_method.apply(self.base_layer, x)
         _apply_lora(
             x,
             self.lora_a_stacked,
@@ -898,7 +894,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
             input_parallel = splitted_input[tp_rank].contiguous()
 
         # Matrix multiply.
-        output_parallel = self.apply_weights(input_parallel)
+        output_parallel = self.apply(input_parallel)
         if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:

+ 86 - 81
aphrodite/modeling/layers/linear.py

@@ -1,10 +1,9 @@
-from abc import ABC, abstractmethod
+from abc import abstractmethod
 from typing import List, Optional
 
 import torch
 import torch.nn.functional as F
 from loguru import logger
-from torch import nn
 from torch.nn.parameter import Parameter
 
 from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
@@ -13,6 +12,8 @@ from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
                                    tensor_model_parallel_all_gather,
                                    tensor_model_parallel_all_reduce)
 from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import (QuantizationConfig,
+                                                QuantizeMethodBase)
 
 
 def adjust_marlin_shard(param, shard_size, shard_offset):
@@ -23,7 +24,7 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
     return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
 
 
-class LinearMethodBase(ABC):
+class LinearMethodBase(QuantizeMethodBase):
     """Base class for different (maybe quantized) linear methods."""
 
     @abstractmethod
@@ -48,22 +49,15 @@ class LinearMethodBase(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def apply_weights(self,
-                      layer: torch.nn.Module,
-                      x: torch.Tensor,
-                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         """Apply the weights in layer to the input tensor.
 
         Expects create_weights to have been called before on the layer."""
         raise NotImplementedError
 
-    def process_weights_after_loading(self, layer: nn.Module) -> None:
-        """Process the weight after loading.
-
-        This can be used for example, to transpose weights for computation.
-        """
-        return
-
 
 class UnquantizedLinearMethod(LinearMethodBase):
     """Linear method without quantization.
@@ -90,10 +84,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, extra_weight_attrs)
 
-    def apply_weights(self,
-                      layer: torch.nn.Module,
-                      x: torch.Tensor,
-                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         weight = layer.weight
         if self.separate_bias_add:
             if bias is not None:
@@ -102,8 +96,8 @@ class UnquantizedLinearMethod(LinearMethodBase):
         return F.linear(x, weight, bias)
 
 
-class ReplicatedLinear(torch.nn.Module):
-    """Replicated linear layer.
+class LinearBase(torch.nn.Module):
+    """Base linear layer.
 
     Args:
         input_size: input dimension of the linear layer.
@@ -111,17 +105,16 @@ class ReplicatedLinear(torch.nn.Module):
         bias: If true, add bias.
         skip_bias_add: If true, skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
-        linear_method: (Maybe quantized) linear method.
+        quant_config: Quantization config..
     """
 
     def __init__(
         self,
         input_size: int,
         output_size: int,
-        bias: bool = True,
         skip_bias_add: bool = False,
         params_dtype: Optional[torch.dtype] = None,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
 
@@ -132,12 +125,41 @@ class ReplicatedLinear(torch.nn.Module):
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
-        if linear_method is None:
-            linear_method = UnquantizedLinearMethod()
-        self.linear_method = linear_method
-        self.linear_method.create_weights(self, self.input_size,
-                                          [self.output_size], self.input_size,
-                                          self.output_size, self.params_dtype)
+        if quant_config is None:
+            self.quant_method = UnquantizedLinearMethod()
+        else:
+            self.quant_method = quant_config.get_quant_method(self)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+
+class ReplicatedLinear(LinearBase):
+    """Replicated linear layer.
+    Args:
+        input_size: input dimension of the linear layer.
+        output_size: output dimension of the linear layer.
+        bias: If true, add bias.
+        skip_bias_add: If true, skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configure.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
+                         quant_config)
+
+        self.quant_method.create_weights(self, self.input_size,
+                                         [self.output_size], self.input_size,
+                                         self.output_size, self.params_dtype)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size, dtype=self.params_dtype))
@@ -147,12 +169,12 @@ class ReplicatedLinear(torch.nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         bias = self.bias if not self.skip_bias_add else None
-        output = self.linear_method.apply_weights(self, x, bias)
+        output = self.quant_method.apply(self, x, bias)
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
 
 
-class ColumnParallelLinear(torch.nn.Module):
+class ColumnParallelLinear(LinearBase):
     """Linear layer with column parallelism.
 
     The linear layer is defined as Y = XA + b. A is parallelized along
@@ -169,7 +191,7 @@ class ColumnParallelLinear(torch.nn.Module):
                        bias can be fused with other element-wise operations. we
                        skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
-        linear_method: (Maybe quantized) linear method.
+        quant_config: Quantization configure.
         output_sizes: list of output sizes packed into one output, like for QKV
                        the list would be size 3.
     """
@@ -182,34 +204,26 @@ class ColumnParallelLinear(torch.nn.Module):
         gather_output: bool = False,
         skip_bias_add: bool = False,
         params_dtype: Optional[torch.dtype] = None,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         output_sizes: Optional[List[int]] = None,
     ):
-        super().__init__()
+        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
+                         quant_config)
 
-        # Keep input parameters
-        self.input_size = input_size
-        self.output_size = output_size
         self.gather_output = gather_output
+
         # Divide the weight matrix along the last dimension.
         tp_size = get_tensor_model_parallel_world_size()
         self.output_size_per_partition = divide(output_size, tp_size)
-        self.skip_bias_add = skip_bias_add
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
-        if linear_method is None:
-            linear_method = UnquantizedLinearMethod()
         if output_sizes is None:
             output_sizes = [output_size]
-        self.linear_method = linear_method
-        self.linear_method.create_weights(self,
-                                          self.input_size,
-                                          [x // tp_size for x in output_sizes],
-                                          self.input_size,
-                                          self.output_size,
-                                          self.params_dtype,
-                                          weight_loader=self.weight_loader)
+        self.quant_method.create_weights(self,
+                                         self.input_size,
+                                         [x // tp_size for x in output_sizes],
+                                         self.input_size,
+                                         self.output_size,
+                                         self.params_dtype,
+                                         weight_loader=self.weight_loader)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,
@@ -237,7 +251,7 @@ class ColumnParallelLinear(torch.nn.Module):
         bias = self.bias if not self.skip_bias_add else None
 
         # Matrix multiply.
-        output_parallel = self.linear_method.apply_weights(self, input_, bias)
+        output_parallel = self.quant_method.apply(self, input_, bias)
         if self.gather_output:
             # All-gather across the partitions.
             output = tensor_model_parallel_all_gather(output_parallel)
@@ -265,7 +279,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                        bias can be fused with other element-wise operations. we
                        skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
-        linear_method: (Maybe quantized) linear method.
+        quant_config: Quantization configure.
     """
 
     def __init__(
@@ -276,13 +290,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         gather_output: bool = False,
         skip_bias_add: bool = False,
         params_dtype: Optional[torch.dtype] = None,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         self.output_sizes = output_sizes
         tp_size = get_tensor_model_parallel_world_size()
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
         super().__init__(input_size, sum(output_sizes), bias, gather_output,
-                         skip_bias_add, params_dtype, linear_method,
+                         skip_bias_add, params_dtype, quant_config,
                          self.output_sizes)
 
     def weight_loader(self,
@@ -382,7 +396,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                        bias can be fused with other element-wise operations. we
                        skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
-        linear_method: (Maybe quantized) linear method.
+        quant_config: Quantization configure.
     """
 
     def __init__(
@@ -394,7 +408,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         bias: bool = True,
         skip_bias_add: bool = False,
         params_dtype: Optional[torch.dtype] = None,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         self.hidden_size = hidden_size
         self.head_size = head_size
@@ -422,7 +436,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         ]
 
         super().__init__(input_size, output_size, bias, False, skip_bias_add,
-                         params_dtype, linear_method, output_sizes)
+                         params_dtype, quant_config, output_sizes)
 
     def weight_loader(self,
                       param: Parameter,
@@ -515,7 +529,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         param_data.copy_(loaded_weight)
 
 
-class RowParallelLinear(torch.nn.Module):
+class RowParallelLinear(LinearBase):
     """Linear layer with row parallelism.
 
     The linear layer is defined as Y = XA + b. A is parallelized along
@@ -538,7 +552,7 @@ class RowParallelLinear(torch.nn.Module):
                        bias can be fused with other element-wise operations.
                        We skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
-        linear_method: (Maybe quantized) linear method.
+        quant_config: Quantization configure.
     """
 
     def __init__(
@@ -550,32 +564,24 @@ class RowParallelLinear(torch.nn.Module):
         skip_bias_add: bool = False,
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = True,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
-        super().__init__()
-        # Keep input parameters
-        self.input_size = input_size
-        self.output_size = output_size
+        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
+                         quant_config)
+
         self.input_is_parallel = input_is_parallel
         self.reduce_results = reduce_results
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
 
         # Divide the weight matrix along the last dimension.
         self.tp_size = get_tensor_model_parallel_world_size()
         self.input_size_per_partition = divide(input_size, self.tp_size)
-        self.skip_bias_add = skip_bias_add
-        if linear_method is None:
-            linear_method = UnquantizedLinearMethod()
-        self.linear_method = linear_method
-        self.linear_method.create_weights(self,
-                                          self.input_size_per_partition,
-                                          [self.output_size],
-                                          self.input_size,
-                                          self.output_size,
-                                          self.params_dtype,
-                                          weight_loader=self.weight_loader)
+        self.quant_method.create_weights(self,
+                                         self.input_size_per_partition,
+                                         [self.output_size],
+                                         self.input_size,
+                                         self.output_size,
+                                         self.params_dtype,
+                                         weight_loader=self.weight_loader)
 
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
@@ -614,8 +620,7 @@ class RowParallelLinear(torch.nn.Module):
             input_parallel = splitted_input[tp_rank].contiguous()
 
         # Matrix multiply.
-        output_parallel = self.linear_method.apply_weights(
-            self, input_parallel)
+        output_parallel = self.quant_method.apply(self, input_parallel)
         if self.reduce_results and self.tp_size > 1:
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:
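
The hunks above are the core of the refactor: LinearBase resolves self.quant_method once (UnquantizedLinearMethod when quant_config is None, otherwise quant_config.get_quant_method(self)), and each subclass calls create_weights(...) on itself and quant_method.apply(self, ...) in forward. The toy classes below sketch that wiring with plain torch; the names mirror the diff, but nothing here is the aphrodite implementation, and the weight is left uninitialized, so the example only demonstrates the control flow and output shape.

from typing import List, Optional

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class ToyUnquantizedMethod:
    # Stand-in for UnquantizedLinearMethod: register a weight on the layer,
    # then apply it with F.linear.
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs) -> None:
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)

    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        return F.linear(x, layer.weight, bias)


class ToyQuantConfig:
    # Stand-in for a QuantizationConfig: hands each layer its quant method.
    def get_quant_method(self, layer: torch.nn.Module) -> ToyUnquantizedMethod:
        return ToyUnquantizedMethod()


class ToyLinear(torch.nn.Module):
    # Mimics the LinearBase/ReplicatedLinear split introduced above.
    def __init__(self, input_size: int, output_size: int,
                 quant_config: Optional[ToyQuantConfig] = None):
        super().__init__()
        self.quant_method = (ToyUnquantizedMethod() if quant_config is None
                             else quant_config.get_quant_method(self))
        self.quant_method.create_weights(self, input_size, [output_size],
                                         input_size, output_size,
                                         torch.float32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.quant_method.apply(self, x)


layer = ToyLinear(8, 16, quant_config=ToyQuantConfig())
print(layer(torch.randn(2, 8)).shape)  # torch.Size([2, 16])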

+ 20 - 21
aphrodite/modeling/model_loader/loader.py

@@ -1,12 +1,11 @@
 # ruff: noqa: SIM117
 import copy
+import gc
 import glob
 import os
 from abc import ABC, abstractmethod
-import gc
 from contextlib import nullcontext
-from typing import (TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple,
-                    Type)
+from typing import (Any, Dict, Generator, List, Optional, Tuple, Type)
 
 import torch
 from torch import nn
@@ -25,22 +24,19 @@ from aphrodite.modeling.model_loader.weight_utils import (
     get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
     pt_weights_iterator, safetensors_weights_iterator)
 from aphrodite.modeling.models.llava import LlavaForConditionalGeneration
+from aphrodite.quantization.base_config import QuantizationConfig
 from aphrodite.quantization.bitsandbytes import (BNBLinearMethod,
                                                  replace_quant_params)
 
-if TYPE_CHECKING:
-    from aphrodite.modeling.layers.linear import LinearMethodBase
-
 _VISION_MODEL_CLASSES = [
     LlavaForConditionalGeneration,
 ]
 
 
-def _get_linear_method(
+def _get_quantization_config(
         model_config: ModelConfig,
-        load_config: LoadConfig) -> Optional["LinearMethodBase"]:
-    """Get the (maybe quantized) linear method."""
-    linear_method = None
+        load_config: LoadConfig) -> Optional[QuantizationConfig]:
+    """Get the quantization config."""
     if model_config.quantization is not None:
         quant_config = get_quant_config(model_config, load_config)
         capability = torch.cuda.get_device_capability()
@@ -58,8 +54,8 @@ def _get_linear_method(
                 f"method {model_config.quantization}. Supported dtypes: "
                 f"{supported_dtypes}")
 
-        linear_method = quant_config.get_linear_method()
-    return linear_method
+        return quant_config
+    return None
 
 
 def _get_model_initialization_kwargs(
@@ -87,10 +83,10 @@ def _initialize_model(
         vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
     """Initialize a model with the given configurations."""
     model_class = get_model_architecture(model_config)[0]
-    linear_method = _get_linear_method(model_config, load_config)
+    quant_config = _get_quantization_config(model_config, load_config)
 
     return model_class(config=model_config.hf_config,
-                       linear_method=linear_method,
+                       quant_config=quant_config,
                        **_get_model_initialization_kwargs(
                            model_class, lora_config, vision_language_config))
 
@@ -221,7 +217,8 @@ class DefaultModelLoader(BaseModelLoader):
                    parallel_config: ParallelConfig,
                    scheduler_config: SchedulerConfig) -> nn.Module:
         with set_default_torch_dtype(model_config.dtype):
-            linear_method = _get_linear_method(model_config, self.load_config)
+            linear_method = _get_quantization_config(model_config,
+                                                     self.load_config)
 
             context = torch.device(device_config.device) if not (
                 isinstance(linear_method, BNBLinearMethod)
@@ -238,9 +235,11 @@ class DefaultModelLoader(BaseModelLoader):
                                                "fall_back_to_pt_during_load",
                                                True)), )
             for _, module in model.named_modules():
-                linear_method = getattr(module, "linear_method", None)
-                if linear_method is not None:
-                    linear_method.process_weights_after_loading(module)
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    quant_method.process_weights_after_loading(module)
+                # FIXME: Remove this after Mixtral is updated
+                # to use quant_method.
                 if hasattr(module, "process_weights_after_loading"):
                     module.process_weights_after_loading()
 
@@ -336,11 +335,11 @@ class TensorizerLoader(BaseModelLoader):
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model_class = get_model_architecture(model_config)[0]
-                linear_method = _get_linear_method(model_config,
-                                                   self.load_config)
+                quant_config = _get_quantization_config(
+                    model_config, self.load_config)
                 extra_kwargs = _get_model_initialization_kwargs(
                     model_class, lora_config, vision_language_config)
-                extra_kwargs["linear_method"] = linear_method
+                extra_kwargs["quant_config"] = quant_config
 
                 tensorizer_config = copy.copy(self.tensorizer_config)
                 tensorizer_config.model_class = model_class
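
Two loader details worth noting from the hunks above: _get_quantization_config now returns the QuantizationConfig itself (the capability and dtype checks are unchanged), and the post-load pass looks the hook up under the new attribute name, with a temporary module-level fallback for models not yet migrated (the Mixtral FIXME). The function below is a sketch of that post-load pass for any loaded torch module; it mirrors the hunk rather than reproducing it verbatim.

import torch


def run_post_load_hooks(model: torch.nn.Module) -> None:
    # After weights are loaded, give each module's quant method a chance to
    # post-process them (for example, transposing or repacking for the kernel).
    for _, module in model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            quant_method.process_weights_after_loading(module)
        # Temporary fallback for modules that still define the hook on
        # themselves, as noted in the FIXME above.
        if hasattr(module, "process_weights_after_loading"):
            module.process_weights_after_loading()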

+ 6 - 6
aphrodite/modeling/model_loader/tensorizer.py

@@ -13,9 +13,9 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from aphrodite.common.config import ModelConfig, ParallelConfig
-from aphrodite.modeling.layers.linear import LinearMethodBase
 from aphrodite.modeling.layers.vocab_parallel_embedding import \
     VocabParallelEmbedding
+from aphrodite.quantization.base_config import QuantizationConfig
 
 tensorizer_load_fail = None
 
@@ -251,7 +251,7 @@ class TensorizerAgent:
     """
 
     def __init__(self, tensorizer_config: TensorizerConfig,
-                 linear_method: LinearMethodBase, **extra_kwargs):
+                 quant_config: QuantizationConfig, **extra_kwargs):
         if tensorizer_load_fail is not None:
             raise ImportError(
                 "Tensorizer is not installed. Please install tensorizer "
@@ -263,10 +263,10 @@ class TensorizerAgent:
         self.tensorizer_args = (
             self.tensorizer_config._construct_tensorizer_args())
         self.extra_kwargs = extra_kwargs
-        if extra_kwargs.get("linear_method", None) is not None:
-            self.linear_method = extra_kwargs["linear_method"]
+        if extra_kwargs.get("quant_config", None) is not None:
+            self.quant_config = extra_kwargs["quant_config"]
         else:
-            self.linear_method = linear_method
+            self.quant_config = quant_config
         self.model = self._init_model()
 
     def _init_model(self):
@@ -275,7 +275,7 @@ class TensorizerAgent:
         with no_init_or_tensor():
             return self.tensorizer_config.model_class(
                 config=model_args,
-                linear_method=self.linear_method,
+                quant_config=self.quant_config,
                 **self.extra_kwargs)
 
     def _resize_lora_embeddings(self):

+ 6 - 6
aphrodite/modeling/models/__init__.py

@@ -40,7 +40,7 @@ _MODELS = {
     "MptForCausalLM": ("mpt", "MPTForCausalLM"),
     "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
     "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
-    "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
+    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
     "OrionForCausalLM": ("orion", "OrionForCausalLM"),
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),
@@ -89,8 +89,8 @@ class ModelRegistry:
                     "ROCm for now.")
             if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
                 logger.warning(
-                    f"Model architecture {model_arch} is partially supported "
-                    "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
+                    "Model architecture %s is partially supported by ROCm: %s",
+                    model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
 
         module_name, model_cls_name = _MODELS[model_arch]
         module = importlib.import_module(
@@ -105,9 +105,9 @@ class ModelRegistry:
     def register_model(model_arch: str, model_cls: Type[nn.Module]):
         if model_arch in _MODELS:
             logger.warning(
-                f"Model architecture {model_arch} is already registered, "
-                "and will be overwritten by the new model "
-                f"class {model_cls.__name__}.")
+                "Model architecture %s is already registered, and will be "
+                "overwritten by the new model class %s.", model_arch,
+                model_cls.__name__)
         global _OOT_MODELS
         _OOT_MODELS[model_arch] = model_cls
 

+ 23 - 22
aphrodite/modeling/models/baichuan.py

@@ -32,8 +32,7 @@ from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              MergedColumnParallelLinear,
+from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -43,6 +42,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@@ -77,17 +77,17 @@ class BaiChuanMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size, [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method)
+            quant_config=quant_config)
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
-                                           linear_method=linear_method)
+                                           quant_config=quant_config)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -110,7 +110,7 @@ class BaiChuanAttention(nn.Module):
         position_embedding: str,
         rope_theta: float = 10000,
         max_position_embeddings: int = 8192,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -132,13 +132,13 @@ class BaiChuanAttention(nn.Module):
             self.total_num_heads,
             self.total_num_heads,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         # Create the alibi slopes and slice them.
         if self.postion_embedding == "ALIBI":
@@ -184,7 +184,7 @@ class BaiChuanDecoderLayer(nn.Module):
     def __init__(self,
                  config: PretrainedConfig,
                  position_embedding: str,
-                 linear_method: Optional[LinearMethodBase] = None):
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -196,13 +196,13 @@ class BaiChuanDecoderLayer(nn.Module):
             position_embedding=position_embedding,
             rope_theta=rope_theta,
             max_position_embeddings=max_position_embeddings,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.mlp = BaiChuanMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@@ -243,7 +243,7 @@ class BaiChuanModel(nn.Module):
     def __init__(self,
                  config: PretrainedConfig,
                  position_embedding: str,
-                 linear_method: Optional[LinearMethodBase] = None):
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
@@ -254,7 +254,7 @@ class BaiChuanModel(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            BaiChuanDecoderLayer(config, position_embedding, linear_method)
+            BaiChuanDecoderLayer(config, position_embedding, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -303,13 +303,13 @@ class BaiChuanBaseForCausalLM(nn.Module):
         self,
         config,
         position_embedding: str,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = BaiChuanModel(config, position_embedding, linear_method)
+        self.quant_config = quant_config
+        self.model = BaiChuanModel(config, position_embedding, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -354,7 +354,8 @@ class BaiChuanBaseForCausalLM(nn.Module):
                 # Refer to:
                 # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
                 # Distinguish between Baichuan and Baichuan2 by checking the
-                # vocab size.
+                # vocab size. This is suggested by
+                # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
                 is_baichuan2 = self.config.vocab_size == 125696
                 if is_baichuan2:
                     loaded_weight = torch.nn.functional.normalize(
@@ -387,13 +388,13 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
     def __init__(
         self,
         config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ):
         if config.hidden_size == 4096:  # baichuan2 7b
-            super().__init__(config, "ROPE", linear_method, lora_config)
+            super().__init__(config, "ROPE", quant_config, lora_config)
         else:  # baichuan 13b, baichuan2 13b
-            super().__init__(config, "ALIBI", linear_method, lora_config)
+            super().__init__(config, "ALIBI", quant_config, lora_config)
 
 
 class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
@@ -402,7 +403,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
     def __init__(
         self,
         config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ):
-        super().__init__(config, "ROPE", linear_method, lora_config)
+        super().__init__(config, "ROPE", quant_config, lora_config)
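
Every model file from here on follows the same substitution seen in baichuan.py: the constructor parameter linear_method: Optional[LinearMethodBase] becomes quant_config: Optional[QuantizationConfig], the import moves to aphrodite.quantization.base_config, and the value is threaded unchanged into each parallel linear layer. The MLP below restates that pattern in isolation; it copies the shapes of the BaiChuanMLP hunk above and only constructs successfully inside an initialized aphrodite tensor-parallel context.

from typing import Optional

import torch
from torch import nn

from aphrodite.modeling.layers.activation import SiluAndMul
from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                              RowParallelLinear)
from aphrodite.quantization.base_config import QuantizationConfig


class ExampleMLP(nn.Module):
    # quant_config (possibly None) is passed straight through to the parallel
    # linear layers, which resolve their own quant_method internally.
    def __init__(self, hidden_size: int, intermediate_size: int,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2, bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up, _ = self.gate_up_proj(x)
        out, _ = self.down_proj(self.act_fn(gate_up))
        return out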

+ 16 - 17
aphrodite/modeling/models/bloom.py

@@ -1,7 +1,6 @@
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2022 HuggingFace Inc. team and BigScience workshop.
 #
@@ -30,7 +29,6 @@ from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
-                                              LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -39,6 +37,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import \
     VocabParallelEmbedding
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@@ -71,7 +70,7 @@ class BloomAttention(nn.Module):
     def __init__(
         self,
         config: BloomConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -88,13 +87,13 @@ class BloomAttention(nn.Module):
             self.head_dim,
             self.total_num_heads,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.dense = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         # Create the alibi slopes and slice them.
@@ -130,21 +129,21 @@ class BloomMLP(nn.Module):
     def __init__(
         self,
         config: BloomConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         hidden_size = config.hidden_size
         self.dense_h_to_4h = ColumnParallelLinear(
             hidden_size,
             4 * hidden_size,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
-        quant_config = getattr(linear_method, "quant_config", None)
+        quant_config = getattr(quant_config, "quant_config", None)
         self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
         self.dense_4h_to_h = RowParallelLinear(
             4 * hidden_size,
             hidden_size,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -159,17 +158,17 @@ class BloomBlock(nn.Module):
     def __init__(
         self,
         config: BloomConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         hidden_size = config.hidden_size
 
         self.input_layernorm = nn.LayerNorm(hidden_size,
                                             eps=config.layer_norm_epsilon)
-        self.self_attention = BloomAttention(config, linear_method)
+        self.self_attention = BloomAttention(config, quant_config)
         self.post_attention_layernorm = nn.LayerNorm(
             hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = BloomMLP(config, linear_method)
+        self.mlp = BloomMLP(config, quant_config)
         self.apply_residual_connection_post_layernorm = (
             config.apply_residual_connection_post_layernorm)
 
@@ -215,7 +214,7 @@ class BloomModel(nn.Module):
     def __init__(
         self,
         config: BloomConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -230,7 +229,7 @@ class BloomModel(nn.Module):
 
         # Transformer blocks
         self.h = nn.ModuleList([
-            BloomBlock(config, linear_method)
+            BloomBlock(config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
 
@@ -263,12 +262,12 @@ class BloomForCausalLM(nn.Module):
     def __init__(
         self,
         config: BloomConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.transformer = BloomModel(config, linear_method)
+        self.quant_config = quant_config
+        self.transformer = BloomModel(config, quant_config)
         self.lm_head_weight = self.transformer.word_embeddings.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 18 - 18
aphrodite/modeling/models/chatglm.py

@@ -14,8 +14,7 @@ from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              MergedColumnParallelLinear,
+from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -25,6 +24,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 from aphrodite.transformers_utils.configs import ChatGLMConfig
 
 
@@ -33,7 +33,7 @@ class GLMAttention(nn.Module):
     def __init__(
         self,
         config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -65,13 +65,13 @@ class GLMAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=config.add_bias_linear or config.add_qkv_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.dense = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             config.hidden_size,
             bias=config.add_bias_linear,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
@@ -123,7 +123,7 @@ class GLMMLP(nn.Module):
     def __init__(
         self,
         config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
 
@@ -134,7 +134,7 @@ class GLMMLP(nn.Module):
             config.hidden_size,
             [config.ffn_hidden_size] * 2,
             bias=config.add_bias_linear,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.activation_func = SiluAndMul()
@@ -144,7 +144,7 @@ class GLMMLP(nn.Module):
             config.ffn_hidden_size,
             config.hidden_size,
             bias=config.add_bias_linear,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
     def forward(self, hidden_states):
@@ -166,7 +166,7 @@ class GLMBlock(nn.Module):
     def __init__(
         self,
         config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.apply_residual_connection_post_layernorm = (
@@ -180,7 +180,7 @@ class GLMBlock(nn.Module):
                                                eps=config.layernorm_epsilon)
 
         # Self attention.
-        self.self_attention = GLMAttention(config, linear_method)
+        self.self_attention = GLMAttention(config, quant_config)
         self.hidden_dropout = config.hidden_dropout
 
         # Layernorm on the attention output
@@ -188,7 +188,7 @@ class GLMBlock(nn.Module):
             config.hidden_size, eps=config.layernorm_epsilon)
 
         # MLP
-        self.mlp = GLMMLP(config, linear_method)
+        self.mlp = GLMMLP(config, quant_config)
 
     def forward(
         self,
@@ -236,7 +236,7 @@ class GLMTransformer(nn.Module):
     def __init__(
         self,
         config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.post_layer_norm = config.post_layer_norm
@@ -246,7 +246,7 @@ class GLMTransformer(nn.Module):
 
         # Transformer layers.
         self.layers = nn.ModuleList(
-            [GLMBlock(config, linear_method) for i in range(self.num_layers)])
+            [GLMBlock(config, quant_config) for i in range(self.num_layers)])
 
         if self.post_layer_norm:
             layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
@@ -281,7 +281,7 @@ class ChatGLMModel(nn.Module):
     def __init__(
         self,
         config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
 
@@ -291,7 +291,7 @@ class ChatGLMModel(nn.Module):
         self.num_layers = config.num_layers
         self.multi_query_group_num = config.multi_query_group_num
         self.kv_channels = config.kv_channels
-        self.encoder = GLMTransformer(config, linear_method)
+        self.encoder = GLMTransformer(config, quant_config)
 
         self.output_layer = ParallelLMHead(config.padded_vocab_size,
                                            config.hidden_size)
@@ -333,13 +333,13 @@ class ChatGLMForCausalLM(nn.Module):
     def __init__(
         self,
         config: ChatGLMConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config: ChatGLMConfig = config
-        self.linear_method = linear_method
-        self.transformer = ChatGLMModel(config, linear_method)
+        self.quant_config = quant_config
+        self.transformer = ChatGLMModel(config, quant_config)
         self.lm_head_weight = self.transformer.output_layer.weight
         self.logits_processor = LogitsProcessor(config.padded_vocab_size)
         self.sampler = Sampler()

+ 18 - 18
aphrodite/modeling/models/commandr.py

@@ -33,8 +33,7 @@ from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              MergedColumnParallelLinear,
+from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -45,6 +44,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import \
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 @torch.compile
@@ -91,7 +91,7 @@ class CohereMLP(nn.Module):
     def __init__(
         self,
         config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -101,13 +101,13 @@ class CohereMLP(nn.Module):
             self.hidden_size,
             [self.intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
             self.intermediate_size,
             self.hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.act_fn = SiluAndMul()
 
@@ -123,7 +123,7 @@ class CohereAttention(nn.Module):
     def __init__(
         self,
         config: CohereConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         tp_size = get_tensor_model_parallel_world_size()
@@ -158,13 +158,13 @@ class CohereAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -218,13 +218,13 @@ class CohereDecoderLayer(nn.Module):
 
     def __init__(self,
                  config: CohereConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.hidden_size = config.hidden_size
 
-        self.self_attn = CohereAttention(config, linear_method=linear_method)
+        self.self_attn = CohereAttention(config, quant_config=quant_config)
 
-        self.mlp = CohereMLP(config, linear_method=linear_method)
+        self.mlp = CohereMLP(config, quant_config=quant_config)
         self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
                                          eps=config.layer_norm_eps)
 
@@ -257,7 +257,7 @@ class CohereModel(nn.Module):
     def __init__(
         self,
         config: CohereConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -265,7 +265,7 @@ class CohereModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                    config.hidden_size)
         self.layers = nn.ModuleList([
-            CohereDecoderLayer(config, linear_method=linear_method)
+            CohereDecoderLayer(config, quant_config=quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = LayerNorm(param_shape=(config.hidden_size),
@@ -298,14 +298,14 @@ class CohereForCausalLM(nn.Module):
     def __init__(
         self,
         config: CohereConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
+        self.quant_config = quant_config
         self.logits_processor = LogitsProcessor(config.vocab_size,
                                                 scale=config.logit_scale)
-        self.model = CohereModel(config, linear_method)
+        self.model = CohereModel(config, quant_config)
         self.sampler = Sampler()
 
     @torch.no_grad()
@@ -358,8 +358,8 @@ class CohereForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # lm_head is not used in Aphrodite as it is tied with
-                # embed_token. To prevent errors, skip loading lm_head.weight.
+                # lm_head is not used in vllm as it is tied with embed_token.
+                # To prevent errors, skip loading lm_head.weight.
                 if "lm_head.weight" in name:
                     continue
                 # Skip loading extra bias for GPTQ models.

+ 17 - 17
aphrodite/modeling/models/dbrx.py

@@ -10,8 +10,7 @@ from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
 from aphrodite.modeling.layers.fused_moe import fused_moe
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              QKVParallelLinear,
+from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -22,6 +21,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import (
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import QuantizationConfig
 from aphrodite.transformers_utils.configs.dbrx import DbrxConfig
 
 
@@ -44,7 +44,7 @@ class DbrxRouter(nn.Module):
             self.num_total_experts,
             bias=False,
             params_dtype=params_dtype,
-            linear_method=None,
+            quant_config=None,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -63,7 +63,7 @@ class DbrxExperts(nn.Module):
     def __init__(
         self,
         config: DbrxConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         params_dtype: Optional[torch.dtype] = None,
     ):
         super().__init__()
@@ -165,7 +165,7 @@ class DbrxAttention(nn.Module):
     def __init__(
         self,
         config: DbrxConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.d_model = config.d_model
@@ -183,13 +183,13 @@ class DbrxAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.out_proj = RowParallelLinear(
             self.d_model,
             self.d_model,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -244,11 +244,11 @@ class DbrxFusedNormAttention(nn.Module):
     def __init__(
         self,
         config: DbrxConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.d_model = config.d_model
-        self.attn = DbrxAttention(config, linear_method)
+        self.attn = DbrxAttention(config, quant_config)
         self.norm_1 = nn.LayerNorm(self.d_model)
         self.norm_2 = nn.LayerNorm(self.d_model)
 
@@ -278,11 +278,11 @@ class DbrxBlock(nn.Module):
     def __init__(
         self,
         config: DbrxConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
-        self.norm_attn_norm = DbrxFusedNormAttention(config, linear_method)
-        self.ffn = DbrxExperts(config, linear_method)
+        self.norm_attn_norm = DbrxFusedNormAttention(config, quant_config)
+        self.ffn = DbrxExperts(config, quant_config)
 
     def forward(
         self,
@@ -307,7 +307,7 @@ class DbrxModel(nn.Module):
     def __init__(
         self,
         config: DbrxConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.wte = VocabParallelEmbedding(
@@ -315,7 +315,7 @@ class DbrxModel(nn.Module):
             config.d_model,
         )
         self.blocks = nn.ModuleList(
-            [DbrxBlock(config, linear_method) for _ in range(config.n_layers)])
+            [DbrxBlock(config, quant_config) for _ in range(config.n_layers)])
         self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
         for module in self.modules():
             if hasattr(module, "bias") and isinstance(module.bias,
@@ -348,13 +348,13 @@ class DbrxForCausalLM(nn.Module):
     def __init__(
         self,
         config: DbrxConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
+        self.quant_config = quant_config
         self.unpadded_vocab_size = config.vocab_size
-        self.transformer = DbrxModel(config, linear_method)
+        self.transformer = DbrxModel(config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size,
             config.d_model,

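One detail worth noting in the DBRX hunk: DbrxRouter passes quant_config=None, so the expert-routing projection stays unquantized even when a quantization config is supplied for the rest of the model (the Deepseek gate below does the same). A sketch of that split; the toy class is illustrative:

    from typing import Optional

    import torch.nn as nn

    from aphrodite.modeling.layers.linear import (ReplicatedLinear,
                                                  RowParallelLinear)
    from aphrodite.quantization.base_config import QuantizationConfig


    class ToyMoEBlock(nn.Module):

        def __init__(self,
                     hidden_size: int,
                     num_experts: int,
                     quant_config: Optional[QuantizationConfig] = None) -> None:
            super().__init__()
            # Router/gate: deliberately kept unquantized, mirroring DbrxRouter.
            self.gate = ReplicatedLinear(hidden_size,
                                         num_experts,
                                         bias=False,
                                         quant_config=None)
            # Expert and attention projections do receive the real config.
            self.out_proj = RowParallelLinear(hidden_size,
                                              hidden_size,
                                              bias=False,
                                              quant_config=quant_config)
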
+ 3 - 4
aphrodite/modeling/models/decilm.py

@@ -2,7 +2,6 @@
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
 # Copyright 2023 DeciAI Research Team. All rights reserved.
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
@@ -30,9 +29,9 @@ import torch
 from transformers import PretrainedConfig
 
 from aphrodite.common.config import LoRAConfig
-from aphrodite.modeling.layers.linear import LinearMethodBase
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.llama import LlamaForCausalLM
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class DeciLMForCausalLM(LlamaForCausalLM):
@@ -56,13 +55,13 @@ class DeciLMForCausalLM(LlamaForCausalLM):
     def __init__(
         self,
         config: Optional[PretrainedConfig] = None,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
         delattr(config, "num_key_value_heads_per_layer")
         super().__init__(config=config,
-                         linear_method=linear_method,
+                         quant_config=quant_config,
                          lora_config=lora_config)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

+ 21 - 24
aphrodite/modeling/models/deepseek.py

@@ -1,7 +1,6 @@
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
 #
@@ -36,8 +35,7 @@ from aphrodite.distributed import (get_tensor_model_parallel_rank,
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.fused_moe import fused_moe
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              MergedColumnParallelLinear,
+from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
@@ -48,6 +46,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class DeepseekMLP(nn.Module):
@@ -57,18 +56,18 @@ class DeepseekMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size, [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method)
+            quant_config=quant_config)
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
-                                           linear_method=linear_method,
+                                           quant_config=quant_config,
                                            reduce_results=reduce_results)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
@@ -87,7 +86,7 @@ class DeepseekMoE(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -104,7 +103,7 @@ class DeepseekMoE(nn.Module):
             DeepseekMLP(hidden_size=config.hidden_size,
                         intermediate_size=config.moe_intermediate_size,
                         hidden_act=config.hidden_act,
-                        linear_method=linear_method,
+                        quant_config=quant_config,
                         reduce_results=False)
             for idx in range(self.n_routed_experts)
         ])
@@ -113,7 +112,7 @@ class DeepseekMoE(nn.Module):
         self.gate = ReplicatedLinear(config.hidden_size,
                                      self.n_routed_experts,
                                      bias=False,
-                                     linear_method=None)
+                                     quant_config=None)
 
         if config.n_shared_experts is not None:
             intermediate_size = (config.moe_intermediate_size *
@@ -122,7 +121,7 @@ class DeepseekMoE(nn.Module):
                 hidden_size=config.hidden_size,
                 intermediate_size=intermediate_size,
                 hidden_act=config.hidden_act,
-                linear_method=linear_method,
+                quant_config=quant_config,
                 reduce_results=False,
             )
 
@@ -178,7 +177,7 @@ class DeepseekAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -209,14 +208,14 @@ class DeepseekAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.rotary_emb = get_rope(
@@ -252,7 +251,7 @@ class DeepseekDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_idx: int,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -267,18 +266,18 @@ class DeepseekDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         if (config.n_routed_experts is not None
                 and layer_idx >= config.first_k_dense_replace
                 and layer_idx % config.moe_layer_freq == 0):
-            self.mlp = DeepseekMoE(config=config, linear_method=linear_method)
+            self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
         else:
             self.mlp = DeepseekMLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
-                linear_method=linear_method,
+                quant_config=quant_config,
             )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@@ -321,7 +320,7 @@ class DeepseekModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -332,9 +331,7 @@ class DeepseekModel(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            DeepseekDecoderLayer(config,
-                                 layer_idx,
-                                 linear_method=linear_method)
+            DeepseekDecoderLayer(config, layer_idx, quant_config=quant_config)
             for layer_idx in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -362,12 +359,12 @@ class DeepseekForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = DeepseekModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = DeepseekModel(config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

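The DeepseekDecoderLayer change only renames the keyword; the layer-placement rule is unchanged: a layer becomes a DeepseekMoE block when routed experts are configured, the layer index is at least first_k_dense_replace, and the index is a multiple of moe_layer_freq; otherwise it stays a dense DeepseekMLP. A small sketch of that predicate (the example config values are made up):

    from typing import Optional


    def is_moe_layer(layer_idx: int,
                     n_routed_experts: Optional[int],
                     first_k_dense_replace: int,
                     moe_layer_freq: int) -> bool:
        # Same condition as DeepseekDecoderLayer.__init__ above.
        return (n_routed_experts is not None
                and layer_idx >= first_k_dense_replace
                and layer_idx % moe_layer_freq == 0)


    # With hypothetical first_k_dense_replace=1 and moe_layer_freq=1,
    # layer 0 stays dense and every later layer is an MoE block.
    moe_layers = [i for i in range(28) if is_moe_layer(i, 64, 1, 1)]
    assert moe_layers == list(range(1, 28))
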
+ 16 - 17
aphrodite/modeling/models/falcon.py

@@ -1,7 +1,6 @@
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2023 the Falcon authors and HuggingFace Inc. team.  All rights
 # reserved.
@@ -34,7 +33,6 @@ from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    tensor_model_parallel_all_reduce)
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
-                                              LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -44,6 +42,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import \
     VocabParallelEmbedding
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 from aphrodite.transformers_utils.configs import RWConfig
 
 FalconConfig = Union[HF_FalconConfig, RWConfig]
@@ -77,7 +76,7 @@ class FalconAttention(nn.Module):
     def __init__(
         self,
         config: FalconConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
 
@@ -116,7 +115,7 @@ class FalconAttention(nn.Module):
             self.total_num_kv_heads,
             bias=config.bias,
             skip_bias_add=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
@@ -130,7 +129,7 @@ class FalconAttention(nn.Module):
             self.hidden_size,
             bias=config.bias,
             skip_bias_add=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
             reduce_results=self.reduce_row_parallel_results)
 
         self.use_rotary = config.rotary
@@ -193,7 +192,7 @@ class FalconMLP(nn.Module):
     def __init__(
         self,
         config: FalconConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -202,8 +201,8 @@ class FalconMLP(nn.Module):
                                                   4 * hidden_size,
                                                   bias=config.bias,
                                                   skip_bias_add=True,
-                                                  linear_method=linear_method)
-        quant_config = getattr(linear_method, "quant_config", None)
+                                                  quant_config=quant_config)
+        quant_config = getattr(quant_config, "quant_config", None)
         self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
         self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                 or config.parallel_attn)
@@ -213,7 +212,7 @@ class FalconMLP(nn.Module):
             bias=config.bias,
             skip_bias_add=True,
             reduce_results=self.reduce_row_parallel_results,
-            linear_method=linear_method)
+            quant_config=quant_config)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
@@ -230,13 +229,13 @@ class FalconDecoderLayer(nn.Module):
     def __init__(
         self,
         config: FalconConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.self_attention = FalconAttention(config, linear_method)
-        self.mlp = FalconMLP(config, linear_method)
+        self.self_attention = FalconAttention(config, quant_config)
+        self.mlp = FalconMLP(config, quant_config)
         self.config = config
 
         if config.new_decoder_architecture:
@@ -312,7 +311,7 @@ class FalconModel(nn.Module):
     def __init__(
         self,
         config: FalconConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -328,7 +327,7 @@ class FalconModel(nn.Module):
 
         # Transformer blocks
         self.h = nn.ModuleList([
-            FalconDecoderLayer(config, linear_method)
+            FalconDecoderLayer(config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
 
@@ -360,12 +359,12 @@ class FalconForCausalLM(nn.Module):
     def __init__(
         self,
         config: FalconConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.transformer = FalconModel(config, linear_method)
+        self.quant_config = quant_config
+        self.transformer = FalconModel(config, quant_config)
         self.lm_head_weight = self.transformer.word_embeddings.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

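FalconMLP and the GPT-family MLPs further down (gpt2, gpt_bigcode, gpt_j, gpt_neox) end up with the same shape after this change: both projections and the activation lookup are driven by the quantization config. A minimal illustrative MLP under the signatures used in this diff:

    from typing import Optional

    import torch
    import torch.nn as nn

    from aphrodite.modeling.layers.activation import get_act_fn
    from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                                  RowParallelLinear)
    from aphrodite.quantization.base_config import QuantizationConfig


    class ToyMLP(nn.Module):

        def __init__(self,
                     hidden_size: int,
                     intermediate_size: int,
                     quant_config: Optional[QuantizationConfig] = None) -> None:
            super().__init__()
            self.c_fc = ColumnParallelLinear(hidden_size,
                                             intermediate_size,
                                             bias=True,
                                             quant_config=quant_config)
            self.c_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=True,
                                            quant_config=quant_config)
            # get_act_fn also takes the config so quantization-aware (scaled)
            # activations can be substituted when the quant method needs them.
            self.act = get_act_fn("gelu", quant_config, intermediate_size)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x, _ = self.c_fc(x)  # parallel linear layers return (output, bias)
            x = self.act(x)
            x, _ = self.c_proj(x)
            return x
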
+ 21 - 22
aphrodite/modeling/models/gemma.py

@@ -1,5 +1,4 @@
 # coding=utf-8
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright (c) Google Inc.
 #
@@ -29,8 +28,7 @@ from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import GeluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              MergedColumnParallelLinear,
+from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -40,6 +38,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import \
     VocabParallelEmbedding
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 @lru_cache(maxsize=None)
@@ -54,10 +53,10 @@ def _get_gemma_act_fn(
                 "in the config JSON file when it was initially released. "
                 "Changing the activation function to approximate GeLU "
                 "(`gelu_pytorch_tanh`). If you want to use the legacy "
-                f"`{hidden_act}`, edit the config JSON to set "
-                f"`hidden_activation={hidden_act}` instead of `hidden_act`. "
+                "`%s`, edit the config JSON to set "
+                "`hidden_activation=%s` instead of `hidden_act`. "
                 "See https://github.com/huggingface/transformers/pull/29402 "
-                "for more details.")
+                "for more details.", hidden_act, hidden_act)
         return GeluAndMul(approximate="tanh")
     elif hidden_activation == "gelu_pytorch_tanh":
         return GeluAndMul(approximate="tanh")
@@ -76,17 +75,17 @@ class GemmaMLP(nn.Module):
         intermediate_size: int,
         hidden_act: Optional[str] = None,
         hidden_activation: Optional[str] = None,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size, [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method)
+            quant_config=quant_config)
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
-                                           linear_method=linear_method)
+                                           quant_config=quant_config)
         self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)
 
     def forward(self, x):
@@ -105,7 +104,7 @@ class GemmaAttention(nn.Module):
                  head_dim: int,
                  max_position_embeddings: int = 8192,
                  rope_theta: float = 10000,
-                 linear_method: Optional[LinearMethodBase] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -134,13 +133,13 @@ class GemmaAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.rotary_emb = get_rope(
@@ -175,7 +174,7 @@ class GemmaDecoderLayer(nn.Module):
     def __init__(
         self,
         config: GemmaConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -186,14 +185,14 @@ class GemmaDecoderLayer(nn.Module):
             head_dim=config.head_dim,
             max_position_embeddings=config.max_position_embeddings,
             rope_theta=config.rope_theta,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.mlp = GemmaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             hidden_activation=getattr(config, "hidden_activation", None),
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@@ -234,7 +233,7 @@ class GemmaModel(nn.Module):
     def __init__(
         self,
         config: GemmaConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -244,7 +243,7 @@ class GemmaModel(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            GemmaDecoderLayer(config, linear_method)
+            GemmaDecoderLayer(config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -307,14 +306,14 @@ class GemmaForCausalLM(nn.Module):
     def __init__(
         self,
         config: GemmaConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         del lora_config  # Unused.
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = GemmaModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = GemmaModel(config, quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -368,8 +367,8 @@ class GemmaForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # lm_head is not used in Aphrodite as it is tied with
-                # embed_token. To prevent errors, skip loading lm_head.weight.
+                # lm_head is not used in vllm as it is tied with embed_token.
+                # To prevent errors, skip loading lm_head.weight.
                 if "lm_head.weight" in name:
                     continue
                 # Skip loading extra bias for GPTQ models.

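Besides the keyword rename, the Gemma hunk switches the activation warning from an f-string to lazy %-style arguments, so the message is only interpolated if the warning is actually emitted. A short sketch of the difference, using the stdlib logging module for illustration (the logger object used by this file is not shown in the excerpt):

    import logging

    logger = logging.getLogger(__name__)
    hidden_act = "gelu"

    # Eager: the f-string is built even when WARNING is filtered out.
    logger.warning(f"If you want to use the legacy `{hidden_act}`, "
                   "edit the config JSON.")

    # Lazy: interpolation is deferred to the logging framework.
    logger.warning(
        "If you want to use the legacy `%s`, edit the config JSON to set "
        "`hidden_activation=%s` instead of `hidden_act`.", hidden_act,
        hidden_act)
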
+ 16 - 17
aphrodite/modeling/models/gpt2.py

@@ -1,7 +1,6 @@
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
 # Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
@@ -29,7 +28,6 @@ from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
-                                              LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -38,6 +36,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import \
     VocabParallelEmbedding
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class GPT2Attention(nn.Module):
@@ -45,7 +44,7 @@ class GPT2Attention(nn.Module):
     def __init__(
         self,
         config: GPT2Config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -62,13 +61,13 @@ class GPT2Attention(nn.Module):
             self.head_dim,
             total_num_heads,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.c_proj = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale)
 
@@ -91,7 +90,7 @@ class GPT2MLP(nn.Module):
         self,
         intermediate_size: int,
         config: GPT2Config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -99,15 +98,15 @@ class GPT2MLP(nn.Module):
             hidden_size,
             intermediate_size,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
-        quant_config = getattr(linear_method, "quant_config", None)
+        quant_config = getattr(quant_config, "quant_config", None)
         self.act = get_act_fn(config.activation_function, quant_config,
                               intermediate_size)
 
@@ -123,7 +122,7 @@ class GPT2Block(nn.Module):
     def __init__(
         self,
         config: GPT2Config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -131,9 +130,9 @@ class GPT2Block(nn.Module):
                      hidden_size)
 
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPT2Attention(config, linear_method)
+        self.attn = GPT2Attention(config, quant_config)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = GPT2MLP(inner_dim, config, linear_method)
+        self.mlp = GPT2MLP(inner_dim, config, quant_config)
 
     def forward(
         self,
@@ -164,7 +163,7 @@ class GPT2Model(nn.Module):
     def __init__(
         self,
         config: GPT2Config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -175,7 +174,7 @@ class GPT2Model(nn.Module):
         self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList([
-            GPT2Block(config, linear_method)
+            GPT2Block(config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -204,12 +203,12 @@ class GPT2LMHeadModel(nn.Module):
     def __init__(
         self,
         config: GPT2Config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.transformer = GPT2Model(config, linear_method)
+        self.quant_config = quant_config
+        self.transformer = GPT2Model(config, quant_config)
         self.lm_head_weight = self.transformer.wte.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 16 - 17
aphrodite/modeling/models/gpt_bigcode.py

@@ -1,7 +1,6 @@
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2023 CTranslate2, and Michael Feil
 # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
@@ -30,7 +29,6 @@ from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
-                                              LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -39,6 +37,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import \
     VocabParallelEmbedding
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class GPTBigCodeAttention(nn.Module):
@@ -46,7 +45,7 @@ class GPTBigCodeAttention(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -73,14 +72,14 @@ class GPTBigCodeAttention(nn.Module):
             total_num_heads,
             total_num_kv_heads,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.c_proj = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
@@ -112,7 +111,7 @@ class GPTBigMLP(nn.Module):
         self,
         intermediate_size: int,
         config: GPTBigCodeConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -120,15 +119,15 @@ class GPTBigMLP(nn.Module):
             hidden_size,
             intermediate_size,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
-        quant_config = getattr(linear_method, "quant_config", None)
+        quant_config = getattr(quant_config, "quant_config", None)
         self.act = get_act_fn(config.activation_function, quant_config,
                               intermediate_size)
 
@@ -144,7 +143,7 @@ class GPTBigCodeBlock(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -152,9 +151,9 @@ class GPTBigCodeBlock(nn.Module):
                      hidden_size)
 
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPTBigCodeAttention(config, linear_method)
+        self.attn = GPTBigCodeAttention(config, quant_config)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = GPTBigMLP(inner_dim, config, linear_method)
+        self.mlp = GPTBigMLP(inner_dim, config, quant_config)
 
     def forward(
         self,
@@ -185,7 +184,7 @@ class GPTBigCodeModel(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -196,7 +195,7 @@ class GPTBigCodeModel(nn.Module):
         self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList([
-            GPTBigCodeBlock(config, linear_method)
+            GPTBigCodeBlock(config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -225,12 +224,12 @@ class GPTBigCodeForCausalLM(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.transformer = GPTBigCodeModel(config, linear_method)
+        self.quant_config = quant_config
+        self.transformer = GPTBigCodeModel(config, quant_config)
         self.lm_head_weight = self.transformer.wte.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 16 - 17
aphrodite/modeling/models/gpt_j.py

@@ -1,7 +1,6 @@
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
 #
@@ -28,7 +27,6 @@ from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
-                                              LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -38,6 +36,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class GPTJAttention(nn.Module):
@@ -45,7 +44,7 @@ class GPTJAttention(nn.Module):
     def __init__(
         self,
         config: GPTJConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.total_num_heads = config.num_attention_heads
@@ -57,13 +56,13 @@ class GPTJAttention(nn.Module):
             self.head_size,
             self.total_num_heads,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.out_proj = RowParallelLinear(
             config.hidden_size,
             config.hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         tp_world_size = get_tensor_model_parallel_world_size()
@@ -106,21 +105,21 @@ class GPTJMLP(nn.Module):
         self,
         intermediate_size: int,
         config: GPTJConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         hidden_size = config.n_embd
         self.fc_in = ColumnParallelLinear(
             hidden_size,
             intermediate_size,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.fc_out = RowParallelLinear(
             intermediate_size,
             hidden_size,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
-        quant_config = getattr(linear_method, "quant_config", None)
+        quant_config = getattr(quant_config, "quant_config", None)
         self.act = get_act_fn(config.activation_function, quant_config,
                               intermediate_size)
 
@@ -136,14 +135,14 @@ class GPTJBlock(nn.Module):
     def __init__(
         self,
         config: GPTJConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         inner_dim = (4 * config.n_embd
                      if config.n_inner is None else config.n_inner)
         self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
-        self.attn = GPTJAttention(config, linear_method)
-        self.mlp = GPTJMLP(inner_dim, config, linear_method)
+        self.attn = GPTJAttention(config, quant_config)
+        self.mlp = GPTJMLP(inner_dim, config, quant_config)
 
     def forward(
         self,
@@ -170,7 +169,7 @@ class GPTJModel(nn.Module):
     def __init__(
         self,
         config: GPTJConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -180,7 +179,7 @@ class GPTJModel(nn.Module):
             self.embed_dim,
         )
         self.h = nn.ModuleList(
-            [GPTJBlock(config, linear_method) for _ in range(config.n_layer)])
+            [GPTJBlock(config, quant_config) for _ in range(config.n_layer)])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
     def forward(
@@ -208,13 +207,13 @@ class GPTJForCausalLM(nn.Module):
     def __init__(
         self,
         config: GPTJConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
+        self.quant_config = quant_config
         assert not config.tie_word_embeddings
-        self.transformer = GPTJModel(config, linear_method)
+        self.transformer = GPTJModel(config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size,
             config.n_embd,

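GPT-J is one of the models here that keeps a separate output head: the constructor asserts tie_word_embeddings is off and builds a ParallelLMHead, while GPT-2 and Falcon above (and JAIS below) reuse the input embedding weight (self.lm_head_weight = self.transformer.wte.weight). The same tying is why the Cohere and Gemma loaders skip lm_head.weight entirely. A sketch of that branch; the class is illustrative, not taken from the diff:

    import torch.nn as nn

    from aphrodite.modeling.layers.vocab_parallel_embedding import (
        ParallelLMHead, VocabParallelEmbedding)


    class ToyLMHeadModel(nn.Module):

        def __init__(self, config) -> None:
            super().__init__()
            self.wte = VocabParallelEmbedding(config.vocab_size,
                                              config.hidden_size)
            if getattr(config, "tie_word_embeddings", False):
                # Tied: logits use the embedding matrix, so no separate head
                # weight exists in the checkpoint (gpt2/falcon style).
                self.lm_head_weight = self.wte.weight
            else:
                # Untied: a real ParallelLMHead is created (gpt_j style).
                self.lm_head = ParallelLMHead(config.vocab_size,
                                              config.hidden_size)
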
+ 16 - 17
aphrodite/modeling/models/gpt_neox.py

@@ -1,7 +1,6 @@
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
 #
@@ -28,7 +27,6 @@ from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
-                                              LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -38,6 +36,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class GPTNeoXAttention(nn.Module):
@@ -45,7 +44,7 @@ class GPTNeoXAttention(nn.Module):
     def __init__(
         self,
         config: GPTNeoXConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.total_num_heads = config.num_attention_heads
@@ -64,13 +63,13 @@ class GPTNeoXAttention(nn.Module):
             self.head_size,
             self.total_num_heads,
             bias=self.bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.dense = RowParallelLinear(
             config.hidden_size,
             config.hidden_size,
             bias=self.bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         scaling = self.head_size**-0.5
         rotary_dim = int(self.head_size * config.rotary_pct)
@@ -106,20 +105,20 @@ class GPTNeoXMLP(nn.Module):
     def __init__(
         self,
         config: GPTNeoXConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.dense_h_to_4h = ColumnParallelLinear(
             config.hidden_size,
             config.intermediate_size,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.dense_4h_to_h = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
-        quant_config = getattr(linear_method, "quant_config", None)
+        quant_config = getattr(quant_config, "quant_config", None)
         self.act = get_act_fn(config.hidden_act, quant_config,
                               config.intermediate_size)
 
@@ -135,7 +134,7 @@ class GPTNeoXLayer(nn.Module):
     def __init__(
         self,
         config: GPTNeoXConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.use_parallel_residual = config.use_parallel_residual
@@ -143,8 +142,8 @@ class GPTNeoXLayer(nn.Module):
                                             eps=config.layer_norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                      eps=config.layer_norm_eps)
-        self.attention = GPTNeoXAttention(config, linear_method)
-        self.mlp = GPTNeoXMLP(config, linear_method)
+        self.attention = GPTNeoXAttention(config, quant_config)
+        self.mlp = GPTNeoXMLP(config, quant_config)
 
     def forward(
         self,
@@ -183,7 +182,7 @@ class GPTNeoXModel(nn.Module):
     def __init__(
         self,
         config: GPTNeoXConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -193,7 +192,7 @@ class GPTNeoXModel(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            GPTNeoXLayer(config, linear_method)
+            GPTNeoXLayer(config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.final_layer_norm = nn.LayerNorm(config.hidden_size,
@@ -224,12 +223,12 @@ class GPTNeoXForCausalLM(nn.Module):
     def __init__(
         self,
         config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.gpt_neox = GPTNeoXModel(config, linear_method)
+        self.quant_config = quant_config
+        self.gpt_neox = GPTNeoXModel(config, quant_config)
         self.embed_out = ParallelLMHead(
             config.vocab_size,
             config.hidden_size,

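GPTNeoXLayer keeps its use_parallel_residual switch through this refactor; the hunk above only touches the constructor. For reference, in the upstream GPT-NeoX design that flag selects between summing the attention and MLP branches off the same input and applying them sequentially. A sketch of the forward logic under that assumption (callables stand in for the real submodules):

    import torch


    def neox_block_forward(x: torch.Tensor, attn, mlp, ln1, ln2,
                           use_parallel_residual: bool) -> torch.Tensor:
        # Sketch only; the real GPTNeoXLayer.forward is not part of this diff.
        attn_out = attn(ln1(x))
        if use_parallel_residual:
            # x + attn(ln1(x)) + mlp(ln2(x))
            return x + attn_out + mlp(ln2(x))
        # Sequential: the attention residual feeds the MLP branch.
        attn_out = attn_out + x
        return attn_out + mlp(ln2(attn_out))
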
+ 16 - 16
aphrodite/modeling/models/internlm2.py

@@ -10,8 +10,7 @@ from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              MergedColumnParallelLinear,
+from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -21,6 +20,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class InternLM2MLP(nn.Module):
@@ -30,17 +30,17 @@ class InternLM2MLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size, [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method)
+            quant_config=quant_config)
         self.w2 = RowParallelLinear(intermediate_size,
                                     hidden_size,
                                     bias=False,
-                                    linear_method=linear_method)
+                                    quant_config=quant_config)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -63,7 +63,7 @@ class InternLM2Attention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -94,13 +94,13 @@ class InternLM2Attention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.wo = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.rotary_emb = get_rope(
@@ -135,7 +135,7 @@ class InternLMDecoderLayer(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -150,13 +150,13 @@ class InternLMDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.feed_forward = InternLM2MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.attention_norm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
@@ -195,7 +195,7 @@ class InternLM2Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -206,7 +206,7 @@ class InternLM2Model(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            InternLMDecoderLayer(config, linear_method)
+            InternLMDecoderLayer(config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -238,12 +238,12 @@ class InternLM2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = InternLM2Model(config, linear_method)
+        self.quant_config = quant_config
+        self.model = InternLM2Model(config, quant_config)
         self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 16 - 17
aphrodite/modeling/models/jais.py

@@ -1,7 +1,6 @@
 # coding=utf-8
 # Adapted from
 # https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2023 the Jais authors and HuggingFace Inc. team.  All rights
 # reserved.
@@ -31,7 +30,6 @@ from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
-                                              LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -40,6 +38,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import \
     VocabParallelEmbedding
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 from aphrodite.transformers_utils.configs import JAISConfig
 
 
@@ -69,7 +68,7 @@ class JAISAttention(nn.Module):
     def __init__(
         self,
         config: JAISConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -89,13 +88,13 @@ class JAISAttention(nn.Module):
             self.head_dim,
             total_num_heads,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.c_proj = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         tp_rank = get_tensor_model_parallel_rank()
@@ -129,7 +128,7 @@ class JAISMLP(nn.Module):
         self,
         intermediate_size: int,
         config: JAISConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -138,19 +137,19 @@ class JAISMLP(nn.Module):
             hidden_size,
             intermediate_size,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.c_fc2 = (ColumnParallelLinear(
             hidden_size,
             intermediate_size,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         ) if self.swiglu else None)
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.act = SwiGLUActivation()
@@ -170,7 +169,7 @@ class JAISBlock(nn.Module):
     def __init__(
         self,
         config: JAISConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         hidden_size = config.hidden_size
@@ -178,9 +177,9 @@ class JAISBlock(nn.Module):
                      hidden_size)
 
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = JAISAttention(config, linear_method)
+        self.attn = JAISAttention(config, quant_config)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.mlp = JAISMLP(inner_dim, config, linear_method)
+        self.mlp = JAISMLP(inner_dim, config, quant_config)
 
     def forward(
         self,
@@ -211,7 +210,7 @@ class JAISModel(nn.Module):
     def __init__(
         self,
         config: JAISConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -228,7 +227,7 @@ class JAISModel(nn.Module):
         else:
             self.embeddings_scale = config.mup_embeddings_scale
         self.h = nn.ModuleList([
-            JAISBlock(config, linear_method)
+            JAISBlock(config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -262,12 +261,12 @@ class JAISLMHeadModel(nn.Module):
     def __init__(
         self,
         config: JAISConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.transformer = JAISModel(config, linear_method)
+        self.quant_config = quant_config
+        self.transformer = JAISModel(config, quant_config)
         self.lm_head_weight = self.transformer.wte.weight
         if hasattr(config, "width_scale"):
             self.output_logits_scale = config.width_scale
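
The jais.py hunks above show the pattern this commit repeats across every model file: constructors stop taking a LinearMethodBase and instead pass an optional QuantizationConfig down to the parallel linear layers, which derive their quantization method from it. A minimal sketch of the new call shape, with an illustrative helper name (building these layers also assumes the model-parallel state has been initialized):

    from typing import Optional

    from aphrodite.modeling.layers.linear import RowParallelLinear
    from aphrodite.quantization.base_config import QuantizationConfig

    def make_out_proj(hidden_size: int,
                      quant_config: Optional[QuantizationConfig] = None):
        # quant_config=None keeps the unquantized default path; a real config
        # (GPTQ, AWQ, FP8, ...) selects the matching quantized method.
        return RowParallelLinear(hidden_size,
                                 hidden_size,
                                 bias=True,
                                 quant_config=quant_config)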

+ 15 - 17
aphrodite/modeling/models/llama.py

@@ -1,7 +1,6 @@
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
@@ -36,8 +35,7 @@ from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              MergedColumnParallelLinear,
+from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -48,6 +46,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import (
 from aphrodite.modeling.model_loader.weight_utils import (
     default_weight_loader, kv_cache_scales_loader)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class LlamaMLP(nn.Module):
@@ -57,17 +56,17 @@ class LlamaMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size, [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method)
+            quant_config=quant_config)
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
-                                           linear_method=linear_method)
+                                           quant_config=quant_config)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -90,7 +89,7 @@ class LlamaAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
         sliding_window: Optional[int] = None,
     ) -> None:
@@ -132,13 +131,13 @@ class LlamaAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.rotary_emb = get_rope(
@@ -175,7 +174,7 @@ class LlamaDecoderLayer(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -200,7 +199,7 @@ class LlamaDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            linear_method=linear_method,
+            quant_config=quant_config,
             bias=attention_bias,
             sliding_window=sliding_window,
         )
@@ -208,7 +207,7 @@ class LlamaDecoderLayer(nn.Module):
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@@ -249,7 +248,7 @@ class LlamaModel(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
@@ -265,7 +264,7 @@ class LlamaModel(nn.Module):
             org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
-            LlamaDecoderLayer(config, linear_method)
+            LlamaDecoderLayer(config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -330,13 +329,12 @@ class LlamaForCausalLM(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = LlamaModel(config, linear_method, lora_config=lora_config)
+        self.model = LlamaModel(config, quant_config, lora_config=lora_config)
         self.unpadded_vocab_size = config.vocab_size
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

+ 4 - 4
aphrodite/modeling/models/llava.py

@@ -10,13 +10,13 @@ from aphrodite.attention import AttentionMetadata
 from aphrodite.common.config import VisionLanguageConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.linear import LinearMethodBase
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.llama import LlamaModel
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 _KEYS_TO_MODIFY_MAPPING = {
     "language_model.lm_head": "lm_head",
@@ -61,7 +61,7 @@ class LlavaForConditionalGeneration(nn.Module):
     def __init__(self,
                  config: "LlavaConfig",
                  vision_language_config: VisionLanguageConfig,
-                 linear_method: Optional["LinearMethodBase"] = None) -> None:
+                 quant_config: Optional["QuantizationConfig"] = None) -> None:
         super().__init__()
         self.config = config
 
@@ -83,8 +83,8 @@ class LlavaForConditionalGeneration(nn.Module):
             text_hidden_size=config.text_config.hidden_size,
             projector_hidden_act=config.projector_hidden_act)
 
-        self.linear_method = linear_method
-        self.language_model = LlamaModel(config.text_config, linear_method)
+        self.quant_config = quant_config
+        self.language_model = LlamaModel(config.text_config, quant_config)
         self.unpadded_vocab_size = config.text_config.vocab_size
         self.lm_head = ParallelLMHead(
             self.unpadded_vocab_size,

+ 17 - 18
aphrodite/modeling/models/minicpm.py

@@ -1,7 +1,6 @@
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
@@ -37,8 +36,7 @@ from aphrodite.distributed import (get_tensor_model_parallel_rank,
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.fused_moe import fused_moe
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              MergedColumnParallelLinear,
+from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
@@ -50,6 +48,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import (
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class MiniCPMMoE(nn.Module):
@@ -85,7 +84,7 @@ class MiniCPMMoE(nn.Module):
                                      self.num_total_experts,
                                      bias=False,
                                      params_dtype=self.params_dtype,
-                                     linear_method=None)
+                                     quant_config=None)
 
         self.ws = nn.Parameter(
             torch.empty(self.num_total_experts,
@@ -148,17 +147,17 @@ class MiniCPMMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size, [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method)
+            quant_config=quant_config)
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
-                                           linear_method=linear_method)
+                                           quant_config=quant_config)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -181,7 +180,7 @@ class MiniCPMAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -212,13 +211,13 @@ class MiniCPMAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.rotary_emb = get_rope(
@@ -259,7 +258,7 @@ class MiniCPMDecoderLayer(nn.Module):
     def __init__(
         self,
         config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -275,7 +274,7 @@ class MiniCPMDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.num_experts = getattr(self.config, "num_experts", 0)
         if self.num_experts == 0:
@@ -283,7 +282,7 @@ class MiniCPMDecoderLayer(nn.Module):
                 hidden_size=self.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
-                linear_method=linear_method,
+                quant_config=quant_config,
             )
         else:
             self.mlp = MiniCPMMoE(num_experts=config.num_experts,
@@ -330,7 +329,7 @@ class MiniCPMModel(nn.Module):
     def __init__(
         self,
         config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
@@ -346,7 +345,7 @@ class MiniCPMModel(nn.Module):
             org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
-            MiniCPMDecoderLayer(config, linear_method)
+            MiniCPMDecoderLayer(config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -413,15 +412,15 @@ class MiniCPMForCausalLM(nn.Module):
     def __init__(
         self,
         config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.num_experts = getattr(self.config, "num_experts", 0)
-        self.linear_method = linear_method
+        self.quant_config = quant_config
         self.model = MiniCPMModel(config,
-                                  linear_method,
+                                  quant_config,
                                   lora_config=lora_config)
         unpadded_vocab_size = config.vocab_size
         if lora_config:
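
One detail in the MoE paths here (and in mixtral.py below): the router gate is still built with quant_config=None, so it stays unquantized regardless of the model's scheme, while the expert weights follow the real config. A small sketch of that distinction with made-up sizes (ReplicatedLinear keeps a full, unsharded copy of its weight):

    from aphrodite.modeling.layers.linear import ReplicatedLinear

    hidden_size, num_total_experts = 4096, 8
    # The gate/router is deliberately left unquantized.
    gate = ReplicatedLinear(hidden_size,
                            num_total_experts,
                            bias=False,
                            quant_config=None)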

+ 21 - 22
aphrodite/modeling/models/mixtral.py

@@ -1,7 +1,6 @@
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
@@ -28,6 +27,7 @@ import torch
 from torch import nn
 from transformers import MixtralConfig
 
+from aphrodite._C import ops
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.config import LoRAConfig
 from aphrodite.common.sequence import SamplerOutput
@@ -37,12 +37,10 @@ from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    tensor_model_parallel_all_reduce)
 from aphrodite.modeling.layers.fused_moe import fused_moe
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              QKVParallelLinear,
+from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
-from aphrodite.quantization.fp8 import (Fp8LinearMethod, per_tensor_quantize)
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
@@ -50,6 +48,8 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import (
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import QuantizationConfig
+from aphrodite.quantization.fp8 import Fp8Config
 
 
 class MixtralMoE(nn.Module):
@@ -69,7 +69,7 @@ class MixtralMoE(nn.Module):
         intermediate_size: int,
         params_dtype: Optional[torch.dtype] = None,
         tp_size: Optional[int] = None,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.tp_size = tp_size or get_tensor_model_parallel_world_size()
@@ -79,7 +79,7 @@ class MixtralMoE(nn.Module):
         self.intermediate_size = intermediate_size // self.tp_size
         # FIXME(pcmoritz): Make this more general to support different
         # quantization schemes
-        self.use_fp8 = isinstance(linear_method, Fp8LinearMethod)
+        self.use_fp8 = isinstance(quant_config, Fp8Config)
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -89,7 +89,7 @@ class MixtralMoE(nn.Module):
                                      self.num_total_experts,
                                      bias=False,
                                      params_dtype=self.params_dtype,
-                                     linear_method=None)
+                                     quant_config=None)
 
         self.ws = nn.Parameter(
             torch.empty(self.num_total_experts,
@@ -140,10 +140,10 @@ class MixtralMoE(nn.Module):
             ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn)
             w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn)
             for expert in range(self.num_total_experts):
-                ws[expert, :, :], self.ws_scale[expert] = per_tensor_quantize(
+                ws[expert, :, :], self.ws_scale[expert] = ops.scaled_fp8_quant(
                     self.ws.data[expert, :, :])
                 w2s[expert, :, :], self.w2s_scale[
-                    expert] = per_tensor_quantize(self.w2s.data[expert, :, :])
+                    expert] = ops.scaled_fp8_quant(self.w2s.data[expert, :, :])
             self.ws = nn.Parameter(ws, requires_grad=False)
             self.w2s = nn.Parameter(w2s, requires_grad=False)
 
@@ -178,7 +178,7 @@ class MixtralAttention(nn.Module):
                  num_kv_heads: int,
                  max_position: int = 4096 * 32,
                  rope_theta: float = 10000,
-                 linear_method: Optional[LinearMethodBase] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
                  sliding_window: Optional[int] = None) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -203,12 +203,12 @@ class MixtralAttention(nn.Module):
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window
 
-        if isinstance(linear_method, Fp8LinearMethod):
+        if isinstance(quant_config, Fp8Config):
             print_warning_once(
                 "For Mixtral FP8 quantization, we currently do not quantize "
                 "the attention layers until their FP8 performance is improved."
             )
-            linear_method = None
+            quant_config = None
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -216,13 +216,13 @@ class MixtralAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -259,7 +259,7 @@ class MixtralDecoderLayer(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -272,13 +272,13 @@ class MixtralDecoderLayer(nn.Module):
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
             sliding_window=config.sliding_window,
-            linear_method=linear_method)
+            quant_config=quant_config)
         self.block_sparse_moe = MixtralMoE(
             num_experts=config.num_local_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.intermediate_size,
-            linear_method=linear_method)
+            quant_config=quant_config)
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -318,7 +318,7 @@ class MixtralModel(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
@@ -334,7 +334,7 @@ class MixtralModel(nn.Module):
             org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
-            MixtralDecoderLayer(config, linear_method=linear_method)
+            MixtralDecoderLayer(config, quant_config=quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -384,14 +384,13 @@ class MixtralForCausalLM(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
         self.model = MixtralModel(config,
-                                  linear_method,
+                                  quant_config,
                                   lora_config=lora_config)
         self.unpadded_vocab_size = config.vocab_size
         if lora_config:
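
Two FP8-related changes in mixtral.py are worth calling out: detection now keys off the config type (isinstance(quant_config, Fp8Config)) rather than the old Fp8LinearMethod, and per-expert weights are quantized with the ops.scaled_fp8_quant kernel imported from aphrodite._C. A condensed sketch of that loop, mirroring the call shape used in the diff (the helper name is mine, and a CUDA device with FP8 support is assumed):

    import torch

    from aphrodite._C import ops
    from aphrodite.quantization.fp8 import Fp8Config

    def maybe_quantize_expert_weights(ws: torch.Tensor, quant_config):
        """Quantize stacked expert weights [num_experts, n, k] to FP8."""
        if not isinstance(quant_config, Fp8Config):
            return ws, None  # leave weights untouched when FP8 is not in use
        num_experts = ws.shape[0]
        ws_fp8 = torch.empty_like(ws, dtype=torch.float8_e4m3fn)
        scales = torch.empty(num_experts, dtype=torch.float32, device=ws.device)
        for expert in range(num_experts):
            # scaled_fp8_quant returns the quantized tensor and its scale.
            ws_fp8[expert, :, :], scales[expert] = ops.scaled_fp8_quant(
                ws[expert, :, :])
        return ws_fp8, scales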

+ 21 - 22
aphrodite/modeling/models/mixtral_quant.py

@@ -1,7 +1,6 @@
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
@@ -36,8 +35,7 @@ from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
                                    tensor_model_parallel_all_reduce)
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              QKVParallelLinear,
+from aphrodite.modeling.layers.linear import (QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -47,6 +45,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class MixtralMLP(nn.Module):
@@ -56,7 +55,7 @@ class MixtralMLP(nn.Module):
         num_experts: int,
         hidden_size: int,
         intermediate_size: int,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.num_experts = num_experts
@@ -66,17 +65,17 @@ class MixtralMLP(nn.Module):
         self.w1 = ReplicatedLinear(self.hidden_dim,
                                    self.ffn_dim,
                                    bias=False,
-                                   linear_method=linear_method)
+                                   quant_config=quant_config)
         self.w2 = ReplicatedLinear(self.ffn_dim,
                                    self.hidden_dim,
                                    bias=False,
-                                   linear_method=linear_method)
+                                   quant_config=quant_config)
         self.w3 = ReplicatedLinear(self.hidden_dim,
                                    self.ffn_dim,
                                    bias=False,
-                                   linear_method=linear_method)
+                                   quant_config=quant_config)
 
         # TODO: Use Aphrodite's SiluAndMul
         self.act_fn = nn.SiLU()
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -93,7 +92,7 @@ class MixtralMoE(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -116,14 +115,14 @@ class MixtralMoE(nn.Module):
             MixtralMLP(self.num_total_experts,
                        config.hidden_size,
                        config.intermediate_size,
-                       linear_method=linear_method)
+                       quant_config=quant_config)
             if idx in self.expert_indicies else None
             for idx in range(self.num_total_experts)
         ])
         self.gate = ReplicatedLinear(config.hidden_size,
                                      self.num_total_experts,
                                      bias=False,
-                                     linear_method=None)
+                                     quant_config=None)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
@@ -163,7 +162,7 @@ class MixtralAttention(nn.Module):
                  num_kv_heads: int,
                  max_position: int = 4096 * 32,
                  rope_theta: float = 10000,
-                 linear_method: Optional[LinearMethodBase] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
                  sliding_window: Optional[int] = None) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -194,13 +193,13 @@ class MixtralAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -237,7 +236,7 @@ class MixtralDecoderLayer(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -250,9 +249,9 @@ class MixtralDecoderLayer(nn.Module):
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
             sliding_window=config.sliding_window,
-            linear_method=linear_method)
+            quant_config=quant_config)
         self.block_sparse_moe = MixtralMoE(config=config,
-                                           linear_method=linear_method)
+                                           quant_config=quant_config)
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -292,7 +291,7 @@ class MixtralModel(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -303,7 +302,7 @@ class MixtralModel(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            MixtralDecoderLayer(config, linear_method=linear_method)
+            MixtralDecoderLayer(config, quant_config=quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -332,12 +331,12 @@ class MixtralForCausalLM(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = MixtralModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = MixtralModel(config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 16 - 16
aphrodite/modeling/models/mpt.py

@@ -12,7 +12,6 @@ from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
-                                              LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -21,6 +20,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import \
     VocabParallelEmbedding
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 from aphrodite.transformers_utils.configs.mpt import MPTConfig
 
 
@@ -42,7 +42,7 @@ class MPTAttention(nn.Module):
     def __init__(
         self,
         config: MPTConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.d_model = config.d_model
@@ -65,7 +65,7 @@ class MPTAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=not config.no_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         if self.qk_ln:
             self.q_ln = nn.LayerNorm(self.d_model)
@@ -74,7 +74,7 @@ class MPTAttention(nn.Module):
             self.d_model,
             self.d_model,
             bias=not config.no_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         tp_world_size = get_tensor_model_parallel_world_size()
@@ -133,7 +133,7 @@ class MPTMLP(nn.Module):
     def __init__(
         self,
         config: MPTConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         hidden_size = config.d_model
@@ -143,15 +143,15 @@ class MPTMLP(nn.Module):
             hidden_size,
             intermediate_size,
             bias=not config.no_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
-        quant_config = getattr(linear_method, "quant_config", None)
         self.act = get_act_fn("gelu", quant_config, intermediate_size)
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=not config.no_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -166,14 +166,14 @@ class MPTBlock(nn.Module):
     def __init__(
         self,
         config: MPTConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         hidden_size = config.d_model
         self.norm_1 = nn.LayerNorm(hidden_size)
-        self.attn = MPTAttention(config, linear_method)
+        self.attn = MPTAttention(config, quant_config)
         self.norm_2 = nn.LayerNorm(hidden_size)
-        self.ffn = MPTMLP(config, linear_method)
+        self.ffn = MPTMLP(config, quant_config)
 
     def forward(
         self,
@@ -201,7 +201,7 @@ class MPTModel(nn.Module):
     def __init__(
         self,
         config: MPTConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         assert config.embedding_fraction == 1.0
@@ -212,7 +212,7 @@ class MPTModel(nn.Module):
             config.d_model,
         )
         self.blocks = nn.ModuleList(
-            [MPTBlock(config, linear_method) for _ in range(config.n_layers)])
+            [MPTBlock(config, quant_config) for _ in range(config.n_layers)])
         self.norm_f = nn.LayerNorm(config.d_model)
         if config.no_bias:
             for module in self.modules():
@@ -246,14 +246,14 @@ class MPTForCausalLM(nn.Module):
     def __init__(
         self,
         config: MPTConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
         assert config.tie_word_embeddings
-        self.linear_method = linear_method
+        self.quant_config = quant_config
 
-        self.transformer = MPTModel(config, linear_method)
+        self.transformer = MPTModel(config, quant_config)
         self.lm_head_weight = self.transformer.wte.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
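
A related cleanup in MPTMLP (and the same spot in opt.py further down): with the old API the quantization config had to be fished off the linear method, but the constructor now receives a QuantizationConfig directly, so it can simply be handed to get_act_fn. A minimal sketch under that assumption (the builder function is illustrative):

    from typing import Optional

    from aphrodite.modeling.layers.activation import get_act_fn
    from aphrodite.quantization.base_config import QuantizationConfig

    def build_mpt_act(intermediate_size: int,
                      quant_config: Optional[QuantizationConfig] = None):
        # get_act_fn can pick a quantization-aware GELU when the config
        # provides one; passing None yields the plain implementation.
        return get_act_fn("gelu", quant_config, intermediate_size)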

+ 160 - 173
aphrodite/modeling/models/olmo.py

@@ -1,65 +1,47 @@
 # coding=utf-8
 # Adapted from
-# https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and
-# https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py
-# Copyright 2023 The PygmalionAI team.
-# Copyright 2023 The vLLM team.
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT license.
+# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py
+# Copyright 2024 The vLLM team.
+# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
-# BSD 3-Clause License
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
 #
-# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
-# All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
 #
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
+#     http://www.apache.org/licenses/LICENSE-2.0
 #
-# * Redistributions of source code must retain the above copyright notice, this
-#   list of conditions and the following disclaimer.
-#
-# * Redistributions in binary form must reproduce the above copyright notice,
-#   this list of conditions and the following disclaimer in the documentation
-#   and/or other materials provided with the distribution.
-#
-# * Neither the name of the copyright holder nor the names of its
-#   contributors may be used to endorse or promote products derived from
-#   this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Inference-only OLMo model compatible with HuggingFace weights."""
 from typing import Iterable, List, Optional, Tuple
 
 import torch
-# this model must need this dependency
-from hf_olmo import OLMoConfig
 from torch import nn
+from transformers import OlmoConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
-                                              LinearMethodBase,
-                                              MergedColumnParallelLinear,
+from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
-from aphrodite.modeling.layers.vocab_parallel_embedding import \
-    VocabParallelEmbedding
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class OlmoAttention(nn.Module):
@@ -71,56 +53,53 @@ class OlmoAttention(nn.Module):
 
     def __init__(
         self,
-        config: OLMoConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        config: OlmoConfig,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.hidden_size = config.d_model
-        assert config.d_model % config.n_heads == 0
+        self.hidden_size = config.hidden_size
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
-        self.total_num_heads = self.config.n_heads
+        self.total_num_heads = config.num_attention_heads
+
+        assert self.hidden_size % self.total_num_heads == 0
         assert self.total_num_heads % tensor_model_parallel_world_size == 0
+
         self.num_heads = (self.total_num_heads //
                           tensor_model_parallel_world_size)
         self.head_dim = self.hidden_size // self.total_num_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+        self.clip_qkv = config.clip_qkv
 
-        # Layer norms.
-        self.attn_norm = nn.LayerNorm(config.d_model,
-                                      elementwise_affine=False,
-                                      bias=False)
         # Attention input projection. Projects x -> (q, k, v)
-        self.att_proj = QKVParallelLinear(
-            config.d_model,
+        self.qkv_proj = QKVParallelLinear(
+            self.hidden_size,
             self.head_dim,
             self.total_num_heads,
-            bias=config.include_bias,
-            linear_method=linear_method,
+            bias=config.attention_bias,
+            quant_config=quant_config,
         )
 
         # Rotary embeddings.
-        if self.config.rope:
-            rope_theta = getattr(config, "rope_theta", 10000)
-            max_position_embeddings = getattr(config,
-                                              "max_position_embeddings", 8192)
-            self.rotary_emb = get_rope(
-                self.head_dim,
-                rotary_dim=self.head_dim,
-                max_position=max_position_embeddings,
-                base=rope_theta,
-            )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_theta,
+        )
         self.scaling = self.head_dim**-0.5
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               scale=self.scaling)
 
         # Attention output projection.
-        self.attn_out = RowParallelLinear(
-            config.d_model,
-            config.d_model,
-            bias=config.include_bias,
-            linear_method=linear_method,
+        self.o_proj = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=config.attention_bias,
+            quant_config=quant_config,
         )
 
     def forward(
@@ -130,13 +109,13 @@ class OlmoAttention(nn.Module):
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        hidden_states = self.attn_norm(hidden_states)
-        qkv, _ = self.att_proj(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
+        if self.clip_qkv is not None:
+            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
-        if self.config.rope:
-            q, k = self.rotary_emb(positions, q, k)
+        q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
-        output, _ = self.attn_out(attn_output)
+        output, _ = self.o_proj(attn_output)
         return output
 
 
@@ -149,57 +128,44 @@ class OlmoMLP(nn.Module):
 
     def __init__(
         self,
-        config: OLMoConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        config: OlmoConfig,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.hidden_size = (config.mlp_hidden_size if config.mlp_hidden_size
-                            is not None else config.mlp_ratio * config.d_model)
-
-        # Layer norms.
-        self.ff_norm = nn.LayerNorm(config.d_model,
-                                    elementwise_affine=False,
-                                    bias=False)
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
 
         # Feed-forward input projection.
-        self.ff_proj = MergedColumnParallelLinear(
-            config.d_model,
-            [self.hidden_size // 2] * 2,
-            bias=config.include_bias,
-            linear_method=linear_method,
+        self.gate_up_proj = MergedColumnParallelLinear(
+            self.hidden_size,
+            [self.intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
         )
 
         # Activation function.
-        self.act = SiluAndMul()
-        self.act.output_multiplier = 0.5
-        assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
+        self.act_fn = SiluAndMul()
 
         # Feed-forward output projection.
-        self.ff_out = RowParallelLinear(
-            int(self.act.output_multiplier * self.hidden_size),
-            config.d_model,
-            bias=config.include_bias,
-            linear_method=linear_method,
+        self.down_proj = RowParallelLinear(
+            self.intermediate_size,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
 
     def forward(
         self,
         x: torch.Tensor,
     ) -> torch.Tensor:
-        # Add feed-forward projection.
-        # shape: (batch_size, seq_len, d_model)
-        og_x = x
-        x = self.ff_norm(x)
-        x, _ = self.ff_proj(x)
-        x = self.act(x)
-        x, _ = self.ff_out(x)
-        x = og_x + x
-
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
         return x
 
 
-class OlmoBlock(nn.Module):
+class OlmoDecoderLayer(nn.Module):
     """
     This is a typical transformer block where the output is
     computed as ``MLP(LN(x + Attention(LN(x))))``
@@ -207,14 +173,22 @@ class OlmoBlock(nn.Module):
     """
 
     def __init__(self,
-                 config: OLMoConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
+                 config: OlmoConfig,
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         # Attention block.
-        self.attn = OlmoAttention(config, linear_method)
+        self.self_attn = OlmoAttention(config, quant_config)
 
         # MLP block.
-        self.mlp = OlmoMLP(config, linear_method)
+        self.mlp = OlmoMLP(config, quant_config)
+
+        # LayerNorm
+        self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                            elementwise_affine=False,
+                                            bias=False)
+        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
+                                                     elementwise_affine=False,
+                                                     bias=False)
 
     def forward(
         self,
@@ -224,52 +198,37 @@ class OlmoBlock(nn.Module):
         attn_metadata: AttentionMetadata,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
         # Attention block.
-        og_x = hidden_states
-        x = self.attn(positions, hidden_states, kv_cache, attn_metadata)
-        x = x + og_x
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(positions, hidden_states, kv_cache,
+                                       attn_metadata)
+        hidden_states = hidden_states + residual
 
         # MLP block.
-        hidden_states = self.mlp(x)
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
         return hidden_states
 
 
 class OlmoModel(nn.Module):
 
     def __init__(self,
-                 config: OLMoConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
+                 config: OlmoConfig,
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.config = config
 
-        self.transformer = nn.ModuleDict(
-            dict(
-                wte=VocabParallelEmbedding(
-                    config.embedding_size or config.vocab_size,
-                    config.d_model,
-                ),
-                ln_f=nn.LayerNorm(config.d_model,
-                                  elementwise_affine=False,
-                                  bias=False),
-            ))
-
-        blocks = [
-            OlmoBlock(config, linear_method) for i in range(config.n_layers)
-        ]
-        if self.config.block_group_size > 1:
-            raise NotImplementedError("Block group size > 1 not supported yet")
-        else:
-            self.transformer.update({"blocks": nn.ModuleList(blocks)})
-
-        if not config.weight_tying:
-            self.transformer.update({
-                "ff_out":
-                ColumnParallelLinear(
-                    config.d_model,
-                    config.embedding_size or config.vocab_size,
-                    bias=config.include_bias,
-                    linear_method=linear_method,
-                )
-            })
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size)
+        self.layers = nn.ModuleList([
+            OlmoDecoderLayer(config, quant_config)
+            for layer_idx in range(config.num_hidden_layers)
+        ])
+        self.norm = nn.LayerNorm(config.hidden_size,
+                                 elementwise_affine=False,
+                                 bias=False)
 
     def forward(
         self,
@@ -283,39 +242,48 @@ class OlmoModel(nn.Module):
         """
         # Get embeddings of input.
         # shape: (batch_size, seq_len, d_model)
-        x = self.transformer.wte(input_ids)  # type: ignore
+        inputs_embeds = self.embed_tokens(input_ids)
+
+        # embed positions
+        hidden_states = inputs_embeds
 
         # Apply blocks one-by-one.
-        for block_idx, block in enumerate(self.transformer.blocks):
+        for layer_idx, decoder_layer in enumerate(self.layers):
             # shape: (batch_size, seq_len, d_model)
-            x = block(
+            hidden_states = decoder_layer(
                 positions,
-                x,
-                kv_caches[block_idx],
+                hidden_states,
+                kv_caches[layer_idx],
                 attn_metadata,
             )
 
         # Apply final layer norm.
         # shape: (batch_size, seq_len or 1, d_model)
-        x = self.transformer.ln_f(x)  # type: ignore
-        return x
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
 
 
-class OLMoForCausalLM(nn.Module):
+class OlmoForCausalLM(nn.Module):
     """
     Extremely barebones HF model wrapper.
     """
 
     def __init__(self,
-                 config: OLMoConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
+                 config: OlmoConfig,
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = OlmoModel(config, linear_method)
-        self.lm_head_weight = (self.model.transformer.wte.weight
-                               if config.weight_tying else
-                               self.model.transformer.ff_out.weight)
+        self.model = OlmoModel(config, quant_config)
+        if config.tie_word_embeddings:
+            self.lm_head_weight = self.model.embed_tokens.weight
+        else:
+            self.unpadded_vocab_size = config.vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+            )
+            self.lm_head_weight = self.lm_head.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 
@@ -349,20 +317,39 @@ class OLMoForCausalLM(nn.Module):
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
-            # attention
-            if ".att" in name:
-                name = name.replace(".att", ".attn.att")
-            # mlp
-            if ".ff_proj" in name:
-                name = name.replace(".ff_proj", ".mlp.ff_proj")
-                # Reverse the weight for the MergeColumnParallelLinear
-                loaded_weight = torch.concat(loaded_weight.chunk(2)[::-1])
-            if ".ff_out" in name and "transformer.ff_out" not in name:
-                name = name.replace(".ff_out", ".mlp.ff_out")
-            # there is no bias in olmo
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if ("rotary_emb.cos_cached" in name
+                    or "rotary_emb.sin_cached" in name):
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
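
The rewritten load_weights above adopts the stacked-parameter convention used by the other models in this repo: separate q/k/v and gate/up checkpoint tensors are folded into the fused qkv_proj and gate_up_proj parameters, with a shard id telling the weight loader which slice to fill. A small sketch of just the name-remapping step, using a hypothetical checkpoint key:

    stacked_params_mapping = [
        # (fused param name, checkpoint shard name, shard id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    def remap(name: str):
        for fused_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name in name:
                return name.replace(shard_name, fused_name), shard_id
        return name, None

    # remap("model.layers.0.self_attn.k_proj.weight")
    # -> ("model.layers.0.self_attn.qkv_proj.weight", "k")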

+ 18 - 19
aphrodite/modeling/models/opt.py

@@ -1,7 +1,6 @@
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights
 # reserved.
@@ -29,7 +28,6 @@ from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
-                                              LinearMethodBase,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
@@ -39,6 +37,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import \
     VocabParallelEmbedding
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class OPTLearnedPositionalEmbedding(nn.Embedding):
@@ -61,7 +60,7 @@ class OPTAttention(nn.Module):
         embed_dim: int,
         num_heads: int,
         bias: bool = True,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.embed_dim = embed_dim
@@ -78,13 +77,13 @@ class OPTAttention(nn.Module):
             self.head_dim,
             total_num_heads,
             bias=bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.out_proj = RowParallelLinear(
             embed_dim,
             embed_dim,
             bias=bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
@@ -108,7 +107,7 @@ class OPTDecoderLayer(nn.Module):
     def __init__(
         self,
         config: OPTConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -117,7 +116,7 @@ class OPTDecoderLayer(nn.Module):
             embed_dim=self.embed_dim,
             num_heads=config.num_attention_heads,
             bias=config.enable_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.do_layer_norm_before = config.do_layer_norm_before
 
@@ -128,16 +127,16 @@ class OPTDecoderLayer(nn.Module):
             self.embed_dim,
             config.ffn_dim,
             bias=config.enable_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
-        quant_config = getattr(linear_method, "quant_config", None)
+        quant_config = getattr(quant_config, "quant_config", None)
         self.activation_fn = get_act_fn(config.activation_function,
                                         quant_config, config.ffn_dim)
         self.fc2 = RowParallelLinear(
             config.ffn_dim,
             self.embed_dim,
             bias=config.enable_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.final_layer_norm = nn.LayerNorm(
             self.embed_dim,
@@ -182,7 +181,7 @@ class OPTDecoder(nn.Module):
     def __init__(
         self,
         config: OPTConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -203,7 +202,7 @@ class OPTDecoder(nn.Module):
             self.project_out = ReplicatedLinear(config.hidden_size,
                                                 config.word_embed_proj_dim,
                                                 bias=False,
-                                                linear_method=linear_method)
+                                                quant_config=quant_config)
         else:
             self.project_out = None
 
@@ -211,7 +210,7 @@ class OPTDecoder(nn.Module):
             self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
                                                config.hidden_size,
                                                bias=False,
-                                               linear_method=linear_method)
+                                               quant_config=quant_config)
         else:
             self.project_in = None
 
@@ -227,7 +226,7 @@ class OPTDecoder(nn.Module):
             self.final_layer_norm = None
 
         self.layers = nn.ModuleList([
-            OPTDecoderLayer(config, linear_method)
+            OPTDecoderLayer(config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
 
@@ -260,10 +259,10 @@ class OPTModel(nn.Module):
     def __init__(
         self,
         config: OPTConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
-        self.decoder = OPTDecoder(config, linear_method)
+        self.decoder = OPTDecoder(config, quant_config)
 
     def forward(
         self,
@@ -280,12 +279,12 @@ class OPTForCausalLM(nn.Module):
     def __init__(
         self,
         config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = OPTModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = OPTModel(config, quant_config)
         self.lm_head_weight = self.model.decoder.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
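
Every model file in this commit repeats the signature change seen here: layers now accept the QuantizationConfig itself instead of a prebuilt linear_method, and each layer later asks the config for its own method. A toy sketch of that threading; the Toy* classes are invented, and only the Optional[QuantizationConfig] annotation mirrors the real code.

from typing import Optional

import torch
from torch import nn

from aphrodite.quantization.base_config import QuantizationConfig


class ToyLinear(nn.Module):
    def __init__(self, hidden: int,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        # With quant_config=None this behaves like a plain linear layer;
        # otherwise the layer would ask the config for a quant method here.
        self.quant_config = quant_config
        self.weight = nn.Parameter(torch.randn(hidden, hidden))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight.t()


class ToyDecoderLayer(nn.Module):
    def __init__(self, hidden: int,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        # Same pattern as OPTDecoderLayer: pass the config down, not a method.
        self.fc = ToyLinear(hidden, quant_config=quant_config)


layer = ToyDecoderLayer(8, quant_config=None)
out = layer.fc(torch.randn(2, 8))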

+ 16 - 16
aphrodite/modeling/models/orion.py

@@ -14,8 +14,7 @@ from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              MergedColumnParallelLinear,
+from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -25,6 +24,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class OrionMLP(nn.Module):
@@ -34,17 +34,17 @@ class OrionMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size, [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method)
+            quant_config=quant_config)
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
-                                           linear_method=linear_method)
+                                           quant_config=quant_config)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -67,7 +67,7 @@ class OrionAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -98,13 +98,13 @@ class OrionAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.rotary_emb = get_rope(
@@ -139,7 +139,7 @@ class OrionDecoderLayer(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -154,13 +154,13 @@ class OrionDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.mlp = OrionMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.input_layernorm = nn.LayerNorm(config.hidden_size,
@@ -201,7 +201,7 @@ class OrionModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -212,7 +212,7 @@ class OrionModel(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            OrionDecoderLayer(config, linear_method)
+            OrionDecoderLayer(config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -244,12 +244,12 @@ class OrionForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = OrionModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = OrionModel(config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 17 - 18
aphrodite/modeling/models/phi.py

@@ -1,7 +1,6 @@
 # coding=utf-8
 # Adapted from
 # https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
@@ -47,7 +46,6 @@ from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
-                                              LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -57,13 +55,14 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class PhiAttention(nn.Module):
 
     def __init__(self,
                  config: PretrainedConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.total_num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
@@ -81,12 +80,12 @@ class PhiAttention(nn.Module):
             self.head_size,
             self.total_num_heads,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.dense = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         scaling = self.head_size**-0.5
@@ -126,7 +125,7 @@ class PhiMLP(nn.Module):
 
     def __init__(self,
                  config: PretrainedConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
 
         n_inner = getattr(config, "n_inner", None)
@@ -135,14 +134,14 @@ class PhiMLP(nn.Module):
         self.fc1 = ColumnParallelLinear(
             config.hidden_size,
             n_inner,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.fc2 = RowParallelLinear(
             n_inner,
             config.hidden_size,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
-        quant_config = getattr(linear_method, "quant_config", None)
+        quant_config = getattr(quant_config, "quant_config", None)
         self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
 
     def forward(self, hidden_states):
@@ -156,12 +155,12 @@ class PhiLayer(nn.Module):
 
     def __init__(self,
                  config: PretrainedConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)
-        self.self_attn = PhiAttention(config, linear_method)
-        self.mlp = PhiMLP(config, linear_method)
+        self.self_attn = PhiAttention(config, quant_config)
+        self.mlp = PhiMLP(config, quant_config)
 
     def forward(
         self,
@@ -187,14 +186,14 @@ class PhiModel(nn.Module):
 
     def __init__(self,
                  config: PretrainedConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
+        self.quant_config = quant_config
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                    config.hidden_size)
         self.layers = nn.ModuleList([
-            PhiLayer(config, linear_method)
+            PhiLayer(config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.final_layernorm = nn.LayerNorm(config.hidden_size,
@@ -226,12 +225,12 @@ class PhiForCausalLM(nn.Module):
 
     def __init__(self,
                  config: PretrainedConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
+        self.quant_config = quant_config
 
-        self.model = PhiModel(config, linear_method)
+        self.model = PhiModel(config, quant_config)
 
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,

+ 16 - 16
aphrodite/modeling/models/qwen.py

@@ -15,8 +15,7 @@ from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              MergedColumnParallelLinear,
+from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -26,6 +25,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class QWenMLP(nn.Module):
@@ -35,17 +35,17 @@ class QWenMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str = "silu",
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size, [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method)
+            quant_config=quant_config)
         self.c_proj = RowParallelLinear(intermediate_size,
                                         hidden_size,
                                         bias=False,
-                                        linear_method=linear_method)
+                                        quant_config=quant_config)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -67,7 +67,7 @@ class QWenAttention(nn.Module):
         max_position_embeddings: int,
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -83,13 +83,13 @@ class QWenAttention(nn.Module):
             self.head_dim,
             self.total_num_heads,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.scaling = self.head_dim**-0.5
 
@@ -122,7 +122,7 @@ class QWenBlock(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -134,13 +134,13 @@ class QWenBlock(nn.Module):
                                   config.max_position_embeddings,
                                   rope_theta=rope_theta,
                                   rope_scaling=rope_scaling,
-                                  linear_method=linear_method)
+                                  quant_config=quant_config)
 
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
         self.mlp = QWenMLP(config.hidden_size,
                            config.intermediate_size // 2,
-                           linear_method=linear_method)
+                           quant_config=quant_config)
 
     def forward(
         self,
@@ -174,7 +174,7 @@ class QWenModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -185,7 +185,7 @@ class QWenModel(nn.Module):
             config.hidden_size,
         )
         self.h = nn.ModuleList([
-            QWenBlock(config, linear_method)
+            QWenBlock(config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -217,12 +217,12 @@ class QWenLMHeadModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.transformer = QWenModel(config, linear_method)
+        self.quant_config = quant_config
+        self.transformer = QWenModel(config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 16 - 17
aphrodite/modeling/models/qwen2.py

@@ -2,7 +2,6 @@
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
 # Copyright 2024 The Qwen team.
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
@@ -35,8 +34,7 @@ from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              MergedColumnParallelLinear,
+from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -46,6 +44,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class Qwen2MLP(nn.Module):
@@ -55,17 +54,17 @@ class Qwen2MLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size, [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method)
+            quant_config=quant_config)
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
-                                           linear_method=linear_method)
+                                           quant_config=quant_config)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -87,7 +86,7 @@ class Qwen2Attention(nn.Module):
                  max_position: int = 4096 * 32,
                  rope_theta: float = 10000,
                  use_sliding_window: bool = False,
-                 linear_method: Optional[LinearMethodBase] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
                  sliding_window: Optional[int] = None) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -118,13 +117,13 @@ class Qwen2Attention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.rotary_emb = get_rope(
@@ -160,7 +159,7 @@ class Qwen2DecoderLayer(nn.Module):
         self,
         config: Qwen2Config,
         layer_idx: int,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -175,13 +174,13 @@ class Qwen2DecoderLayer(nn.Module):
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
             use_sliding_window=use_sliding_window,
-            linear_method=linear_method,
+            quant_config=quant_config,
             sliding_window=config.sliding_window)
         self.mlp = Qwen2MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@@ -222,7 +221,7 @@ class Qwen2Model(nn.Module):
     def __init__(
         self,
         config: Qwen2Config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -234,7 +233,7 @@ class Qwen2Model(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            Qwen2DecoderLayer(config, layer_idx, linear_method)
+            Qwen2DecoderLayer(config, layer_idx, quant_config)
             for layer_idx in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -287,14 +286,14 @@ class Qwen2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: Qwen2Config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         del lora_config
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = Qwen2Model(config, linear_method)
+        self.quant_config = quant_config
+        self.model = Qwen2Model(config, quant_config)
 
         if config.tie_word_embeddings:
             self.lm_head_weight = self.model.embed_tokens.weight

+ 21 - 24
aphrodite/modeling/models/qwen2_moe.py

@@ -2,7 +2,6 @@
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
 # Copyright 2024 The Qwen team.
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
@@ -38,8 +37,7 @@ from aphrodite.distributed import (get_tensor_model_parallel_rank,
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.fused_moe import fused_moe
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              MergedColumnParallelLinear,
+from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
@@ -50,6 +48,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class Qwen2MoeMLP(nn.Module):
@@ -59,18 +58,18 @@ class Qwen2MoeMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size, [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method)
+            quant_config=quant_config)
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
-                                           linear_method=linear_method,
+                                           quant_config=quant_config,
                                            reduce_results=reduce_results)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
@@ -89,7 +88,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -106,7 +105,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             Qwen2MoeMLP(hidden_size=config.hidden_size,
                         intermediate_size=config.moe_intermediate_size,
                         hidden_act=config.hidden_act,
-                        linear_method=linear_method,
+                        quant_config=quant_config,
                         reduce_results=False)
             for idx in range(self.n_routed_experts)
         ])
@@ -115,13 +114,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self.gate = ReplicatedLinear(config.hidden_size,
                                      self.n_routed_experts,
                                      bias=False,
-                                     linear_method=None)
+                                     quant_config=None)
         if config.shared_expert_intermediate_size > 0:
             self.shared_expert = Qwen2MoeMLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.shared_expert_intermediate_size,
                 hidden_act=config.hidden_act,
-                linear_method=linear_method,
+                quant_config=quant_config,
                 reduce_results=False,
             )
         else:
@@ -187,7 +186,7 @@ class Qwen2MoeAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -218,14 +217,14 @@ class Qwen2MoeAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.rotary_emb = get_rope(
@@ -261,7 +260,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_idx: int,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -276,18 +275,18 @@ class Qwen2MoeDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         if (config.num_experts is not None
                 and (layer_idx + 1) % config.decoder_sparse_step == 0):
             self.mlp = Qwen2MoeSparseMoeBlock(config=config,
-                                              linear_method=linear_method)
+                                              quant_config=quant_config)
         else:
             self.mlp = Qwen2MoeMLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
-                linear_method=linear_method,
+                quant_config=quant_config,
             )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@@ -328,7 +327,7 @@ class Qwen2MoeModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -339,9 +338,7 @@ class Qwen2MoeModel(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            Qwen2MoeDecoderLayer(config,
-                                 layer_idx,
-                                 linear_method=linear_method)
+            Qwen2MoeDecoderLayer(config, layer_idx, quant_config=quant_config)
             for layer_idx in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -371,12 +368,12 @@ class Qwen2MoeForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = Qwen2MoeModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = Qwen2MoeModel(config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
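
For context on the sparse block touched above: the deliberately unquantized gate (quant_config=None) only produces router logits, which select the top-k experts for each token. A rough pure-PyTorch sketch of that routing step; the real block uses fused_moe, tensor parallelism and a shared expert, and may renormalize the top-k probabilities, none of which is shown here.

import torch
import torch.nn.functional as F

def route(hidden: torch.Tensor, gate_weight: torch.Tensor, top_k: int):
    # hidden: [num_tokens, hidden_size]; gate_weight: [n_experts, hidden_size]
    logits = hidden @ gate_weight.t()                 # router logits per token
    probs = F.softmax(logits, dim=-1, dtype=torch.float)
    topk_probs, topk_idx = probs.topk(top_k, dim=-1)  # chosen experts + weights
    return topk_probs, topk_idx

hidden = torch.randn(4, 16)
gate_weight = torch.randn(8, 16)   # 8 routed experts
weights, experts = route(hidden, gate_weight, top_k=2)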

+ 14 - 14
aphrodite/modeling/models/stablelm.py

@@ -29,8 +29,7 @@ from aphrodite.attention import Attention, AttentionMetadata
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              MergedColumnParallelLinear,
+from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -40,13 +39,14 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class StablelmMLP(nn.Module):
 
     def __init__(self,
                  config: PretrainedConfig,
-                 linear_method: Optional[LinearMethodBase] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
@@ -54,7 +54,7 @@ class StablelmMLP(nn.Module):
         self.gate_up_proj = MergedColumnParallelLinear(
             config.hidden_size, [config.intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method)
+            quant_config=quant_config)
         self.down_proj = RowParallelLinear(config.intermediate_size,
                                            config.hidden_size,
                                            bias=False)
@@ -71,7 +71,7 @@ class StablelmAttention(nn.Module):
 
     def __init__(self,
                  config: PretrainedConfig,
-                 linear_method: Optional[LinearMethodBase] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
@@ -109,11 +109,11 @@ class StablelmAttention(nn.Module):
                                           self.total_num_heads,
                                           self.total_num_key_value_heads,
                                           self.qkv_bias,
-                                          linear_method=linear_method)
+                                          quant_config=quant_config)
         self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                         self.hidden_size,
                                         bias=False,
-                                        linear_method=linear_method)
+                                        quant_config=quant_config)
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.rotary_ndims,
@@ -145,11 +145,11 @@ class StablelmDecoderLayer(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.self_attn = StablelmAttention(config)
-        self.mlp = StablelmMLP(config, linear_method)
+        self.mlp = StablelmMLP(config, quant_config)
         norm_eps = getattr(config, "norm_eps",
                            getattr(config, "layer_norm_eps", 1e-05))
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
@@ -187,14 +187,14 @@ class StableLMEpochModel(nn.Module):
 
     def __init__(self,
                  config: PretrainedConfig,
-                 linear_method: Optional[LinearMethodBase] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            StablelmDecoderLayer(config, linear_method)
+            StablelmDecoderLayer(config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         norm_eps = getattr(config, "norm_eps",
@@ -226,12 +226,12 @@ class StablelmForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = StableLMEpochModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = StableLMEpochModel(config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 15 - 16
aphrodite/modeling/models/starcoder2.py

@@ -29,7 +29,6 @@ from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
-                                              LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -39,13 +38,14 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class Starcoder2Attention(nn.Module):
 
     def __init__(self,
                  config: Starcoder2Config,
-                 linear_method: Optional[LinearMethodBase] = None):
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.config = config
 
@@ -79,13 +79,13 @@ class Starcoder2Attention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=self.use_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=self.use_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -121,21 +121,21 @@ class Starcoder2MLP(nn.Module):
 
     def __init__(self,
                  config: Starcoder2Config,
-                 linear_method: Optional[LinearMethodBase] = None):
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.c_fc = ColumnParallelLinear(
             config.hidden_size,
             config.intermediate_size,
             bias=config.use_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.c_proj = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
             bias=config.use_bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
-        quant_config = getattr(linear_method, "quant_config", None)
+        quant_config = getattr(quant_config, "quant_config", None)
         self.act = get_act_fn(config.hidden_act, quant_config,
                               config.intermediate_size)
 
@@ -150,12 +150,11 @@ class Starcoder2DecoderLayer(nn.Module):
 
     def __init__(self,
                  config: Starcoder2Config,
-                 linear_method: Optional[LinearMethodBase] = None):
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.self_attn = Starcoder2Attention(config,
-                                             linear_method=linear_method)
-        self.mlp = Starcoder2MLP(config, linear_method=linear_method)
+        self.self_attn = Starcoder2Attention(config, quant_config=quant_config)
+        self.mlp = Starcoder2MLP(config, quant_config=quant_config)
         self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                             eps=config.norm_epsilon)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
@@ -192,7 +191,7 @@ class Starcoder2Model(nn.Module):
 
     def __init__(self,
                  config: Starcoder2Config,
-                 linear_method: Optional[LinearMethodBase] = None):
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
@@ -202,7 +201,7 @@ class Starcoder2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                    config.hidden_size)
         self.layers = nn.ModuleList([
-            Starcoder2DecoderLayer(config, linear_method=linear_method)
+            Starcoder2DecoderLayer(config, quant_config=quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
@@ -227,10 +226,10 @@ class Starcoder2ForCausalLM(nn.Module):
 
     def __init__(self,
                  config: Starcoder2Config,
-                 linear_method: Optional[LinearMethodBase] = None):
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.config = config
-        self.model = Starcoder2Model(config, linear_method=linear_method)
+        self.model = Starcoder2Model(config, quant_config=quant_config)
         self.vocab_size = config.vocab_size
         self.unpadded_vocab_size = config.vocab_size
         if config.tie_word_embeddings:

+ 16 - 16
aphrodite/modeling/models/xverse.py

@@ -32,8 +32,7 @@ from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              MergedColumnParallelLinear,
+from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -43,6 +42,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class XverseMLP(nn.Module):
@@ -52,17 +52,17 @@ class XverseMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size, [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method)
+            quant_config=quant_config)
         self.down_proj = RowParallelLinear(intermediate_size,
                                            hidden_size,
                                            bias=False,
-                                           linear_method=linear_method)
+                                           quant_config=quant_config)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -85,7 +85,7 @@ class XverseAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
         sliding_window: Optional[int] = None,
     ) -> None:
@@ -112,13 +112,13 @@ class XverseAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=bias,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.rotary_emb = get_rope(
@@ -154,7 +154,7 @@ class XverseDecoderLayer(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -171,7 +171,7 @@ class XverseDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            linear_method=linear_method,
+            quant_config=quant_config,
             bias=getattr(config, "bias", False),
             sliding_window=sliding_window,
         )
@@ -179,7 +179,7 @@ class XverseDecoderLayer(nn.Module):
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@@ -220,7 +220,7 @@ class XverseModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
@@ -236,7 +236,7 @@ class XverseModel(nn.Module):
             org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
-            XverseDecoderLayer(config, linear_method)
+            XverseDecoderLayer(config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -294,13 +294,13 @@ class XverseForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         lora_config=None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = XverseModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = XverseModel(config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 4 - 4
aphrodite/quantization/__init__.py

@@ -4,12 +4,11 @@ from loguru import logger
 
 from aphrodite.quantization.aqlm import AQLMConfig
 from aphrodite.quantization.awq import AWQConfig
-from aphrodite.quantization.base_config import \
-    QuantizationConfig
-from aphrodite.quantization.bitsandbytes import \
-    BitsandBytesConfig
+from aphrodite.quantization.base_config import QuantizationConfig
+from aphrodite.quantization.bitsandbytes import BitsandBytesConfig
 from aphrodite.quantization.eetq import EETQConfig
 from aphrodite.quantization.exl2 import Exl2Config
+from aphrodite.quantization.fp8 import Fp8Config
 from aphrodite.quantization.gguf import GGUFConfig
 from aphrodite.quantization.gptq import GPTQConfig
 from aphrodite.quantization.marlin import MarlinConfig
@@ -30,6 +29,7 @@ QUANTIZATION_METHODS = {
     "bnb": BitsandBytesConfig,
     "eetq": EETQConfig,
     "exl2": Exl2Config,
+    "fp8": Fp8Config,
     "gguf": GGUFConfig,
     "gptq": GPTQConfig,
     "quip": QuipConfig,

+ 9 - 6
aphrodite/quantization/aqlm.py

@@ -9,9 +9,9 @@ import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
-from aphrodite.modeling.layers.linear import LinearMethodBase, set_weight_attrs
-from aphrodite.quantization.base_config import \
-    QuantizationConfig
+from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
+from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import QuantizationConfig
 
 HAS_QUANTS = False
 with suppress(ImportError):
@@ -211,8 +211,11 @@ class AQLMConfig(QuantizationConfig):
         return cls(in_group_size, nbits_per_codebook, num_code_books,
                    out_group_size)
 
-    def get_linear_method(self) -> "AQLMLinearMethod":
-        return AQLMLinearMethod(self)
+    def get_quant_method(
+            self, layer: torch.nn.Module) -> Optional["AQLMLinearMethod"]:
+        if isinstance(layer, LinearBase):
+            return AQLMLinearMethod(self)
+        return None
 
     def get_scaled_act_names(self) -> List[str]:
         return []
@@ -325,7 +328,7 @@ class AQLMLinearMethod(LinearMethodBase):
         layer.register_parameter("scales", scales)
         set_weight_attrs(scales, extra_weight_attrs)
 
-    def apply_weights(
+    def apply(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,

+ 12 - 9
aphrodite/quantization/awq.py

@@ -4,9 +4,9 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
-from aphrodite.modeling.layers.linear import LinearMethodBase, set_weight_attrs
-from aphrodite.quantization.base_config import \
-    QuantizationConfig
+from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
+from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import QuantizationConfig
 
 HAS_QUANTS = False
 with suppress(ImportError):
@@ -65,8 +65,11 @@ class AWQConfig(QuantizationConfig):
         zero_point = cls.get_from_keys(config, ["zero_point"])
         return cls(weight_bits, group_size, zero_point)
 
-    def get_linear_method(self) -> "AWQLinearMethod":
-        return AWQLinearMethod(self)
+    def get_quant_method(
+            self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]:
+        if isinstance(layer, LinearBase):
+            return AWQLinearMethod(self)
+        return None
 
     def get_scaled_act_names(self) -> List[str]:
         return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
@@ -150,10 +153,10 @@ class AWQLinearMethod(LinearMethodBase):
         layer.register_parameter("scales", scales)
         set_weight_attrs(scales, extra_weight_attrs)
 
-    def apply_weights(self,
-                      layer: torch.nn.Module,
-                      x: torch.Tensor,
-                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         qweight = layer.qweight
         scales = layer.scales
         qzeros = layer.qzeros
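
The new get_quant_method gates on layer type: AWQ (like AQLM above) hands back a linear method only for LinearBase layers and None for everything else, and that method is now invoked as apply(layer, x, bias). A small usage sketch, assuming the AWQ kernels are importable and that AWQConfig still takes (weight_bits, group_size, zero_point).

from torch import nn

from aphrodite.quantization.awq import AWQConfig

# 4-bit, group size 128, with zero points -- typical AWQ checkpoint settings.
awq_config = AWQConfig(weight_bits=4, group_size=128, zero_point=True)

# A LayerNorm is not a LinearBase, so no quant method is returned for it.
assert awq_config.get_quant_method(nn.LayerNorm(16)) is None

# For a LinearBase layer, the returned AWQLinearMethod would register
# qweight/qzeros/scales via create_weights(...) and later be called as:
#   out = layer.quant_method.apply(layer, x, bias)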

+ 25 - 3
aphrodite/quantization/base_config.py

@@ -2,8 +2,30 @@ from abc import ABC, abstractmethod
 from typing import Any, Dict, List
 
 import torch
+from torch import nn
 
-from aphrodite.modeling.layers.linear import LinearMethodBase
+
+class QuantizeMethodBase(ABC):
+    """Base class for different quantized methods."""
+
+    @abstractmethod
+    def create_weights(self, layer: torch.nn.Module, *weight_args,
+                       **extra_weight_attrs):
+        """Create weights for a layer.
+        The weights will be set as attributes of the layer."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
+        """Apply the weights in layer to the input tensor.
+        Expects create_weights to have been called before on the layer."""
+        raise NotImplementedError
+
+    def process_weights_after_loading(self, layer: nn.Module) -> None:
+        """Process the weight after loading.
+        This can be used for example, to transpose weights for computation.
+        """
+        return
 
 
 class QuantizationConfig(ABC):
@@ -51,8 +73,8 @@ class QuantizationConfig(ABC):
                          "quantization config.")
 
     @abstractmethod
-    def get_linear_method(self) -> LinearMethodBase:
-        """Get the linear method to use for the quantized linear layer."""
+    def get_quant_method(self, layer: torch.nn.Module) -> QuantizeMethodBase:
+        """Get the quantize method to use for the quantized layer."""
         raise NotImplementedError
 
     @abstractmethod

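For orientation, a minimal sketch of the contract the new QuantizeMethodBase establishes — not part of this commit, and the concrete create_weights parameters (input_size, output_size, params_dtype) are illustrative stand-ins for whatever *weight_args a given layer passes:

    from typing import Optional

    import torch
    from torch import nn

    from aphrodite.quantization.base_config import QuantizeMethodBase


    class PassthroughMethod(QuantizeMethodBase):
        """Toy method: stores a plain dense weight and applies a matmul."""

        def create_weights(self, layer: nn.Module, input_size: int,
                           output_size: int, params_dtype: torch.dtype,
                           **extra_weight_attrs):
            # Weights are registered directly on the layer, as the docstring
            # in base_config.py specifies.
            weight = nn.Parameter(
                torch.empty(output_size, input_size, dtype=params_dtype),
                requires_grad=False)
            layer.register_parameter("weight", weight)

        def apply(self, layer: nn.Module, x: torch.Tensor,
                  bias: Optional[torch.Tensor] = None) -> torch.Tensor:
            return nn.functional.linear(x, layer.weight, bias)

        def process_weights_after_loading(self, layer: nn.Module) -> None:
            # Optional hook, e.g. for transposing or repacking after load;
            # nothing to do for this toy method.
            return
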
+ 16 - 12
aphrodite/quantization/bitsandbytes.py

@@ -1,14 +1,15 @@
+from contextlib import suppress
+from typing import Any, Dict, List, NamedTuple, Optional, TypeVar
+
 import torch
 from torch.nn.parameter import Parameter
-from typing import List, Dict, Any, Optional, TypeVar, NamedTuple
-from contextlib import suppress
 
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              set_weight_attrs)
-from aphrodite.quantization.base_config import (QuantizationConfig)
-from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear, LinearBase,
+                                              LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
+from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import QuantizationConfig
 
 HAS_QUANTS = False
 with suppress(ImportError):
@@ -87,8 +88,11 @@ class BitsandBytesConfig(QuantizationConfig):
             quant_mode = "weight_only"
         return cls(weight_bits, group_size, zero_point, from_float, quant_mode)
 
-    def get_linear_method(self) -> "BNBLinearMethod":
-        return BNBLinearMethod(self)
+    def get_quant_method(
+            self, layer: torch.nn.Module) -> Optional["BNBLinearMethod"]:
+        if isinstance(layer, LinearBase):
+            return BNBLinearMethod(self)
+        return None
 
     def get_scaled_act_names(self) -> List[str]:
         return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
@@ -180,10 +184,10 @@ class BNBLinearMethod(LinearMethodBase):
             layer.register_parameter("weight", weight)
             set_weight_attrs(weight, extra_weight_attrs)
 
-    def apply_weights(self,
-                      layer: torch.nn.Module,
-                      x: torch.Tensor,
-                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         if self.quant_config.quant_mode == "weight_only":
             qweight = layer.qweight
             scales_zeros = layer.scales_zeros

+ 14 - 11
aphrodite/quantization/eetq.py

@@ -1,12 +1,12 @@
-from typing import Any, Dict, List, Optional
 from contextlib import suppress
+from typing import Any, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
 
-from aphrodite.modeling.layers.linear import LinearMethodBase, set_weight_attrs
-from aphrodite.quantization.base_config import \
-    QuantizationConfig
+from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
+from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import QuantizationConfig
 
 HAS_EETQ = False
 with suppress(ImportError):
@@ -58,8 +58,11 @@ class EETQConfig(QuantizationConfig):
         zero_point = cls.get_from_keys(config, ["zero_point"])
         return cls(weight_bits, zero_point)
 
-    def get_linear_method(self) -> "EETQLinearMethod":
-        return EETQLinearMethod(self)
+    def get_quant_method(
+            self, layer: torch.nn.Module) -> Optional["EETQLinearMethod"]:
+        if isinstance(layer, LinearBase):
+            return EETQLinearMethod(self)
+        return None
 
     def get_scaled_act_names(self) -> List[str]:
         return []
@@ -97,11 +100,11 @@ class EETQLinearMethod(LinearMethodBase):
         layer.register_parameter("weight_scales", weight_scales)
         set_weight_attrs(weight_scales, extra_weight_attrs)
 
-    def apply_weights(self,
-                      layer: torch.nn.Module,
-                      x: torch.Tensor,
-                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        qweight = layer.qweightdata
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        qweight = layer.qweight.data
         weight_scales = layer.weight_scales.data
 
         if HAS_EETQ:

+ 13 - 10
aphrodite/quantization/exl2.py

@@ -1,11 +1,11 @@
-from typing import Any, Dict, List, Optional
 from contextlib import suppress
+from typing import Any, Dict, List, Optional
 
 import torch
 
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              set_weight_attrs)
-from aphrodite.quantization.base_config import (QuantizationConfig)
+from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
+from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import QuantizationConfig
 
 HAS_QUANTS = False
 with suppress(ImportError):
@@ -58,8 +58,11 @@ class Exl2Config(QuantizationConfig):
     def from_config(cls, config: Dict[str, Any]) -> "Exl2Config":
         return cls()
 
-    def get_linear_method(self) -> "Exl2LinearMethod":
-        return Exl2LinearMethod(self)
+    def get_quant_method(
+            self, layer: torch.nn.Module) -> Optional["Exl2LinearMethod"]:
+        if isinstance(layer, LinearBase):
+            return Exl2LinearMethod(self)
+        return None
 
     def get_scaled_act_names(self) -> List[str]:
         return []
@@ -116,10 +119,10 @@ class Exl2LinearMethod(LinearMethodBase):
             set_weight_attrs(fake_weight, {"ignore_warning": True})
             layer.register_parameter(name, fake_weight)
 
-    def apply_weights(self,
-                      layer: torch.nn.Module,
-                      x: torch.Tensor,
-                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         out_shape = x.shape[:-1] + (layer.q_weight.shape[-1], )
         reshaped_x = x.reshape(-1, x.shape[-1])
 

+ 27 - 39
aphrodite/quantization/fp8.py

@@ -1,15 +1,22 @@
-from typing import Any, Dict, List, Optional, Tuple
+from contextlib import suppress
+from typing import Any, Dict, List, Optional
 
 import torch
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 
-from aphrodite.modeling.layers.linear import LinearMethodBase, set_weight_attrs
-from aphrodite.quantization.base_config import \
-    QuantizationConfig
+from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
+from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import (QuantizationConfig,
+                                                QuantizeMethodBase)
 
+HAS_QUANTS = False
+with suppress(ImportError):
+    from aphrodite._quant_C import quant_ops as ops
+    HAS_QUANTS = True
 
-class FP8Config(QuantizationConfig):
+
+class Fp8Config(QuantizationConfig):
     """Config class for FP8."""
 
     @classmethod
@@ -29,11 +36,14 @@ class FP8Config(QuantizationConfig):
         return []
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "FP8Config":
+    def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
         return cls()
 
-    def get_linear_method(self) -> "Fp8LinearMethod":
-        return Fp8LinearMethod(self)
+    def get_quant_method(
+            self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
+        if isinstance(layer, LinearBase):
+            return Fp8LinearMethod(self)
+        return None
 
     def get_scaled_act_names(self) -> List[str]:
         return []
@@ -53,7 +63,9 @@ class Fp8LinearMethod(LinearMethodBase):
         quant_config: The quantization config.
     """
 
-    def __init__(self, quant_config: FP8Config):
+    def __init__(self, quant_config: Fp8Config):
+        if not HAS_QUANTS:
+            raise ImportError("Could not find the quantization kernels.")
         self.quant_config = quant_config
 
     def create_weights(
@@ -89,17 +101,17 @@ class Fp8LinearMethod(LinearMethodBase):
         if not hasattr(layer, "weight_scaling_factor"):
             return
 
-        qweight, weight_scale = per_tensor_quantize(layer.weight)
+        qweight, weight_scale = ops.scaled_fp8_quant(layer.weight)
         # torch._scaled_mm requires column-major in the second
         # input (weight), so we transpose the quantized weight.
         layer.weight = Parameter(qweight.t(), requires_grad=False)
         layer.weight_scaling_factor.data.copy_(weight_scale)
 
-    def apply_weights(self,
-                      layer: torch.nn.Module,
-                      x: torch.Tensor,
-                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        qinput, x_scale = per_tensor_quantize(x)
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        qinput, x_scale = ops.scaled_fp8_quant(x)
         output, _ = torch._scaled_mm(
             qinput,
             layer.weight,
@@ -109,27 +121,3 @@ class Fp8LinearMethod(LinearMethodBase):
             bias=bias,
         )
         return output
-
-
-def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
-    """Quantize a tensor using per-tensor static scaling factor.
-
-    Args:
-        tensor: The input tensor.
-    """
-    finfo = torch.finfo(torch.float8_e4m3fn)
-    # Calculate the scale as dtype max divided by absmax.
-    # Since .abs() creates a new tensor, we use aminmax to get
-    # the min and max first and then calculate the absmax.
-    min_val, max_val = tensor.aminmax()
-    amax = min_val.abs().max(max_val.abs())
-    scale = finfo.max / amax.clamp(min=1e-12)
-    # scale and clamp the tensor to bring it to
-    # the representative range of float8 data type
-    # (as default cast is unsaturated)
-    qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
-    # Return both float8 data and the inverse scale (as float),
-    # as both required as inputs to torch._scaled_mm
-    qweight = qweight.to(torch.float8_e4m3fn)
-    scale = scale.float().reciprocal()
-    return qweight, scale

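For reference, the behavior of the deleted per_tensor_quantize helper above can be kept around as a standalone sanity check — a pure-PyTorch sketch of the same per-tensor scaling (scale = finfo.max / absmax, returned as its reciprocal), which the commit assumes ops.scaled_fp8_quant now provides via a kernel:

    import torch


    def per_tensor_fp8_reference(tensor: torch.Tensor):
        """Pure-PyTorch mirror of the removed per_tensor_quantize helper."""
        finfo = torch.finfo(torch.float8_e4m3fn)
        amax = tensor.abs().max()
        scale = finfo.max / amax.clamp(min=1e-12)
        # Saturate to the FP8 range before casting (the default cast is
        # unsaturated), and hand back the *inverse* scale, which is what
        # torch._scaled_mm expects alongside the FP8 data.
        qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
        return qweight.to(torch.float8_e4m3fn), scale.float().reciprocal()

Dequantizing with qweight.to(torch.float32) * inv_scale should land within FP8 rounding error of the input, which makes this a convenient cross-check against the kernel output.
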
+ 13 - 10
aphrodite/quantization/gguf.py

@@ -1,12 +1,12 @@
-from typing import Any, Dict, List, Optional
 from contextlib import suppress
+from typing import Any, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
 
-from aphrodite.modeling.layers.linear import (LinearMethodBase,
-                                              set_weight_attrs)
-from aphrodite.quantization.base_config import (QuantizationConfig)
+from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
+from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import QuantizationConfig
 
 HAS_QUANTS = False
 with suppress(ImportError):
@@ -62,8 +62,11 @@ class GGUFConfig(QuantizationConfig):
     def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig":
         return cls()
 
-    def get_linear_method(self) -> "GGUFLinearMethod":
-        return GGUFLinearMethod(self)
+    def get_quant_method(
+            self, layer: torch.nn.Module) -> Optional["GGUFLinearMethod"]:
+        if isinstance(layer, LinearBase):
+            return GGUFLinearMethod(self)
+        return None
 
     def get_scaled_act_names(self) -> List[str]:
         return []
@@ -114,10 +117,10 @@ class GGUFLinearMethod(LinearMethodBase):
         set_weight_attrs(weight_type, {"ignore_warning": True})
         layer.register_parameter("weight_type", weight_type)
 
-    def apply_weights(self,
-                      layer: torch.nn.Module,
-                      x: torch.Tensor,
-                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         if isinstance(layer.weight_type, torch.Tensor):
             layer.weight_type = int(layer.weight_type)
             # Check tensor parallel shape here on first pass

+ 12 - 9
aphrodite/quantization/gptq.py

@@ -7,9 +7,9 @@ import torch
 from torch.nn.parameter import Parameter
 
 from aphrodite._quant_C import quant_ops as ops
-from aphrodite.modeling.layers.linear import LinearMethodBase, set_weight_attrs
-from aphrodite.quantization.base_config import \
-    QuantizationConfig
+from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
+from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class GPTQConfig(QuantizationConfig):
@@ -62,8 +62,11 @@ class GPTQConfig(QuantizationConfig):
         desc_act = cls.get_from_keys(config, ["desc_act"])
         return cls(weight_bits, group_size, desc_act)
 
-    def get_linear_method(self) -> "GPTQLinearMethod":
-        return GPTQLinearMethod(self)
+    def get_quant_method(
+            self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]:
+        if isinstance(layer, LinearBase):
+            return GPTQLinearMethod(self)
+        return None
 
     def get_scaled_act_names(self) -> List[str]:
         return []
@@ -193,10 +196,10 @@ class GPTQLinearMethod(LinearMethodBase):
 
         layer.exllama_state = exllama_state
 
-    def apply_weights(self,
-                      layer: torch.nn.Module,
-                      x: torch.Tensor,
-                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         qweight = layer.qweight
         out_shape = x.shape[:-1] + (qweight.shape[-1], )
         reshaped_x = x.reshape(-1, x.shape[-1])

+ 9 - 6
aphrodite/quantization/marlin.py

@@ -4,9 +4,9 @@ import torch
 from torch.nn.parameter import Parameter
 
 from aphrodite._quant_C import quant_ops as ops
-from aphrodite.modeling.layers.linear import LinearMethodBase, set_weight_attrs
-from aphrodite.quantization.base_config import \
-    QuantizationConfig
+from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
+from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class MarlinConfig(QuantizationConfig):
@@ -71,8 +71,11 @@ class MarlinConfig(QuantizationConfig):
         group_size = cls.get_from_keys(config, ["group_size"])
         return cls(group_size)
 
-    def get_linear_method(self) -> "MarlinLinearMethod":
-        return MarlinLinearMethod(self)
+    def get_quant_method(
+            self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
+        if isinstance(layer, LinearBase):
+            return MarlinLinearMethod(self)
+        return None
 
     def get_scaled_act_names(self) -> List[str]:
         return []
@@ -196,7 +199,7 @@ class MarlinLinearMethod(LinearMethodBase):
         layer.register_parameter("workspace", workspace)
         set_weight_attrs(workspace, extra_weight_attrs)
 
-    def apply_weights(
+    def apply(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,

+ 11 - 7
aphrodite/quantization/quip.py

@@ -4,11 +4,12 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
-from aphrodite.modeling.layers.linear import LinearMethodBase, set_weight_attrs
+from aphrodite.modeling.layers.linear import LinearMethodBase, LinearBase
 from aphrodite.quantization.base_config import QuantizationConfig
 from aphrodite.quantization.quip_utils import (get_hadK, get_packed_abs_grid,
                                                matmul_hadU_cuda,
                                                matmul_hadUt_cuda)
+from aphrodite.modeling.utils import set_weight_attrs
 
 HAS_QUANTS = False
 with suppress(ImportError):
@@ -58,8 +59,11 @@ class QuipConfig(QuantizationConfig):
         use_rand = cls.get_from_keys(config, ["use_rand"])
         return cls(codebook, use_rand)
 
-    def get_linear_method(self) -> "QuipLinearMethod":
-        return QuipLinearMethod(self)
+    def get_quant_method(
+            self, layer: torch.nn.Module) -> Optional["QuipLinearMethod"]:
+        if isinstance(layer, LinearBase):
+            return QuipLinearMethod(self)
+        return None
 
     def get_scaled_act_names(self) -> List[str]:
         return []
@@ -155,10 +159,10 @@ class QuipLinearMethod(LinearMethodBase):
             ))
         set_weight_attrs(layer.SV, extra_weight_attrs)
 
-    def apply_weights(self,
-                      layer: torch.nn.Module,
-                      x: torch.Tensor,
-                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         # First run
         if isinstance(layer.Wscale, torch.Tensor):
             layer.Wscale = layer.Wscale.item()

+ 15 - 10
aphrodite/quantization/squeezellm.py

@@ -5,9 +5,10 @@ from torch.nn.parameter import Parameter
 
 from aphrodite._quant_C import quant_ops as ops
 from aphrodite.common.utils import is_hip
-from aphrodite.modeling.layers.linear import LinearMethodBase, set_weight_attrs
-from aphrodite.quantization.base_config import \
-    QuantizationConfig
+from aphrodite.modeling.layers.linear import LinearBase
+from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import (QuantizationConfig,
+                                                QuantizeMethodBase)
 
 
 class SqueezeLLMConfig(QuantizationConfig):
@@ -50,14 +51,18 @@ class SqueezeLLMConfig(QuantizationConfig):
         weight_bits = cls.get_from_keys(config, ["wbits"])
         return cls(weight_bits)
 
-    def get_linear_method(self) -> "SqueezeLLMLinearMethod":
-        return SqueezeLLMLinearMethod(self)
+    def get_quant_method(
+            self,
+            layer: torch.nn.Module) -> Optional["SqueezeLLMLinearMethod"]:
+        if isinstance(layer, LinearBase):
+            return SqueezeLLMLinearMethod(self)
+        return
 
     def get_scaled_act_names(self) -> List[str]:
         return []
 
 
-class SqueezeLLMLinearMethod(LinearMethodBase):
+class SqueezeLLMLinearMethod(QuantizeMethodBase):
     """Linear method for SqueezeLLM.
 
     Args:
@@ -111,10 +116,10 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
         layer.register_parameter("lookup_table", lookup_table)
         set_weight_attrs(lookup_table, extra_weight_attrs)
 
-    def apply_weights(self,
-                      layer: torch.nn.Module,
-                      x: torch.Tensor,
-                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         qweight = layer.qweight
         lookup_table = layer.lookup_table
         out_shape = x.shape[:-1] + (qweight.shape[-1], )
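
Finally, the common thread across these hunks is that a config no longer hands back one linear method unconditionally: get_quant_method inspects the layer and may return None. A hedged sketch of how a caller can consume that (assuming UnquantizedLinearMethod is still exported from aphrodite.modeling.layers.linear as the unquantized fallback; the helper name is illustrative):

    from typing import Optional

    from torch import nn

    from aphrodite.modeling.layers.linear import (LinearBase,
                                                  UnquantizedLinearMethod)
    from aphrodite.quantization.base_config import (QuantizationConfig,
                                                    QuantizeMethodBase)


    def pick_quant_method(
            layer: nn.Module,
            quant_config: Optional[QuantizationConfig],
    ) -> Optional[QuantizeMethodBase]:
        """Resolve the method a layer should use, or None for 'no method'."""
        if quant_config is not None:
            # Configs now decide per layer type; e.g. AWQConfig returns an
            # AWQLinearMethod only for LinearBase layers and None otherwise.
            method = quant_config.get_quant_method(layer)
            if method is not None:
                return method
        # Fall back to the plain (unquantized) method for linear layers.
        return UnquantizedLinearMethod() if isinstance(layer, LinearBase) else None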