
add awq support

AlpinDale 1 year ago
commit d9c1d4f6e5

+ 16 - 1
aphrodite/common/config.py

@@ -43,6 +43,8 @@ class ModelConfig:
             version.
         max_model_len: Maximum length of a sequence (including prompt and output).
             If None, will be derived from the model.
+        quantization: Quantization method used to quantize the model weights. If
+            None, we assume the model weights are not quantized.
     """

     def __init__(
@@ -57,6 +59,7 @@ class ModelConfig:
         seed: int,
         revision: Optional[str],
         max_model_len: Optional[int] = None,
+        quantization: Optional[str] = None,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -66,11 +69,13 @@ class ModelConfig:
         self.load_format = load_format
         self.seed = seed
         self.revision = revision
+        self.quantization = quantization

         self.hf_config = get_config(model, trust_remote_code, revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self._verify_load_format()
         self._verify_tokenizer_mode()
+        self._verify_quantization()
         self.max_model_len = None
         if max_model_len is not None:
             derived_max_model_len = self.get_max_model_len()
@@ -99,6 +104,16 @@ class ModelConfig:
                 "either 'auto' or 'slow'.")
                 "either 'auto' or 'slow'.")
         self.tokenizer_mode = tokenizer_mode
         self.tokenizer_mode = tokenizer_mode
 
 
+    def _verify_quantization(self) -> None:
+        supported_quantization = ["awq"]
+        if self.quantization is None:
+            return
+        quantization = self.quantization.lower()
+        if quantization not in supported_quantization:
+            raise ValueError(
+                f"Unknown quantization: {self.quantization}. Must be one of the following: {supported_quantization}.")
+        self.quantization = quantization
+
     def verify_with_parallel_config(
     def verify_with_parallel_config(
         self,
         self,
         parallel_config: "ParallelConfig",
         parallel_config: "ParallelConfig",
@@ -319,4 +334,4 @@ def _get_and_verify_dtype(
                 "Bfloat16 is only supported on GPUs with compute capability "
                 "Bfloat16 is only supported on GPUs with compute capability "
                 f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                 f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                 f"{compute_capability[0]}.{compute_capability[1]}.")
                 f"{compute_capability[0]}.{compute_capability[1]}.")
-    return torch_dtype
+    return torch_dtype
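
For context, a minimal usage sketch (not part of the commit) of the new argument. The keyword order mirrors the ModelConfig call in args_tools.py below; the model name and other values are placeholders, and constructing ModelConfig fetches the Hugging Face config for the given model.

    from aphrodite.common.config import ModelConfig

    model_config = ModelConfig(
        model="org/some-awq-quantized-llama",   # placeholder repo id
        tokenizer="org/some-awq-quantized-llama",
        tokenizer_mode="auto",
        trust_remote_code=False,
        download_dir=None,
        load_format="auto",
        dtype="half",
        seed=0,
        revision=None,
        max_model_len=None,
        quantization="awq",   # normalized and validated by _verify_quantization()
    )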

+ 2 - 1
aphrodite/engine/aphrodite_engine.py

@@ -80,6 +80,7 @@ class AphroditeEngine:
             f"download_dir={model_config.download_dir!r}, "
             f"download_dir={model_config.download_dir!r}, "
             f"load_format={model_config.load_format}, "
             f"load_format={model_config.load_format}, "
             f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
             f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
+            f"quantization={model_config.quantization}, "
             f"seed={model_config.seed})")
             f"seed={model_config.seed})")
         # TODO: Print more configs in debug mode.
         # TODO: Print more configs in debug mode.
 
 
@@ -705,4 +706,4 @@ class AphroditeEngine:
         output = all_outputs[0]
         for other_output in all_outputs[1:]:
             assert output == other_output
-        return output
+        return output

+ 8 - 2
aphrodite/engine/args_tools.py

@@ -5,6 +5,7 @@ from typing import Optional, Tuple

 from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)


 @dataclass
@@ -29,6 +30,7 @@ class EngineArgs:
     max_num_seqs: int = 256
     disable_log_stats: bool = False
     revision: Optional[str] = None
+    quantization: Optional[str] = None

     def __post_init__(self):
         if self.tokenizer is None:
@@ -150,6 +152,10 @@ class EngineArgs:
         parser.add_argument('--disable-log-stats',
                             action='store_true',
                             help='disable logging statistics')
+        parser.add_argument('--quantization', '-q',
+                            type=str, choices=["awq", None],
+                            default=None,
+                            help="Method used for quantization.")
         return parser

     @classmethod
@@ -168,7 +174,7 @@ class EngineArgs:
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.load_format,
                                    self.dtype, self.seed, self.revision,
-                                   self.max_model_len)
+                                   self.max_model_len, self.quantization)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space)
@@ -205,4 +211,4 @@ class AsyncEngineArgs(EngineArgs):
                             help='max number of prompt characters or prompt '
                             'ID numbers being printed in the log. '
                             'Default: unlimited.')
-        return parser
+        return parser
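
For context, a hedged sketch (not part of the commit) of driving the new option from Python; the model path is a placeholder, and only the fields relevant here are set.

    from aphrodite.engine.args_tools import EngineArgs

    # Equivalent to passing "--quantization awq" (or "-q awq") on the CLI.
    engine_args = EngineArgs(model="org/some-awq-quantized-llama",
                             quantization="awq")
    print(engine_args.quantization)   # "awq"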

+ 57 - 8
aphrodite/modeling/hf_downloader.py

@@ -13,6 +13,9 @@ import torch
 from tqdm.auto import tqdm

 from aphrodite.common.logger import init_logger
+from aphrodite.modeling.quantization_utils import (QuantizationConfig,
+                                                   get_quant_class)
+

 logger = init_logger(__name__)

@@ -44,7 +47,7 @@ def _shared_pointers(tensors):
 def convert_bin_to_safetensor_file(
     pt_filename: str,
     sf_filename: str,
-):
+) -> None:
     loaded = torch.load(pt_filename, map_location="cpu")
     if "state_dict" in loaded:
         loaded = loaded["state_dict"]
@@ -78,15 +81,55 @@ def convert_bin_to_safetensor_file(
             raise RuntimeError(f"The output tensors do not match for key {k}")
             raise RuntimeError(f"The output tensors do not match for key {k}")
 
 
 
 
+# TODO: Move this to other place.
+def get_quant_config(
+    quantization: str,
+    model_name_or_path: str,
+    cache_dir: Optional[str] = None,
+) -> QuantizationConfig:
+    is_local = os.path.isdir(model_name_or_path)
+    if not is_local:
+        # Download the config files.
+        with get_lock(model_name_or_path, cache_dir):
+            hf_folder = snapshot_download(model_name_or_path,
+                                          allow_patterns="*.json",
+                                          cache_dir=cache_dir,
+                                          tqdm_class=Disabledtqdm)
+    else:
+        hf_folder = model_name_or_path
+    config_files = glob.glob(os.path.join(hf_folder, "*.json"))
+
+    quant_cls = get_quant_class(quantization)
+    quant_config_files = [
+        f for f in config_files if any(
+            f.endswith(x) for x in quant_cls.get_config_filenames())
+    ]
+    if len(quant_config_files) == 0:
+        raise ValueError(f"Cannot find the config file for {quantization}")
+    if len(quant_config_files) > 1:
+        raise ValueError(f"Found multiple config files for {quantization}: "
+                         f"{quant_config_files}")
+
+    quant_config_file = quant_config_files[0]
+    with open(quant_config_file, "r") as f:
+        config = json.load(f)
+    return quant_cls.from_config(config)
+
+
 def prepare_hf_model_weights(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
     use_safetensors: bool = False,
     fall_back_to_pt: bool = True,
-):
+    revision: Optional[str] = None,
+) -> Tuple[str, List[str], bool]:
     # Download model weights from huggingface.
     is_local = os.path.isdir(model_name_or_path)
-    allow_patterns = "*.safetensors" if use_safetensors else "*.bin"
+    if use_safetensors:
+        allow_patterns = ["*.safetensors"]
+    else:
+        # Some quantized models use .pt files for storing the weights.
+        allow_patterns = ["*.bin", "*.pt"]
     if not is_local:
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.
@@ -94,10 +137,13 @@ def prepare_hf_model_weights(
             hf_folder = snapshot_download(model_name_or_path,
                                           allow_patterns=allow_patterns,
                                           cache_dir=cache_dir,
-                                          tqdm_class=Disabledtqdm)
+                                          tqdm_class=Disabledtqdm,
+                                          revision=revision)
     else:
         hf_folder = model_name_or_path
-    hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
+    hf_weights_files: List[str] = []
+    for pattern in allow_patterns:
+        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
     if not use_safetensors:
         hf_weights_files = [
             x for x in hf_weights_files if not x.endswith("training_args.bin")
@@ -107,7 +153,8 @@ def prepare_hf_model_weights(
         return prepare_hf_model_weights(model_name_or_path,
                                         cache_dir=cache_dir,
                                         use_safetensors=False,
-                                        fall_back_to_pt=False)
+                                        fall_back_to_pt=False,
+                                        revision=revision)

     if len(hf_weights_files) == 0:
         raise RuntimeError(
@@ -120,6 +167,7 @@ def hf_model_weights_iterator(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
     load_format: str = "auto",
+    revision: Optional[str] = None,
 ) -> Iterator[Tuple[str, torch.Tensor]]:
     use_safetensors = False
     use_np_cache = False
@@ -140,7 +188,8 @@ def hf_model_weights_iterator(
         model_name_or_path,
         cache_dir=cache_dir,
         use_safetensors=use_safetensors,
-        fall_back_to_pt=fall_back_to_pt)
+        fall_back_to_pt=fall_back_to_pt,
+        revision=revision)

     if use_np_cache:
         # Currently np_cache only support *.bin checkpoints
@@ -260,4 +309,4 @@ def initialize_dummy_weights(
     values between -1e-3 and 1e-3 works well for most models.
     """
     for param in model.state_dict().values():
-        param.data.uniform_(low, high)
+        param.data.uniform_(low, high)
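
For context, a small sketch (not part of the commit) of calling the new helper directly; the repository id is a placeholder, and the call downloads the model's *.json files if they are not cached locally.

    from aphrodite.modeling.hf_downloader import get_quant_config

    quant_config = get_quant_config("awq", "org/some-awq-quantized-llama",
                                    cache_dir=None)
    print(quant_config.weight_bits, quant_config.group_size,
          quant_config.pack_factor)   # e.g. 4 128 8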

+ 36 - 0
aphrodite/modeling/layers/quantized_linear/__init__.py

@@ -0,0 +1,36 @@
+from aphrodite.modeling.layers.quantized_linear.awq import (
+    AWQColumnParallelLinear, AWQRowParallelLinear)
+from aphrodite.modeling.megatron.tensor_parallel import (
+    ColumnParallelLinear, RowParallelLinear)
+
+_QUANTIZED_LINEAR_REGISTRY = {
+    "awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
+}
+
+class ParallelLinear:
+
+    @classmethod
+    def column(cls, *args, **kwargs) -> ColumnParallelLinear:
+        quant_config = kwargs.get("quant_config", None)
+        if quant_config is None:
+            return ColumnParallelLinear(*args, **kwargs)
+
+        name = quant_config.get_name()
+        if name not in _QUANTIZED_LINEAR_REGISTRY:
+            raise ValueError(f"No quantized linear is found for {name}")
+
+        quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][0]
+        return quant_linear_cls(*args, **kwargs)
+
+    @classmethod
+    def row(cls, *args, **kwargs) -> RowParallelLinear:
+        quant_config = kwargs.get("quant_config", None)
+        if quant_config is None:
+            return RowParallelLinear(*args, **kwargs)
+
+        name = quant_config.get_name()
+        if name not in _QUANTIZED_LINEAR_REGISTRY:
+            raise ValueError(f"No quantized linear is found for {name}")
+
+        quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1]
+        return quant_linear_cls(*args, **kwargs)
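
A brief sketch (not part of the commit) of the dispatch this registry enables, assuming the package and its compiled quantization_ops extension import cleanly; the sizes and config values below are illustrative.

    from aphrodite.modeling.layers.quantized_linear import _QUANTIZED_LINEAR_REGISTRY
    from aphrodite.modeling.quantization_utils.awq import AWQConfig

    quant_config = AWQConfig(weight_bits=4, group_size=128, zero_point=True)
    column_cls, row_cls = _QUANTIZED_LINEAR_REGISTRY[quant_config.get_name()]
    print(column_cls.__name__, row_cls.__name__)
    # AWQColumnParallelLinear AWQRowParallelLinear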

+ 87 - 0
aphrodite/modeling/layers/quantized_linear/awq.py

@@ -0,0 +1,87 @@
+from typing import Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from aphrodite import quantization_ops
+from aphrodite.modeling.megatron.tensor_parallel.layers import (
+    ColumnParallelLinear, RowParallelLinear)
+
+class AWQColumnParallelLinear(ColumnParallelLinear):
+
+    def create_weights(self, dtype: torch.dtype) -> None:
+        assert self.input_size % self.quant_config.weight_bits == 0
+        assert self.output_size_per_partition % self.quant_config.pack_factor == 0
+        self.qweight = Parameter(
+            torch.empty(
+                self.input_size,
+                self.output_size_per_partition // self.quant_config.pack_factor,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        self.qzeros = Parameter(
+            torch.empty(
+                self.input_size // self.quant_config.group_size,
+                self.output_size_per_partition // self.quant_config.pack_factor,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        self.scales = Parameter(
+            torch.empty(
+                self.input_size // self.quant_config.group_size,
+                self.output_size_per_partition,
+                device="cuda",
+                dtype=dtype,
+            ),
+            requires_grad=False,
+        )
+
+    def apply_weights(
+        self,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        pack_factor = self.quant_config.pack_factor
+        out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
+        reshaped_x = x.reshape(-1, x.shape[-1])
+        out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
+                                        self.qzeros, pack_factor)
+        if bias is not None:
+            out = out + bias
+        return out.reshape(out_shape)
+
+class AWQRowParallelLinear(RowParallelLinear):
+    def create_weights(self, dtype: torch.dtype) -> None:
+        assert self.input_size_per_partition % self.quant_config.weight_bits == 0
+        assert self.output_size % self.quant_config.pack_factor == 0
+        self.qweight = Parameter(
+            torch.empty(
+                self.input_size_per_partition,
+                self.output_size // self.quant_config.pack_factor,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        self.qzeros = Parameter(
+            torch.empty(
+                self.input_size_per_partition // self.quant_config.group_size,
+                self.output_size // self.quant_config.pack_factor,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        self.scales = Parameter(
+            torch.empty(
+                self.input_size_per_partition // self.quant_config.group_size,
+                self.output_size,
+                device="cuda",
+                dtype=dtype,
+            ),
+            requires_grad=False,
+        )
+
+    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
+        pack_factor = self.quant_config.pack_factor
+        out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
+        reshaped_x = x.reshape(-1, x.shape[-1])
+        out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
+                                        self.qzeros, pack_factor)
+        return out.reshape(out_shape)
+
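
A short illustration (not part of the commit) of the packing arithmetic these layers rely on: with 4-bit weights, pack_factor is 32 // 4 = 8, so each int32 element of qweight holds eight quantized values along the output dimension. The sizes below are assumed for illustration.

    weight_bits = 4
    pack_factor = 32 // weight_bits            # 8 weights per int32
    input_size, output_size, group_size = 4096, 11008, 128

    qweight_shape = (input_size, output_size // pack_factor)
    qzeros_shape = (input_size // group_size, output_size // pack_factor)
    scales_shape = (input_size // group_size, output_size)
    print(qweight_shape, qzeros_shape, scales_shape)
    # (4096, 1376) (32, 1376) (32, 11008)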

+ 24 - 2
aphrodite/modeling/loader.py

@@ -1,12 +1,13 @@
 import contextlib
 from typing import Type
 import torch
 import torch.nn as nn
 from transformers import PretrainedConfig

 from aphrodite.common.config import ModelConfig
 from aphrodite.modeling.models import LlamaForCausalLM, GPTJForCausalLM, GPTNeoXForCausalLM
-from aphrodite.modeling.hf_downloader import initialize_dummy_weights
+from aphrodite.modeling.hf_downloader import initialize_dummy_weights, get_quant_config

 _MODEL_REGISTRY = {
     "LlamaForCausalLM": LlamaForCausalLM,
@@ -15,6 +16,10 @@ _MODEL_REGISTRY = {
     "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
     "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
 }
 }
 
 
+_QUANT_REGISTRY = {
+    "LlamaForCausalLM",
+}
+
 @contextlib.contextmanager
 @contextlib.contextmanager
 def _set_default_torch_dtype(dtype: torch.dtype):
 def _set_default_torch_dtype(dtype: torch.dtype):
     """Sets the default torch dtype to the given dtype."""
     """Sets the default torch dtype to the given dtype."""
@@ -36,10 +41,27 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:

 def get_model(model_config: ModelConfig) -> nn.Module:
     model_class = _get_model_architecture(model_config.hf_config)
+    quant_config = None
+    if model_config.quantization is not None:
+        if model_class not in _QUANT_REGISTRY:
+            raise ValueError(
+                f"Quantization is not supported for {model_class}.")
+        quant_config = get_quant_config(model_config.quantization,
+                                        model_config.model,
+                                        model_config.download_dir)
+        supported_dtypes = quant_config.get_supported_act_dtypes()
+        if model_config.dtype not in supported_dtypes:
+            raise ValueError(
+                f"{model_config.dtype} is not supported for quantization method {model_config.quantization}. "
+                f"Supported datatypes: {supported_dtypes}")
+
     with _set_default_torch_dtype(model_config.dtype):
         # Create a model instance.
         # The weights will be initialized as empty tensors.
-        model = model_class(model_config.hf_config)
+        if model_class in _QUANT_REGISTRY:
+            model = model_class(model_config.hf_config, quant_config)
+        else:
+            model = model_class(model_config.hf_config)
         if model_config.load_format == "dummy":
             model = model.cuda()
             # NOTE: For accurate performance evaluation, we assign
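
A small note, not from the commit text: since AWQConfig.get_supported_act_dtypes() returns only torch.half, the dtype check in get_model means AWQ models must run with float16 activations. A sketch of that check in isolation:

    import torch
    from aphrodite.modeling.quantization_utils.awq import AWQConfig

    quant_config = AWQConfig(weight_bits=4, group_size=128, zero_point=True)
    supported = quant_config.get_supported_act_dtypes()
    assert torch.half in supported
    assert torch.bfloat16 not in supported   # would raise ValueError in get_model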

+ 61 - 143
aphrodite/modeling/megatron/tensor_parallel/layers.py

@@ -15,16 +15,13 @@ from torch.nn.parameter import Parameter
 from aphrodite.modeling.megatron.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    get_all_reduce_launcher,
 )
 from .mappings import (
-    copy_to_tensor_model_parallel_region,
     gather_from_tensor_model_parallel_region,
     reduce_from_tensor_model_parallel_region,
     scatter_to_tensor_model_parallel_region,
 )

-from .random import get_cuda_rng_tracker
 from .utils import (
     divide,
     VocabUtility,
@@ -67,59 +64,6 @@ def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
         maybe_copy(attribute)


-def _initialize_affine_weight_gpu(weight, init_method,
-                                  partition_dim, stride=1):
-    """Initialize affine weight for model parallel on GPU."""
-
-    set_tensor_model_parallel_attributes(tensor=weight,
-                                         is_parallel=True,
-                                         dim=partition_dim,
-                                         stride=stride)
-
-    with get_cuda_rng_tracker().fork():
-        init_method(weight)
-
-
-def _initialize_affine_weight_cpu(weight, output_size, input_size,
-                                  per_partition_size, partition_dim,
-                                  init_method, stride=1,
-                                  return_master_weight=False,
-                                  *, params_dtype=None):
-    """Initialize affine weight for model parallel.
-
-    Build the master weight on all processes and scatter
-    the relevant chunk."""
-
-    set_tensor_model_parallel_attributes(tensor=weight,
-                                         is_parallel=True,
-                                         dim=partition_dim,
-                                         stride=stride)
-
-    if params_dtype is None:
-        params_dtype = torch.get_default_dtype()
-
-    # Initialize master weight
-    master_weight = torch.empty(output_size, input_size,
-                                dtype=torch.float,
-                                requires_grad=False)
-    init_method(master_weight)
-    master_weight = master_weight.to(dtype=params_dtype)
-
-    # Split and copy
-    per_partition_per_stride_size = divide(per_partition_size, stride)
-    weight_list = torch.split(master_weight, per_partition_per_stride_size,
-                              dim=partition_dim)
-    rank = get_tensor_model_parallel_rank()
-    world_size = get_tensor_model_parallel_world_size()
-    my_weight_list = weight_list[rank::world_size]
-
-    with torch.no_grad():
-        torch.cat(my_weight_list, dim=partition_dim, out=weight)
-    if return_master_weight:
-        return master_weight
-    return None
-
-
 class VocabParallelEmbedding(torch.nn.Module):
     """Embedding parallelized in the vocabulary dimension.

@@ -142,6 +86,9 @@ class VocabParallelEmbedding(torch.nn.Module):
                  use_cpu_initialization: bool=False,
                  perform_initialization: bool=True):
         super(VocabParallelEmbedding, self).__init__()
+        assert not perform_initialization
+        assert not use_cpu_initialization
+
         # Keep the input dimensions.
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
@@ -164,24 +111,10 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.num_embeddings_per_partition = self.vocab_end_index - \
             self.vocab_start_index

-        # Allocate weights and initialize.
-        if use_cpu_initialization:
-            self.weight = Parameter(torch.empty(
-                self.num_embeddings_per_partition, self.embedding_dim,
-                dtype=params_dtype))
-            if perform_initialization:
-                _initialize_affine_weight_cpu(
-                    self.weight, self.num_embeddings, self.embedding_dim,
-                    self.num_embeddings_per_partition, 0, init_method,
-                    params_dtype=params_dtype)
-        else:
-            self.weight = Parameter(torch.empty(
-                self.num_embeddings_per_partition, self.embedding_dim,
-                device=torch.cuda.current_device(), dtype=params_dtype))
-            if perform_initialization:
-                _initialize_affine_weight_gpu(self.weight, init_method,
-                                              partition_dim=0, stride=1)
-
+        self.weight = Parameter(torch.empty(
+            self.num_embeddings_per_partition, self.embedding_dim,
+            device=torch.cuda.current_device(), dtype=params_dtype))
+
     def forward(self, input_):
         if self.tensor_model_parallel_size > 1:
             # Build the mask.
@@ -241,17 +174,21 @@ class ColumnParallelLinear(torch.nn.Module):
                  params_dtype=None,
                  use_cpu_initialization=False,
                  perform_initialization=True,
+                 quant_config=None,
                  ):
         super(ColumnParallelLinear, self).__init__()
+        assert not perform_initialization
+        assert not use_cpu_initialization

         # Keep input parameters
         self.input_size = input_size
         self.output_size = output_size
         self.gather_output = gather_output
         # Divide the weight matrix along the last dimension.
-        world_size = get_tensor_model_parallel_world_size()
-        self.output_size_per_partition = divide(output_size, world_size)
+        self.world_size = get_tensor_model_parallel_world_size()
+        self.output_size_per_partition = divide(output_size, self.world_size)
         self.skip_bias_add = skip_bias_add
+        self.quant_config = quant_config

         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -259,33 +196,13 @@ class ColumnParallelLinear(torch.nn.Module):
         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
         # we allocate the transpose.
-        # Initialize weight.
-        if use_cpu_initialization:
-            self.weight = Parameter(torch.empty(self.output_size_per_partition,
-                                                self.input_size,
-                                                dtype=params_dtype))
-            if perform_initialization:
-                self.master_weight = _initialize_affine_weight_cpu(
-                    self.weight, self.output_size, self.input_size,
-                    self.output_size_per_partition, 0, init_method,
-                    stride=stride, return_master_weight=keep_master_weight_for_test)
-        else:
-            self.weight = Parameter(torch.empty(
-                self.output_size_per_partition, self.input_size,
-                device=torch.cuda.current_device(), dtype=params_dtype))
-            if perform_initialization:
-                _initialize_affine_weight_gpu(self.weight, init_method,
-                                              partition_dim=0, stride=stride)
+        self.create_weights(params_dtype)

         if bias:
-            if use_cpu_initialization:
-                self.bias = Parameter(torch.empty(
-                    self.output_size_per_partition, dtype=params_dtype))
-            else:
-                self.bias = Parameter(torch.empty(
-                    self.output_size_per_partition,
-                    device=torch.cuda.current_device(),
-                    dtype=params_dtype))
+            self.bias = Parameter(torch.empty(
+                self.output_size_per_partition,
+                device=torch.cuda.current_device(),
+                dtype=params_dtype))
             set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
             # Always initialize bias to zero.
             with torch.no_grad():
@@ -293,6 +210,17 @@
         else:
             self.register_parameter('bias', None)

+    def create_weights(self, dtype: torch.dtype) -> None:
+        self.weight = Parameter(torch.empty(
+            self.output_size_per_partition, self.input_size,
+            device=torch.cuda.current_device(), dtype=dtype))
+
+    def apply_weights(
+        self,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        return F.linear(x, self.weight, bias)

     def forward(self, input_):
         """Forward of ColumnParallelLinear
@@ -308,7 +236,7 @@

         input_parallel = input_
         # Matrix multiply.
-        output_parallel = F.linear(input_parallel, self.weight, bias)
+        output_parallel = self.apply_weights(input_parallel, bias)
         if self.gather_output:
             # All-gather across the partitions.
             output = gather_from_tensor_model_parallel_region(output_parallel)
@@ -351,6 +279,7 @@ class RowParallelLinear(torch.nn.Module):
         params_dtype:
         use_cpu_initialization:
         perform_initialization:
+        reduce_results:
     """

     def __init__(self, input_size, output_size, *,
@@ -361,57 +290,51 @@ class RowParallelLinear(torch.nn.Module):
                  params_dtype=None,
                  use_cpu_initialization=False,
                  perform_initialization=True,
+                 reduce_results=True,
+                 quant_config=None,
                  ):
         super(RowParallelLinear, self).__init__()
+        assert not perform_initialization
+        assert not use_cpu_initialization

         # Keep input parameters
         self.input_size = input_size
         self.output_size = output_size
         self.input_is_parallel = input_is_parallel
+        self.reduce_results = reduce_results
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()

         # Divide the weight matrix along the last dimension.
-        world_size = get_tensor_model_parallel_world_size()
-        self.input_size_per_partition = divide(input_size, world_size)
+        self.world_size = get_tensor_model_parallel_world_size()
+        self.input_size_per_partition = divide(input_size, self.world_size)
         self.skip_bias_add = skip_bias_add
+        self.quant_config = quant_config
+
+        self.create_weights(params_dtype)
+
+        if not reduce_results and (bias and not skip_bias_add):
+            raise ValueError("When not reduce the results, adding bias to the "
+                             "results can lead to incorrect results")
 
 
-        # Parameters.
-        # Note: torch.nn.functional.linear performs XA^T + b and as a result
-        # we allocate the transpose.
-        # Initialize weight.
-        if use_cpu_initialization:
-            self.weight = Parameter(torch.empty(self.output_size,
-                                                self.input_size_per_partition,
-                                                dtype=params_dtype))
-            if perform_initialization:
-                self.master_weight = _initialize_affine_weight_cpu(
-                    self.weight, self.output_size, self.input_size,
-                    self.input_size_per_partition, 1, init_method,
-                    stride=stride, return_master_weight=keep_master_weight_for_test,
-                    params_dtype=params_dtype)
-        else:
-            self.weight = Parameter(torch.empty(
-                self.output_size, self.input_size_per_partition,
-                device=torch.cuda.current_device(), dtype=params_dtype))
-            if perform_initialization:
-                _initialize_affine_weight_gpu(self.weight, init_method,
-                                              partition_dim=1, stride=stride)
         if bias:
-            if use_cpu_initialization:
-                self.bias = Parameter(torch.empty(self.output_size,
-                                                  dtype=params_dtype))
-            else:
-                self.bias = Parameter(torch.empty(
-                    self.output_size, device=torch.cuda.current_device(),
-                    dtype=params_dtype))
+            self.bias = Parameter(torch.empty(
+                self.output_size, device=torch.cuda.current_device(),
+                dtype=params_dtype))

             # Always initialize bias to zero.
             with torch.no_grad():
                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)
-        self.weight_t = self.weight.t()
+
+    def create_weights(self, dtype: torch.dtype) -> None:
+        self.weight = Parameter(torch.empty(
+            self.output_size, self.input_size_per_partition,
+            device=torch.cuda.current_device(), dtype=dtype))
+
+    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
+        return F.linear(x, self.weight)

     def forward(self, input_):
         """Forward of RowParallelLinear
@@ -428,17 +351,12 @@
             input_parallel = input_
         else:
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
-        if get_tensor_model_parallel_world_size() == 1:
-            # Matrix multiply.
-            output_ = F.linear(input_parallel, self.weight)
+        # Matrix multiply.
+        output_parallel = self.apply_weights(input_parallel)
+        if self.reduce_results and self.world_size > 1:
+            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         else:
-            # Matrix multiply.
-            all_reduce_launcher = get_all_reduce_launcher()
-            num_tokens = input_parallel.shape[0]
-            output_buffer = all_reduce_launcher.buffer[:num_tokens]
-            torch.matmul(input_parallel, self.weight_t, out=output_buffer)
-            # All-reduce across all the partitions.
-            output_ = all_reduce_launcher.launch(output_buffer)
+            output_ = output_parallel

         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
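
For context, a hypothetical sketch (not part of the commit) of the extension point this refactor creates: a subclass only has to override create_weights and apply_weights, which is exactly what the AWQ layers do. The class name below is made up.

    import torch
    import torch.nn.functional as F
    from torch.nn.parameter import Parameter
    from aphrodite.modeling.megatron.tensor_parallel.layers import ColumnParallelLinear

    class PassthroughColumnParallelLinear(ColumnParallelLinear):

        def create_weights(self, dtype: torch.dtype) -> None:
            # Allocate whatever (possibly packed) buffers the method needs.
            self.weight = Parameter(torch.empty(
                self.output_size_per_partition, self.input_size,
                device=torch.cuda.current_device(), dtype=dtype))

        def apply_weights(self, x: torch.Tensor, bias) -> torch.Tensor:
            # A real quantized layer would call a custom dequant/matmul kernel here.
            return F.linear(x, self.weight, bias)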

+ 90 - 33
aphrodite/modeling/models/llama.py

@@ -35,7 +35,8 @@ from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.attention import PagedAttentionWithRoPE
 from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.quantized_linear import ParallelLinear
+from aphrodite.modeling.quantization_utils import QuantizationConfig
 from aphrodite.modeling.hf_downloader import load_tensor_parallel_weights, load_padded_tensor_parallel_vocab, hf_model_weights_iterator
 from aphrodite.modeling.megatron.parallel_state import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
 from aphrodite.modeling.megatron.tensor_parallel import VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear
@@ -52,18 +53,21 @@ class LlamaMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
-    ):
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
-        self.gate_up_proj = ColumnParallelLinear(hidden_size,
-                                                 2 * intermediate_size,
-                                                 bias=False,
-                                                 gather_output=False,
-                                                 perform_initialization=False)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           input_is_parallel=True,
-                                           perform_initialization=False)
+        self.gate_up_proj = ParallelLinear.column(hidden_size,
+                                                  2 * intermediate_size,
+                                                  bias=False,
+                                                  gather_output=False,
+                                                  perform_initialization=False,
+                                                  quant_config=quant_config)
+        self.down_proj = ParallelLinear.row(intermediate_size,
+                                            hidden_size,
+                                            bias=False,
+                                            input_is_parallel=True,
+                                            perform_initialization=False,
+                                            quant_config=quant_config)
         if hidden_act != "silu":
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
                              "Only silu is supported for now.")
@@ -84,7 +88,8 @@ class LlamaAttention(nn.Module):
         num_heads: int,
         num_kv_heads: int,
         rope_theta: float = 10000,
-    ):
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -100,20 +105,22 @@ class LlamaAttention(nn.Module):
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta

-        self.qkv_proj = ColumnParallelLinear(
+        self.qkv_proj = ParallelLinear.column(
             hidden_size,
             (self.total_num_heads + 2 * self.total_num_kv_heads) *
             self.head_dim,
             bias=False,
             gather_output=False,
             perform_initialization=False,
+            quant_config=quant_config,
         )
-        self.o_proj = RowParallelLinear(
+        self.o_proj = ParallelLinear.row(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             input_is_parallel=True,
             perform_initialization=False,
+            quant_config=quant_config,
         )
         self.attn = PagedAttentionWithRoPE(self.num_heads,
                                            self.head_dim,
@@ -141,7 +148,11 @@

 class LlamaDecoderLayer(nn.Module):

-    def __init__(self, config: LlamaConfig):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
@@ -151,11 +162,13 @@ class LlamaDecoderLayer(nn.Module):
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
+            quant_config=quant_config,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
+            quant_config=quant_config,
         )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@@ -192,7 +205,11 @@

 class LlamaModel(nn.Module):

-    def __init__(self, config: LlamaConfig):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
@@ -202,7 +219,8 @@ class LlamaModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             vocab_size, config.hidden_size, perform_initialization=False)
         self.layers = nn.ModuleList([
-            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
+            LlamaDecoderLayer(config, quant_config)
+            for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -234,16 +252,23 @@

 class LlamaForCausalLM(nn.Module):

-    def __init__(self, config):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
         self.config = config
-        self.model = LlamaModel(config)
+        self.quant_config = quant_config
+        self.model = LlamaModel(config, quant_config)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.lm_head = ColumnParallelLinear(config.hidden_size,
-                                            vocab_size,
-                                            bias=False,
-                                            gather_output=False,
-                                            perform_initialization=False)
+        # NOTE: The LM head is not quantized.
+        self.lm_head = ParallelLinear.column(config.hidden_size,
+                                             vocab_size,
+                                             bias=False,
+                                             gather_output=False,
+                                             perform_initialization=False,
+                                             quant_config=None)
         self.sampler = Sampler(config.vocab_size)

     def forward(
@@ -260,15 +285,28 @@ class LlamaForCausalLM(nn.Module):
                                    input_metadata)
         return next_tokens

-    _column_parallel_weights = [
-        "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
-    ]
-    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
+    _column_parallel_layers = []
+    _row_parallel_layers = ["o_proj", "down_proj"]
 
 
     def load_weights(self,
     def load_weights(self,
                      model_name_or_path: str,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        if self.quant_config is None:
+            weight_suffixes = ["weight"]
+        else:
+            weight_suffixes = self.quant_config.get_tp_tensor_names()
+
+        column_parallel_weights: List[str] = []
+        for layer in self._column_parallel_layers:
+            for suffix in weight_suffixes:
+                column_parallel_weights.append(f"{layer}.{suffix}")
+        row_parallel_weights: List[str] = []
+        for layer in self._row_parallel_layers:
+            for suffix in weight_suffixes:
+                row_parallel_weights.append(f"{layer}.{suffix}")
+
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -285,15 +323,29 @@ class LlamaForCausalLM(nn.Module):
         state_dict = self.state_dict()

         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue

+            is_packed = False
+            is_transposed = False
+            if self.quant_config is not None:
+                is_packed = self.quant_config.is_packed(name)
+                is_transposed = self.quant_config.is_transposed(name)
+            if is_transposed:
+                loaded_weight = loaded_weight.T
+
             is_attention_weight = False
             for weight_name, shard_size, offset in attention_weight_specs:
                 if weight_name not in name:
                     continue
                 param = state_dict[name.replace(weight_name, "qkv_proj")]
+                if is_transposed:
+                    param = param.T
+
+                if is_packed:
+                    shard_size //= self.quant_config.pack_factor
+                    offset //= self.quant_config.pack_factor

                 loaded_weight = loaded_weight[
                     shard_size * tensor_model_parallel_rank:shard_size *
@@ -312,6 +364,9 @@ class LlamaForCausalLM(nn.Module):
                 if weight_name not in name:
                     continue
                 param = state_dict[name.replace(weight_name, "gate_up_proj")]
+                if is_transposed:
+                    param = param.T
+
                 shard_size = param.shape[0] // 2
                 loaded_weight = loaded_weight[
                     shard_size * tensor_model_parallel_rank:shard_size *
@@ -326,6 +381,8 @@ class LlamaForCausalLM(nn.Module):
                 continue

             param = state_dict[name]
+            if is_transposed:
+                param = param.T

             if "embed_tokens" in name or "lm_head" in name:
                 load_padded_tensor_parallel_vocab(param, loaded_weight,
@@ -333,6 +390,6 @@ class LlamaForCausalLM(nn.Module):
                 continue

             load_tensor_parallel_weights(param, loaded_weight, name,
-                                         self._column_parallel_weights,
-                                         self._row_parallel_weights,
+                                         column_parallel_weights,
+                                         row_parallel_weights,
                                          tensor_model_parallel_rank)
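
For context, a small sketch (not part of the commit) of how the tensor-parallel weight-name lists are assembled from the config's suffixes when AWQ is enabled; it mirrors the loop added in load_weights above.

    row_parallel_layers = ["o_proj", "down_proj"]
    weight_suffixes = ["qweight", "qzeros", "scales"]  # AWQConfig.get_tp_tensor_names()

    row_parallel_weights = [f"{layer}.{suffix}"
                            for layer in row_parallel_layers
                            for suffix in weight_suffixes]
    print(row_parallel_weights)
    # ['o_proj.qweight', 'o_proj.qzeros', 'o_proj.scales',
    #  'down_proj.qweight', 'down_proj.qzeros', 'down_proj.scales']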

+ 20 - 0
aphrodite/modeling/quantization_utils/__init__.py

@@ -0,0 +1,20 @@
+from typing import Type
+
+from aphrodite.modeling.quantization_utils.awq import AWQConfig
+from aphrodite.modeling.quantization_utils.base import QuantizationConfig
+
+_QUANTIZATION_REGISTRY = {
+    "awq": AWQConfig,
+}
+
+
+def get_quant_class(quantization: str) -> Type[QuantizationConfig]:
+    if quantization not in _QUANTIZATION_REGISTRY:
+        raise ValueError(f"Invalid quantization method: {quantization}")
+    return _QUANTIZATION_REGISTRY[quantization]
+
+
+__all__ = [
+    "QuantizationConfig",
+    "get_quant_class",
+]
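
A minimal usage sketch of the registry above; the config keys follow the aliases accepted by AWQConfig.from_config further down in this diff, and the concrete values are illustrative.

from aphrodite.modeling.quantization_utils import get_quant_class

quant_cls = get_quant_class("awq")            # -> AWQConfig
quant_config = quant_cls.from_config({
    "w_bit": 4,
    "q_group_size": 128,
    "zero_point": True,
})
print(quant_config)               # AWQConfig(weight_bits=4, group_size=128, zero_point=True)
print(quant_config.pack_factor)   # 8 (eight 4-bit weights per int32)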

+ 67 - 0
aphrodite/modeling/quantization_utils/awq.py

@@ -0,0 +1,67 @@
+from typing import Any, Dict, List
+
+import torch
+
+from aphrodite.modeling.quantization_utils.base import QuantizationConfig
+
+
+class AWQConfig(QuantizationConfig):
+    """Config class for AWQ.
+
+    Reference: https://arxiv.org/abs/2306.00978
+    """
+
+    def __init__(
+        self,
+        weight_bits: int,
+        group_size: int,
+        zero_point: bool,
+    ) -> None:
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.zero_point = zero_point
+
+        if self.weight_bits != 4:
+            raise ValueError(
+                "Currently, only 4-bit weight quantization is supported for "
+                f"AWQ, but got {self.weight_bits} bits.")
+        self.pack_factor = 32 // self.weight_bits
+
+    def __repr__(self) -> str:
+        return (f"AWQConfig(weight_bits={self.weight_bits}, "
+                f"group_size={self.group_size}, "
+                f"zero_point={self.zero_point})")
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "awq"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return [
+            "quant_config.json",  
+            "quantize_config.json",
+        ]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
+        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
+        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
+        zero_point = cls.get_from_keys(config, ["zero_point"])
+        return cls(weight_bits, group_size, zero_point)
+
+    @classmethod
+    def get_packed_tensor_names(cls) -> List[str]:
+        return ["qweight", "qzeros"]
+
+    @classmethod
+    def get_transposed_tensor_names(cls) -> List[str]:
+        return ["qweight", "qzeros", "scales"]
+
+    @classmethod
+    def get_tp_tensor_names(cls) -> List[str]:
+        return ["qweight", "qzeros", "scales"]
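
The weight loader only consults this class through its name-based predicates. A quick sketch of how typical AWQ checkpoint tensors would be classified; the tensor names are illustrative, not taken from a specific checkpoint.

from aphrodite.modeling.quantization_utils.awq import AWQConfig

for name in ("layers.0.self_attn.qkv_proj.qweight",
             "layers.0.self_attn.qkv_proj.scales",
             "layers.0.self_attn.qkv_proj.bias"):
    print(name, AWQConfig.is_packed(name), AWQConfig.is_transposed(name))
# qweight -> packed and transposed, scales -> transposed only, bias -> neither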

+ 65 - 0
aphrodite/modeling/quantization_utils/base.py

@@ -0,0 +1,65 @@
+from typing import Any, Dict, List
+
+import torch
+
+
+class QuantizationConfig:
+
+    @classmethod
+    def get_name(cls) -> str:
+        """Name of the quantization method."""
+        raise NotImplementedError
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        """List of supported activation dtypes."""
+        raise NotImplementedError
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        """List of filenames to search for in the model directory."""
+        raise NotImplementedError
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
+        """Create a config class from the model's quantization config."""
+        raise NotImplementedError
+
+    @staticmethod
+    def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
+        """Get a value from the model's quantization config."""
+        for key in keys:
+            if key in config:
+                return config[key]
+        raise ValueError(f"Cannot find any of {keys} in the model's "
+                         "quantization config.")
+
+    @classmethod
+    def get_packed_tensor_names(cls) -> List[str]:
+        raise NotImplementedError
+
+    @classmethod
+    def is_packed(cls, tensor_name: str) -> bool:
+        """Returns True if a tensor is packed.
+
+        A tensor is considered packed if each element in the tensor is a
+        packed representation of multiple elements in the original tensor.
+        For example, an INT32 element in the tensor may represent 8 INT4
+        elements in the original tensor.
+        """
+        return any(tag in tensor_name for tag in cls.get_packed_tensor_names())
+
+    @classmethod
+    def get_transposed_tensor_names(cls) -> List[str]:
+        raise NotImplementedError
+
+    @classmethod
+    def is_transposed(cls, tensor_name: str) -> bool:
+        """Returns True if a tensor is transposed relative to nn.Linear.weight.
+        """
+        return any(tag in tensor_name
+                   for tag in cls.get_transposed_tensor_names())
+
+    @classmethod
+    def get_tp_tensor_names(cls) -> List[str]:
+        raise NotImplementedError
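
For reference, a hypothetical subclass sketch showing the minimum a new quantization method would need to implement to satisfy this interface and plug into the registry; the method name, dtypes, and tensor names below are invented for illustration and are not part of this commit.

from typing import Any, Dict, List

import torch

from aphrodite.modeling.quantization_utils.base import QuantizationConfig


class FakeInt8Config(QuantizationConfig):
    """Illustrative only: a made-up 8-bit method."""

    def __init__(self, group_size: int) -> None:
        self.group_size = group_size

    @classmethod
    def get_name(cls) -> str:
        return "fake_int8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "FakeInt8Config":
        return cls(cls.get_from_keys(config, ["group_size"]))

    @classmethod
    def get_packed_tensor_names(cls) -> List[str]:
        return []      # 8-bit weights need no sub-byte packing

    @classmethod
    def get_transposed_tensor_names(cls) -> List[str]:
        return []

    @classmethod
    def get_tp_tensor_names(cls) -> List[str]:
        return ["weight", "scales"]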

+ 15 - 0
kernels/quantization.cpp

@@ -0,0 +1,15 @@
+#include <torch/extension.h>
+
+torch::Tensor awq_gemm(
+  torch::Tensor _in_feats,
+  torch::Tensor _kernel,
+  torch::Tensor _scaling_factors,
+  torch::Tensor _zeros,
+  int split_k_iters);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def(
+    "awq_gemm",
+    &awq_gemm,
+    "Quantized GEMM for AWQ");
+}
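
Once this extension is built (the setup.py hunk below compiles it as aphrodite.quantization_ops), the op would be invoked roughly as follows. The tensor shapes follow the layout comments in gemm_kernels.cu; the sizes, split_k_iters value, and random contents are illustrative, and a CUDA device is assumed.

import torch
from aphrodite import quantization_ops

M, IC, OC, G = 8, 4096, 4096, 128
in_feats = torch.randn(M, IC, dtype=torch.half, device="cuda")
qweight = torch.randint(0, 2**31 - 1, (IC, OC // 8), dtype=torch.int32, device="cuda")
scales = torch.randn(IC // G, OC, dtype=torch.half, device="cuda")
qzeros = torch.randint(0, 2**31 - 1, (IC // G, OC // 8), dtype=torch.int32, device="cuda")

out = quantization_ops.awq_gemm(in_feats, qweight, scales, qzeros, 8)  # split_k_iters = 8
print(out.shape)   # torch.Size([8, 4096]) == [M, OC]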

+ 79 - 0
kernels/quantization/awq/dequantize.cuh

@@ -0,0 +1,79 @@
+/*
+Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
+
+@article{lin2023awq,
+  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
+  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
+  journal={arXiv},
+  year={2023}
+}
+*/
+
+#pragma once
+
+
+__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
+{
+    uint4 result;
+
+    uint32_t*      h   = reinterpret_cast<uint32_t*>(&result);
+    uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
+
+    // First, we extract the i4s and construct an intermediate fp16 number.
+    static constexpr uint32_t immLut                = (0xf0 & 0xcc) | 0xaa;
+    static constexpr uint32_t BOTTOM_MASK           = 0x000f000f;
+    static constexpr uint32_t TOP_MASK              = 0x00f000f0;
+    static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
+
+    // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
+    // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
+    // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
+    // elt_67 to fp16 without having to shift them to the bottom bits beforehand.
+
+    // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
+    // immediately before required.
+    const uint32_t top_i4s = i4s >> 8;
+    // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                    : "=r"(h[0])
+                    : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+    // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                    : "=r"(h[1])
+                    : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+    // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                    : "=r"(h[2])
+                    : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+    // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                    : "=r"(h[3])
+                    : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+
+    // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
+    // half2 ctor. In this case, I chose performance reliability over code readability.
+
+    // This is the half2 {1032, 1032} represented as an integer.
+    // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
+    // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
+    static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
+    // This is the half2 {1 / 16, 1 / 16} represented as an integer.
+    static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
+    // This is the half2 {-72, -72} represented as an integer.
+    // static constexpr uint32_t NEG_72 = 0xd480d480;
+    // Haotian: Let's use {-64, -64}.
+    static constexpr uint32_t NEG_64 = 0xd400d400;
+
+    // Finally, we construct the output numbers.
+    // Convert elt_01
+    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
+    // Convert elt_23
+    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
+    // Convert elt_45
+    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
+    // Convert elt_67
+    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
+
+    return result;
+}
+
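
The register layout this routine produces is easier to state in scalar form. The plain-Python sketch below is a reading of the PTX above, offered as an interpretation rather than code from this commit: each 32-bit word holds eight 4-bit values, and they come out interleaved as the half2 pairs {n0,n4}, {n1,n5}, {n2,n6}, {n3,n7}.

from typing import List


def dequantize_s4_to_fp16x2_ref(word: int) -> List[int]:
    """Scalar reading of the PTX: unpack eight 4-bit values from one 32-bit
    word in the order the kernel writes them into its four half2 registers."""
    nibbles = [(word >> (4 * k)) & 0xF for k in range(8)]
    return [nibbles[i] for i in (0, 4, 1, 5, 2, 6, 3, 7)]


print(dequantize_s4_to_fp16x2_ref(0x76543210))   # [0, 4, 1, 5, 2, 6, 3, 7]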

+ 477 - 0
kernels/quantization/awq/gemm_kernels.cu

@@ -0,0 +1,477 @@
+/*
+
+@article{lin2023awq,
+  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
+  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
+  journal={arXiv},
+  year={2023}
+}
+
+ */
+
+
+#include <torch/extension.h>
+#include "dequantize.cuh"
+#include <cuda_fp16.h>
+#include <c10/cuda/CUDAGuard.h>
+
+
+// Pack two half values.
+static inline __device__ __host__ unsigned
+__pack_half2(const half x, const half y) {
+  unsigned v0 = *((unsigned short *)&x);
+  unsigned v1 = *((unsigned short *)&y);
+  return (v1 << 16) | v0;
+}
+
+__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) 
+{
+  static constexpr uint32_t ZERO = 0x0;
+  float C_warp[32];
+  __shared__ half A_shared[16 * (32 + 8)];
+  __shared__ half B_shared[32 * (128 + 8)];
+  
+  __shared__ half scaling_factors_shared[128];
+  __shared__ half zeros_shared[128];
+
+  int j_factors1 = ((OC + 128 - 1) / 128);
+  int blockIdx_x = 0;
+  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
+  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
+
+  half A_shared_warp[8];
+  half B_shared_warp[32];
+  for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
+    for (int i = 0; i < 8; ++i) {
+      C_warp[(j_0_4_init * 8) + i] = 0.0;
+    }
+  }
+
+  static constexpr int row_stride_warp = 32 * 8 / 32;
+  static constexpr int row_stride = 2 * 32 * 8 / 128;
+  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
+  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+  bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;     // threadIdx.y is warp_id
+  // bool wb_C_flag = (threadIdx.x / 4) < M;
+
+  half* A_ptr = A 
+                + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+                + (((int)threadIdx.x) % (32 / 8)) * 8;
+  
+  int* B_ptr = B
+            + ((int)threadIdx.y) * (OC / 8) * 2
+            + (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
+            + (((int)blockIdx_y) % j_factors1) * (128 / 8)
+            + (((int)threadIdx.x) % (128 / 8)) * 1;
+// Why * 1 in the above line?
+                        
+  half* A_shared_ptr = A_shared 
+                    + ((int)threadIdx.y) * row_stride_warp * (32 + 8) 
+                    + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+                    + (((int)threadIdx.x) % (32 / 8) ) * 8;
+
+  half* B_shared_ptr = B_shared
+                    + ((int)threadIdx.y) * (row_stride / 2) * (128 + 8)
+                    + (((int)threadIdx.x) / (128 / 8)) * (128 + 8)
+                    + (((int)threadIdx.x) % (128 / 8)) * 8;
+  
+  int* zeros_ptr = zeros
+                + (((int)blockIdx_y) % j_factors1) * (128 / 8)
+                + ((int)threadIdx.x) % (128 / 8);
+  
+  half* scaling_factors_ptr = scaling_factors
+                            + (((int)blockIdx_y) % j_factors1) * (128) 
+                            + (((int)threadIdx.x) % (128 / 8)) * 8;
+
+  half* C_ptr = C 
+              + blockIdx_z * M * OC        // blockIdx_z -> split_k dim
+              + (((int)blockIdx_y) % j_factors1) * 128
+              + ((int)threadIdx.y) * 64
+              + (((int)threadIdx.x) % 4) * 2;
+
+  // preload s.f. and zeros
+  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
+  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
+  for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
+    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
+    __syncthreads();
+    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+    if (ld_A_flag)
+    {
+      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
+    }
+    else
+    {
+      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
+    }
+
+    // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
+    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
+    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
+    uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
+    /*
+    if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
+      printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
+    }
+    */
+    // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
+    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
+
+    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
+
+      // B: 32 x 136 (128+8) float16
+      // each warp: 32 x 4
+      // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
+      // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
+      // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) 
+      uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
+      uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
+      //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
+
+      // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
+      // - zero and * scale
+      // TODO (Haotian): can save 4 assembly instructions if formulated as deq = q * scale - zero * scale.
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
+      /*
+      if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
+        printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
+      }
+      */
+
+      // write back
+      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16;
+    }
+    __syncthreads();
+
+    for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
+      {
+        unsigned int addr;
+        __asm__ __volatile__(
+          "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+          : "=r"(addr)
+          : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
+        );
+
+
+        __asm__ __volatile__(
+          "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
+          "{%0, %1, %2, %3}, [%4];\n"
+          : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
+          : "r"(addr)
+        );
+      }
+
+      for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {
+        {
+          unsigned int addr;
+          __asm__ __volatile__(
+            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+            : "=r"(addr)
+            : "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))
+          );
+          __asm__ __volatile__(
+            "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
+            "{%0, %1, %2, %3}, [%4];\n"
+            : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
+            : "r"(addr)
+          );
+        }
+      }
+      for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+            :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+            :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+        }
+      }
+    }
+  }
+
+// TODO: Shang: Hoist loop invariance.
+  for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
+    for (int local_id = 0; local_id < 8; ++local_id) {
+      int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
+      if (row_offset < M)
+      {
+        *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
+      }
+    }
+  }
+}
+
+
+__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) 
+{
+  static constexpr uint32_t ZERO = 0x0;
+  float C_warp[32];
+  __shared__ half A_shared[16 * (32 + 8)];
+  __shared__ half B_shared[32 * (64 + 8)];
+  
+  __shared__ half scaling_factors_shared[64];
+  __shared__ half zeros_shared[64];
+
+  int j_factors1 = ((OC + 64 - 1) / 64);
+
+  int blockIdx_x = 0;
+  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
+  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
+
+  half A_shared_warp[8];
+  half B_shared_warp[16];
+  for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
+    for (int i = 0; i < 8; ++i) {
+      C_warp[(j_0_4_init * 8) + i] = 0.0;
+    }
+  }
+
+  static constexpr int row_stride_warp = 32 * 8 / 32;
+  static constexpr int row_stride = 2 * 32 * 8 / 64;
+  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
+  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+  bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;     // threadIdx.y is warp_id
+  // bool wb_C_flag = (threadIdx.x / 4) < M;
+
+  half* A_ptr = A 
+                + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+                + (((int)threadIdx.x) % (32 / 8)) * 8;
+  
+  int* B_ptr = B
+            + ((int)threadIdx.y) * (OC / 8) * 4
+            + (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
+            + (((int)blockIdx_y) % j_factors1) * (64 / 8)
+            + (((int)threadIdx.x) % (64 / 8)) * 1;
+// Why * 1 in the above line?
+                        
+  half* A_shared_ptr = A_shared 
+                    + ((int)threadIdx.y) * row_stride_warp * (32 + 8) 
+                    + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+                    + (((int)threadIdx.x) % (32 / 8) ) * 8;
+
+  half* B_shared_ptr = B_shared
+                    + ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
+                    + (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
+                    + (((int)threadIdx.x) % (64 / 8)) * 8;
+  
+  int* zeros_ptr = zeros
+                + (((int)blockIdx_y) % j_factors1) * (64 / 8)
+                + ((int)threadIdx.x) % (64 / 8);
+  
+  half* scaling_factors_ptr = scaling_factors
+                            + (((int)blockIdx_y) % j_factors1) * (64) 
+                            + (((int)threadIdx.x) % (64 / 8)) * 8;
+
+  half* C_ptr = C 
+              + blockIdx_z * M * OC        // blockIdx_z -> split_k dim
+              + (((int)blockIdx_y) % j_factors1) * 64
+              + ((int)threadIdx.y) * 32
+              + (((int)threadIdx.x) % 4) * 2;
+
+  // preload s.f. and zeros
+  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
+  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
+  for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
+    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
+    __syncthreads();
+    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+    if (ld_A_flag)
+    {
+      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
+    }
+    else
+    {
+      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
+    }
+
+    // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
+    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
+    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
+    uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
+    /*
+    if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
+      printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
+    }
+    */
+    // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
+    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
+
+    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
+
+      // B: 32 x 72 (64+8) float16
+      // each warp: 32 x 4
+      // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
+      // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
+      // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) 
+      uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
+      uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
+      //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
+
+      // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
+      // - zero and * scale
+      // TODO (Haotian): can save 4 assembly instructions if formulated as deq = q * scale - zero * scale.
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
+      /*
+      if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
+        printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
+      }
+      */
+
+      // write back
+      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
+    }
+    __syncthreads();
+
+    for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) 
+    {
+      {
+        unsigned int addr;
+        __asm__ __volatile__(
+          "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+          : "=r"(addr)
+          : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
+        );
+        __asm__ __volatile__(
+          "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
+          "{%0, %1, %2, %3}, [%4];\n"
+          : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
+          : "r"(addr)
+        );
+      }
+        
+
+      for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) 
+      {
+        {
+          unsigned int addr;
+          __asm__ __volatile__(
+            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+            : "=r"(addr)
+            : "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
+          );
+          __asm__ __volatile__(
+            "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
+            "{%0, %1, %2, %3}, [%4];\n"
+            : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
+            : "r"(addr)
+          );
+        }
+      }
+      
+      for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) 
+      {
+
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+            :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+            :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+        }
+      }
+    }
+  }
+
+// TODO: Shang: Hoist loop invariance.
+  for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
+    for (int local_id = 0; local_id < 8; ++local_id) {
+      int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
+      if (row_offset < M)
+      {
+        *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
+      }
+    }
+  }
+}
+
+// in_feats: M, IC [float16]
+// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
+// scaling_factors: IC // G, OC [float16]
+// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
+// assume that batch_size < 16 for now
+
+torch::Tensor awq_gemm(
+    torch::Tensor _in_feats,
+    torch::Tensor _kernel,
+    torch::Tensor _scaling_factors,
+    torch::Tensor _zeros,
+    int split_k_iters)
+{
+    int num_in_feats = _in_feats.size(0);
+    int num_in_channels = _in_feats.size(1);
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
+
+    auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
+    at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
+    int num_out_feats = _out_feats.size(-2);
+    int num_out_channels = _out_feats.size(-1);
+
+    auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
+    auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
+    auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
+    auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
+    auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
+    int group_size = num_in_channels / _scaling_factors.size(0);
+
+    if (num_out_channels % 64 != 0)
+        throw std::invalid_argument("OC is not multiple of cta_N = 64");
+    if (num_out_channels % 8 != 0)
+        throw std::invalid_argument("OC is not multiple of pack_num = 8");
+    if (group_size % 32 != 0)
+        throw std::invalid_argument("Group size should be a multiple of 32");
+    if (num_out_channels % group_size != 0)
+        throw std::invalid_argument("OC is not multiple of Group size");
+
+    if (num_out_channels % 128 == 0)
+    {
+        int j_factors1 = num_out_channels / 128 / 1;
+        dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
+        // threadIdx.x: 32
+        // threadIdx.y: i_factors[2] * j_factors[2]
+        dim3 threads_per_block(32, 2);
+        gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
+            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
+    }
+    else if (num_out_channels % 64 == 0)
+    {
+        int j_factors1 = num_out_channels / 64 / 1;
+        dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
+    
+        // threadIdx.x: 32
+        // threadIdx.y: i_factors[2] * j_factors[2]
+        dim3 threads_per_block(32, 2);
+        gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
+            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
+    }
+    return _out_feats.sum(0);
+}
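
As a sanity check on the layout comments above, here is a slow PyTorch reference of the same computation: dequantize the packed weights and zeros, then run a dense fp16 GEMM. The nibble unpack order is a reading of dequantize_s4_to_fp16x2 and is an assumption rather than something this commit states explicitly; up to fp16 rounding the result should match the fused kernel.

import torch

AWQ_ORDER = (0, 4, 1, 5, 2, 6, 3, 7)   # assumed per-int32 unpack order


def awq_gemm_ref(in_feats, qweight, scales, qzeros, group_size):
    shifts = torch.tensor([4 * k for k in AWQ_ORDER], dtype=torch.int32)

    def unpack(t):
        # [rows, cols] int32 -> [rows, cols * 8], one integer per 4-bit nibble
        return ((t[:, :, None] >> shifts) & 0xF).flatten(1)

    w = unpack(qweight).half()                  # [IC, OC]
    z = unpack(qzeros).half()                   # [IC // G, OC]
    deq = (w - z.repeat_interleave(group_size, 0)) * scales.repeat_interleave(group_size, 0)
    return in_feats.half() @ deq                # [M, OC]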

+ 12 - 0
setup.py

@@ -131,6 +131,18 @@ activation_extension = CUDAExtension(
 )
 ext_modules.append(activation_extension)

+quantization_extension = CUDAExtension(
+    name="aphrodite.quantization_ops",
+    sources=[
+        "kernels/quantization.cpp",
+        "kernels/quantization/awq/gemm_kernels.cu",
+    ],
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
+)
+ext_modules.append(quantization_extension)

 def get_path(*filepath) -> str:
     return os.path.join(ROOT_DIR, *filepath)
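
After building (for example with an editable install of this package), a quick check that the new extension imports and exposes the bound op; the module and symbol names follow the binding in kernels/quantization.cpp earlier in this diff.

import torch  # noqa: F401  (loads libtorch before the extension module)

from aphrodite import quantization_ops

assert hasattr(quantization_ops, "awq_gemm")
print("awq_gemm is available")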