
Revert "Refactor AWQ support."

This reverts commit d09e27f5d4de222d8b3873659492b821e8e5c85d.
AlpinDale 1 year ago
parent commit
39beed0b87

+ 0 - 16
aphrodite/common/config.py

@@ -43,8 +43,6 @@ class ModelConfig:
             version.
         max_model_len: Maximum length of a sequence (including prompt and output).
             If None, will be derived from the model.
-        quantization: Quantization method that was used to quantize the model
-            weights. If None, we assume the model weights are not quantized.
     """
 
     def __init__(
@@ -59,7 +57,6 @@ class ModelConfig:
         seed: int,
         revision: Optional[str],
         max_model_len: Optional[int] = None,
-        quantization: Optional[str] = None,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -69,13 +66,11 @@ class ModelConfig:
         self.load_format = load_format
         self.seed = seed
         self.revision = revision
-        self.quantization = quantization
 
         self.hf_config = get_config(model, trust_remote_code, revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self._verify_load_format()
         self._verify_tokenizer_mode()
-        self._verify_quantization()
         self.max_model_len = None
         if max_model_len is not None:
             derived_max_model_len = self.get_max_model_len()
@@ -104,17 +99,6 @@ class ModelConfig:
                 "either 'auto' or 'slow'.")
         self.tokenizer_mode = tokenizer_mode
 
-    def _verify_quantization(self) -> None:
-        supported_quantization = ["awq"]
-        if self.quantization is None:
-            return
-        quantization = self.quantization.lower()
-        if quantization not in supported_quantization:
-            raise ValueError(
-                f"Unknown quantization: {self.quantization}. Must be one of "
-                f"{supported_quantization}.")
-        self.quantization = quantization
-        
     def verify_with_parallel_config(
         self,
         parallel_config: "ParallelConfig",
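
The hunk above removes the validation hook along with the constructor argument. As a stand-alone restatement of the deleted check (the free-function name below is ours, not the project's), the logic was simply:

def normalize_quantization(quantization, supported=("awq",)):
    # None means the weights are not quantized; anything else must be a
    # known method, compared case-insensitively (mirrors _verify_quantization).
    if quantization is None:
        return None
    method = quantization.lower()
    if method not in supported:
        raise ValueError(f"Unknown quantization: {quantization}. "
                         f"Must be one of {list(supported)}.")
    return method

assert normalize_quantization("AWQ") == "awq"
assert normalize_quantization(None) is None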

+ 0 - 1
aphrodite/engine/aphrodite_engine.py

@@ -80,7 +80,6 @@ class AphroditeEngine:
             f"download_dir={model_config.download_dir!r}, "
             f"load_format={model_config.load_format}, "
             f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
-            f"quantization={model_config.quantization}, "
             f"seed={model_config.seed})")
         # TODO: Print more configs in debug mode.
 

+ 1 - 7
aphrodite/engine/args_tools.py

@@ -29,7 +29,6 @@ class EngineArgs:
     max_num_seqs: int = 256
     disable_log_stats: bool = False
     revision: Optional[str] = None
-    quantization: Optional[str] = None
 
     def __post_init__(self):
         if self.tokenizer is None:
@@ -151,11 +150,6 @@ class EngineArgs:
         parser.add_argument('--disable-log-stats',
                             action='store_true',
                             help='disable logging statistics')
-        parser.add_argument('--quantization', '-q',
-                            type=str,
-                            choices=['awq', None],
-                            default=None,
-                            help="Method used to quantize the weights.")
         return parser
 
     @classmethod
@@ -174,7 +168,7 @@ class EngineArgs:
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.load_format,
                                    self.dtype, self.seed, self.revision,
-                                   self.max_model_len, self.quantization)
+                                   self.max_model_len)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space)

+ 5 - 46
aphrodite/modeling/hf_downloader.py

@@ -13,8 +13,6 @@ import torch
 from tqdm.auto import tqdm
 
 from aphrodite.common.logger import init_logger
-from aphrodite.modeling.quantization_utils import get_quant_class
-from aphrodite.modeling.quantization_utils.base import QuantizationConfig
 
 logger = init_logger(__name__)
 
@@ -46,7 +44,7 @@ def _shared_pointers(tensors):
 def convert_bin_to_safetensor_file(
     pt_filename: str,
     sf_filename: str,
-) -> None:
+):
     loaded = torch.load(pt_filename, map_location="cpu")
     if "state_dict" in loaded:
         loaded = loaded["state_dict"]
@@ -79,52 +77,16 @@ def convert_bin_to_safetensor_file(
         if not torch.equal(pt_tensor, sf_tensor):
             raise RuntimeError(f"The output tensors do not match for key {k}")
 
-def get_quant_config(
-        quantization: str,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-) -> QuantizationConfig:
-    is_local = os.path.isdir(model_name_or_path)
-    if not is_local:
-        with get_lock(model_name_or_path, cache_dir):
-            hf_folder = snapshot_download(model_name_or_path,
-                                          allow_patterns="*.json",
-                                          cache_dir=cache_dir,
-                                          tqdm_class=Disabledtqdm)
-    else:
-        hf_folder = model_name_or_path
-    config_files = glob.glob(os.path.join(hf_folder, "*.json"))
-
-    quant_cls = get_quant_class(quantization)
-    quant_config_files = [
-        f for f in config_files if any(
-            f.endswith(x) for x in quant_cls.get_config_filenames())
-    ]
-    if len(quant_config_files) == 0:
-        raise ValueError(f"Cannot find the config file for {quantization}")
-    if len(quant_config_files) > 1:
-        raise ValueError(f"Found multiple config files for {quantization}: "
-                         f"{quant_config_files}")
-    
-    quant_config_file = quant_config_files[0]
-    with open(quant_config_file, "r") as f:
-        config = json.load(f)
-    return quant_cls.from_config(config)
-
 
 def prepare_hf_model_weights(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
     use_safetensors: bool = False,
     fall_back_to_pt: bool = True,
-    revision: Optional[str] = None,
-) -> Tuple[str, List[str], bool]:
+):
     # Download model weights from huggingface.
     is_local = os.path.isdir(model_name_or_path)
-    if use_safetensors:
-        allow_patterns = ["*.safetensors"]
-    else:
-        allow_patterns = ["*.bin", "*.pt"]
+    allow_patterns = "*.safetensors" if use_safetensors else "*.bin"
     if not is_local:
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.
@@ -132,13 +94,10 @@ def prepare_hf_model_weights(
             hf_folder = snapshot_download(model_name_or_path,
                                           allow_patterns=allow_patterns,
                                           cache_dir=cache_dir,
-                                          tqdm_class=Disabledtqdm,
-                                          revision=revision)
+                                          tqdm_class=Disabledtqdm)
     else:
         hf_folder = model_name_or_path
-    hf_weights_files: List[str] = []
-    for pattern in allow_patterns:
-        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
+    hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
     if not use_safetensors:
         hf_weights_files = [
             x for x in hf_weights_files if not x.endswith("training_args.bin")

+ 0 - 36
aphrodite/modeling/layers/quantized_linear/__init__.py

@@ -1,36 +0,0 @@
-from aphrodite.modeling.layers.quantized_linear.awq import (
-    AWQColumnParallelLinear, AWQRowParallelLinear)
-from aphrodite.modeling.megatron.tensor_parallel import (
-    ColumnParallelLinear, RowParallelLinear)
-
-_QUANTIZED_LINEAR_REGISTRY = {
-    "awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
-}
-
-class ParallelLinear:
-
-    @classmethod
-    def column(cls, *args, **kwargs) -> ColumnParallelLinear:
-        quant_config = kwargs.get("quant_config", None)
-        if quant_config is None:
-            return ColumnParallelLinear(*args, **kwargs)
-        
-        name = quant_config.get_name()
-        if name not in _QUANTIZED_LINEAR_REGISTRY:
-            raise ValueError(f"No quantized linear is found for {name}")
-        
-        quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][0]
-        return quant_linear_cls(*args, **kwargs)
-    
-    @classmethod
-    def row(cls, *args, **kwargs) -> RowParallelLinear:
-        quant_config = kwargs.get("quant_config", None)
-        if quant_config is None:
-            return RowParallelLinear(*args, **kwargs)
-        
-        name = quant_config.get_name()
-        if name not in _QUANTIZED_LINEAR_REGISTRY:
-            raise ValueError(f"No quantized linear is found for {name}")
-        
-        quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1]
-        return quant_linear_cls(*args, **kwargs)
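
The file above is the registry/factory layer the revert deletes. A dependency-free sketch of the same dispatch pattern, with placeholder classes standing in for the Megatron layers (all names below are illustrative, not the project's):

class PlainColumnLinear: ...
class AWQColumnLinear(PlainColumnLinear): ...

_QUANTIZED_COLUMN = {"awq": AWQColumnLinear}

def column_linear(*args, quant_config=None, **kwargs):
    # No quant config -> plain layer; otherwise look the method up by name.
    if quant_config is None:
        return PlainColumnLinear(*args, **kwargs)
    name = quant_config["name"]
    if name not in _QUANTIZED_COLUMN:
        raise ValueError(f"No quantized linear is found for {name}")
    return _QUANTIZED_COLUMN[name](*args, **kwargs)

assert type(column_linear()) is PlainColumnLinear
assert type(column_linear(quant_config={"name": "awq"})) is AWQColumnLinear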

+ 0 - 100
aphrodite/modeling/layers/quantized_linear/awq.py

@@ -1,100 +0,0 @@
-from typing import Optional
-
-import torch
-from torch.nn.parameter import Parameter
-
-from aphrodite import quantization_ops
-from aphrodite.modeling.megatron.tensor_parallel.layers import (
-    ColumnParallelLinear, RowParallelLinear)
-
-class AWQColumnParallelLinear(ColumnParallelLinear):
-
-    def create_weights(self, dtype: torch.dtype) -> None:
-        assert self.input_size % self.quant_config.weight_bits == 0
-        assert (self.output_size_per_partition %
-                self.quant_config.pack_factor == 0)
-        self.qweight = Parameter(
-            torch.empty(
-                self.input_size,
-                self.output_size_per_partition //
-                self.quant_config.pack_factor,
-                device="cuda",
-                dtype=torch.int32,
-            ),
-            requires_grad=False,
-        )
-        self.qzeros = Parameter(
-            torch.empty(
-                self.input_size // self.quant_config.group_size,
-                self.output_size_per_partition //
-                self.quant_config.pack_factor,
-                device="cuda",
-                dtype=torch.int32,
-            ),
-            requires_grad=False,
-        )
-        self.scales = Parameter(
-            torch.empty(
-                self.input_size // self.quant_config.group_size,
-                self.output_size_per_partition,
-                device="cuda",
-                dtype=dtype,
-            ),
-            requires_grad=False,
-        )
-
-    def apply_weights(self, x: torch.Tensor,
-                      bias: Optional[torch.Tensor]
-                      ) -> torch.Tensor:
-        pack_factor = self.quant_config.pack_factor
-        out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
-        reshaped_x = x.reshape(-1, x.shape[-1])
-        out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
-                                        self.qzeros, pack_factor)
-        
-        if bias is not None:
-            out = out + bias
-        return out.reshape(out_shape)
-    
-
-class AWQRowParallelLinear(RowParallelLinear):
-
-    def create_weights(self, dtype: torch.dtype) -> None:
-        assert (self.input_size_per_partition %
-                self.quant_config.weight_bits == 0)
-        assert self.output_size % self.quant_config.pack_factor == 0
-        self.qweight = Parameter(
-            torch.empty(
-                self.input_size_per_partition,
-                self.output_size // self.quant_config.pack_factor,
-                device="cuda",
-                dtype=torch.int32,
-            ),
-            requires_grad=False,
-        )
-        self.qzeros = Parameter(
-            torch.empty(
-                self.input_size_per_partition // self.quant_config.group_size,
-                self.output_size // self.quant_config.pack_factor,
-                device="cuda",
-                dtype=torch.int32,
-            ),
-            requires_grad=False,
-        )
-        self.scales = Parameter(
-            torch.empty(
-                self.input_size_per_partition // self.quant_config.group_size,
-                self.output_size,
-                device="cuda",
-                dtype=dtype,
-            ),
-            requires_grad=False,
-        )
-
-    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
-        pack_factor = self.quant_config.pack_factor
-        out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
-        reshaped_x = x.reshape(-1, x.shape[-1])
-        out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
-                                        self.qzeros, pack_factor)
-        return out.reshape(out_shape)
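
The two deleted layers differ only in which dimension is partitioned; the tensor shapes follow directly from 4-bit packing. A small arithmetic check of the shapes create_weights allocates (concrete sizes are illustrative):

weight_bits, group_size = 4, 128
pack_factor = 32 // weight_bits                    # 8 int4 values per int32 word
in_features, out_features = 4096, 11008            # e.g. one partition of an MLP projection

qweight_shape = (in_features, out_features // pack_factor)               # int32
qzeros_shape = (in_features // group_size, out_features // pack_factor)  # int32
scales_shape = (in_features // group_size, out_features)                 # fp16

# awq_gemm re-expands the packed output dimension, which is why apply_weights
# builds out_shape from qweight.shape[-1] * pack_factor.
assert qweight_shape[-1] * pack_factor == out_features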

+ 3 - 30
aphrodite/modeling/loader.py

@@ -5,10 +5,8 @@ import torch.nn as nn
 from transformers import PretrainedConfig
 
 from aphrodite.common.config import ModelConfig
-from aphrodite.modeling.models import (
-    LlamaForCausalLM, GPTJForCausalLM, GPTNeoXForCausalLM)
-from aphrodite.modeling.hf_downloader import (
-    initialize_dummy_weights, get_quant_config)
+from aphrodite.modeling.models import LlamaForCausalLM, GPTJForCausalLM, GPTNeoXForCausalLM
+from aphrodite.modeling.hf_downloader import initialize_dummy_weights
 
 _MODEL_REGISTRY = {
     "LlamaForCausalLM": LlamaForCausalLM,
@@ -17,11 +15,6 @@ _MODEL_REGISTRY = {
     "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
 }
 
-# TODO: Remove this once all models support quant.
-_MODEL_CLASSES_SUPPORT_QUANTIZATION = [
-    LlamaForCausalLM,
-]
-
 @contextlib.contextmanager
 def _set_default_torch_dtype(dtype: torch.dtype):
     """Sets the default torch dtype to the given dtype."""
@@ -43,30 +36,10 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
 
 def get_model(model_config: ModelConfig) -> nn.Module:
     model_class = _get_model_architecture(model_config.hf_config)
-
-    # get the quant config
-    quant_config = None
-    if model_config.quantization is not None:
-        if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
-            raise ValueError(
-                f"Quantization is not supported for {model_class}.")
-        quant_config = get_quant_config(model_config.quantization,
-                                        model_config.model,
-                                        model_config.download_dir)
-        supported_dtypes = quant_config.get_supported_act_dtypes()
-        if model_config.dtype not in supported_dtypes:
-            raise ValueError(
-                f"{model_config.dtype} is not supported for quantization "
-                f"method {model_config.quantization}. Supported dtypes: "
-                f"{supported_dtypes}"
-            )
     with _set_default_torch_dtype(model_config.dtype):
         # Create a model instance.
         # The weights will be initialized as empty tensors.
-        if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
-            model = model_class(model_config.hf_config, quant_config)
-        else:
-            model = model_class(model_config.hf_config)
+        model = model_class(model_config.hf_config)
         if model_config.load_format == "dummy":
             model = model.cuda()
             # NOTE: For accurate performance evaluation, we assign

+ 0 - 1
aphrodite/modeling/megatron/tensor_parallel/__init__.py

@@ -12,7 +12,6 @@ from .mappings import (
     copy_to_tensor_model_parallel_region,
     gather_from_tensor_model_parallel_region,
     gather_from_sequence_parallel_region,
-    reduce_from_tensor_model_parallel_region,
     scatter_to_tensor_model_parallel_region,
     scatter_to_sequence_parallel_region,
 )

+ 146 - 65
aphrodite/modeling/megatron/tensor_parallel/layers.py

@@ -6,7 +6,6 @@
 # Parts of the code here are adapted from PyTorch
 # repo: https://github.com/pytorch/pytorch
 
-from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -16,13 +15,16 @@ from torch.nn.parameter import Parameter
 from aphrodite.modeling.megatron.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    get_all_reduce_launcher,
 )
 from .mappings import (
+    copy_to_tensor_model_parallel_region,
     gather_from_tensor_model_parallel_region,
     reduce_from_tensor_model_parallel_region,
     scatter_to_tensor_model_parallel_region,
 )
 
+from .random import get_cuda_rng_tracker
 from .utils import (
     divide,
     VocabUtility,
@@ -65,6 +67,59 @@ def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
         maybe_copy(attribute)
 
 
+def _initialize_affine_weight_gpu(weight, init_method,
+                                  partition_dim, stride=1):
+    """Initialize affine weight for model parallel on GPU."""
+
+    set_tensor_model_parallel_attributes(tensor=weight,
+                                         is_parallel=True,
+                                         dim=partition_dim,
+                                         stride=stride)
+
+    with get_cuda_rng_tracker().fork():
+        init_method(weight)
+
+
+def _initialize_affine_weight_cpu(weight, output_size, input_size,
+                                  per_partition_size, partition_dim,
+                                  init_method, stride=1,
+                                  return_master_weight=False,
+                                  *, params_dtype=None):
+    """Initialize affine weight for model parallel.
+
+    Build the master weight on all processes and scatter
+    the relevant chunk."""
+
+    set_tensor_model_parallel_attributes(tensor=weight,
+                                         is_parallel=True,
+                                         dim=partition_dim,
+                                         stride=stride)
+
+    if params_dtype is None:
+        params_dtype = torch.get_default_dtype()
+
+    # Initialize master weight
+    master_weight = torch.empty(output_size, input_size,
+                                dtype=torch.float,
+                                requires_grad=False)
+    init_method(master_weight)
+    master_weight = master_weight.to(dtype=params_dtype)
+
+    # Split and copy
+    per_partition_per_stride_size = divide(per_partition_size, stride)
+    weight_list = torch.split(master_weight, per_partition_per_stride_size,
+                              dim=partition_dim)
+    rank = get_tensor_model_parallel_rank()
+    world_size = get_tensor_model_parallel_world_size()
+    my_weight_list = weight_list[rank::world_size]
+
+    with torch.no_grad():
+        torch.cat(my_weight_list, dim=partition_dim, out=weight)
+    if return_master_weight:
+        return master_weight
+    return None
+
+
 class VocabParallelEmbedding(torch.nn.Module):
     """Embedding parallelized in the vocabulary dimension.
 
@@ -85,11 +140,8 @@ class VocabParallelEmbedding(torch.nn.Module):
                  init_method=init.xavier_normal_,
                  params_dtype: torch.dtype=None,
                  use_cpu_initialization: bool=False,
-                 perform_initialization: bool=False):
+                 perform_initialization: bool=True):
         super(VocabParallelEmbedding, self).__init__()
-        assert not perform_initialization
-        assert not use_cpu_initialization
-
         # Keep the input dimensions.
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
@@ -112,9 +164,23 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.num_embeddings_per_partition = self.vocab_end_index - \
             self.vocab_start_index
 
-        self.weight = Parameter(torch.empty(
-            self.num_embeddings_per_partition, self.embedding_dim,
-            device=torch.cuda.current_device(), dtype=params_dtype))
+        # Allocate weights and initialize.
+        if use_cpu_initialization:
+            self.weight = Parameter(torch.empty(
+                self.num_embeddings_per_partition, self.embedding_dim,
+                dtype=params_dtype))
+            if perform_initialization:
+                _initialize_affine_weight_cpu(
+                    self.weight, self.num_embeddings, self.embedding_dim,
+                    self.num_embeddings_per_partition, 0, init_method,
+                    params_dtype=params_dtype)
+        else:
+            self.weight = Parameter(torch.empty(
+                self.num_embeddings_per_partition, self.embedding_dim,
+                device=torch.cuda.current_device(), dtype=params_dtype))
+            if perform_initialization:
+                _initialize_affine_weight_gpu(self.weight, init_method,
+                                              partition_dim=0, stride=1)
 
     def forward(self, input_):
         if self.tensor_model_parallel_size > 1:
@@ -174,22 +240,18 @@ class ColumnParallelLinear(torch.nn.Module):
                  skip_bias_add=False,
                  params_dtype=None,
                  use_cpu_initialization=False,
-                 perform_initialization=False,
-                 quant_config=None,
+                 perform_initialization=True,
                  ):
         super(ColumnParallelLinear, self).__init__()
-        assert not perform_initialization
-        assert not use_cpu_initialization
 
         # Keep input parameters
         self.input_size = input_size
         self.output_size = output_size
         self.gather_output = gather_output
         # Divide the weight matrix along the last dimension.
-        self.world_size = get_tensor_model_parallel_world_size()
-        self.output_size_per_partition = divide(output_size, self.world_size)
+        world_size = get_tensor_model_parallel_world_size()
+        self.output_size_per_partition = divide(output_size, world_size)
         self.skip_bias_add = skip_bias_add
-        self.quant_config = quant_config
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@@ -197,13 +259,33 @@ class ColumnParallelLinear(torch.nn.Module):
         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
         # we allocate the transpose.
-        self.create_weights(params_dtype)
+        # Initialize weight.
+        if use_cpu_initialization:
+            self.weight = Parameter(torch.empty(self.output_size_per_partition,
+                                                self.input_size,
+                                                dtype=params_dtype))
+            if perform_initialization:
+                self.master_weight = _initialize_affine_weight_cpu(
+                    self.weight, self.output_size, self.input_size,
+                    self.output_size_per_partition, 0, init_method,
+                    stride=stride, return_master_weight=keep_master_weight_for_test)
+        else:
+            self.weight = Parameter(torch.empty(
+                self.output_size_per_partition, self.input_size,
+                device=torch.cuda.current_device(), dtype=params_dtype))
+            if perform_initialization:
+                _initialize_affine_weight_gpu(self.weight, init_method,
+                                              partition_dim=0, stride=stride)
 
         if bias:
-            self.bias = Parameter(torch.empty(
-                self.output_size_per_partition,
-                device=torch.cuda.current_device(),
-                dtype=params_dtype))
+            if use_cpu_initialization:
+                self.bias = Parameter(torch.empty(
+                    self.output_size_per_partition, dtype=params_dtype))
+            else:
+                self.bias = Parameter(torch.empty(
+                    self.output_size_per_partition,
+                    device=torch.cuda.current_device(),
+                    dtype=params_dtype))
             set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
             # Always initialize bias to zero.
             with torch.no_grad():
@@ -211,17 +293,6 @@ class ColumnParallelLinear(torch.nn.Module):
         else:
             self.register_parameter('bias', None)
 
-    def create_weights(self, dtype: torch.dtype) -> None:
-        self.weight = Parameter(torch.empty(
-            self.output_size_per_partition, self.input_size,
-            device=torch.cuda.current_device(), dtype=dtype))
-
-    def apply_weights(
-        self,
-        x: torch.Tensor,
-        bias: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        return F.linear(x, self.weight, bias)
 
     def forward(self, input_):
         """Forward of ColumnParallelLinear
@@ -237,7 +308,7 @@ class ColumnParallelLinear(torch.nn.Module):
 
         input_parallel = input_
         # Matrix multiply.
-        output_parallel = self.apply_weights(input_parallel, bias)
+        output_parallel = F.linear(input_parallel, self.weight, bias)
         if self.gather_output:
             # All-gather across the partitions.
             output = gather_from_tensor_model_parallel_region(output_parallel)
@@ -280,7 +351,6 @@ class RowParallelLinear(torch.nn.Module):
         params_dtype:
         use_cpu_initialization:
         perform_initialization:
-        reduce_results:
     """
 
     def __init__(self, input_size, output_size, *,
@@ -290,52 +360,58 @@ class RowParallelLinear(torch.nn.Module):
                  skip_bias_add=False,
                  params_dtype=None,
                  use_cpu_initialization=False,
-                 perform_initialization=False,
-                 reduce_results=True,
-                 quant_config=None,
+                 perform_initialization=True,
                  ):
         super(RowParallelLinear, self).__init__()
-        assert not perform_initialization
-        assert not use_cpu_initialization
 
         # Keep input parameters
         self.input_size = input_size
         self.output_size = output_size
         self.input_is_parallel = input_is_parallel
-        self.reduce_results = reduce_results
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
 
         # Divide the weight matrix along the last dimension.
-        self.world_size = get_tensor_model_parallel_world_size()
-        self.input_size_per_partition = divide(input_size, self.world_size)
+        world_size = get_tensor_model_parallel_world_size()
+        self.input_size_per_partition = divide(input_size, world_size)
         self.skip_bias_add = skip_bias_add
-        self.quant_config = quant_config
-
-        self.create_weights(params_dtype)
-
-        if not reduce_results and (bias and not skip_bias_add):
-            raise ValueError("When not reduce the results, adding bias to the "
-                             "results can lead to incorrect results")
 
+        # Parameters.
+        # Note: torch.nn.functional.linear performs XA^T + b and as a result
+        # we allocate the transpose.
+        # Initialize weight.
+        if use_cpu_initialization:
+            self.weight = Parameter(torch.empty(self.output_size,
+                                                self.input_size_per_partition,
+                                                dtype=params_dtype))
+            if perform_initialization:
+                self.master_weight = _initialize_affine_weight_cpu(
+                    self.weight, self.output_size, self.input_size,
+                    self.input_size_per_partition, 1, init_method,
+                    stride=stride, return_master_weight=keep_master_weight_for_test,
+                    params_dtype=params_dtype)
+        else:
+            self.weight = Parameter(torch.empty(
+                self.output_size, self.input_size_per_partition,
+                device=torch.cuda.current_device(), dtype=params_dtype))
+            if perform_initialization:
+                _initialize_affine_weight_gpu(self.weight, init_method,
+                                              partition_dim=1, stride=stride)
         if bias:
-            self.bias = Parameter(torch.empty(
-                self.output_size, device=torch.cuda.current_device(),
-                dtype=params_dtype))
+            if use_cpu_initialization:
+                self.bias = Parameter(torch.empty(self.output_size,
+                                                  dtype=params_dtype))
+            else:
+                self.bias = Parameter(torch.empty(
+                    self.output_size, device=torch.cuda.current_device(),
+                    dtype=params_dtype))
 
             # Always initialize bias to zero.
             with torch.no_grad():
                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)
-
-    def create_weights(self, dtype: torch.dtype) -> None:
-        self.weight = Parameter(torch.empty(
-                self.output_size, self.input_size_per_partition,
-                device=torch.cuda.current_device(), dtype=dtype))
-
-    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
-        return F.linear(x, self.weight)
+        self.weight_t = self.weight.t()
 
     def forward(self, input_):
         """Forward of RowParallelLinear
@@ -352,12 +428,17 @@ class RowParallelLinear(torch.nn.Module):
             input_parallel = input_
         else:
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
-        # Matrix multiply.
-        output_parallel = self.apply_weights(input_parallel)
-        if self.reduce_results and self.world_size > 1:
-            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+        if get_tensor_model_parallel_world_size() == 1:
+            # Matrix multiply.
+            output_ = F.linear(input_parallel, self.weight)
         else:
-            output_ = output_parallel
+            # Matrix multiply.
+            all_reduce_launcher = get_all_reduce_launcher()
+            num_tokens = input_parallel.shape[0]
+            output_buffer = all_reduce_launcher.buffer[:num_tokens]
+            torch.matmul(input_parallel, self.weight_t, out=output_buffer)
+            # All-reduce across all the partitions.
+            output_ = all_reduce_launcher.launch(output_buffer)
 
         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
@@ -365,4 +446,4 @@ class RowParallelLinear(torch.nn.Module):
         else:
             output = output_
             output_bias = self.bias
-        return output, output_bias
+        return output, output_bias
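
The restored multi-GPU branch of RowParallelLinear.forward writes each rank's partial product into a buffer owned by the all-reduce launcher and then reduces it, rather than allocating a fresh output per call. A generic torch.distributed restatement of that pattern (assumes an initialized process group; the function and buffer names are ours, the project's launcher API is only what the diff shows):

import torch
import torch.distributed as dist

def row_parallel_matmul(input_parallel, weight_t, buffer):
    # Each tensor-parallel rank computes its partial result in place ...
    num_tokens = input_parallel.shape[0]
    out = buffer[:num_tokens]
    torch.matmul(input_parallel, weight_t, out=out)
    # ... then the partials are summed across ranks so every rank ends up
    # with the full (num_tokens, output_size) activation.
    dist.all_reduce(out, op=dist.ReduceOp.SUM)
    return out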

+ 29 - 81
aphrodite/modeling/models/llama.py

@@ -36,16 +36,10 @@ from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.attention import PagedAttentionWithRoPE
 from aphrodite.modeling.layers.sampler import Sampler
-from aphrodite.modeling.hf_downloader import (
-    load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
-    hf_model_weights_iterator)
-from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from aphrodite.modeling.megatron.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from aphrodite.modeling.hf_downloader import load_tensor_parallel_weights, load_padded_tensor_parallel_vocab, hf_model_weights_iterator
+from aphrodite.modeling.megatron.parallel_state import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
+from aphrodite.modeling.megatron.tensor_parallel import VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear
 from aphrodite.common.sequence import SamplerOutput
-from aphrodite.modeling.layers.quantized_linear import ParallelLinear
-from aphrodite.modeling.quantization_utils import QuantizationConfig
 
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@@ -58,21 +52,18 @@ class LlamaMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
+    ):
         super().__init__()
-        self.gate_up_proj = ParallelLinear.column(hidden_size,
-                                                  2 * intermediate_size,
-                                                  bias=False,
-                                                  gather_output=False,
-                                                  perform_initialization=False,
-                                                  quant_config=quant_config)
-        self.down_proj = ParallelLinear.row(intermediate_size,
-                                            hidden_size,
-                                            bias=False,
-                                            input_is_parallel=True,
-                                            perform_initialization=False,
-                                            quant_config=quant_config)
+        self.gate_up_proj = ColumnParallelLinear(hidden_size,
+                                                 2 * intermediate_size,
+                                                 bias=False,
+                                                 gather_output=False,
+                                                 perform_initialization=False)
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           input_is_parallel=True,
+                                           perform_initialization=False)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@@ -93,8 +84,7 @@ class LlamaAttention(nn.Module):
         num_heads: int,
         num_kv_heads: int,
         rope_theta: float = 10000,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
+    ):
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -110,14 +100,13 @@ class LlamaAttention(nn.Module):
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
 
-        self.qkv_proj = ParallelLinear.column(
+        self.qkv_proj = ColumnParallelLinear(
             hidden_size,
             (self.total_num_heads + 2 * self.total_num_kv_heads) *
             self.head_dim,
             bias=False,
             gather_output=False,
             perform_initialization=False,
-            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
@@ -152,10 +141,7 @@ class LlamaAttention(nn.Module):
 
 class LlamaDecoderLayer(nn.Module):
 
-    def __init__(self,
-                 config: LlamaConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 ) -> None:
+    def __init__(self, config: LlamaConfig):
         super().__init__()
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
@@ -165,13 +151,11 @@ class LlamaDecoderLayer(nn.Module):
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
-            quant_config=quant_config,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
-            quant_config=quant_config,
         )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@@ -208,10 +192,7 @@ class LlamaDecoderLayer(nn.Module):
 
 class LlamaModel(nn.Module):
 
-    def __init__(self,
-                 config: LlamaConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 ) -> None:
+    def __init__(self, config: LlamaConfig):
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
@@ -221,7 +202,7 @@ class LlamaModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             vocab_size, config.hidden_size, perform_initialization=False)
         self.layers = nn.ModuleList([
-            LlamaDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers)
+            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -253,21 +234,16 @@ class LlamaModel(nn.Module):
 
 class LlamaForCausalLM(nn.Module):
 
-    def __init__(self, config: LlamaConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 ) -> None:
+    def __init__(self, config):
         super().__init__()
         self.config = config
-        self.quant_config = quant_config
-        self.model = LlamaModel(config, quant_config)
+        self.model = LlamaModel(config)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
-
-        self.lm_head = ParallelLinear.column(config.hidden_size,
+        self.lm_head = ColumnParallelLinear(config.hidden_size,
                                             vocab_size,
                                             bias=False,
                                             gather_output=False,
-                                            perform_initialization=False,
-                                            quant_config=None) # NOTE: the lm_head is not quantized.
+                                            perform_initialization=False)
         self.sampler = Sampler(config.vocab_size)
 
     def forward(
@@ -284,27 +260,16 @@ class LlamaForCausalLM(nn.Module):
                                    input_metadata)
         return next_tokens
 
-    _column_parallel_layers = []
-    _row_parallel_layers = ["o_proj", "down_proj"]
+    _column_parallel_weights = [
+        "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
+    ]
+    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
 
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      load_format: str = "auto",
                      revision: Optional[str] = None):
-        if self.quant_config is None:
-            weight_suffixes = ["weight"]
-        else:
-            weight_suffixes = self.quant_config.get_tp_tensor_names()
-
-        column_parallel_weights: List[str] = []
-        for layer in self._column_parallel_layers:
-            for suffix in weight_suffixes:
-                column_parallel_weights.append(f"{layer}.{suffix}")
-        row_parallel_weights: List[str] = []
-        for layer in self._row_parallel_layers:
-            for suffix in weight_suffixes:
-                row_parallel_weights.append(f"{layer}.{suffix}")
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -325,24 +290,11 @@ class LlamaForCausalLM(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue
 
-            is_packed = False
-            is_transposed = False
-            if self.quant_config is not None:
-                is_packed = self.quant_config.is_packed(name)
-                is_transposed = self.quant_config.is_transposed(name)
-            if is_transposed:
-                loaded_weight = loaded_weight.T
-
             is_attention_weight = False
             for weight_name, shard_size, offset in attention_weight_specs:
                 if weight_name not in name:
                     continue
                 param = state_dict[name.replace(weight_name, "qkv_proj")]
-                if is_transposed:
-                    param = param.T
-                if is_packed:
-                    shard_size //= self.quant_config.pack_factor
-                    offset //= self.quant_config.pack_factor
 
                 loaded_weight = loaded_weight[
                     shard_size * tensor_model_parallel_rank:shard_size *
@@ -361,8 +313,6 @@ class LlamaForCausalLM(nn.Module):
                 if weight_name not in name:
                     continue
                 param = state_dict[name.replace(weight_name, "gate_up_proj")]
-                if is_transposed:
-                    param = param.T
                 shard_size = param.shape[0] // 2
                 loaded_weight = loaded_weight[
                     shard_size * tensor_model_parallel_rank:shard_size *
@@ -377,8 +327,6 @@ class LlamaForCausalLM(nn.Module):
                 continue
 
             param = state_dict[name]
-            if is_transposed:
-                param = param.T
 
             if "embed_tokens" in name or "lm_head" in name:
                 load_padded_tensor_parallel_vocab(param, loaded_weight,
@@ -386,6 +334,6 @@ class LlamaForCausalLM(nn.Module):
                 continue
 
             load_tensor_parallel_weights(param, loaded_weight, name,
-                                         column_parallel_weights,
-                                         row_parallel_weights,
+                                         self._column_parallel_weights,
+                                         self._row_parallel_weights,
                                          tensor_model_parallel_rank)

+ 0 - 18
aphrodite/modeling/quantization_utils/__init__.py

@@ -1,18 +0,0 @@
-from typing import Type
-
-from aphrodite.modeling.quantization_utils.awq import AWQConfig
-from aphrodite.modeling.quantization_utils.base import QuantizationConfig
-
-_QUANTIZATION_REGISTRY = {
-    "awq": AWQConfig,
-}
-
-def get_quant_class(quantization: str) -> Type[QuantizationConfig]:
-    if quantization not in _QUANTIZATION_REGISTRY:
-        raise ValueError(f"Invalid quantization method: {quantization}")
-    return _QUANTIZATION_REGISTRY[quantization]
-
-__all__ = [
-    "QuantizationConfig",
-    "get_quant_class",
-]

+ 0 - 68
aphrodite/modeling/quantization_utils/awq.py

@@ -1,68 +0,0 @@
-from typing import Any, Dict, List
-
-import torch
-
-from aphrodite.modeling.quantization_utils.base import QuantizationConfig
-
-class AWQConfig(QuantizationConfig):
-    """Config class for AWQ.
-    Reference: https://arxiv.org/abs/2306.00978
-    """
-    def __init__(
-            self,
-            weight_bits: int,
-            group_size: int,
-            zero_point: bool,
-    ) -> None:
-        self.weight_bits = weight_bits
-        self.group_size = group_size
-        self.zero_point = zero_point
-
-        if self.weight_bits != 4:
-            raise ValueError(
-                "Currently, only 4-bit weight quantization is supported for "
-                f"AWQ, but got {self.weight_bits} bits instead.")
-        self.pack_factor = 32 // self.weight_bits
-
-    def __repr__(self) -> str:
-        return (f"AWQConfig(weight_bits={self.weight_bits}, "
-                f"group_size={self.group_size}, "
-                f"zero_point={self.zero_point})")
-    
-    @classmethod
-    def get_name(cls) -> str:
-        return "awq"
-    
-    @classmethod
-    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
-        return [torch.half]
-    
-    @classmethod
-    def get_min_capability(cls) -> int:
-        return 80
-    
-    @classmethod
-    def get_config_filenames(cls) -> List[str]:
-        return [
-            "quant_config.json",
-            "quantize_config.json",]
-    
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
-        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
-        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
-        zero_point = cls.get_from_keys(config, ["zero_point"])
-        return cls(weight_bits, group_size, zero_point)
-    
-    @classmethod
-    def get_packed_tensor_names(cls) -> List[str]:
-        return ["qweight", "qzeros"]
-    
-    @classmethod
-    def get_transposed_tensor_names(cls) -> List[str]:
-        return ["qweight", "qzeros", "scales"]
-    
-    @classmethod
-    def get_tp_tensor_names(cls) -> List[str]:
-        return ["qweight", "qzeros", "scales"]
-    
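
For reference, the quantize_config.json shipped with an AWQ checkpoint carries exactly the keys from_config() looks up. With the pre-revert package importable, it parsed like this (the dict values are illustrative):

from aphrodite.modeling.quantization_utils.awq import AWQConfig

quant_config = AWQConfig.from_config(
    {"w_bit": 4, "q_group_size": 128, "zero_point": True})
assert quant_config.pack_factor == 32 // 4          # 8 int4 weights per int32
assert "quantize_config.json" in AWQConfig.get_config_filenames()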

+ 0 - 62
aphrodite/modeling/quantization_utils/base.py

@@ -1,62 +0,0 @@
-from typing import Any, Dict, List
-
-import torch
-
-class QuantizationConfig:
-
-    @classmethod
-    def get_name(cls) -> str:
-        """Name of the quant method."""
-        raise NotImplementedError
-    
-    @classmethod
-    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
-        """List of supported activation datatypes."""
-        raise NotImplementedError
-    
-    @classmethod
-    def get_config_filenames(cls) -> List[str]:
-        """List of filenames to search for in the model directory."""
-        raise NotImplementedError
-    
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
-        """Create a config class from the model's quant config."""
-        raise NotImplementedError
-    
-    @staticmethod
-    def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
-        """Get a value from the model's quant config."""
-        for key in keys:
-            if key in config:
-                return config[key]
-        raise ValueError(f"Cannot find any of the {keys} in the model's quant config.")
-    
-    @classmethod
-    def get_packed_tensor_names(cls) -> List[str]:
-        raise NotImplementedError
-    
-    @classmethod
-    def is_packed(cls, tensor_name: str) -> bool:
-        """Returns True if a tensor is packed.
-        
-        A tensor is considered packed if each element in the tensor is a 
-        packed representation of multiple elements in the original tensor.
-        For example, an INT32 element in the tensor may represent 8 INT4
-        elements in the original tensor.
-        """
-        return any(tag in tensor_name for tag in cls.get_packed_tensor_names())
-    
-    @classmethod
-    def get_transposed_tensor_names(cls) -> List[str]:
-        raise NotImplementedError
-    
-    @classmethod
-    def is_transposed(cls, tensor_name: str) -> bool:
-        """Returns True if a tensor is transposed relative to nn.Linear.weight."""
-        return any(tag in tensor_name
-                   for tag in cls.get_transposed_tensor_names())
-    
-    @classmethod
-    def get_tp_tensor_names(cls) -> List[str]:
-        raise NotImplementedError

+ 0 - 15
kernels/quantization.cpp

@@ -1,15 +0,0 @@
-#include <torch/extension.h>
-
-torch::Tensor awq_gemm(
-    torch::Tensor _in_feats,
-    torch::Tensor _kernel,
-    torch::Tensor _scaling_factors,
-    torch::Tensor _zeros,
-    int split_k_iters);
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def(
-        "awq_gemm",
-        &awq_gemm,
-        "Quantized GEMM for AWQ");
-}
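
The deleted extension exposed a single op, awq_gemm, which the AWQ layers above invoke from Python. A hedged sketch of that call (requires the pre-revert build and a CUDA device; sizes match the shape note earlier and are otherwise illustrative):

import torch
from aphrodite import quantization_ops   # compiled from kernels/quantization.cpp

pack_factor = 8                                     # 32 bits / 4-bit weights
x = torch.zeros(16, 4096, dtype=torch.half, device="cuda")
qweight = torch.zeros(4096, 11008 // pack_factor, dtype=torch.int32, device="cuda")
qzeros = torch.zeros(4096 // 128, 11008 // pack_factor, dtype=torch.int32, device="cuda")
scales = torch.zeros(4096 // 128, 11008, dtype=torch.half, device="cuda")

out = quantization_ops.awq_gemm(x, qweight, scales, qzeros, pack_factor)
# The caller (AWQColumnParallelLinear.apply_weights) reshapes the result to
# (num_tokens, qweight.shape[-1] * pack_factor), i.e. (16, 11008) here.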

+ 0 - 56
kernels/quantization/awq/dequantize.cuh

@@ -1,56 +0,0 @@
-#pragma once
-
-__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
-{
-    uint4 result;
-
-    uint32_t* h = reinterpret_cast<uint32_t*>(&result);
-    uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
-
-    static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
-    static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
-    static constexpr uint32_t TOP_MASK = 0x00f000f0;
-    static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
-
-    // Note that the entire sequence only requires 1 shift instruction. This is thanks to
-    // the register packing format and the fact that we force our integers to be unsigned,
-    // and account for this in the fp16 subtractions.
-    // We exploit the fact that sub and fma have the same throughput in order to convert
-    // elt_23 and elt_67 to fp16 without having to shift them to the bottom bits beforehand.
-
-    // shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency
-    // if we issue immediately before required.
-    const uint32_t top_i4s = i4s >> 8;
-    // extract elt_01 - (i4s & 0x000f000f) | 0x64006400
-    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
-                    : "=r"(h[0])
-                    : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
-    // extract elt_23 (i4s & 0x00f000f0) | 0x64006400
-    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
-                    :"=r"(h[1])
-                    : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
-    // extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
-    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
-                    :"=r"(h[2])
-                    : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
-    // extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
-    asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
-                    :"=r"(h[3])
-                    : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
-
-    static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
-    static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
-    static constexpr uint32_t NEG_64 = 0xd400d400;
-
-    // convert elt_01
-    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
-    // convert elt_23
-    asm volatile("sub.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
-    // convert elt_45
-    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
-    // convert elt_67
-    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
-
-    return result;
-
-}
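
The magic constants in the deleted kernel implement the usual int4-to-fp16 trick: OR a 4-bit value into the mantissa of 1024.0 (0x6400), then subtract 1024.0 to recover it. A quick numeric check of that identity (assumes NumPy; not part of the commit):

import numpy as np

v = 11                                             # one 4-bit weight value
bits = np.array([0x6400 | v], dtype=np.uint16)     # fp16 bit pattern of 1024.0 + v
assert float(bits.view(np.float16)[0]) == 1024.0 + v
# sub.f16x2 with FP16_TOP_MAGIC_NUM (1024.0 per half) therefore recovers v;
# for the upper nibbles the fma variant computes (1024 + 16*v) * (1/16) - 64 = v,
# using ONE_SIXTEENTH (0x2c00) and NEG_64 (0xd400).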

+ 0 - 476
kernels/quantization/awq/gemm_kernels.cu

@@ -1,476 +0,0 @@
-/*
-Adapted from https://github.com/mit-han-lab/llm-awq
-@article{lin2023awq,
-  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
-  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
-  journal={arXiv},
-  year={2023}
-}
-*/
-
-#include <torch/extension.h>
-#include <c10/cuda/CUDAGuard.h>
-
-#include "dequantize.cuh"
-
-#include <cuda_fp16.h>
-
-// Pack two half values.
-static inline __device__ __host__ unsigned
-__pack_half2(const half x, const half y) {
-  unsigned v0 = *((unsigned short *)&x);
-  unsigned v1 = *((unsigned short *)&y);
-  return (v1 << 16) | v0;
-}
-
-__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) 
-{
-  static constexpr uint32_t ZERO = 0x0;
-  float C_warp[32];
-  __shared__ half A_shared[16 * (32 + 8)];
-  __shared__ half B_shared[32 * (128 + 8)];
-  
-  __shared__ half scaling_factors_shared[128];
-  __shared__ half zeros_shared[128];
-
-  int j_factors1 = ((OC + 128 - 1) / 128);
-  int blockIdx_x = 0;
-  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
-  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
-
-  half A_shared_warp[8];
-  half B_shared_warp[32];
-  for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
-    for (int i = 0; i < 8; ++i) {
-      C_warp[(j_0_4_init * 8) + i] = 0.0;
-    }
-  }
-
-  static constexpr int row_stride_warp = 32 * 8 / 32;
-  static constexpr int row_stride = 2 * 32 * 8 / 128;
-  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
-  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
-  bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;     // threadIdx.y is warp_id
-  // bool wb_C_flag = (threadIdx.x / 4) < M;
-
-  half* A_ptr = A 
-                + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
-                + (((int)threadIdx.x) % (32 / 8)) * 8;
-  
-  int* B_ptr = B
-            + ((int)threadIdx.y) * (OC / 8) * 2
-            + (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
-            + (((int)blockIdx_y) % j_factors1) * (128 / 8)
-            + (((int)threadIdx.x) % (128 / 8)) * 1;
-// Why * 1 in the above line?
-                        
-  half* A_shared_ptr = A_shared 
-                    + ((int)threadIdx.y) * row_stride_warp * (32 + 8) 
-                    + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
-                    + (((int)threadIdx.x) % (32 / 8) ) * 8;
-
-  half* B_shared_ptr = B_shared
-                    + ((int)threadIdx.y) * (row_stride / 2) * (128 + 8)
-                    + (((int)threadIdx.x) / (128 / 8)) * (128 + 8)
-                    + (((int)threadIdx.x) % (128 / 8)) * 8;
-  
-  int* zeros_ptr = zeros
-                + (((int)blockIdx_y) % j_factors1) * (128 / 8)
-                + ((int)threadIdx.x) % (128 / 8);
-  
-  half* scaling_factors_ptr = scaling_factors
-                            + (((int)blockIdx_y) % j_factors1) * (128) 
-                            + (((int)threadIdx.x) % (128 / 8)) * 8;
-
-  half* C_ptr = C 
-              + blockIdx_z * M * OC        // blockIdz.x -> split_k dim
-              + (((int)blockIdx_y) % j_factors1) * 128
-              + ((int)threadIdx.y) * 64
-              + (((int)threadIdx.x) % 4) * 2;
-
-  // preload s.f. and zeros
-  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
-  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
-  for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
-    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
-    __syncthreads();
-    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
-    if (ld_A_flag)
-    {
-      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
-    }
-    else
-    {
-      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
-    }
-
-    // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
-    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
-    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
-    uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
-    /*
-    if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
-      printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
-    }
-    */
-    // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
-    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
-
-    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
-
-      // B: 32 x 136 (128+8) float16
-      // each warp: 32 x 4
-      // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
-      // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
-      // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) 
-      uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
-      uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
-      //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
-
-      // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
-      // - zero and * scale
-      // TODO (Haotian): can save 4 assembly instructions if we formulate as deq = q * scale - zero * scale.
-      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
-      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
-      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
-      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
-      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
-      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
-      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
-      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
-      /*
-      if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
-        printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
-      }
-      */
-
-      // write back
-      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16;
-    }
-    __syncthreads();
-
-    for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
-      {
-        unsigned int addr;
-        __asm__ __volatile__(
-          "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
-          : "=r"(addr)
-          : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
-        );
-
-
-        __asm__ __volatile__(
-          "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
-          "{%0, %1, %2, %3}, [%4];\n"
-          : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
-          : "r"(addr)
-        );
-      }
-
-      for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {
-        {
-          unsigned int addr;
-          __asm__ __volatile__(
-            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
-            : "=r"(addr)
-            : "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))
-          );
-          __asm__ __volatile__(
-            "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
-            "{%0, %1, %2, %3}, [%4];\n"
-            : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
-            : "r"(addr)
-          );
-        }
-      }
-      for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
-        {
-          __asm__ __volatile__(
-            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
-            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
-            :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
-            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
-        }
-
-        {
-          __asm__ __volatile__(
-            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
-            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
-            :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
-            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
-        }
-      }
-    }
-  }
-
-// TODO: Shang: Hoist loop invariance.
-  for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
-    for (int local_id = 0; local_id < 8; ++local_id) {
-      int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
-      if (row_offset < M)
-      {
-        *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
-      }
-    }
-  }
-}
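The k_bound computation at the top of the main loop above is the split-K bookkeeping: the reduction dimension IC is cut into 32-wide chunks, and each blockIdx_z slice takes every split_k_iters-th chunk starting at its own offset. A minimal Python sketch of that index math (function and variable names are mine, not from the kernel):

    def split_k_chunks(IC, split_k_iters, blockIdx_z):
        # ceil((IC / 32) / split_k_iters): chunks per split-K slice, before trimming
        k_bound = (IC // 32 + split_k_iters - 1) // split_k_iters
        # drop the last chunk if this slice would read past IC
        if (k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC:
            k_bound -= 1
        # k_0_0 in the kernel: stride split_k_iters, offset blockIdx_z
        return [k * split_k_iters + blockIdx_z for k in range(k_bound)]

    # e.g. IC = 4096, split_k_iters = 8: slice 0 handles chunks 0, 8, 16, ...,
    # and the 8 slices together cover every 32-wide chunk exactly once.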
-
-
-__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) 
-{
-  static constexpr uint32_t ZERO = 0x0;
-  float C_warp[32];
-  __shared__ half A_shared[16 * (32 + 8)];
-  __shared__ half B_shared[32 * (64 + 8)];
-  
-  __shared__ half scaling_factors_shared[64];
-  __shared__ half zeros_shared[64];
-
-  int j_factors1 = ((OC + 64 - 1) / 64);
-
-  int blockIdx_x = 0;
-  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
-  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
-
-  half A_shared_warp[8];
-  half B_shared_warp[16];
-  for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
-    for (int i = 0; i < 8; ++i) {
-      C_warp[(j_0_4_init * 8) + i] = 0.0;
-    }
-  }
-
-  static constexpr int row_stride_warp = 32 * 8 / 32;
-  static constexpr int row_stride = 2 * 32 * 8 / 64;
-  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
-  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
-  bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;     // threadIdx.y is warp_id
-  // bool wb_C_flag = (threadIdx.x / 4) < M;
-
-  half* A_ptr = A 
-                + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
-                + (((int)threadIdx.x) % (32 / 8)) * 8;
-  
-  int* B_ptr = B
-            + ((int)threadIdx.y) * (OC / 8) * 4
-            + (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
-            + (((int)blockIdx_y) % j_factors1) * (64 / 8)
-            + (((int)threadIdx.x) % (64 / 8)) * 1;
-// Why * 1 in the above line?
-                        
-  half* A_shared_ptr = A_shared 
-                    + ((int)threadIdx.y) * row_stride_warp * (32 + 8) 
-                    + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
-                    + (((int)threadIdx.x) % (32 / 8) ) * 8;
-
-  half* B_shared_ptr = B_shared
-                    + ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
-                    + (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
-                    + (((int)threadIdx.x) % (64 / 8)) * 8;
-  
-  int* zeros_ptr = zeros
-                + (((int)blockIdx_y) % j_factors1) * (64 / 8)
-                + ((int)threadIdx.x) % (64 / 8);
-  
-  half* scaling_factors_ptr = scaling_factors
-                            + (((int)blockIdx_y) % j_factors1) * (64) 
-                            + (((int)threadIdx.x) % (64 / 8)) * 8;
-
-  half* C_ptr = C 
-              + blockIdx_z * M * OC        // blockIdx_z -> split_k dim
-              + (((int)blockIdx_y) % j_factors1) * 64
-              + ((int)threadIdx.y) * 32
-              + (((int)threadIdx.x) % 4) * 2;
-
-  // preload s.f. and zeros
-  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
-  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
-  for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
-    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
-    __syncthreads();
-    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
-    if (ld_A_flag)
-    {
-      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
-    }
-    else
-    {
-      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
-    }
-
-    // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
-    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
-    uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
-    uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
-    /*
-    if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
-      printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
-    }
-    */
-    // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
-    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
-
-    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
-
-      // B: 32 x 72 (64+8) float16
-      // each warp: 32 x 4
-      // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
-      // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
-      // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) 
-      uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
-      uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
-      //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
-
-      // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
-      // - zero and * scale
-      // TODO (Haotian): can save 4 assembly instructions if formulated as deq = q * scale - zero * scale.
-      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
-      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
-      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
-      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
-      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
-      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
-      asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
-      asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
-      /*
-      if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
-        printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
-      }
-      */
-
-      // write back
-      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
-    }
-    __syncthreads();
-
-    for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) 
-    {
-      {
-        unsigned int addr;
-        __asm__ __volatile__(
-          "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
-          : "=r"(addr)
-          : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
-        );
-        __asm__ __volatile__(
-          "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
-          "{%0, %1, %2, %3}, [%4];\n"
-          : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
-          : "r"(addr)
-        );
-      }
-        
-
-      for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) 
-      {
-        {
-          unsigned int addr;
-          __asm__ __volatile__(
-            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
-            : "=r"(addr)
-            : "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
-          );
-          __asm__ __volatile__(
-            "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
-            "{%0, %1, %2, %3}, [%4];\n"
-            : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
-            : "r"(addr)
-          );
-        }
-      }
-      
-      for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) 
-      {
-
-        {
-          __asm__ __volatile__(
-            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
-            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
-            :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
-            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
-        }
-
-        {
-          __asm__ __volatile__(
-            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
-            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
-            :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
-            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
-        }
-      }
-    }
-  }
-
-// TODO: Shang: Hoist loop invariance.
-  for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
-    for (int local_id = 0; local_id < 8; ++local_id) {
-      int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
-      if (row_offset < M)
-      {
-        *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
-      }
-    }
-  }
-}
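Both kernels dequantize B inline with the sub.f16x2 / fma.rn.f16x2 pairs, i.e. deq = (q - zero) * scale, with one scale and one zero per group of G input channels (the TODO above notes this could be refolded as q * scale - zero * scale to save instructions). A rough reference of the same math in Python; the int4 unpacking that dequantize_s4_to_fp16x2 performs is left out, the packed zeros are treated as already unpacked, and the array names are invented for the sketch:

    import numpy as np

    def dequantize_awq_reference(q, scales, zeros, G):
        # q: (IC, OC) already-unpacked uint4 values; scales, zeros: (IC // G, OC)
        IC, _ = q.shape
        g = np.arange(IC) // G                        # group index of each input channel
        w = (q.astype(np.float32) - zeros[g]) * scales[g]
        return w.astype(np.float16)                   # matches deq = (q - z) * s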
-
-// in_feats: M, IC [float16]
-// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
-// scaling_factors: IC // G, OC [float16]
-// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
-// assume that batch_size < 16 for now
-
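That comment block is the shape contract the caller has to satisfy. As a sanity check, a hedged PyTorch sketch (helper name and asserts are mine, not part of the codebase) of what the packed tensors look like for a weight of logical shape (IC, OC) with group size G:

    import torch

    def check_awq_gemm_inputs(in_feats, kernel, scaling_factors, zeros, G):
        M, IC = in_feats.shape                        # (M, IC) float16 activations
        assert kernel.dtype == torch.int32 and kernel.shape[0] == IC
        OC = kernel.shape[1] * 8                      # 8 uint4 weights packed per int32
        assert scaling_factors.shape == (IC // G, OC)           # float16
        assert zeros.shape == (IC // G, OC // 8)                # int32, packed like kernel
        return M, IC, OC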
-torch::Tensor awq_gemm(
-    torch::Tensor _in_feats,
-    torch::Tensor _kernel,
-    torch::Tensor _scaling_factors,
-    torch::Tensor _zeros,
-    int split_k_iters)
-{
-    int num_in_feats = _in_feats.size(0);
-    int num_in_channels = _in_feats.size(1);
-    const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
-
-    auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
-    at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
-    int num_out_feats = _out_feats.size(-2);
-    int num_out_channels = _out_feats.size(-1);
-
-    auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
-    auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
-    auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
-    auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
-    auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
-    int group_size = num_in_channels / _scaling_factors.size(0);
-
-    if (num_out_channels % 64 != 0)
-        throw std::invalid_argument("OC is not multiple of cta_N = 64");
-    if (num_out_channels % 8 != 0)
-        throw std::invalid_argument("OC is not multiple of pack_num = 8");
-    if (group_size % 32 != 0)
-	      throw std::invalid_argument("Group size should be a multiple of 32");
-    if (num_out_channels % group_size != 0)
-        throw std::invalid_argument("OC is not multiple of Group size");
-
-    if (num_out_channels % 128 == 0)
-    {
-        int j_factors1 = num_out_channels / 128 / 1;
-        dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
-        // threadIdx.x: 32
-        // threadIdx.y: i_factors[2] * j_factors[2]
-        dim3 threads_per_block(32, 2);
-        gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
-            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
-    }
-    else if (num_out_channels % 64 == 0)
-    {
-	int j_factors1 = num_out_channels / 64 / 1;
-        dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
-    
-        // threadIdx.x: 32
-        // threadIdx.y: i_factors[2] * j_factors[2]
-        dim3 threads_per_block(32, 2);
-        gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
-            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
-    }
-    return _out_feats.sum(0);
-}
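On the host side, awq_gemm allocates one (M, OC) partial result per split-K slice and reduces them with _out_feats.sum(0); the grid size mirrors the blockIdx decomposition used inside the kernels. A small Python sketch of that launch arithmetic (illustrative only; cta_n stands for the 128- or 64-wide tile the dispatch selects):

    def awq_launch_shape(M, OC, split_k_iters, cta_n):
        j_factors1 = OC // cta_n                      # output-column tiles
        m_tiles = (M + 16 - 1) // 16                  # 16 output rows per block
        num_blocks = m_tiles * j_factors1 * split_k_iters
        threads_per_block = (32, 2)                   # one warp wide, two warps deep
        return num_blocks, threads_per_block

    # Inside the kernel the flat index is unpacked as:
    #   blockIdx_y = blockIdx.x %  (m_tiles * j_factors1)   -> which (M, N) tile
    #   blockIdx_z = blockIdx.x // (m_tiles * j_factors1)   -> which split-K slice
    # and the split_k_iters partial outputs are summed on the host.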

+ 0 - 13
setup.py

@@ -131,19 +131,6 @@ activation_extension = CUDAExtension(
 )
 ext_modules.append(activation_extension)
 
-# Quant kernels
-quantization_extension = CUDAExtension(
-    name="aphrodite.quantization_ops",
-    sources=[
-        "kernels/quantization.cpp",
-        "kernels/quantization/awq/gemm_kernels.cu",
-    ],
-    extra_compile_args={
-        "cxx": CXX_FLAGS,
-        "nvcc": NVCC_FLAGS,
-    },
-)
-ext_modules.append(quantization_extension)
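For context, this was the build entry for the AWQ kernels; once compiled, the Python side reaches them through the aphrodite.quantization_ops module. Assuming the binding in kernels/quantization.cpp exposes awq_gemm under that name (the binding itself is not shown in this diff), a call would look roughly like:

    import torch
    from aphrodite import quantization_ops  # built by the extension removed above

    # Placeholder tensors, only to illustrate the expected shapes and dtypes.
    M, IC, OC, G, split_k_iters = 16, 4096, 4096, 128, 8
    in_feats = torch.randn(M, IC, dtype=torch.float16, device="cuda")
    kernel = torch.zeros(IC, OC // 8, dtype=torch.int32, device="cuda")
    scales = torch.ones(IC // G, OC, dtype=torch.float16, device="cuda")
    zeros = torch.zeros(IC // G, OC // 8, dtype=torch.int32, device="cuda")

    out = quantization_ops.awq_gemm(in_feats, kernel, scales, zeros, split_k_iters)
    # out: (M, OC) float16, already summed over the split-K partials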
 
 def get_path(*filepath) -> str:
     return os.path.join(ROOT_DIR, *filepath)