Explorar el Código

add awq support

AlpinDale hace 1 año
padre
commit
d9c1d4f6e5

+ 16 - 1
aphrodite/common/config.py

@@ -43,6 +43,8 @@ class ModelConfig:
             version.
         max_model_len: Maximum length of a sequence (including prompt and output).
             If None, will be derived from the model.
+        quantization: Quantization method used to quantize the model weights. if 
+            none, we assume the model weights are not quantized.
     """
 
     def __init__(
@@ -57,6 +59,7 @@ class ModelConfig:
         seed: int,
         revision: Optional[str],
         max_model_len: Optional[int] = None,
+        quantization: Optional[str] = None,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -66,11 +69,13 @@ class ModelConfig:
         self.load_format = load_format
         self.seed = seed
         self.revision = revision
+        self.quantization = quantization
 
         self.hf_config = get_config(model, trust_remote_code, revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self._verify_load_format()
         self._verify_tokenizer_mode()
+        self._verify_quantization_mode()
         self.max_model_len = None
         if max_model_len is not None:
             derived_max_model_len = self.get_max_model_len()
@@ -99,6 +104,16 @@ class ModelConfig:
                 "either 'auto' or 'slow'.")
         self.tokenizer_mode = tokenizer_mode
 
+    def _verify_quantization(self) -> None:
+        supported_quantization = ["awq"]
+        if self.quantization is None:
+            return
+        quantization = self.quantization.lower()
+        if quantization not in supported_quantization:
+            raise ValueError(
+                f"Unknown quantization: {self.quantization}. Must be one of the following: {supported_quantization}.")
+        self.quantization = quantization
+
     def verify_with_parallel_config(
         self,
         parallel_config: "ParallelConfig",
@@ -319,4 +334,4 @@ def _get_and_verify_dtype(
                 "Bfloat16 is only supported on GPUs with compute capability "
                 f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                 f"{compute_capability[0]}.{compute_capability[1]}.")
-    return torch_dtype
+    return torch_dtype

+ 2 - 1
aphrodite/engine/aphrodite_engine.py

@@ -80,6 +80,7 @@ class AphroditeEngine:
             f"download_dir={model_config.download_dir!r}, "
             f"load_format={model_config.load_format}, "
             f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
+            f"quantization={model_config.quantization}, "
             f"seed={model_config.seed})")
         # TODO: Print more configs in debug mode.
 
@@ -705,4 +706,4 @@ class AphroditeEngine:
         output = all_outputs[0]
         for other_output in all_outputs[1:]:
             assert output == other_output
-        return output
+        return output

+ 8 - 2
aphrodite/engine/args_tools.py

@@ -5,6 +5,7 @@ from typing import Optional, Tuple
 
 from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
+from torch import quantization
 
 
 @dataclass
@@ -29,6 +30,7 @@ class EngineArgs:
     max_num_seqs: int = 256
     disable_log_stats: bool = False
     revision: Optional[str] = None
+    quantization = Optional[str] = None
 
     def __post_init__(self):
         if self.tokenizer is None:
@@ -150,6 +152,10 @@ class EngineArgs:
         parser.add_argument('--disable-log-stats',
                             action='store_true',
                             help='disable logging statistics')
+        parser.add_argument('--quantization', '-q',
+                            type=str, choices=["awq", None],
+                            default=None,
+                            help="Method used for quantization.")
         return parser
 
     @classmethod
@@ -168,7 +174,7 @@ class EngineArgs:
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.load_format,
                                    self.dtype, self.seed, self.revision,
-                                   self.max_model_len)
+                                   self.max_model_len, self.quantization)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space)
@@ -205,4 +211,4 @@ class AsyncEngineArgs(EngineArgs):
                             help='max number of prompt characters or prompt '
                             'ID numbers being printed in the long. '
                             'Default: unlimited.')
-        return parser
+        return parser

+ 57 - 8
aphrodite/modeling/hf_downloader.py

@@ -13,6 +13,9 @@ import torch
 from tqdm.auto import tqdm
 
 from aphrodite.common.logger import init_logger
+from aphrodite.modeling.quantization_utils import get_quant_class
+from aphrodite.modeling.quantization_utils import QuantizationConfig
+
 
 logger = init_logger(__name__)
 
@@ -44,7 +47,7 @@ def _shared_pointers(tensors):
 def convert_bin_to_safetensor_file(
     pt_filename: str,
     sf_filename: str,
-):
+) -> None:
     loaded = torch.load(pt_filename, map_location="cpu")
     if "state_dict" in loaded:
         loaded = loaded["state_dict"]
@@ -78,15 +81,55 @@ def convert_bin_to_safetensor_file(
             raise RuntimeError(f"The output tensors do not match for key {k}")
 
 
+# TODO: Move this to other place.
+def get_quant_config(
+    quantization: str,
+    model_name_or_path: str,
+    cache_dir: Optional[str] = None,
+) -> QuantizationConfig:
+    is_local = os.path.isdir(model_name_or_path)
+    if not is_local:
+        # Download the config files.
+        with get_lock(model_name_or_path, cache_dir):
+            hf_folder = snapshot_download(model_name_or_path,
+                                          allow_patterns="*.json",
+                                          cache_dir=cache_dir,
+                                          tqdm_class=Disabledtqdm)
+    else:
+        hf_folder = model_name_or_path
+    config_files = glob.glob(os.path.join(hf_folder, "*.json"))
+
+    quant_cls = get_quant_class(quantization)
+    quant_config_files = [
+        f for f in config_files if any(
+            f.endswith(x) for x in quant_cls.get_config_filenames())
+    ]
+    if len(quant_config_files) == 0:
+        raise ValueError(f"Cannot find the config file for {quantization}")
+    if len(quant_config_files) > 1:
+        raise ValueError(f"Found multiple config files for {quantization}: "
+                         f"{quant_config_files}")
+
+    quant_config_file = quant_config_files[0]
+    with open(quant_config_file, "r") as f:
+        config = json.load(f)
+    return quant_cls.from_config(config)
+
+
 def prepare_hf_model_weights(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
     use_safetensors: bool = False,
     fall_back_to_pt: bool = True,
-):
+    revision: Optional[str] = None,
+) -> Tuple[str, List[str], bool]:
     # Download model weights from huggingface.
     is_local = os.path.isdir(model_name_or_path)
-    allow_patterns = "*.safetensors" if use_safetensors else "*.bin"
+    if use_safetensors:
+        allow_patterns = ["*.safetensors"]
+    else:
+        # Some quantized models use .pt files for storing the weights.
+        allow_patterns = ["*.bin", "*.pt"]
     if not is_local:
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.
@@ -94,10 +137,13 @@ def prepare_hf_model_weights(
             hf_folder = snapshot_download(model_name_or_path,
                                           allow_patterns=allow_patterns,
                                           cache_dir=cache_dir,
-                                          tqdm_class=Disabledtqdm)
+                                          tqdm_class=Disabledtqdm,
+                                          revision=revision)
     else:
         hf_folder = model_name_or_path
-    hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
+    hf_weights_files: List[str] = []
+    for pattern in allow_patterns:
+        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
     if not use_safetensors:
         hf_weights_files = [
             x for x in hf_weights_files if not x.endswith("training_args.bin")
@@ -107,7 +153,8 @@ def prepare_hf_model_weights(
         return prepare_hf_model_weights(model_name_or_path,
                                         cache_dir=cache_dir,
                                         use_safetensors=False,
-                                        fall_back_to_pt=False)
+                                        fall_back_to_pt=False,
+                                        revision=revision)
 
     if len(hf_weights_files) == 0:
         raise RuntimeError(
@@ -120,6 +167,7 @@ def hf_model_weights_iterator(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
     load_format: str = "auto",
+    revision: Optional[str] = None,
 ) -> Iterator[Tuple[str, torch.Tensor]]:
     use_safetensors = False
     use_np_cache = False
@@ -140,7 +188,8 @@ def hf_model_weights_iterator(
         model_name_or_path,
         cache_dir=cache_dir,
         use_safetensors=use_safetensors,
-        fall_back_to_pt=fall_back_to_pt)
+        fall_back_to_pt=fall_back_to_pt,
+        revision=revision)
 
     if use_np_cache:
         # Currently np_cache only support *.bin checkpoints
@@ -260,4 +309,4 @@ def initialize_dummy_weights(
     values between -1e-3 and 1e-3 works well for most models.
     """
     for param in model.state_dict().values():
-        param.data.uniform_(low, high)
+

+ 36 - 0
aphrodite/modeling/layers/quantized_linear/__init__.py

@@ -0,0 +1,36 @@
+from aphrodite.modeling.layers.quantized_linear.awq import (
+    AWQColumnParallelLinear, AWQRowParallelLinear)
+from aphrodite.modeling.megatron.tensor_parallel import (
+    ColumnParallelLinear, RowParallelLinear)
+
+_QUANTIZED_LINEAR_REGISTRY = {
+    "awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
+}
+
+class ParallelLinear:
+
+    @classmethod
+    def column(cls, *args, **kwargs) -> ColumnParallelLinear:
+        quant_config = kwargs.get("quant_config", None)
+        if quant_config is None:
+            return ColumnParallelLinear(*args, **kwargs)
+        
+        name = quant_config.get_name()
+        if name not in _QUANTIZED_LINEAR_REGISTRY:
+            raise ValueError(f"No quantized linear is found for {name}")
+
+        quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][0]
+        return quant_linear_cls(*args, **kwargs)
+
+    @classmethod
+    def row(cls, *args, **kwargs) -> RowParallelLinear:
+        quant_config = kwargs.get("quant_config", None)
+        if quant_config is None:
+            return RowParallelLinear(*args, **kwargs)
+        
+        name = quant_config.get_name()
+        if name not in _QUANTIZED_LINEAR_REGISTRY:
+            raise ValueError(f"No quantized linear is found for {name}")
+
+        quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1]
+        return quant_linear_cls(*args, **kwargs)

+ 87 - 0
aphrodite/modeling/layers/quantized_linear/awq.py

@@ -0,0 +1,87 @@
+from ast import Param
+from typing import Optional
+
+import torch
+from torch.nn.parameter import Parameter
+
+from aphrodite import quantization_ops
+from aphrodite.modeling.megatron.tensor_parallel.layers import (
+    ColumnParallelLinear, RowParallelLinear)
+
+class AWQColumnParallelLinear(ColumnParallelLinear):
+
+    def create_weights(self, dtype: torch.dtype) -> None:
+        assert self.input_size % self.quant_config.weight_bits == 0
+        assert self.output_size_per_partition % self.quant_config.pack_factor == 0
+        self.qweight = Parameter(
+            torch.empty(
+                self.input_size,
+                self.output_size_per_partition // self.quant_config.pack_factor,
+                device = "cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        self.scales = Parameter(
+            torch.empty(
+                self.input_size // self.quant_config.group_size,
+                self.output_size_per_partition,
+                device="cuda",
+                dtype=dtype,
+            ),
+            requires_grad=False,
+        )
+
+    def apply_weights(
+        self,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor],
+        ) -> torch.Tensor:
+            pack_factor = self.quant_config.pack_factor
+            out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
+            reshaped_x = x.reshape(-1, x.reshape[-1])
+            out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales, self.qzeros, pack_factor)
+            
+            if bias is not None:
+                out = out + bias
+            return out.reshape(out_shape)
+
+class AWQRowParallelLinear(RowParallelLinear):
+    def create_weights(self, dtype: torch.dtype) -> None:
+        assert self.input_size_per_partition % self.quant_config.weight_bits == 0
+        assert self.output_size % self.quant_config.pack_factor == 0
+        self.qweight = Parameter(
+            torch.empty(
+                self.input_size_per_partition,
+                self.output_size // self.quant_config.pack_factor,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        self.qzeros = Parameter(
+            torch.empty(
+                self.input_size_per_partition // self.quant_config.group_size,
+                self.output_size // self.quant_config.pack_factor,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        self.scales = Parameter(
+            torch.empty(
+                self.input_size_per_partition // self.quant_config.group_size,
+                self.output_size,
+                device="cuda",
+                dtype=dtype,
+            ),
+            requires_grad=False,
+        )
+
+    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
+        pack_factor = self.quant_config.pack_factor
+        out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
+        reshaped_x = x.reshape(-1, x.shape[-1])
+        out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales, self.qzeros, pack_factor)
+        return out.reshape(out_shape)
+

+ 24 - 2
aphrodite/modeling/loader.py

@@ -1,12 +1,13 @@
 import contextlib
 from typing import Type
 import torch
+from torch.multiprocessing import Value
 import torch.nn as nn
 from transformers import PretrainedConfig
 
 from aphrodite.common.config import ModelConfig
 from aphrodite.modeling.models import LlamaForCausalLM, GPTJForCausalLM, GPTNeoXForCausalLM
-from aphrodite.modeling.hf_downloader import initialize_dummy_weights
+from aphrodite.modeling.hf_downloader import initialize_dummy_weights, get_quant_config
 
 _MODEL_REGISTRY = {
     "LlamaForCausalLM": LlamaForCausalLM,
@@ -15,6 +16,10 @@ _MODEL_REGISTRY = {
     "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
 }
 
+_QUANT_REGISTRY = {
+    "LlamaForCausalLM",
+}
+
 @contextlib.contextmanager
 def _set_default_torch_dtype(dtype: torch.dtype):
     """Sets the default torch dtype to the given dtype."""
@@ -36,10 +41,27 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
 
 def get_model(model_config: ModelConfig) -> nn.Module:
     model_class = _get_model_architecture(model_config.hf_config)
+    quant_config = None
+    if model_config.quantization is not None:
+        if model_class not in _QUANT_REGISTRY:
+            raise ValueError(
+                f"Quantization is not supported for {model_class}.")
+        quant_config = get_quant_config(model_config.quantization,
+                                        model_config.model,
+                                        model_config.download_dir)
+        supported_dtypes = quant_config.get_supported_act_dtypes()
+        if model_config.dtype not in supported_dtypes:
+            raise ValueError(
+                f"{model_config.dtype} is not supported for quantization method {model_config.quantization}. "
+                f"Supported datatypes: {supported_dtypes}")
+
     with _set_default_torch_dtype(model_config.dtype):
         # Create a model instance.
         # The weights will be initialized as empty tensors.
-        model = model_class(model_config.hf_config)
+        if model_class in _QUANT_REGISTRY:
+            model = model_class(model_config.hf_config, quant_config)
+        else:
+            model = model_class(model_config.hf_config)
         if model_config.load_format == "dummy":
             model = model.cuda()
             # NOTE: For accurate performance evaluation, we assign

+ 61 - 143
aphrodite/modeling/megatron/tensor_parallel/layers.py

@@ -15,16 +15,13 @@ from torch.nn.parameter import Parameter
 from aphrodite.modeling.megatron.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    get_all_reduce_launcher,
 )
 from .mappings import (
-    copy_to_tensor_model_parallel_region,
     gather_from_tensor_model_parallel_region,
     reduce_from_tensor_model_parallel_region,
     scatter_to_tensor_model_parallel_region,
 )
 
-from .random import get_cuda_rng_tracker
 from .utils import (
     divide,
     VocabUtility,
@@ -67,59 +64,6 @@ def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
         maybe_copy(attribute)
 
 
-def _initialize_affine_weight_gpu(weight, init_method,
-                                  partition_dim, stride=1):
-    """Initialize affine weight for model parallel on GPU."""
-
-    set_tensor_model_parallel_attributes(tensor=weight,
-                                         is_parallel=True,
-                                         dim=partition_dim,
-                                         stride=stride)
-
-    with get_cuda_rng_tracker().fork():
-        init_method(weight)
-
-
-def _initialize_affine_weight_cpu(weight, output_size, input_size,
-                                  per_partition_size, partition_dim,
-                                  init_method, stride=1,
-                                  return_master_weight=False,
-                                  *, params_dtype=None):
-    """Initialize affine weight for model parallel.
-
-    Build the master weight on all processes and scatter
-    the relevant chunk."""
-
-    set_tensor_model_parallel_attributes(tensor=weight,
-                                         is_parallel=True,
-                                         dim=partition_dim,
-                                         stride=stride)
-
-    if params_dtype is None:
-        params_dtype = torch.get_default_dtype()
-
-    # Initialize master weight
-    master_weight = torch.empty(output_size, input_size,
-                                dtype=torch.float,
-                                requires_grad=False)
-    init_method(master_weight)
-    master_weight = master_weight.to(dtype=params_dtype)
-
-    # Split and copy
-    per_partition_per_stride_size = divide(per_partition_size, stride)
-    weight_list = torch.split(master_weight, per_partition_per_stride_size,
-                              dim=partition_dim)
-    rank = get_tensor_model_parallel_rank()
-    world_size = get_tensor_model_parallel_world_size()
-    my_weight_list = weight_list[rank::world_size]
-
-    with torch.no_grad():
-        torch.cat(my_weight_list, dim=partition_dim, out=weight)
-    if return_master_weight:
-        return master_weight
-    return None
-
-
 class VocabParallelEmbedding(torch.nn.Module):
     """Embedding parallelized in the vocabulary dimension.
 
@@ -142,6 +86,9 @@ class VocabParallelEmbedding(torch.nn.Module):
                  use_cpu_initialization: bool=False,
                  perform_initialization: bool=True):
         super(VocabParallelEmbedding, self).__init__()
+        assert not perform_initialization
+        assert not use_cpu_initialization
+
         # Keep the input dimensions.
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
@@ -164,24 +111,10 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.num_embeddings_per_partition = self.vocab_end_index - \
             self.vocab_start_index
 
-        # Allocate weights and initialize.
-        if use_cpu_initialization:
-            self.weight = Parameter(torch.empty(
-                self.num_embeddings_per_partition, self.embedding_dim,
-                dtype=params_dtype))
-            if perform_initialization:
-                _initialize_affine_weight_cpu(
-                    self.weight, self.num_embeddings, self.embedding_dim,
-                    self.num_embeddings_per_partition, 0, init_method,
-                    params_dtype=params_dtype)
-        else:
-            self.weight = Parameter(torch.empty(
-                self.num_embeddings_per_partition, self.embedding_dim,
-                device=torch.cuda.current_device(), dtype=params_dtype))
-            if perform_initialization:
-                _initialize_affine_weight_gpu(self.weight, init_method,
-                                              partition_dim=0, stride=1)
-
+        self.weight = Parameter(torch.empty(
+            self.num_embeddings_per_partition, self.embedding_dim,
+            device=torch.cuda.current_device(), dtype=params_dtype))
+ 
     def forward(self, input_):
         if self.tensor_model_parallel_size > 1:
             # Build the mask.
@@ -241,17 +174,21 @@ class ColumnParallelLinear(torch.nn.Module):
                  params_dtype=None,
                  use_cpu_initialization=False,
                  perform_initialization=True,
+                 quant_config=None,
                  ):
         super(ColumnParallelLinear, self).__init__()
+        assert not perform_initialization
+        assert not use_cpu_initialization
 
         # Keep input parameters
         self.input_size = input_size
         self.output_size = output_size
         self.gather_output = gather_output
         # Divide the weight matrix along the last dimension.
-        world_size = get_tensor_model_parallel_world_size()
-        self.output_size_per_partition = divide(output_size, world_size)
+        self.world_size = get_tensor_model_parallel_world_size()
+        self.output_size_per_partition = divide(output_size, self.world_size)
         self.skip_bias_add = skip_bias_add
+        self.quant_config = quant_config
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -259,33 +196,13 @@ class ColumnParallelLinear(torch.nn.Module):
         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
         # we allocate the transpose.
-        # Initialize weight.
-        if use_cpu_initialization:
-            self.weight = Parameter(torch.empty(self.output_size_per_partition,
-                                                self.input_size,
-                                                dtype=params_dtype))
-            if perform_initialization:
-                self.master_weight = _initialize_affine_weight_cpu(
-                    self.weight, self.output_size, self.input_size,
-                    self.output_size_per_partition, 0, init_method,
-                    stride=stride, return_master_weight=keep_master_weight_for_test)
-        else:
-            self.weight = Parameter(torch.empty(
-                self.output_size_per_partition, self.input_size,
-                device=torch.cuda.current_device(), dtype=params_dtype))
-            if perform_initialization:
-                _initialize_affine_weight_gpu(self.weight, init_method,
-                                              partition_dim=0, stride=stride)
+        self.create_weights(params_dtype)
 
         if bias:
-            if use_cpu_initialization:
-                self.bias = Parameter(torch.empty(
-                    self.output_size_per_partition, dtype=params_dtype))
-            else:
-                self.bias = Parameter(torch.empty(
-                    self.output_size_per_partition,
-                    device=torch.cuda.current_device(),
-                    dtype=params_dtype))
+            self.bias = Parameter(torch.empty(
+                self.output_size_per_partition,
+                device=torch.cuda.current_device(),
+                dtype=params_dtype))
             set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
             # Always initialize bias to zero.
             with torch.no_grad():
@@ -293,6 +210,17 @@ class ColumnParallelLinear(torch.nn.Module):
         else:
             self.register_parameter('bias', None)
 
+    def create_weights(self, dtype: torch.dtype) -> None:
+        self.weight = Parameter(torch.empty(
+            self.output_size_per_partition, self.input_size,
+            device=torch.cuda.current_device(), dtype=dtype))
+
+    def apply_weights(
+        self,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        return F.linear(x, self.weight, bias)
 
     def forward(self, input_):
         """Forward of ColumnParallelLinear
@@ -308,7 +236,7 @@ class ColumnParallelLinear(torch.nn.Module):
 
         input_parallel = input_
         # Matrix multiply.
-        output_parallel = F.linear(input_parallel, self.weight, bias)
+        output_parallel = self.apply_weights(input_parallel, bias)
         if self.gather_output:
             # All-gather across the partitions.
             output = gather_from_tensor_model_parallel_region(output_parallel)
@@ -351,6 +279,7 @@ class RowParallelLinear(torch.nn.Module):
         params_dtype:
         use_cpu_initialization:
         perform_initialization:
+        reduce_results:
     """
 
     def __init__(self, input_size, output_size, *,
@@ -361,57 +290,51 @@ class RowParallelLinear(torch.nn.Module):
                  params_dtype=None,
                  use_cpu_initialization=False,
                  perform_initialization=True,
+                 reduce_results=True,
+                 quant_config=None,
                  ):
         super(RowParallelLinear, self).__init__()
+        assert not perform_initialization
+        assert not use_cpu_initialization
 
         # Keep input parameters
         self.input_size = input_size
         self.output_size = output_size
         self.input_is_parallel = input_is_parallel
+        self.reduce_results = reduce_results
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
 
         # Divide the weight matrix along the last dimension.
-        world_size = get_tensor_model_parallel_world_size()
-        self.input_size_per_partition = divide(input_size, world_size)
+        self.world_size = get_tensor_model_parallel_world_size()
+        self.input_size_per_partition = divide(input_size, self.world_size)
         self.skip_bias_add = skip_bias_add
+        self.quant_config = quant_config
+
+        self.create_weights(params_dtype)
+
+        if not reduce_results and (bias and not skip_bias_add):
+            raise ValueError("When not reduce the results, adding bias to the "
+                             "results can lead to incorrect results")
 
-        # Parameters.
-        # Note: torch.nn.functional.linear performs XA^T + b and as a result
-        # we allocate the transpose.
-        # Initialize weight.
-        if use_cpu_initialization:
-            self.weight = Parameter(torch.empty(self.output_size,
-                                                self.input_size_per_partition,
-                                                dtype=params_dtype))
-            if perform_initialization:
-                self.master_weight = _initialize_affine_weight_cpu(
-                    self.weight, self.output_size, self.input_size,
-                    self.input_size_per_partition, 1, init_method,
-                    stride=stride, return_master_weight=keep_master_weight_for_test,
-                    params_dtype=params_dtype)
-        else:
-            self.weight = Parameter(torch.empty(
-                self.output_size, self.input_size_per_partition,
-                device=torch.cuda.current_device(), dtype=params_dtype))
-            if perform_initialization:
-                _initialize_affine_weight_gpu(self.weight, init_method,
-                                              partition_dim=1, stride=stride)
         if bias:
-            if use_cpu_initialization:
-                self.bias = Parameter(torch.empty(self.output_size,
-                                                  dtype=params_dtype))
-            else:
-                self.bias = Parameter(torch.empty(
-                    self.output_size, device=torch.cuda.current_device(),
-                    dtype=params_dtype))
+            self.bias = Parameter(torch.empty(
+                self.output_size, device=torch.cuda.current_device(),
+                dtype=params_dtype))
 
             # Always initialize bias to zero.
             with torch.no_grad():
                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)
-        self.weight_t = self.weight.t()
+
+    def create_weights(self, dtype: torch.dtype) -> None:
+        self.weight = Parameter(torch.empty(
+                self.output_size, self.input_size_per_partition,
+                device=torch.cuda.current_device(), dtype=dtype))
+
+    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
+        return F.linear(x, self.weight)
 
     def forward(self, input_):
         """Forward of RowParallelLinear
@@ -428,17 +351,12 @@ class RowParallelLinear(torch.nn.Module):
             input_parallel = input_
         else:
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
-        if get_tensor_model_parallel_world_size() == 1:
-            # Matrix multiply.
-            output_ = F.linear(input_parallel, self.weight)
+        # Matrix multiply.
+        output_parallel = self.apply_weights(input_parallel)
+        if self.reduce_results and self.world_size > 1:
+            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         else:
-            # Matrix multiply.
-            all_reduce_launcher = get_all_reduce_launcher()
-            num_tokens = input_parallel.shape[0]
-            output_buffer = all_reduce_launcher.buffer[:num_tokens]
-            torch.matmul(input_parallel, self.weight_t, out=output_buffer)
-            # All-reduce across all the partitions.
-            output_ = all_reduce_launcher.launch(output_buffer)
+            output_ = output_parallel
 
         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_

+ 90 - 33
aphrodite/modeling/models/llama.py

@@ -35,7 +35,8 @@ from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.attention import PagedAttentionWithRoPE
-from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.sampler import Sampler 
+from aphrodite.modeling.layers.quantized_linear import ParallelLinear
 from aphrodite.modeling.hf_downloader import load_tensor_parallel_weights, load_padded_tensor_parallel_vocab, hf_model_weights_iterator
 from aphrodite.modeling.megatron.parallel_state import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
 from aphrodite.modeling.megatron.tensor_parallel import VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear
@@ -52,18 +53,21 @@ class LlamaMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
-    ):
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
-        self.gate_up_proj = ColumnParallelLinear(hidden_size,
-                                                 2 * intermediate_size,
-                                                 bias=False,
-                                                 gather_output=False,
-                                                 perform_initialization=False)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           input_is_parallel=True,
-                                           perform_initialization=False)
+        self.gate_up_proj = ParallelLinear.column(hidden_size,
+                                                  2 * intermediate_size,
+                                                  bias=False,
+                                                  gather_output=False,
+                                                  perform_initialization=False,
+                                                  quant_config=quant_config)
+        self.down_proj = ParallelLinear.row(intermediate_size,
+                                            hidden_size,
+                                            bias=False,
+                                            input_is_parallel=True,
+                                            perform_initialization=False,
+                                            quant_config=quant_config)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -84,7 +88,8 @@ class LlamaAttention(nn.Module):
         num_heads: int,
         num_kv_heads: int,
         rope_theta: float = 10000,
-    ):
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -100,20 +105,22 @@ class LlamaAttention(nn.Module):
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
 
-        self.qkv_proj = ColumnParallelLinear(
+        self.qkv_proj = ParallelLinear.column(
             hidden_size,
             (self.total_num_heads + 2 * self.total_num_kv_heads) *
             self.head_dim,
             bias=False,
             gather_output=False,
             perform_initialization=False,
+            quant_config=quant_config,
         )
-        self.o_proj = RowParallelLinear(
+        self.o_proj = ParallelLinear.row(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             input_is_parallel=True,
             perform_initialization=False,
+            quant_config=quant_config,
         )
         self.attn = PagedAttentionWithRoPE(self.num_heads,
                                            self.head_dim,
@@ -141,7 +148,11 @@ class LlamaAttention(nn.Module):
 
 class LlamaDecoderLayer(nn.Module):
 
-    def __init__(self, config: LlamaConfig):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
@@ -151,11 +162,13 @@ class LlamaDecoderLayer(nn.Module):
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
+            quant_config=quant_config,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
+            quant_config=quant_config,
         )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@@ -192,7 +205,11 @@ class LlamaDecoderLayer(nn.Module):
 
 class LlamaModel(nn.Module):
 
-    def __init__(self, config: LlamaConfig):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
@@ -202,7 +219,8 @@ class LlamaModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             vocab_size, config.hidden_size, perform_initialization=False)
         self.layers = nn.ModuleList([
-            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
+            LlamaDecoderLayer(config, quant_config)
+            for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -234,16 +252,23 @@ class LlamaModel(nn.Module):
 
 class LlamaForCausalLM(nn.Module):
 
-    def __init__(self, config):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
         self.config = config
-        self.model = LlamaModel(config)
+        self.quant_config = quant_config
+        self.model = LlamaModel(config, quant_config)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.lm_head = ColumnParallelLinear(config.hidden_size,
-                                            vocab_size,
-                                            bias=False,
-                                            gather_output=False,
-                                            perform_initialization=False)
+        # NOTE: The LM head is not quantized.
+        self.lm_head = ParallelLinear.column(config.hidden_size,
+                                             vocab_size,
+                                             bias=False,
+                                             gather_output=False,
+                                             perform_initialization=False,
+                                             quant_config=None)
         self.sampler = Sampler(config.vocab_size)
 
     def forward(
@@ -260,15 +285,28 @@ class LlamaForCausalLM(nn.Module):
                                    input_metadata)
         return next_tokens
 
-    _column_parallel_weights = [
-        "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
-    ]
-    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
+    _column_parallel_layers = []
+    _row_parallel_layers = ["o_proj", "down_proj"]
 
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
+        if self.quant_config is None:
+            weight_suffixes = ["weight"]
+        else:
+            weight_suffixes = self.quant_config.get_tp_tensor_names()
+
+        column_parallel_weights: List[str] = []
+        for layer in self._column_parallel_layers:
+            for suffix in weight_suffixes:
+                column_parallel_weights.append(f"{layer}.{suffix}")
+        row_parallel_weights: List[str] = []
+        for layer in self._row_parallel_layers:
+            for suffix in weight_suffixes:
+                row_parallel_weights.append(f"{layer}.{suffix}")
+
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -285,15 +323,29 @@ class LlamaForCausalLM(nn.Module):
         state_dict = self.state_dict()
 
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
 
+            is_packed = False
+            is_transposed = False
+            if self.quant_config is not None:
+                is_packed = self.quant_config.is_packed(name)
+                is_transposed = self.quant_config.is_transposed(name)
+            if is_transposed:
+                loaded_weight = loaded_weight.T
+
             is_attention_weight = False
             for weight_name, shard_size, offset in attention_weight_specs:
                 if weight_name not in name:
                     continue
                 param = state_dict[name.replace(weight_name, "qkv_proj")]
+                if is_transposed:
+                    param = param.T
+
+                if is_packed:
+                    shard_size //= self.quant_config.pack_factor
+                    offset //= self.quant_config.pack_factor
 
                 loaded_weight = loaded_weight[
                     shard_size * tensor_model_parallel_rank:shard_size *
@@ -312,6 +364,9 @@ class LlamaForCausalLM(nn.Module):
                 if weight_name not in name:
                     continue
                 param = state_dict[name.replace(weight_name, "gate_up_proj")]
+                if is_transposed:
+                    param = param.T
+
                 shard_size = param.shape[0] // 2
                 loaded_weight = loaded_weight[
                     shard_size * tensor_model_parallel_rank:shard_size *
@@ -326,6 +381,8 @@ class LlamaForCausalLM(nn.Module):
                 continue
 
             param = state_dict[name]
+            if is_transposed:
+                param = param.T
 
             if "embed_tokens" in name or "lm_head" in name:
                 load_padded_tensor_parallel_vocab(param, loaded_weight,
@@ -333,6 +390,6 @@ class LlamaForCausalLM(nn.Module):
                 continue
 
             load_tensor_parallel_weights(param, loaded_weight, name,
-                                         self._column_parallel_weights,
-                                         self._row_parallel_weights,
+                                         column_parallel_weights,
+                                         row_parallel_weights,
                                          tensor_model_parallel_rank)

+ 20 - 0
aphrodite/modeling/quantization_utils/__init__.py

@@ -0,0 +1,20 @@
+from typing import Type
+
+from aphrodite.modeling.quantization_utils.awq import AWQConfig
+from aphrodite.modeling.quantization_utils.base import QuantizationConfig
+
+_QUANTIZATION_REGISTRY = {
+    "awq": AWQConfig,
+}
+
+
+def get_quant_class(quantization: str) -> Type[QuantizationConfig]:
+    if quantization not in _QUANTIZATION_REGISTRY:
+        raise ValueError(f"Invalid quantization method: {quantization}")
+    return _QUANTIZATION_REGISTRY[quantization]
+
+
+__all__ = [
+    "QuantizationConfig",
+    "get_quant_class",
+]

+ 67 - 0
aphrodite/modeling/quantization_utils/awq.py

@@ -0,0 +1,67 @@
+from typing import Any, Dict, List
+
+import torch
+
+from aphrodite.modeling.quantization_utils.base import QuantizationConfig
+
+
+class AWQConfig(QuantizationConfig):
+    """Config class for AWQ.
+
+    Reference: https://arxiv.org/abs/2306.00978
+    """
+
+    def __init__(
+        self,
+        weight_bits: int,
+        group_size: int,
+        zero_point: bool,
+    ) -> None:
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.zero_point = zero_point
+
+        if self.weight_bits != 4:
+            raise ValueError(
+                "Currently, only 4-bit weight quantization is supported for "
+                f"AWQ, but got {self.weight_bits} bits.")
+        self.pack_factor = 32 // self.weight_bits
+
+    def __repr__(self) -> str:
+        return (f"AWQConfig(weight_bits={self.weight_bits}, "
+                f"group_size={self.group_size}, "
+                f"zero_point={self.zero_point})")
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "awq"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return [
+            "quant_config.json",  
+            "quantize_config.json",
+        ]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
+        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
+        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
+        zero_point = cls.get_from_keys(config, ["zero_point"])
+        return cls(weight_bits, group_size, zero_point)
+
+    @classmethod
+    def get_packed_tensor_names(cls) -> List[str]:
+        return ["qweight", "qzeros"]
+
+    @classmethod
+    def get_transposed_tensor_names(cls) -> List[str]:
+        return ["qweight", "qzeros", "scales"]
+
+    @classmethod
+    def get_tp_tensor_names(cls) -> List[str]:
+        return ["qweight", "qzeros", "scales"]

+ 65 - 0
aphrodite/modeling/quantization_utils/base.py

@@ -0,0 +1,65 @@
+from typing import Any, Dict, List
+
+import torch
+
+
+class QuantizationConfig:
+
+    @classmethod
+    def get_name(cls) -> str:
+        """Name of the quantization method."""
+        raise NotImplementedError
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        """List of supported activation dtypes."""
+        raise NotImplementedError
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        """List of filenames to search for in the model directory."""
+        raise NotImplementedError
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
+        """Create a config class from the model's quantization config."""
+        raise NotImplementedError
+
+    @staticmethod
+    def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
+        """Get a value from the model's quantization config."""
+        for key in keys:
+            if key in config:
+                return config[key]
+        raise ValueError(f"Cannot find any of {keys} in the model's "
+                         "quantization config.")
+
+    @classmethod
+    def get_packed_tensor_names(cls) -> List[str]:
+        raise NotImplementedError
+
+    @classmethod
+    def is_packed(cls, tensor_name: str) -> bool:
+        """Returns True if a tensor is packed.
+
+        A tensor is considered packed if each element in the tensor is a
+        packed representation of multiple elements in the original tensor.
+        For example, an INT32 element in the tensor may represent 8 INT4
+        elements in the original tensor.
+        """
+        return any(tag in tensor_name for tag in cls.get_packed_tensor_names())
+
+    @classmethod
+    def get_transposed_tensor_names(cls) -> List[str]:
+        raise NotImplementedError
+
+    @classmethod
+    def is_transposed(cls, tensor_name: str) -> bool:
+        """Returns True if a tensor is transposed relative to nn.Linear.weight.
+        """
+        return any(tag in tensor_name
+                   for tag in cls.get_transposed_tensor_names())
+
+    @classmethod
+    def get_tp_tensor_names(cls) -> List[str]:
+        raise NotImplementedError

+ 15 - 0
kernels/quantization.cpp

@@ -0,0 +1,15 @@
+#include <torch/extension.h>
+
+torch::Tensor awq_gemm(
+  torch::Tensor _in_feats,
+  torch::Tensor _kernel,
+  torch::Tensor _scaling_factors,
+  torch::Tensor _zeros,
+  int split_k_iters);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def(
+    "awq_gemm",
+    &awq_gemm,
+    "Quantized GEMM for AWQ");
+}

+ 79 - 0
kernels/quantization/awq/dequantize.cuh

@@ -0,0 +1,79 @@
+/*
+Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
+
+@article{lin2023awq,
+  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
+  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
+  journal={arXiv},
+  year={2023}
+}
+*/
+
+#pragma once
+
+
+__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
+{
+    uint4 result;
+
+    uint32_t*      h   = reinterpret_cast<uint32_t*>(&result);
+    uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
+
+    // First, we extract the i4s and construct an intermediate fp16 number.
+    static constexpr uint32_t immLut                = (0xf0 & 0xcc) | 0xaa;
+    static constexpr uint32_t BOTTOM_MASK           = 0x000f000f;
+    static constexpr uint32_t TOP_MASK              = 0x00f000f0;
+    static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
+
+    // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
+    // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
+    // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
+    // elt_67 to fp16 without having to shift them to the bottom bits before hand.
+
+    // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
+    // immediately before required.
+    const uint32_t top_i4s = i4s >> 8;
+    // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                    : "=r"(h[0])
+                    : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+    // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                    : "=r"(h[1])
+                    : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+    // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                    : "=r"(h[2])
+                    : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+    // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
+    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                    : "=r"(h[3])
+                    : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
+
+    // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
+    // half2 ctor. In this case, I chose performance reliability over code readability.
+
+    // This is the half2 {1032, 1032} represented as an integer.
+    // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
+    // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
+    static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
+    // This is the half2 {1 / 16, 1 / 16} represented as an integer.
+    static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
+    // This is the half2 {-72, -72} represented as an integer.
+    // static constexpr uint32_t NEG_72 = 0xd480d480;
+    // Haotian: Let's use {-64, -64}.
+    static constexpr uint32_t NEG_64 = 0xd400d400;
+
+    // Finally, we construct the output numbers.
+    // Convert elt_01
+    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
+    // Convert elt_23
+    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
+    // Convert elt_45
+    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
+    // Convert elt_67
+    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
+
+    return result;
+}
+

+ 477 - 0
kernels/quantization/awq/gemm_kernels.cu

@@ -0,0 +1,477 @@
+/*
+
+@article{lin2023awq,
+  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
+  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
+  journal={arXiv},
+  year={2023}
+}
+
+ */
+
+
+#include <torch/extension.h>
+#include "dequantize.cuh"
+#include <cuda_fp16.h>
+#include <c10/cuda/CUDAGuard.h>
+
+
+// Pack two half values.
+static inline __device__ __host__ unsigned
+__pack_half2(const half x, const half y) {
+  unsigned v0 = *((unsigned short *)&x);
+  unsigned v1 = *((unsigned short *)&y);
+  return (v1 << 16) | v0;
+}
+
+__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) 
+{
+  static constexpr uint32_t ZERO = 0x0;
+  float C_warp[32];
+  __shared__ half A_shared[16 * (32 + 8)];
+  __shared__ half B_shared[32 * (128 + 8)];
+  
+  __shared__ half scaling_factors_shared[128];
+  __shared__ half zeros_shared[128];
+
+  int j_factors1 = ((OC + 128 - 1) / 128);
+  int blockIdx_x = 0;
+  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
+  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
+
+  half A_shared_warp[8];
+  half B_shared_warp[32];
+  for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
+    for (int i = 0; i < 8; ++i) {
+      C_warp[(j_0_4_init * 8) + i] = 0.0;
+    }
+  }
+
+  static constexpr int row_stride_warp = 32 * 8 / 32;
+  static constexpr int row_stride = 2 * 32 * 8 / 128;
+  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
+  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+  bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;     // threadIdx.y is warp_id
+  // bool wb_C_flag = (threadIdx.x / 4) < M;
+
+  half* A_ptr = A 
+                + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+                + (((int)threadIdx.x) % (32 / 8)) * 8;
+  
+  int* B_ptr = B
+            + ((int)threadIdx.y) * (OC / 8) * 2
+            + (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
+            + (((int)blockIdx_y) % j_factors1) * (128 / 8)
+            + (((int)threadIdx.x) % (128 / 8)) * 1;
+// Why * 1 in the above line?
+                        
+  half* A_shared_ptr = A_shared 
+                    + ((int)threadIdx.y) * row_stride_warp * (32 + 8) 
+                    + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+                    + (((int)threadIdx.x) % (32 / 8) ) * 8;
+
+  half* B_shared_ptr = B_shared
+                    + ((int)threadIdx.y) * (row_stride / 2) * (128 + 8)
+                    + (((int)threadIdx.x) / (128 / 8)) * (128 + 8)
+                    + (((int)threadIdx.x) % (128 / 8)) * 8;
+  
+  int* zeros_ptr = zeros
+                + (((int)blockIdx_y) % j_factors1) * (128 / 8)
+                + ((int)threadIdx.x) % (128 / 8);
+  
+  half* scaling_factors_ptr = scaling_factors
+                            + (((int)blockIdx_y) % j_factors1) * (128) 
+                            + (((int)threadIdx.x) % (128 / 8)) * 8;
+
+  half* C_ptr = C 
+              + blockIdx_z * M * OC        // blockIdz.x -> split_k dim
+              + (((int)blockIdx_y) % j_factors1) * 128
+              + ((int)threadIdx.y) * 64
+              + (((int)threadIdx.x) % 4) * 2;
+
+  // preload s.f. and zeros
+  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
+  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
+  for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
+    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
+    __syncthreads();
+    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+    if (ld_A_flag)
+    {
+      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
+    }
+    else
+    {
+      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
+    }
+
+    // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
+    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
+    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
+    uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
+    /*
+    if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
+      printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
+    }
+    */
+    // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
+    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
+
+    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
+
+      // B: 32 x 136 (128+8) float16
+      // each warp: 32 x 4
+      // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
+      // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
+      // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) 
+      uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
+      uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
+      //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
+
+      // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
+      // - zero and * scale
+      // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
+      /*
+      if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
+        printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
+      }
+      */
+
+      // write back
+      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16;
+    }
+    __syncthreads();
+
+    for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
+      {
+        unsigned int addr;
+        __asm__ __volatile__(
+          "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+          : "=r"(addr)
+          : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
+        );
+
+
+        __asm__ __volatile__(
+          "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
+          "{%0, %1, %2, %3}, [%4];\n"
+          : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
+          : "r"(addr)
+        );
+      }
+
+      for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {
+        {
+          unsigned int addr;
+          __asm__ __volatile__(
+            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+            : "=r"(addr)
+            : "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))
+          );
+          __asm__ __volatile__(
+            "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
+            "{%0, %1, %2, %3}, [%4];\n"
+            : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
+            : "r"(addr)
+          );
+        }
+      }
+      for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+            :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+            :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+        }
+      }
+    }
+  }
+
+// TODO: Shang: Hoist loop invariance.
+  for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
+    for (int local_id = 0; local_id < 8; ++local_id) {
+      int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
+      if (row_offset < M)
+      {
+        *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
+      }
+    }
+  }
+}
+
+
+__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) 
+{
+  static constexpr uint32_t ZERO = 0x0;
+  float C_warp[32];
+  __shared__ half A_shared[16 * (32 + 8)];
+  __shared__ half B_shared[32 * (64 + 8)];
+  
+  __shared__ half scaling_factors_shared[64];
+  __shared__ half zeros_shared[64];
+
+  int j_factors1 = ((OC + 64 - 1) / 64);
+
+  int blockIdx_x = 0;
+  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
+  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
+
+  half A_shared_warp[8];
+  half B_shared_warp[16];
+  for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
+    for (int i = 0; i < 8; ++i) {
+      C_warp[(j_0_4_init * 8) + i] = 0.0;
+    }
+  }
+
+  static constexpr int row_stride_warp = 32 * 8 / 32;
+  static constexpr int row_stride = 2 * 32 * 8 / 64;
+  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
+  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+  bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;     // threadIdx.y is warp_id
+  // bool wb_C_flag = (threadIdx.x / 4) < M;
+
+  half* A_ptr = A 
+                + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+                + (((int)threadIdx.x) % (32 / 8)) * 8;
+  
+  int* B_ptr = B
+            + ((int)threadIdx.y) * (OC / 8) * 4
+            + (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
+            + (((int)blockIdx_y) % j_factors1) * (64 / 8)
+            + (((int)threadIdx.x) % (64 / 8)) * 1;
+// Why * 1 in the above line?
+                        
+  half* A_shared_ptr = A_shared 
+                    + ((int)threadIdx.y) * row_stride_warp * (32 + 8) 
+                    + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+                    + (((int)threadIdx.x) % (32 / 8) ) * 8;
+
+  half* B_shared_ptr = B_shared
+                    + ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
+                    + (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
+                    + (((int)threadIdx.x) % (64 / 8)) * 8;
+  
+  int* zeros_ptr = zeros
+                + (((int)blockIdx_y) % j_factors1) * (64 / 8)
+                + ((int)threadIdx.x) % (64 / 8);
+  
+  half* scaling_factors_ptr = scaling_factors
+                            + (((int)blockIdx_y) % j_factors1) * (64) 
+                            + (((int)threadIdx.x) % (64 / 8)) * 8;
+
+  half* C_ptr = C 
+              + blockIdx_z * M * OC        // blockIdz.x -> split_k dim
+              + (((int)blockIdx_y) % j_factors1) * 64
+              + ((int)threadIdx.y) * 32
+              + (((int)threadIdx.x) % 4) * 2;
+
+  // preload s.f. and zeros
+  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
+  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
+  for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
+    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
+    __syncthreads();
+    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
+    if (ld_A_flag)
+    {
+      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
+    }
+    else
+    {
+      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
+    }
+
+    // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
+    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
+    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
+    uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
+    /*
+    if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
+      printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
+    }
+    */
+    // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
+    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
+
+    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
+
+      // B: 32 x 136 (128+8) float16
+      // each warp: 32 x 4
+      // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
+      // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
+      // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) 
+      uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
+      uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
+      //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
+
+      // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
+      // - zero and * scale
+      // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
+      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
+      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
+      /*
+      if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
+        printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
+      }
+      */
+
+      // write back
+      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
+    }
+    __syncthreads();
+
+    for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) 
+    {
+      {
+        unsigned int addr;
+        __asm__ __volatile__(
+          "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+          : "=r"(addr)
+          : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
+        );
+        __asm__ __volatile__(
+          "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
+          "{%0, %1, %2, %3}, [%4];\n"
+          : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
+          : "r"(addr)
+        );
+      }
+        
+
+      for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) 
+      {
+        {
+          unsigned int addr;
+          __asm__ __volatile__(
+            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
+            : "=r"(addr)
+            : "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
+          );
+          __asm__ __volatile__(
+            "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
+            "{%0, %1, %2, %3}, [%4];\n"
+            : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
+            : "r"(addr)
+          );
+        }
+      }
+      
+      for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) 
+      {
+
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+            :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
+        }
+
+        {
+          __asm__ __volatile__(
+            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
+            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+            :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
+            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
+        }
+      }
+    }
+  }
+
+// TODO: Shang: Hoist loop invariance.
+  for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
+    for (int local_id = 0; local_id < 8; ++local_id) {
+      int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
+      if (row_offset < M)
+      {
+        *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
+      }
+    }
+  }
+}
+
+// in_feats: M, IC [float16]
+// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
+// scaling_factors: IC // G, OC [float16]
+// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
+// assume that batch_size < 16 for now
+
+torch::Tensor gemm_forward_cuda(
+    torch::Tensor _in_feats,
+    torch::Tensor _kernel,
+    torch::Tensor _scaling_factors,
+    torch::Tensor _zeros,
+    int split_k_iters)
+{
+    int num_in_feats = _in_feats.size(0);
+    int num_in_channels = _in_feats.size(1);
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
+
+    auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
+    at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
+    int num_out_feats = _out_feats.size(-2);
+    int num_out_channels = _out_feats.size(-1);
+
+    auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
+    auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
+    auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
+    auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
+    auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
+    int group_size = num_in_channels / _scaling_factors.size(0);
+
+    if (num_out_channels % 64 != 0)
+        throw std::invalid_argument("OC is not multiple of cta_N = 64");
+    if (num_out_channels % 8 != 0)
+        throw std::invalid_argument("OC is not multiple of pack_num = 8");
+    if (group_size % 32 != 0)
+	      throw std::invalid_argument("Group size should be a multiple of 32");
+    if (num_out_channels % group_size != 0)
+        throw std::invalid_argument("OC is not multiple of Group size");
+
+    if (num_out_channels % 128 == 0)
+    {
+        int j_factors1 = num_out_channels / 128 / 1;
+        dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
+        // threadIdx.x: 32
+        // threadIdx.y: i_factors[2] * j_factors[2]
+        dim3 threads_per_block(32, 2);
+        gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
+            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
+    }
+    else if (num_out_channels % 64 == 0)
+    {
+	int j_factors1 = num_out_channels / 64 / 1;
+        dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
+    
+        // threadIdx.x: 32
+        // threadIdx.y: i_factors[2] * j_factors[2]
+        dim3 threads_per_block(32, 2);
+        gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
+            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
+    }
+    return _out_feats.sum(0);
+}

+ 12 - 0
setup.py

@@ -131,6 +131,18 @@ activation_extension = CUDAExtension(
 )
 ext_modules.append(activation_extension)
 
+quantization_extension = CUDAExtension(
+    name="aphrodite.quantization_ops",
+    sources=[
+        "kernels/quantization.cpp",
+        "kernels/quantization/awq/gemm_kernels.cu",
+    ],
+    extra_compile_args={
+        "cxx": CXX_FLAGS,
+        "nvcc": NVCC_FLAGS,
+    },
+)
+ext_modules.append(quantization_extension)
 
 def get_path(*filepath) -> str:
     return os.path.join(ROOT_DIR, *filepath)