
feat: INT8 KV Cache (#298)

* add INT8 KV cache kernels

Co-authored-by: aniz1905@gmail.com <zhangpeng156@meituan.com>

* add kv quant utils for exporting scales

Co-authored-by: aniz1905@gmail.com <zhangpeng156@meituan.com>

* add args

* integrate into attention and model

* formatting

Co-authored-by: aniz1905@gmail.com <zhangpeng156@meituan.com>

---------

Co-authored-by: aniz1905@gmail.com <zhangpeng156@meituan.com>
AlpinDale 1 year ago
parent
commit
9810daa699

+ 4 - 1
.gitignore

@@ -195,4 +195,7 @@ _build/
 
 # HIP files generated by PyTorch
 *.hip
-*_hip*
+*_hip*
+
+kv_cache_states/*
+quant_params/*

+ 6 - 2
aphrodite/common/config.py

@@ -73,7 +73,7 @@ class ModelConfig:
         trust_remote_code: bool,
         download_dir: Optional[str],
         load_format: str,
-        dtype: Union[str, torch.dtype],
+        dtype: str,
         seed: int,
         revision: Optional[str] = None,
         tokenizer_revision: Optional[str] = None,
@@ -362,6 +362,8 @@ class CacheConfig:
             Aphrodite execution.
         swap_space: Size of the CPU swap space per GPU (in GiB).
         cache_dtype: Data Type for KV cache storage.
+        cache_quant_params_path: Path to the scales and zero points
+            of KV cache quantization when cache_dtype is int8.
     """
 
     def __init__(
@@ -370,6 +372,7 @@ class CacheConfig:
         gpu_memory_utilization: float,
         swap_space: int,
         cache_dtype: str,
+        cache_quant_params_path: Optional[str] = None,
         sliding_window: Optional[int] = None,
         context_shift: bool = False,
     ) -> None:
@@ -378,6 +381,7 @@ class CacheConfig:
         self.swap_space_bytes = swap_space * _GB
         self.cache_dtype = cache_dtype
         self.sliding_window = sliding_window
+        self.cache_quant_params_path = cache_quant_params_path
         self.context_shift = context_shift
         self._verify_args()
         self._verify_cache_dtype()
@@ -393,7 +397,7 @@ class CacheConfig:
                 f"{self.gpu_memory_utilization}.")
 
     def _verify_cache_dtype(self) -> None:
-        if self.cache_dtype == "auto":
+        if self.cache_dtype in ["auto", "int8"]:
             pass
         elif self.cache_dtype == "fp8_e5m2":
             nvcc_cuda_version = get_nvcc_cuda_version()
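
For reference, a minimal sketch of building the extended CacheConfig for an int8 KV cache. The keyword names shown in the hunk above come from the diff; `block_size` and all concrete values are assumptions for illustration.

```python
from aphrodite.common.config import CacheConfig

# Hypothetical values; cache_quant_params_path points at the directory produced
# by aphrodite/kv_quant/export_kv_params.py (added further below in this commit).
cache_config = CacheConfig(
    block_size=16,                      # parameter name assumed (not shown in the hunk)
    gpu_memory_utilization=0.9,
    swap_space=4,                       # GiB of CPU swap per GPU
    cache_dtype="int8",                 # now accepted by _verify_cache_dtype()
    cache_quant_params_path="quant_params",
)
```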

+ 2 - 1
aphrodite/common/sequence.py

@@ -175,7 +175,8 @@ class Sequence:
         # TODO: The current hashing function is O(L^2). We should optimize
         # this in the future.
         num_tokens = self.num_hashed_tokens_of_block(logical_idx)
-        return hash(tuple(self.data.get_token_ids()[0:num_tokens]))
+        return hash(
+            (tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id))
 
     def num_hashed_tokens_of_block(self, logical_idx: int):
         return logical_idx * self.block_size + self.block_size

+ 30 - 23
aphrodite/common/utils.py

@@ -24,6 +24,7 @@ STR_DTYPE_TO_TORCH_DTYPE = {
     "bfloat16": torch.bfloat16,
     "float": torch.float,
     "fp8_e5m2": torch.uint8,
+    "int8": torch.int8,
 }
 
 
@@ -168,24 +169,20 @@ def set_cuda_visible_devices(device_ids: List[int]) -> None:
     os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
 
 
-def get_nvcc_cuda_version() -> Version:
-    cuda_home = os.environ.get("CUDA_HOME")
+def get_nvcc_cuda_version() -> Optional[Version]:
+    cuda_home = os.environ.get('CUDA_HOME')
     if not cuda_home:
-        cuda_home = "/usr/local/cuda"
-        logger.info(
-            f"CUDA_HOME is not found in the environment. Using {cuda_home} as "
-            "CUDA_HOME.")
-    try:
-        nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"],
-                                              universal_newlines=True)
-    except FileNotFoundError:
-        print("nvcc is not found. Please make sure to export CUDA_HOME.")
-        return Version("0.0.0")  # return a default Version object
-    except subprocess.CalledProcessError:
-        print("An error occurred while trying to get nvcc output. Please "
-              "make sure to export CUDA_HOME.")
-        return Version("0.0.0")
-
+        cuda_home = '/usr/local/cuda'
+        if os.path.isfile(cuda_home + '/bin/nvcc'):
+            logger.info(
+                f'CUDA_HOME is not found in the environment. Using {cuda_home} as CUDA_HOME.'
+            )
+        else:
+            logger.warning(
+                f'nvcc not found in {cuda_home}. Skipping CUDA version check!')
+            return None
+    nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"],
+                                          universal_newlines=True)
     output = nvcc_output.split()
     release_idx = output.index("release") + 1
     nvcc_cuda_version = parse(output[release_idx].split(",")[0])
@@ -254,10 +251,15 @@ def create_kv_caches_with_random(
         key_cache = torch.empty(size=key_cache_shape,
                                 dtype=torch_dtype,
                                 device=device)
-        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
-            key_cache.uniform_(-scale, scale)
-        elif cache_dtype == "fp8_e5m2":
+        if cache_dtype == 'fp8_e5m2':
             _generate_random_fp8_e5m2(key_cache, -scale, scale)
+        elif cache_dtype == 'int8':
+            torch.randint(-128, 127, key_cache.size(), out=key_cache)
+        elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
+            key_cache.uniform_(-scale, scale)
+        else:
+            raise ValueError(
+                f"Does not support key cache of type {cache_dtype}")
         key_caches.append(key_cache)
 
     value_cache_shape = (num_blocks, num_heads, head_size, block_size)
@@ -266,9 +268,14 @@ def create_kv_caches_with_random(
         value_cache = torch.empty(size=value_cache_shape,
                                   dtype=torch_dtype,
                                   device=device)
-        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
-            value_cache.uniform_(-scale, scale)
-        elif cache_dtype == "fp8_e5m2":
+        if cache_dtype == 'fp8_e5m2':
             _generate_random_fp8_e5m2(value_cache, -scale, scale)
+        elif cache_dtype == 'int8':
+            torch.randint(-128, 127, value_cache.size(), out=value_cache)
+        elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
+            value_cache.uniform_(-scale, scale)
+        else:
+            raise ValueError(
+                f"Does not support value cache of type {cache_dtype}")
         value_caches.append(value_cache)
     return key_caches, value_caches
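
A quick sanity check of the new int8 branch above (a standalone sketch, not part of the commit). Note that torch.randint's upper bound is exclusive, so the fill covers [-128, 126].

```python
import torch

# Fill an int8 tensor in place, mirroring the key/value cache branch above.
key_cache = torch.empty(2, 4, dtype=torch.int8)
torch.randint(-128, 127, key_cache.size(), out=key_cache)
print(key_cache.dtype, key_cache.min().item(), key_cache.max().item())
```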

+ 5 - 0
aphrodite/engine/aphrodite_engine.py

@@ -87,6 +87,7 @@ class AphroditeEngine:
             f"Context Length = {model_config.max_model_len}\n"
             f"Enforce Eager Mode = {model_config.enforce_eager}\n"
             f"KV Cache Data Type = {cache_config.cache_dtype}\n"
+            f"KV Cache Params Path = {cache_config.cache_quant_params_path}\n"
             f"Device = {device_config.device}")
         # TODO: Print more configs in debug mode.
 
@@ -148,6 +149,7 @@ class AphroditeEngine:
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
             kv_cache_dtype=self.cache_config.cache_dtype,
+            kv_quant_params_path=(self.cache_config.cache_quant_params_path),
             is_driver_worker=True,
         )
         self._run_workers("init_model")
@@ -256,6 +258,8 @@ class AphroditeEngine:
                     distributed_init_method,
                     lora_config=self.lora_config,
                     kv_cache_dtype=self.cache_config.cache_dtype,
+                    kv_quant_params_path=
+                    (self.cache_config.cache_quant_params_path),
                 ))
 
         driver_rank = 0
@@ -270,6 +274,7 @@ class AphroditeEngine:
             distributed_init_method,
             lora_config=self.lora_config,
             kv_cache_dtype=self.cache_config.cache_dtype,
+            kv_quant_params_path=(self.cache_config.cache_quant_params_path),
             is_driver_worker=True,
         )
 

+ 10 - 1
aphrodite/engine/args_tools.py

@@ -18,6 +18,7 @@ class EngineArgs:
     load_format: str = 'auto'
     dtype: str = 'auto'
     kv_cache_dtype: str = 'auto'
+    kv_quant_params_path: str = None
     seed: int = 0
     max_model_len: Optional[int] = None
     worker_use_ray: bool = False
@@ -132,11 +133,18 @@ class EngineArgs:
         parser.add_argument(
             '--kv-cache-dtype',
             type=str,
-            choices=['auto', 'fp8_e5m2'],
+            choices=['auto', 'fp8_e5m2', 'int8'],
             default=EngineArgs.kv_cache_dtype,
             help='Data type for kv cache storage. If "auto", will use model '
             'data type. Note FP8 is not supported when cuda version is '
             'lower than 11.8.')
+        parser.add_argument(
+            '--kv-quant-params-path',
+            type=str,
+            default=EngineArgs.kv_quant_params_path,
+            help='Path to scales and zero points of KV cache '
+            'quantization. Only applicable when kv-cache-dtype '
+            'is int8.')
         parser.add_argument('--max-model-len',
                             type=int,
                             default=EngineArgs.max_model_len,
@@ -317,6 +325,7 @@ class EngineArgs:
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space, self.kv_cache_dtype,
+                                   self.kv_quant_params_path,
                                    model_config.get_sliding_window(),
                                    self.context_shift)
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
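
A hedged sketch of how the two new options fit together, equivalent to passing `--kv-cache-dtype int8 --kv-quant-params-path quant_params` on the command line. Only `kv_cache_dtype` and `kv_quant_params_path` come from this diff; the `model` field and all paths are assumptions for illustration.

```python
from aphrodite.engine.args_tools import EngineArgs

args = EngineArgs(
    model="meta-llama/Llama-2-7b-hf",     # field name and model id assumed
    kv_cache_dtype="int8",
    kv_quant_params_path="quant_params",  # directory with per-layer scale files
)
```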

+ 0 - 0
aphrodite/kv_quant/__init__.py


+ 303 - 0
aphrodite/kv_quant/calib_dataloader.py

@@ -0,0 +1,303 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+
+
+def set_seed(seed):
+    np.random.seed(seed)
+    torch.random.manual_seed(seed)
+
+
+def get_wikitext2(tokenizer, nsamples, seed, seqlen, path=None):
+    """Load Wikitext-2 train and test datasets and tokenize.
+    Args:
+        tokenizer: Tokenizer to encode text.
+        nsamples: Number of samples to take from train set.
+        seed: Random seed for sampling.
+        seqlen: Maximum sequence length.
+    Returns:
+        train_loader: List of sampled and tokenized training examples.
+        test_enc: Full tokenized Wikitext-2 test set.
+    """
+    from datasets import load_dataset
+    traindata = load_dataset(path if path else 'wikitext',
+                             'wikitext-2-raw-v1',
+                             split='train')
+    testdata = load_dataset(path if path else 'wikitext',
+                            'wikitext-2-raw-v1',
+                            split='test')
+
+    trainenc = tokenizer('\n\n'.join(traindata['text']), return_tensors='pt')
+    testenc = tokenizer('\n\n'.join(testdata['text']), return_tensors='pt')
+
+    import random
+    random.seed(seed)
+    trainloader = []
+    for _ in range(nsamples):
+        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen)
+        j = i + seqlen
+        inp = trainenc.input_ids[:, i:j]
+        tar = inp.clone()
+        tar[:, :-1] = -100
+        trainloader.append((inp, tar))
+    return trainloader, testenc
+
+
+def get_ptb(tokenizer, nsamples, seed, seqlen):
+    """Load PTB train and validation datasets and tokenize.
+    Args:
+        tokenizer: Tokenizer to encode text.
+        nsamples: Number of samples to take from train set.
+        seed: Random seed for sampling.
+        seqlen: Maximum sequence length.
+    Returns:
+        train_loader: List of sampled and tokenized training examples.
+        test_enc: Full tokenized PTB validation set.
+    """
+    from datasets import load_dataset
+    traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
+    valdata = load_dataset('ptb_text_only',
+                           'penn_treebank',
+                           split='validation')
+
+    trainenc = tokenizer('\n\n'.join(traindata['sentence']),
+                         return_tensors='pt')
+    testenc = tokenizer('\n\n'.join(valdata['sentence']), return_tensors='pt')
+
+    import random
+    random.seed(seed)
+    trainloader = []
+    for _ in range(nsamples):
+        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen)
+        j = i + seqlen
+        inp = trainenc.input_ids[:, i:j]
+        tar = inp.clone()
+        tar[:, :-1] = -100
+        trainloader.append((inp, tar))
+    return trainloader, testenc
+
+
+def get_c4(tokenizer, nsamples, seed, seqlen, path=None):
+    """Load C4 train and validation datasets and tokenize.
+    Args:
+        tokenizer: Tokenizer to encode text.
+        nsamples: Number of samples to take from train set.
+        seed: Random seed for sampling.
+        seqlen: Maximum sequence length.
+    Returns:
+        train_loader: List of sampled and tokenized training examples.
+        test_enc: Full tokenized C4 validation set.
+    """
+    from datasets import load_dataset
+    traindata = load_dataset(
+        path if path else 'allenai/c4',
+        'allenai--c4',
+        data_files={'train': 'en/c4-train.00000-of-01024.json.gz'},
+        split='train',
+        use_auth_token=False)
+    valdata = load_dataset(
+        path if path else 'allenai/c4',
+        'allenai--c4',
+        data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
+        split='validation',
+        use_auth_token=False)
+
+    import random
+    random.seed(seed)
+    trainloader = []
+    for _ in range(nsamples):
+        while True:
+            i = random.randint(0, len(traindata) - 1)
+            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
+            if trainenc.input_ids.shape[1] >= seqlen:
+                break
+        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen)
+        j = i + seqlen
+        inp = trainenc.input_ids[:, i:j]
+        tar = inp.clone()
+        tar[:, :-1] = -100
+        trainloader.append((inp, tar))
+
+    valenc = []
+    for _ in range(256):
+        while True:
+            i = random.randint(0, len(valdata) - 1)
+            tmp = tokenizer(valdata[i]['text'], return_tensors='pt')
+            if tmp.input_ids.shape[1] >= seqlen:
+                break
+        i = random.randint(0, tmp.input_ids.shape[1] - seqlen)
+        j = i + seqlen
+        valenc.append(tmp.input_ids[:, i:j])
+    valenc = torch.hstack(valenc)
+
+    class TokenizerWrapper:
+
+        def __init__(self, input_ids):
+            self.input_ids = input_ids
+
+    valenc = TokenizerWrapper(valenc)
+
+    return trainloader, valenc
+
+
+def get_ptb_new(tokenizer, nsamples, seed, seqlen):
+    """Load PTB New train and validation datasets and tokenize.
+    Args:
+        tokenizer: Tokenizer to encode text.
+        nsamples: Number of samples to take from train set.
+        seed: Random seed for sampling.
+        seqlen: Maximum sequence length.
+    Returns:
+        train_loader: List of sampled and tokenized training examples.
+        test_enc: Full tokenized PTB test set.
+    """
+    from datasets import load_dataset
+    traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
+    testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')
+
+    trainenc = tokenizer(' '.join(traindata['sentence']), return_tensors='pt')
+    testenc = tokenizer(' '.join(testdata['sentence']), return_tensors='pt')
+
+    import random
+    random.seed(seed)
+    trainloader = []
+    for _ in range(nsamples):
+        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen)
+        j = i + seqlen
+        inp = trainenc.input_ids[:, i:j]
+        tar = inp.clone()
+        tar[:, :-1] = -100
+        trainloader.append((inp, tar))
+    return trainloader, testenc
+
+
+def get_c4_new(tokenizer, nsamples, seed, seqlen):
+    """Load C4 New train and validation datasets and tokenize.
+    Args:
+        tokenizer: Tokenizer to encode text.
+        nsamples: Number of samples to take from train set.
+        seed: Random seed for sampling.
+        seqlen: Maximum sequence length.
+    Returns:
+        train_loader: List of sampled and tokenized training examples.
+        test_enc: Full tokenized C4 validation set.
+    """
+    from datasets import load_dataset
+    traindata = load_dataset(
+        'allenai/c4',
+        'allenai--c4',
+        data_files={'train': 'en/c4-train.00000-of-01024.json.gz'},
+        split='train')
+    valdata = load_dataset(
+        'allenai/c4',
+        'allenai--c4',
+        data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
+        split='validation')
+
+    import random
+    random.seed(seed)
+    trainloader = []
+    for _ in range(nsamples):
+        while True:
+            i = random.randint(0, len(traindata) - 1)
+            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
+            if trainenc.input_ids.shape[1] >= seqlen:
+                break
+        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen)
+        j = i + seqlen
+        inp = trainenc.input_ids[:, i:j]
+        tar = inp.clone()
+        tar[:, :-1] = -100
+        trainloader.append((inp, tar))
+
+    valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
+    valenc = valenc.input_ids[:, :(256 * seqlen)]
+
+    class TokenizerWrapper:
+
+        def __init__(self, input_ids):
+            self.input_ids = input_ids
+
+    valenc = TokenizerWrapper(valenc)
+
+    return trainloader, valenc
+
+
+def get_pileval(tokenizer, nsamples, seed, path, seqlen=512):
+    """Load pileval train dataset and tokenize.
+    Args:
+        tokenizer: Tokenizer to encode text.
+        nsamples: Number of samples to take from train set.
+        seed: Random seed for sampling.
+        seqlen: Maximum sequence length.
+    Returns:
+        train_loader: List of sampled and tokenized training examples.
+        test_enc: None; pileval provides no separate validation set.
+    """
+    from datasets import load_dataset
+    from datasets.builder import DatasetGenerationError
+    try:
+        dataset = load_dataset('json', data_files=path, split='train')
+    except DatasetGenerationError as err:
+        raise InterruptedError('There have been some issues when generating '
+                               'the dataset. You could try to download it '
+                               'locally first and replace the `data_files` '
+                               'with local paths, or use other datasets '
+                               '(c4, wiki, ptb).') from err
+    dataset = dataset.shuffle(seed=seed)
+    samples = []
+    n_run = 0
+    for data in dataset:
+        line = data['text']
+        line = line.strip()
+        line_encoded = tokenizer.encode(line)
+        if len(line_encoded) > 512:
+            continue
+        sample = torch.tensor([line_encoded])
+        if sample.numel() == 0:
+            continue
+        samples.append(sample)
+        n_run += 1
+        if n_run == nsamples:
+            break
+    # now concatenate all samples and split according to block size
+    cat_samples = torch.cat(samples, dim=1)
+    n_split = cat_samples.shape[1] // seqlen
+    print(f' * Split into {n_split} blocks')
+    return [
+        cat_samples[:, i * seqlen:(i + 1) * seqlen] for i in range(n_split)
+    ], None
+
+
+def get_calib_loaders(name,
+                      tokenizer,
+                      nsamples=128,
+                      seed=0,
+                      seqlen=2048,
+                      path=None):
+    """Get calibration data loaders for a dataset.
+    Args:
+      name: Dataset name ('wikitext2', 'ptb', 'c4', etc).
+      tokenizer: Tokenizer to encode text.
+      nsamples: Number of samples to take from train set.
+      seed: Random seed for sampling.
+      seqlen: Maximum sequence length.
+    Returns:
+      train_loader: List of sampled and tokenized training examples.
+      test_data: Full tokenized validation set.
+    """
+    if 'wikitext2' in name:
+        return get_wikitext2(tokenizer, nsamples, seed, seqlen, path)
+    if 'ptb' in name:
+        if 'new' in name:
+            return get_ptb_new(tokenizer, nsamples, seed, seqlen)
+        return get_ptb(tokenizer, nsamples, seed, seqlen)
+    if 'c4' in name:
+        if 'new' in name:
+            return get_c4_new(tokenizer, nsamples, seed, seqlen)
+        return get_c4(tokenizer, nsamples, seed, seqlen, path)
+
+    if 'pileval' in name:
+        if path is None:
+            path = 'https://the-eye.eu/public/AI/pile/val.jsonl.zst'
+        return get_pileval(tokenizer, nsamples, seed, path, seqlen)
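
For orientation, a minimal usage sketch of `get_calib_loaders` (the tokenizer and model id are illustrative; the `datasets` package must be installed):

```python
from transformers import AutoTokenizer
from aphrodite.kv_quant.calib_dataloader import get_calib_loaders

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf",
                                          use_fast=False)
# calib_loader is a list of (input_ids, labels) pairs, each of shape (1, seqlen);
# test_enc is the full tokenized Wikitext-2 test set.
calib_loader, test_enc = get_calib_loaders("wikitext2", tokenizer,
                                           nsamples=8, seqlen=512)
```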

+ 112 - 0
aphrodite/kv_quant/calibrate.py

@@ -0,0 +1,112 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+from pathlib import Path
+
+import fire
+import torch
+from accelerate import (infer_auto_device_map, init_empty_weights,
+                        load_checkpoint_in_model)
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+
+from aphrodite.kv_quant.calibration import CalibrationContext
+from aphrodite.kv_quant.utils import collect_target_modules
+from aphrodite.kv_quant.calib_dataloader import get_calib_loaders
+
+LAYER_TYPE_MAP = {
+    'InternLMForCausalLM': 'InternLMDecoderLayer',
+    'QWenLMHeadModel': 'QWenBlock',
+    'BaiChuanForCausalLM': 'DecoderLayer',
+    'LlamaForCausalLM': 'LlamaDecoderLayer',
+}
+NORM_TYPE_MAP = {
+    'InternLMForCausalLM': 'InternLMRMSNorm',
+    'QWenLMHeadModel': 'RMSNorm',
+    'BaiChuanForCausalLM': 'RMSNorm',
+    'LlamaForCausalLM': 'LlamaRMSNorm',
+}
+
+
+def calibrate(model: str,
+              calib_dataset: str = 'c4',
+              calib_samples: int = 128,
+              calib_seqlen: int = 2048,
+              work_dir: str = './work_dir',
+              device: str = 'cuda',
+              dataset_path: str = None) -> None:
+    """The main function for loading the model and performing calibration on a
+    given dataset.
+    Args:
+        model (str): The model to be loaded.
+        calib_dataset (str, optional): The calibration dataset name.
+            Defaults to 'c4'.
+        calib_samples (int, optional): The number of samples for calibration.
+            Defaults to 128.
+        calib_seqlen (int, optional): The sequence length for calibration.
+            Defaults to 2048.
+        work_dir (str): The working directory for outputs.
+            Defaults to './work_dir'.
+        device (str, optional): The device to be used for calculation.
+            Defaults to 'cuda'.
+    """
+
+    assert calib_dataset in ['c4', 'ptb', 'wikitext2', 'pileval'], \
+        'Support only `c4`, `ptb`, `wikitext2` or `pileval`.'
+
+    # Load tokenizer and configuration
+    tokenizer = AutoTokenizer.from_pretrained(model,
+                                              use_fast=False,
+                                              trust_remote_code=True)
+    hf_config = AutoConfig.from_pretrained(model, trust_remote_code=True)
+    checkpoint = hf_config._name_or_path
+
+    with init_empty_weights():
+        # Load model
+        model = AutoModelForCausalLM.from_pretrained(model,
+                                                     torch_dtype=torch.float16,
+                                                     trust_remote_code=True)
+        model.config.use_cache = False
+
+    layer_type = LAYER_TYPE_MAP[type(model).__name__]
+    norm_type = NORM_TYPE_MAP[type(model).__name__]
+
+    decoder_layers = collect_target_modules(model, layer_type)
+
+    # Infer device map
+    device_map = infer_auto_device_map(model,
+                                       no_split_module_classes=[layer_type])
+    for name in device_map:
+        if name in decoder_layers or 'lm_head' in name:
+            device_map[name] = 'cpu'
+        else:
+            device_map[name] = 0
+    load_checkpoint_in_model(model, checkpoint, device_map)
+
+    print('Loading calibrate dataset ...')
+    calib_loader, _ = get_calib_loaders(calib_dataset,
+                                        tokenizer,
+                                        nsamples=calib_samples,
+                                        seqlen=calib_seqlen,
+                                        path=dataset_path)
+
+    # Initialize calibration context
+    calib_ctx = CalibrationContext(model,
+                                   tokenizer,
+                                   layer_type=layer_type,
+                                   norm_type=norm_type,
+                                   device=device)
+
+    with calib_ctx:
+        all_data = torch.cat([
+            data if isinstance(data, torch.Tensor) else data[0]
+            for data in calib_loader
+        ]).to(device)
+        calib_ctx.calibrate(all_data)
+
+    # Create work directory if not exists
+    work_dir = Path(work_dir)
+    work_dir.mkdir(parents=True, exist_ok=True)
+    calib_ctx.export(work_dir)
+
+
+if __name__ == '__main__':
+    fire.Fire(calibrate)
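
The module is exposed as a Fire CLI; an equivalent programmatic call looks roughly like this (model id and paths are illustrative; the work directory name matches the new .gitignore entry):

```python
from aphrodite.kv_quant.calibrate import calibrate

calibrate(model="meta-llama/Llama-2-7b-hf",
          calib_dataset="wikitext2",
          calib_samples=128,
          calib_seqlen=2048,
          work_dir="./kv_cache_states")  # key/value/inputs/outputs *_stats.pth land here
```

The collected statistics are then turned into per-layer scale files by aphrodite/kv_quant/export_kv_params.py (shown below), and the resulting directory is presumably what you pass via --kv-quant-params-path.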

+ 323 - 0
aphrodite/kv_quant/calibration.py

@@ -0,0 +1,323 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from functools import partial
+from typing import Union
+
+import torch
+from torch import nn
+import transformers
+from transformers import PreTrainedTokenizer
+from pkg_resources import parse_version
+from aphrodite.kv_quant.utils import (bimap_name_mod, collect_target_modules,
+                                      concat_decoder_layer_outputs,
+                                      split_decoder_layer_inputs)
+from aphrodite.kv_quant.observer import ActivationObserver, KVCacheObserver
+
+
+class CalibrationContext():
+    """Calibration context manager for model quantization.
+    Parameters:
+      - model: The target model to be calibrated and quantized
+      - tokenizer: The tokenizer used in the model training
+      - layer_type: Layer type to be targeted for calibration
+      - norm_type: Normalization type used for calibration
+      - device: Device on which model is to be calibrated ('cpu' or 'cuda')
+    """
+
+    inp_obs_group = 'inputs'
+    out_obs_group = 'outputs'
+    key_obs_group = 'keys'
+    value_obs_group = 'values'
+
+    def __init__(self,
+                 model: nn.Module,
+                 tokenizer: PreTrainedTokenizer,
+                 layer_type: Union[str, type],
+                 norm_type: Union[str, type],
+                 device: str = 'cuda') -> None:
+        """Initiate calibration context.
+        Args:
+            model (nn.Module): Model to be calibrated.
+            tokenizer (PreTrainedTokenizer): Tokenizer of the given model.
+            layer_type (Union[str, type]): Type of the layers to be observed.
+            norm_type (Union[str, type]): Norm type used in the model.
+            device (str, optional): Device where the model should run.
+                Defaults to 'cuda'.
+        """
+
+        self.layer_type = layer_type
+        self.norm_type = norm_type
+
+        num_kv_heads, num_attn_heads = self._guess_num_heads(model)
+        self.num_kv_heads = num_kv_heads
+        self.head_dim = model.config.hidden_size // num_attn_heads
+        self.model = model
+        del self.model.lm_head
+
+        self.tokenizer = tokenizer
+
+        # Collect modules to observe
+        self.name2layer = collect_target_modules(self.model, layer_type)
+        self.name2fc = {}
+        for l_name, layer in self.name2layer.items():
+            name2fc = collect_target_modules(layer, nn.Linear, prefix=l_name)
+            self.name2fc.update(name2fc)
+        self.name2norm = collect_target_modules(self.model, norm_type)
+
+        maps = bimap_name_mod([self.name2layer, self.name2fc, self.name2norm])
+        self.name2mod, self.mod2name = maps
+
+        # Initialize observers
+        self._init_input_observers(self.name2fc)
+        self._init_output_observers(self.name2norm)
+        self._init_output_observers(self.name2fc)
+        self._init_kv_observers(self.name2layer)
+
+        self.device = device
+
+    def _guess_num_heads(self, model):
+
+        if hasattr(model.config, 'num_key_value_heads'):
+            num_kv_heads = model.config.num_key_value_heads
+        else:
+            num_kv_heads = model.config.num_attention_heads
+
+        num_attn_heads = model.config.num_attention_heads
+
+        return num_kv_heads, num_attn_heads
+
+    def _init_input_observers(self, name2mod):
+        """Initialize input observers for given modules."""
+        for name, mod in name2mod.items():
+            obs = ActivationObserver(mod.weight.size(-1))
+            obs.global_available(name, group=self.inp_obs_group)
+
+    def _init_output_observers(self, name2mod):
+        """Initialize output observers for given modules."""
+        for name, mod in name2mod.items():
+            obs = ActivationObserver(mod.weight.size(0))
+            obs.global_available(name, group=self.out_obs_group)
+
+    def _init_kv_observers(self, name2mod):
+        """Initialize KV observers for given modules."""
+        for name in name2mod:
+            k_obs = KVCacheObserver(self.num_kv_heads, self.head_dim)
+            v_obs = KVCacheObserver(self.num_kv_heads, self.head_dim)
+            k_obs.global_available(name, group=self.key_obs_group)
+            v_obs.global_available(name, group=self.value_obs_group)
+
+    def _insert_input_observers(self):
+        """Insert input observers into the target modules.
+        This function registers a forward pre-hook on each target module to
+        observe the inputs.
+        """
+
+        def _input_hook(mod: nn.Module, inp: torch.Tensor):
+            m_name = self.mod2name[mod]
+            obs = ActivationObserver.find(m_name, group=self.inp_obs_group)
+            obs.observe(inp[0])
+
+        group = ActivationObserver.find_group(self.inp_obs_group)
+        for name in group:
+            mod = self.name2mod[name]
+            hook_fn = mod.register_forward_pre_hook(_input_hook)
+            self._hooks.append(hook_fn)
+
+    def _insert_output_observers(self):
+        """Insert output observers into the target modules.
+        This function registers a forward hook on each target module to observe
+        the outputs.
+        """
+
+        def _output_hook(mod: nn.Module, inp: torch.Tensor, out: torch.Tensor):
+            m_name = self.mod2name[mod]
+            obs = ActivationObserver.find(m_name, group=self.out_obs_group)
+            obs.observe(out)
+
+        group = ActivationObserver.find_group(self.out_obs_group)
+        for name in group:
+            mod = self.name2mod[name]
+            hook_fn = mod.register_forward_hook(_output_hook)
+            self._hooks.append(hook_fn)
+
+    def _wrap_decoder_layers(self):
+        """Method to wrap the decoder layers' forward functions for observing
+        their key/value cache during batched forward passes."""
+
+        def _forward(mod, *args, **kwargs):
+
+            mod.to(self.device)
+            batch_args, batch_kwargs = split_decoder_layer_inputs(
+                *args, **kwargs)
+            batch_outputs = []
+            samples = len(batch_args)
+
+            m_name = self.mod2name[mod]
+            k_obs = KVCacheObserver.find(m_name, group=self.key_obs_group)
+            v_obs = KVCacheObserver.find(m_name, group=self.value_obs_group)
+
+            for i in range(len(batch_args)):
+
+                if k_obs and v_obs:
+                    batch_kwargs[i]['use_cache'] = True
+                    version = parse_version(transformers.__version__)
+                    use_new_cache = type(mod).__name__ == 'LlamaDecoderLayer'
+                    if version > parse_version('4.36.0') and use_new_cache:
+                        from transformers.cache_utils import DynamicCache
+                        batch_kwargs[i]['past_key_value'] = DynamicCache()
+
+                        ori_idx = mod.self_attn.layer_idx
+                        mod.self_attn.layer_idx = 0
+
+                        out = self._ori_forwards[mod](*batch_args[i],
+                                                      **batch_kwargs[i])
+                        mod.self_attn.layer_idx = ori_idx
+
+                        out = list(out)
+                        cache = out.pop(-1)
+
+                        key = cache.key_cache.pop(-1)
+                        value = cache.value_cache.pop(-1)
+
+                        k_obs.observe(key)
+                        v_obs.observe(value)
+                    else:
+                        out = self._ori_forwards[mod](*batch_args[i],
+                                                      **batch_kwargs[i])
+                        out = list(out)
+                        key, value = out.pop(-1)
+                        k_obs.observe(key)
+                        v_obs.observe(value)
+
+                    del key, value
+                    torch.cuda.empty_cache()
+                    batch_outputs.append(tuple(out))
+                else:
+                    batch_outputs.append(self._ori_forwards[mod](
+                        *batch_args[i], **batch_kwargs[i]))
+
+            outputs = concat_decoder_layer_outputs(batch_outputs)
+
+            del batch_outputs, batch_args, batch_kwargs, args
+            mod.to('cpu')
+            torch.cuda.empty_cache()
+            max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
+            print(f'{m_name}, samples: {samples}, '
+                  f'max gpu memory: {max_memory:.2f} GB')
+            return outputs
+
+        for layer in self.name2layer.values():
+            self._ori_forwards[layer] = layer.forward
+            layer.forward = partial(_forward, layer)
+
+    def collect_inputs_stats(self):
+        """Collect statistics (min, max, absmax values) of the observed inputs.
+        Returns a dictionary with these collected stats.
+        """
+        inputs_stats = {
+            'max': {},
+            'min': {},
+            'mean': {},
+            'absmax': {},
+            'absmean': {}
+        }
+        obs_group = ActivationObserver.find_group(self.inp_obs_group)
+        for name, obs in obs_group.items():
+            inputs_stats['max'][name] = obs.max_val
+            inputs_stats['min'][name] = obs.min_val
+            inputs_stats['mean'][name] = obs.mean_val
+            inputs_stats['absmax'][name] = obs.absmax_val
+            inputs_stats['absmean'][name] = obs.absmean_val
+        return inputs_stats
+
+    def collect_outputs_stats(self):
+        """Collect statistics (min, max, absmax values) of the observed
+        outputs.
+        Returns a dictionary with these collected stats.
+        """
+        outputs_stats = {
+            'max': {},
+            'min': {},
+            'mean': {},
+            'absmax': {},
+            'absmean': {}
+        }
+        obs_group = ActivationObserver.find_group(self.out_obs_group)
+        for name, obs in obs_group.items():
+            outputs_stats['max'][name] = obs.max_val
+            outputs_stats['min'][name] = obs.min_val
+            outputs_stats['mean'][name] = obs.mean_val
+            outputs_stats['absmax'][name] = obs.absmax_val
+            outputs_stats['absmean'][name] = obs.absmean_val
+        return outputs_stats
+
+    def collect_kv_stats(self):
+        """Collect statistics (min, max, absmax values) of the observed keys
+        and values.
+        Returns a tuple of two dictionaries with these collected stats.
+        """
+        key_stats = {'max': {}, 'min': {}, 'absmax': {}}
+        obs_group = KVCacheObserver.find_group(self.key_obs_group)
+        for name, obs in obs_group.items():
+            key_stats['max'][name] = obs.max_val
+            key_stats['min'][name] = obs.min_val
+            key_stats['absmax'][name] = obs.absmax_val
+
+        value_stats = {'max': {}, 'min': {}, 'absmax': {}}
+        obs_group = KVCacheObserver.find_group(self.value_obs_group)
+        for name, obs in obs_group.items():
+            value_stats['max'][name] = obs.max_val
+            value_stats['min'][name] = obs.min_val
+            value_stats['absmax'][name] = obs.absmax_val
+        return key_stats, value_stats
+
+    def export(self, out_dir):
+        """Export the calibration statistics (inputs, outputs, keys and values)
+        to specified directory.
+        Args:
+            out_dir (Union[str, Path]): The directory path where the stats
+                will be saved.
+        """
+
+        inp_stats = self.collect_inputs_stats()
+        torch.save(inp_stats, out_dir / 'inputs_stats.pth')
+
+        out_stats = self.collect_outputs_stats()
+        torch.save(out_stats, out_dir / 'outputs_stats.pth')
+
+        key_stats, value_stats = self.collect_kv_stats()
+        torch.save(key_stats, out_dir / 'key_stats.pth')
+        torch.save(value_stats, out_dir / 'value_stats.pth')
+
+    def calibrate(self, data):
+        """Forward pass through the model in inference mode with given data."""
+
+        if type(self.model).__name__ == 'QWenLMHeadModel':
+            model = self.model.transformer
+        else:
+            model = self.model.model
+        with torch.inference_mode():
+            _ = model(data.to(self.device))
+
+    def __enter__(self):
+        """Prepares the Calibration object for a 'with' statement by
+        registering hooks and wrapping layer forward methods."""
+
+        self._hooks = list()
+
+        self._ori_forwards = {}
+        for layer in self.name2layer.values():
+            self._ori_forwards[layer] = layer.forward
+
+        self._insert_input_observers()
+        self._insert_output_observers()
+        self._wrap_decoder_layers()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        """Clean up after a 'with' statement by removing registered hooks,
+        restoring original forward methods, and if no exception occurred,
+        collecting all gathered statistics and saving them."""
+        for h in self._hooks:
+            h.remove()
+
+        for layer in self.name2layer.values():
+            layer.forward = self._ori_forwards[layer]

+ 122 - 0
aphrodite/kv_quant/export_kv_params.py

@@ -0,0 +1,122 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from pathlib import Path
+from typing import Union
+
+import numpy as np
+import torch
+import fire
+
+
+def _export_sym(key_stats: dict,
+                value_stats: dict,
+                bits: int,
+                out_dir: Union[str, Path],
+                tp: int = 1) -> None:
+    """Export symmetric quantization parameters to specified directory."""
+    keys_absmax = key_stats['absmax']
+    values_absmax = value_stats['absmax']
+    for layer_idx, name in enumerate(keys_absmax.keys()):
+        k_absmax = keys_absmax[name]
+        v_absmax = values_absmax[name]
+
+        heads, _ = k_absmax.shape
+        assert heads % tp == 0
+
+        mp_k_absmax = torch.chunk(k_absmax, tp)
+        mp_v_absmax = torch.chunk(v_absmax, tp)
+        for i in range(tp):
+            # quant: q = f / scale
+            # dequant: f = q * scale
+            k_s = mp_k_absmax[i].max() / (2**(bits - 1) - 1)
+            v_s = mp_v_absmax[i].max() / (2**(bits - 1) - 1)
+
+            kv_qparams = np.array([k_s, v_s], dtype=np.float32)
+            out_path = out_dir / f'layers.{layer_idx}.past_kv_scale.{i}.weight'  # noqa: E501
+            kv_qparams.tofile(out_path)
+            print(f'Layer {layer_idx} MP {i} qparam: {k_s} \t{v_s}')
+
+
+def _export_asym(key_stats: dict,
+                 value_stats: dict,
+                 bits: int,
+                 out_dir: Union[str, Path],
+                 tp: int = 1) -> None:
+    """Export asymmetric quantization parameters to specified directory."""
+    keys_min = key_stats['min']
+    values_min = value_stats['min']
+
+    keys_max = key_stats['max']
+    values_max = value_stats['max']
+    for layer_idx, name in enumerate(keys_min.keys()):
+        k_max = keys_max[name]
+        v_max = values_max[name]
+
+        k_min = keys_min[name]
+        v_min = values_min[name]
+
+        heads, _ = k_min.shape
+        assert heads % tp == 0
+
+        tp_k_min = torch.chunk(k_min, tp)
+        tp_v_min = torch.chunk(v_min, tp)
+
+        tp_k_max = torch.chunk(k_max, tp)
+        tp_v_max = torch.chunk(v_max, tp)
+        for i in range(tp):
+            # zp = (min+max) / 2
+            # scale = (max-min) / 255
+            # quant: q = (f-zp) / scale
+            # dequant: f = q * scale + zp
+            k_min = tp_k_min[i].min()
+            v_min = tp_v_min[i].min()
+
+            k_max = tp_k_max[i].max()
+            v_max = tp_v_max[i].max()
+
+            k_scale = (k_max - k_min) / (2**bits - 1)
+            v_scale = (v_max - v_min) / (2**bits - 1)
+
+            k_zp = (k_max + k_min) / 2
+            v_zp = (v_max + v_min) / 2
+
+            kv_qparams = np.array([k_scale, k_zp, v_scale, v_zp],
+                                  dtype=np.float32)
+            out_path = out_dir / f'layers.{layer_idx}.past_kv_scale.{i}.weight'
+            kv_qparams.tofile(out_path)
+            print(f'Layer {layer_idx} MP {i} qparam: '
+                  f'\t{k_scale} \t{k_zp} \t{v_scale} \t{v_zp}')
+
+
+def main(work_dir: str,
+         kv_params_dir: str,
+         kv_bits: int = 8,
+         kv_sym: bool = False,
+         num_tp: int = 1) -> None:
+    """Main function to export key and value stats.
+    Args:
+        work_dir (Union[str, Path]): Directory path where the stats are saved.
+        kv_params_dir (Union[str, Path]): Directory path where to
+            save the results.
+        kv_bits (int, optional): Number of bits for quantization.
+            Defaults to 8.
+        kv_sym (bool, optional): Whether to use symmetric quantization.
+            Defaults to False.
+        num_tp (int, optional): Number of tensor parallelism. Defaults to 1.
+    """
+
+    work_dir = Path(work_dir)
+
+    tm_dir = Path(kv_params_dir)
+    tm_dir.mkdir(parents=True, exist_ok=True)
+
+    key_stats = torch.load(work_dir / 'key_stats.pth')
+    value_stats = torch.load(work_dir / 'value_stats.pth')
+
+    if kv_sym:
+        _export_sym(key_stats, value_stats, kv_bits, tm_dir, num_tp)
+    else:
+        _export_asym(key_stats, value_stats, kv_bits, tm_dir, num_tp)
+
+
+if __name__ == '__main__':
+    fire.Fire(main)
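
A worked example of the asymmetric formulas above for one key range (numbers are illustrative); the script writes `[k_scale, k_zp, v_scale, v_zp]` as float32 per layer and tensor-parallel rank:

```python
import numpy as np

k_min, k_max = -2.0, 6.0
k_zp = (k_max + k_min) / 2           # zero point = 2.0
k_scale = (k_max - k_min) / 255      # 8-bit scale ~= 0.0314
q = round((4.5 - k_zp) / k_scale)    # quant:   q = (f - zp) / scale  -> 80
f = q * k_scale + k_zp               # dequant: f = q * scale + zp    -> ~4.51

# Same layout the exporter writes to layers.{idx}.past_kv_scale.{rank}.weight
kv_qparams = np.array([k_scale, k_zp, 1.0, 0.0], dtype=np.float32)
```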

+ 180 - 0
aphrodite/kv_quant/observer.py

@@ -0,0 +1,180 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Union
+import torch
+from torch import nn
+
+
+class GlobalAvailMixin:
+    """Mixin class to make instances globally available."""
+
+    _instances: Dict[str, Dict[Union[str, nn.Module], 'GlobalAvailMixin']] = {
+        'default': {}
+    }
+
+    def global_available(self,
+                         key: Union[str, nn.Module] = 'default',
+                         group: str = 'default') -> None:
+        """Make the instance globally available.
+        Args:
+            key (Union[str, nn.Module], optional): Key to save the instance.
+                Defaults to 'default'.
+            group (str, optional): Group to save the instance.
+                Defaults to 'default'.
+        """
+        self._save_instance(self, key, group)
+
+    @classmethod
+    def _save_instance(cls,
+                       instance: 'GlobalAvailMixin',
+                       key: Union[str, nn.Module] = 'default',
+                       group: str = 'default') -> None:
+        """Save the instance.
+        Args:
+            instance (GlobalAvailMixin): Instance to save.
+            key (Union[str, nn.Module], optional): Key to save the instance.
+                Defaults to 'default'.
+            group (str, optional): Group to save the instance.
+                Defaults to 'default'.
+        """
+        if group not in cls._instances:
+            assert isinstance(group, str)
+            cls._instances[group] = {}
+
+        cls._instances[group][key] = instance
+
+    @classmethod
+    def find(cls,
+             key: Union[str, nn.Module] = 'default',
+             group: str = 'default') -> Union[None, 'GlobalAvailMixin']:
+        """Find an instance by its key and group.
+        Args:
+            key (Union[str, nn.Module], optional): Key of the instance.
+                Defaults to 'default'.
+            group (str, optional): Group of the instance.
+                Defaults to 'default'.
+        Returns:
+            Union[None, GlobalAvailMixin]: The found instance, or None if
+                it does not exist.
+        """
+        return cls._instances.get(group, {}).get(key)
+
+    @classmethod
+    def find_group(
+            cls,
+            group: str) -> Dict[Union[str, nn.Module], 'GlobalAvailMixin']:
+        """Find all instances in a group.
+        Args:
+            group (str): Group of the instances.
+        Returns:
+            Dict[Union[str, nn.Module], GlobalAvailMixin]: All instances in
+                the group.
+        """
+        return cls._instances.get(group, {})
+
+    @classmethod
+    def instances(
+            cls) -> Dict[str, Dict[Union[str, nn.Module], 'GlobalAvailMixin']]:
+        """Get all instances."""
+        return cls._instances
+
+
+class KVCacheObserver(GlobalAvailMixin):
+    """A class to observe and record the max, min, and absolute max value of
+    given tensor."""
+
+    def __init__(self, num_head: int, head_dim: int) -> None:
+        """Constructor for KVCacheObserver.
+        Args:
+            num_head : Number of heads
+            head_dim : Dimension of each head
+        """
+        self.num_head = num_head
+        self.head_dim = head_dim
+        self.max_val = torch.full((num_head, head_dim),
+                                  -torch.inf,
+                                  dtype=torch.float16)
+        self.min_val = torch.full((num_head, head_dim),
+                                  torch.inf,
+                                  dtype=torch.float16)
+        self.absmax_val = torch.full((num_head, head_dim),
+                                     0,
+                                     dtype=torch.float16)
+
+    @torch.no_grad()
+    def observe(self, x: torch.Tensor) -> None:
+        """Function to observe the input tensor and update the max, min, and
+        absolute max values.
+        Args:
+            x : Input tensor
+        """
+        assert len(x.shape) == 4
+
+        if x.size(1) == self.num_head and x.size(3) == self.head_dim:
+            # layout: (bs, heads, seqlen, dims)
+            x = x.transpose(1, 2)
+        elif x.size(2) != self.num_head or x.size(3) != self.head_dim:
+            raise RuntimeError(
+                'Unexpected dimensions for x, expected (bs, num_head, seqlen, head_dim) or (bs, seqlen, num_head, head_dim)'
+            )
+
+        cur_max = x.flatten(0, 1).max(0)[0].cpu()
+        cur_min = x.flatten(0, 1).min(0)[0].cpu()
+        cur_absmax = x.flatten(0, 1).abs().max(0)[0].cpu()
+
+        self.max_val = torch.maximum(self.max_val, cur_max)
+        self.min_val = torch.minimum(self.min_val, cur_min)
+        self.absmax_val = torch.maximum(self.absmax_val, cur_absmax)
+
+
+class ActivationObserver(GlobalAvailMixin):
+    """A class to observe and record the max, min, mean, absolute max, and
+    absolute mean value of a given tensor.
+    Also keeps track of the number of batches observed.
+    """
+
+    def __init__(self, dim: int) -> None:
+        """Constructor for ActivationObserver.
+        Args:
+            dim : Dimension of the tensor
+        """
+        self.dim = dim
+        self.max_val = torch.full((dim, ), -torch.inf, dtype=torch.float16)
+        self.min_val = torch.full((dim, ), torch.inf, dtype=torch.float16)
+        self.absmax_val = torch.full((dim, ), 0, dtype=torch.float16)
+        self.absmean_val = torch.full((dim, ), 0, dtype=torch.float16)
+        self.mean_val = torch.full((dim, ), 0, dtype=torch.float16)
+        self.num_batches_tracked = 0
+
+    @torch.no_grad()
+    def observe(self, x: torch.Tensor) -> None:
+        """Function to observe the input tensor and update the max, min, mean,
+        absolute max, absolute mean values and number of batches tracked.
+        Args:
+            x : Input tensor
+        """
+        assert len(x.shape) == 3
+        assert x.size(2) == self.dim
+        cur_val = x.flatten(0, 1)
+        cur_max = cur_val.max(0)[0].cpu()
+        cur_min = cur_val.min(0)[0].cpu()
+        cur_mean = cur_val.mean(0).cpu()
+
+        cur_abs = cur_val.abs()
+        cur_absmax = cur_abs.max(0)[0].cpu()
+        cur_absmean = cur_abs.mean(0).cpu()
+
+        self.max_val = torch.maximum(self.max_val, cur_max)
+        self.min_val = torch.minimum(self.min_val, cur_min)
+        self.absmax_val = torch.maximum(self.absmax_val, cur_absmax)
+
+        # Update mean and absmean value with accumulated sum divided
+        # by total number of batches
+        self.mean_val = (
+            (self.mean_val * self.num_batches_tracked + cur_mean) /
+            (self.num_batches_tracked + 1))
+        self.absmean_val = (
+            (self.absmean_val * self.num_batches_tracked + cur_absmean) /
+            (self.num_batches_tracked + 1))
+
+        # Increment the count of batches tracked
+        self.num_batches_tracked += 1
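
A minimal sketch of feeding one fake key tensor through KVCacheObserver; the shape follows one of the two layouts accepted by observe():

```python
import torch
from aphrodite.kv_quant.observer import KVCacheObserver

obs = KVCacheObserver(num_head=8, head_dim=64)
k = torch.randn(1, 8, 16, 64, dtype=torch.float16)  # (bs, heads, seqlen, dims)
obs.observe(k)                                       # updates per-(head, dim) stats
print(obs.max_val.shape, obs.absmax_val.shape)       # torch.Size([8, 64]) twice
```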

+ 156 - 0
aphrodite/kv_quant/utils.py

@@ -0,0 +1,156 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Dict, List, Tuple, Union
+import torch
+from torch import nn
+
+
+def split_decoder_layer_inputs(
+    *args: Union[torch.Tensor, Any], **kwargs: Union[torch.Tensor, Any]
+) -> Tuple[List[List[Any]], List[Dict[str, Any]]]:
+    """This function splits batched decoder layer inputs into individual
+    elements.
+    Args:
+        *args (Union[torch.Tensor, Any]): Positional arguments which could
+            be a mix of tensors and other types.
+        **kwargs (Union[torch.Tensor, Any]): Keyword arguments which could
+            be a mix of tensors and other types.
+    Returns:
+        Tuple[List[List[Any]], List[Dict[str, Any]]]: A tuple containing two
+            lists, one for positional arguments, one for keyword arguments.
+            Each list contains individual elements from the batch.
+    """
+
+    if not isinstance(args[0], torch.Tensor):
+        raise ValueError('The first argument must be a Tensor')
+
+    bs = args[0].size(0)
+
+    batch_args = []
+    batch_kwargs = []
+    for i in range(bs):
+        new_args = []
+        # Iterate over each argument. If it's a torch.Tensor and its first
+        # dimension equals the batch size, then get the value corresponding
+        # to the current index, else directly add the whole value.
+        for val in args:
+            if isinstance(val, torch.Tensor) and val.size(0) == bs:
+                new_args.append(val[i:i + 1])
+            else:
+                new_args.append(val)
+
+        new_kwargs = {}
+        # Execute the same operation for the keyword arguments.
+        for name, val in kwargs.items():
+            if isinstance(val, torch.Tensor) and val.size(0) == bs:
+                new_kwargs[name] = val[i:i + 1]
+            else:
+                new_kwargs[name] = val
+
+        batch_args.append(new_args)
+        batch_kwargs.append(new_kwargs)
+
+    return batch_args, batch_kwargs
+
+
+def concat_decoder_layer_outputs(
+        batch_outputs: List[Tuple[Any]]) -> Tuple[Any]:
+    """This function concatenates individual decoder layer outputs into a
+    batched output.
+    Args:
+        batch_outputs (List[Tuple[Any]]): A list of tuples, where each tuple
+            represents the output from an individual element in the batch.
+    Returns:
+        Tuple[Any]: A tuple representing the batched output.
+    """
+
+    num_returns = len(batch_outputs[0])
+
+    def is_past_key_value(data: Any) -> bool:
+        """Check whether data is a past key-value pair.
+        Args:
+            data (Any): The data to check.
+        Returns:
+            bool: True if data is a past key-value pair, False otherwise.
+        """
+        flag = isinstance(data, tuple)
+        flag = flag and len(data) == 2
+        flag = flag and isinstance(data[0], torch.Tensor)
+        flag = flag and isinstance(data[1], torch.Tensor)
+        return flag
+
+    new_outputs = []
+
+    # Iterate over all types of return values.
+    for i in range(num_returns):
+        # Check if the current element is a past key-value pair.
+        flag = is_past_key_value(batch_outputs[0][i])
+        if flag:
+            # Concatenate the keys and values separately.
+            key = torch.cat([out[i][0] for out in batch_outputs])
+            value = torch.cat([out[i][1] for out in batch_outputs])
+            out_i = (key, value)
+        else:
+            # If it's not a past key-value pair, concatenate directly.
+            out_i = torch.cat([out[i] for out in batch_outputs])
+        new_outputs.append(out_i)
+
+    return tuple(new_outputs)
+
+
+def collect_target_modules(
+        model: nn.Module,
+        #    target: Union[str, type],
+        target: str,
+        skip_names: List[str] = None,
+        prefix: str = '') -> Dict[str, nn.Module]:
+    """Collects the specific target modules from the model.
+    Args:
+        model : The PyTorch module from which to collect the target modules.
+        target : The specific target to be collected. It can be a class of a
+            module or the name of a module.
+        skip_names : List of names of modules to be skipped during collection.
+        prefix : A string to be added as a prefix to the module names.
+    Returns:
+        A dictionary mapping from module names to module instances.
+    """
+
+    # if isinstance(target, LazyAttr):
+    #     target = target.build()
+    if skip_names is None:
+        skip_names = []
+    if not isinstance(target, (type, str)):
+        raise TypeError('Target must be a string (name of the module) '
+                        'or a type (class of the module)')
+
+    def _is_target(n, m):
+        if isinstance(target, str):
+            return target == type(m).__name__ and n not in skip_names
+        return isinstance(m, target) and n not in skip_names
+
+    name2mod = {}
+    for name, mod in model.named_modules():
+        m_name = f'{prefix}.{name}' if prefix else name
+        if _is_target(name, mod):
+            name2mod[m_name] = mod
+    return name2mod
+
+
+def bimap_name_mod(
+    name2mod_mappings: List[Dict[str, nn.Module]]
+) -> Tuple[Dict[str, nn.Module], Dict[nn.Module, str]]:
+    """Generates bidirectional maps from module names to module instances and
+    vice versa.
+    Args:
+        name2mod_mappings : List of dictionaries each mapping from module
+            names to module instances.
+    Returns:
+        Two dictionaries providing bidirectional mappings between module
+            names and module instances.
+    """
+
+    name2mod = {}
+    mod2name = {}
+    for mapping in name2mod_mappings:
+        mod2name.update({v: k for k, v in mapping.items()})
+        name2mod.update(mapping)
+    return name2mod, mod2name
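
A toy example of collect_target_modules and bimap_name_mod on a throwaway module tree; note that a string target is matched against the module's class name:

```python
from torch import nn
from aphrodite.kv_quant.utils import bimap_name_mod, collect_target_modules

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
name2fc = collect_target_modules(model, "Linear")  # {'0': Linear(...), '2': Linear(...)}
name2mod, mod2name = bimap_name_mod([name2fc])
assert mod2name[model[2]] == "2"
```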

+ 9 - 0
aphrodite/modeling/layers/attention.py

@@ -101,6 +101,7 @@ class PagedAttention(nn.Module):
         key_cache: Optional[torch.Tensor],
         value_cache: Optional[torch.Tensor],
         input_metadata: InputMetadata,
+        kv_quant_param: List[float] = None,
     ) -> torch.Tensor:
         """PagedAttention forward pass.
 
@@ -121,6 +122,9 @@ class PagedAttention(nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)
+        # FIXME: Remove this when all models support int8 kv cache
+        kv_quant_param = [1.0, 0.0, 1.0, 0.0
+                          ] if kv_quant_param is None else kv_quant_param
 
         # Reshape the keys and values and store them in the cache.
         # If key_cache and value_cache are not provided, the new key and value
@@ -134,6 +138,7 @@ class PagedAttention(nn.Module):
                 value_cache,
                 input_metadata.slot_mapping.flatten(),
                 input_metadata.kv_cache_dtype,
+                *kv_quant_param,
             )
 
         if input_metadata.is_prompt:
@@ -230,6 +235,7 @@ class PagedAttention(nn.Module):
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
+                kv_quant_param,
             )
 
         # Reshape the output tensor.
@@ -278,6 +284,7 @@ def _paged_attention(
     num_kv_heads: int,
     scale: float,
     alibi_slopes: Optional[torch.Tensor],
+    kv_quant_param: List[float],
 ) -> torch.Tensor:
     output = torch.empty_like(query)
 
@@ -310,6 +317,7 @@ def _paged_attention(
             input_metadata.max_context_len,
             alibi_slopes,
             input_metadata.kv_cache_dtype,
+            *kv_quant_param,
         )
     else:
         # Run PagedAttention V2.
@@ -341,5 +349,6 @@ def _paged_attention(
             input_metadata.max_context_len,
             alibi_slopes,
             input_metadata.kv_cache_dtype,
+            *kv_quant_param,
         )
     return output
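Note: kv_quant_param is a flat list of four floats per layer, in the argument order the kernels expect, and it is splatted into the cache/attention ops via *kv_quant_param. A hedged sketch (the scale values are illustrative, not calibrated):

# Per-layer quantization parameters, in kernel argument order:
#     [k_scale, k_zp, v_scale, v_zp]
kv_quant_param = [0.05, 0.0, 0.07, 0.0]        # illustrative values
k_scale, k_zp, v_scale, v_zp = kv_quant_param

# The identity fallback used above is equivalent to no quantization:
identity_kv_quant_param = [1.0, 0.0, 1.0, 0.0]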

+ 6 - 2
aphrodite/modeling/metadata.py

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, List
 
 import torch
 
@@ -13,6 +13,7 @@ class InputMetadata:
         context_lens: the length of attention context for each sequence.
         block_tables: The block tables. (Seq id -> list of physical block)
         kv_cache_dtype: Data Type to store KV cache.
+        kv_quant_params: Per-layer KV scales and zero points for int8 KV cache.
     """
 
     def __init__(
@@ -27,6 +28,7 @@ class InputMetadata:
         block_tables: Optional[torch.Tensor],
         use_cuda_graph: bool,
         kv_cache_dtype: str,
+        kv_quant_params: List[List[float]],
     ) -> None:
         self.is_prompt = is_prompt
         self.prompt_lens = prompt_lens
@@ -38,6 +40,7 @@ class InputMetadata:
         self.block_tables = block_tables
         self.use_cuda_graph = use_cuda_graph
         self.kv_cache_dtype = kv_cache_dtype
+        self.kv_quant_params = kv_quant_params
 
         # Set during the execution of the first attention op.
         # FIXME: This is a hack.
@@ -51,4 +54,5 @@ class InputMetadata:
                 f"context_lens={self.context_lens}, "
                 f"block_tables={self.block_tables}, "
                 f"use_cuda_graph={self.use_cuda_graph}, "
-                f"kv_cache_dtype={self.kv_cache_dtype})")
+                f"kv_cache_dtype={self.kv_cache_dtype}, "
+                f"kv_quant_params={self.kv_quant_params})")

+ 7 - 1
aphrodite/modeling/models/llama.py

@@ -190,6 +190,7 @@ class LlamaAttention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
+        kv_quant_param: List[float],
     ) -> torch.Tensor:
         if self.merge_weight:
             qkv, _ = self.qkv_proj(hidden_states)
@@ -201,7 +202,8 @@ class LlamaAttention(nn.Module):
             v, _ = self.v_proj(hidden_states)
         q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                kv_quant_param)
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -250,6 +252,7 @@ class LlamaDecoderLayer(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         residual: Optional[torch.Tensor],
+        kv_quant_param: List[float],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
         if residual is None:
@@ -263,6 +266,7 @@ class LlamaDecoderLayer(nn.Module):
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
+            kv_quant_param=kv_quant_param,
         )
 
         # Fully Connected
@@ -316,6 +320,8 @@ class LlamaModel(nn.Module):
                 kv_caches[i],
                 input_metadata,
                 residual,
+                input_metadata.kv_quant_params[i]
+                if input_metadata.kv_quant_params is not None else None,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

+ 30 - 0
aphrodite/task_handler/model_runner.py

@@ -46,6 +46,7 @@ class ModelRunner:
         device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
         kv_cache_dtype: Optional[str] = "auto",
+        kv_quant_params_path: Optional[str] = None,
         is_driver_worker: bool = False,
     ):
         self.model_config = model_config
@@ -81,6 +82,30 @@ class ModelRunner:
         # cache in_wsl result
         self.in_wsl = in_wsl()
         self.kv_cache_dtype = kv_cache_dtype
+        self.kv_quant_params = self.load_kv_quant_params(
+            model_config,
+            kv_quant_params_path) if self.kv_cache_dtype == "int8" else None
+
+    def load_kv_quant_params(
+            self, model_config: ModelConfig,
+            kv_quant_params_path: Optional[str]) -> Optional[List[List[float]]]:
+        if model_config is None:
+            return None
+        # Remove it when all models support kv cache int8.
+        architectures = model_config.hf_config.architectures
+        for arch in architectures:
+            if arch not in ["LlamaForCausalLM", "LLaMAForCausalLM"]:
+                raise ValueError(
+                    f"INT8 KV cache is not yet supported for model "
+                    f"architecture {arch}. Supported architectures: "
+                    "LlamaForCausalLM and LLaMAForCausalLM.")
+        num_layers = model_config.hf_config.num_hidden_layers
+        kv_quant_params = []
+        for i in range(num_layers):
+            if kv_quant_params_path is not None:
+                path = f"{kv_quant_params_path}/layers.{i}.past_kv_scale.0.weight"
+                kv_quant_param = list(np.fromfile(path, dtype=np.float32))
+            else:
+                # No exported scales: fall back to identity scale/zero-point.
+                kv_quant_param = [1.0, 0.0, 1.0, 0.0]
+            kv_quant_params.append(kv_quant_param)
+        return kv_quant_params
 
     def load_model(self) -> None:
         self.model = get_model(self.model_config, self.device_config,
@@ -255,6 +280,7 @@ class ModelRunner:
             block_tables=block_tables,
             use_cuda_graph=False,
             kv_cache_dtype=self.kv_cache_dtype,
+            kv_quant_params=self.kv_quant_params,
         )
         return (input_tokens, input_positions, input_metadata, prompt_lens,
                 subquery_lens, lora_index_mapping, lora_prompt_mapping,
@@ -383,6 +409,7 @@ class ModelRunner:
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
             kv_cache_dtype=self.kv_cache_dtype,
+            kv_quant_params=self.kv_quant_params,
         )
         return (input_tokens, input_positions, input_metadata,
                 lora_index_mapping, lora_prompt_mapping, lora_requests)
@@ -525,6 +552,7 @@ class ModelRunner:
                 "block_tables": input_metadata.block_tables,
                 "use_cuda_graph": input_metadata.use_cuda_graph,
                 "kv_cache_dtype": input_metadata.kv_cache_dtype,
+                "kv_quant_params": input_metadata.kv_quant_params,
                 "selected_token_indices":
                 sampling_metadata.selected_token_indices,
                 "lora_requests": lora_requests,
@@ -548,6 +576,7 @@ class ModelRunner:
                 block_tables=metadata_dict["block_tables"],
                 use_cuda_graph=metadata_dict["use_cuda_graph"],
                 kv_cache_dtype=metadata_dict["kv_cache_dtype"],
+                kv_quant_params=metadata_dict["kv_quant_params"],
             )
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
@@ -739,6 +768,7 @@ class ModelRunner:
                         block_tables=block_tables[:batch_size],
                         use_cuda_graph=True,
                         kv_cache_dtype=self.kv_cache_dtype,
+                        kv_quant_params=self.kv_quant_params,
                     )
 
                     if self.lora_config:
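Note: load_kv_quant_params() expects one small file of four float32 values per decoder layer under kv_quant_params_path. A hedged sketch of producing that layout (directory name, layer count, and identity values are placeholders; real scales come from the calibration/export utilities added in this patch):

import os
import numpy as np

quant_dir = "quant_params"   # passed to ModelRunner as kv_quant_params_path
os.makedirs(quant_dir, exist_ok=True)

num_layers = 2               # toy value; normally hf_config.num_hidden_layers
for i in range(num_layers):
    # [k_scale, k_zp, v_scale, v_zp] -- identity values shown here.
    np.array([1.0, 0.0, 1.0, 0.0], dtype=np.float32).tofile(
        os.path.join(quant_dir, f"layers.{i}.past_kv_scale.0.weight"))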

+ 10 - 7
aphrodite/task_handler/worker.py

@@ -42,6 +42,7 @@ class Worker:
         distributed_init_method: str,
         lora_config: Optional[LoRAConfig] = None,
         kv_cache_dtype: Optional[str] = "auto",
+        kv_quant_params_path: Optional[str] = None,
         is_driver_worker: bool = False,
     ) -> None:
         self.model_config = model_config
@@ -56,13 +57,15 @@ class Worker:
         if self.is_driver_worker:
             assert self.rank == 0, "The driver worker must have rank 0."
 
-        self.model_runner = ModelRunner(model_config,
-                                        parallel_config,
-                                        scheduler_config,
-                                        device_config,
-                                        lora_config=self.lora_config,
-                                        kv_cache_dtype=kv_cache_dtype,
-                                        is_driver_worker=is_driver_worker)
+        self.model_runner = ModelRunner(
+            model_config,
+            parallel_config,
+            scheduler_config,
+            device_config,
+            lora_config=self.lora_config,
+            kv_cache_dtype=kv_cache_dtype,
+            kv_quant_params_path=kv_quant_params_path,
+            is_driver_worker=is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # self.init_cache_engine().
         self.cache_config = None

+ 2 - 1
kernels/attention/attention_dtypes.h

@@ -4,4 +4,5 @@
 #include "dtype_float16.cuh"
 #include "dtype_float32.cuh"
 #include "dtype_bfloat16.cuh"
-#include "dtype_fp8_e5m2.cuh"
+#include "dtype_fp8_e5m2.cuh"
+#include "dtype_int8.cuh"

+ 153 - 74
kernels/attention/attention_kernels.cu

@@ -26,6 +26,7 @@
 
 #include "attention_dtypes.h"
 #include "attention_utils.cuh"
+#include "../quantization/int8_kvcache/quant_utils.cuh"
 #ifdef ENABLE_FP8_E5M2
 #include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
 #endif
@@ -41,6 +42,13 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
 
+enum kv_cache_dtype {
+  AUTO,
+#ifdef ENABLE_FP8_E5M2
+  FP8_E5M2,
+#endif
+  INT8
+};
+
 namespace aphrodite {
 
 // Utility function for attention softmax.
@@ -87,7 +95,7 @@ template<
   int HEAD_SIZE,
   int BLOCK_SIZE,
   int NUM_THREADS,
-  bool IS_FP8_E5M2_KV_CACHE,
+  kv_cache_dtype KV_CACHE_DTYPE,
   int PARTITION_SIZE = 0> // Zero means no partitioning.
 __device__ void paged_attention_kernel(
   float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
@@ -104,7 +112,11 @@ __device__ void paged_attention_kernel(
   const float* __restrict__ alibi_slopes, // [num_heads]
   const int q_stride,
   const int kv_block_stride,
-  const int kv_head_stride) {
+  const int kv_head_stride,
+  const float k_scale = 1.0f,
+  const float k_zp = 0.0f,
+  const float v_scale = 1.0f,
+  const float v_zp = 0.0f) {
   const int seq_idx = blockIdx.y;
   const int partition_idx = blockIdx.z;
   const int max_num_partitions = gridDim.z;
@@ -151,9 +163,7 @@ __device__ void paged_attention_kernel(
   constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
   using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
   using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
-#ifdef ENABLE_FP8_E5M2
   using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
-#endif
 
   constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
   constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
@@ -217,13 +227,16 @@ __device__ void paged_attention_kernel(
         const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
         const int offset1 = (vec_idx * VEC_SIZE) / x;
         const int offset2 = (vec_idx * VEC_SIZE) % x;
-        if constexpr (IS_FP8_E5M2_KV_CACHE) {
+        if constexpr (KV_CACHE_DTYPE == INT8) {
+          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          using Dequant_vec = typename FloatVec<Quant_vec>::Type;
+          Dequant_vec k_vec_dequant = int8::dequant(k_vec_quant, k_scale, k_zp);
+          k_vecs[j] = int8::vec_conversion<K_vec, Dequant_vec>(k_vec_dequant);
 #ifdef ENABLE_FP8_E5M2
+        } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
           Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
           // Vector conversion from Quant_vec to K_vec.
           k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
-#else
-          assert(false);
 #endif
         } else {
           k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
@@ -301,9 +314,7 @@ __device__ void paged_attention_kernel(
   constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
   using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
   using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
-#ifdef ENABLE_FP8_E5M2
   using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
-#endif
   using Float_L_vec = typename FloatVec<L_vec>::Type;
 
   constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
@@ -337,13 +348,17 @@ __device__ void paged_attention_kernel(
       if (row_idx < HEAD_SIZE) {
         const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
         V_vec v_vec;
-        if constexpr (IS_FP8_E5M2_KV_CACHE) {
+        if constexpr (KV_CACHE_DTYPE == INT8) {
+          // Dequantize the int8 value vector and convert it to V_vec.
+          V_quant_vec v_vec_quant = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+          using V_dequant_vec = typename FloatVec<V_quant_vec>::Type;
+          V_dequant_vec v_vec_dequant = int8::dequant(v_vec_quant, v_scale, v_zp);
+          v_vec = int8::vec_conversion<V_vec, V_dequant_vec>(v_vec_dequant);
 #ifdef ENABLE_FP8_E5M2
+        } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
           V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
           // Vector conversion from V_quant_vec to V_vec.
           v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
-#else
-          assert(false);
 #endif
         } else {
           v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
@@ -431,7 +446,7 @@ template<
   int HEAD_SIZE,
   int BLOCK_SIZE,
   int NUM_THREADS,
-  bool IS_FP8_E5M2_KV_CACHE>
+  kv_cache_dtype KV_CACHE_DTYPE>
 __global__ void paged_attention_v1_kernel(
   scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
@@ -445,11 +460,15 @@ __global__ void paged_attention_v1_kernel(
   const float* __restrict__ alibi_slopes, // [num_heads]
   const int q_stride,
   const int kv_block_stride,
-  const int kv_head_stride) {
-  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
+  const int kv_head_stride,
+  const float k_scale,
+  const float k_zp,
+  const float v_scale,
+  const float v_zp) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE>(
     /* exp_sums */ nullptr, /* max_logits */ nullptr,
     out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
-    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
+    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
 }
 
 // Grid: (num_heads, num_seqs, max_num_partitions).
@@ -459,7 +478,7 @@ template<
   int HEAD_SIZE,
   int BLOCK_SIZE,
   int NUM_THREADS,
-  bool IS_FP8_E5M2_KV_CACHE,
+  kv_cache_dtype KV_CACHE_DTYPE,
   int PARTITION_SIZE>
 __global__ void paged_attention_v2_kernel(
   float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
@@ -476,11 +495,15 @@ __global__ void paged_attention_v2_kernel(
   const float* __restrict__ alibi_slopes, // [num_heads]
   const int q_stride,
   const int kv_block_stride,
-  const int kv_head_stride) {
-  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
+  const int kv_head_stride,
+  const float k_scale,
+  const float k_zp,
+  const float v_scale,
+  const float v_zp) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE, PARTITION_SIZE>(
     exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
     block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
-    q_stride, kv_block_stride, kv_head_stride);
+    q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
 }
 
 // Grid: (num_heads, num_seqs).
@@ -584,32 +607,36 @@ __global__ void paged_attention_v2_reduce_kernel(
 
 } // namespace aphrodite
 
-#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                                  \
-  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                                       \
-    ((void*)aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,   \
-      IS_FP8_E5M2_KV_CACHE>), shared_mem_size);                                               \
-  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,             \
-  IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>(                            \
-    out_ptr,                                                                                  \
-    query_ptr,                                                                                \
-    key_cache_ptr,                                                                            \
-    value_cache_ptr,                                                                          \
-    num_kv_heads,                                                                             \
-    scale,                                                                                    \
-    block_tables_ptr,                                                                         \
-    context_lens_ptr,                                                                         \
-    max_num_blocks_per_seq,                                                                   \
-    alibi_slopes_ptr,                                                                         \
-    q_stride,                                                                                 \
-    kv_block_stride,                                                                          \
-    kv_head_stride);
+#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                                        \
+  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                                        \
+    ((void*)aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,    \
+      KV_CACHE_DTYPE>), shared_mem_size);                                                           \
+  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,              \
+  KV_CACHE_DTYPE><<<grid, block, shared_mem_size, stream>>>(                                        \
+    out_ptr,                                                                                        \
+    query_ptr,                                                                                      \
+    key_cache_ptr,                                                                                  \
+    value_cache_ptr,                                                                                \
+    num_kv_heads,                                                                                   \
+    scale,                                                                                          \
+    block_tables_ptr,                                                                               \
+    context_lens_ptr,                                                                               \
+    max_num_blocks_per_seq,                                                                         \
+    alibi_slopes_ptr,                                                                               \
+    q_stride,                                                                                       \
+    kv_block_stride,                                                                                \
+    kv_head_stride,                                                                                 \
+    k_scale,                                                                                        \
+    k_zp,                                                                                           \
+    v_scale,                                                                                        \
+    v_zp);
 
 // TODO: Tune NUM_THREADS.
 template<
   typename T,
   typename CACHE_T,
   int BLOCK_SIZE,
-  bool IS_FP8_E5M2_KV_CACHE,
+  kv_cache_dtype KV_CACHE_DTYPE,
   int NUM_THREADS = 128>
 void paged_attention_v1_launcher(
   torch::Tensor& out,
@@ -621,7 +648,11 @@ void paged_attention_v1_launcher(
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
   int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes) {
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const float k_scale,
+  const float k_zp,
+  const float v_scale,
+  const float v_zp) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -685,8 +716,8 @@ void paged_attention_v1_launcher(
   }
 }
 
-#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE)       \
-  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
+#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE)             \
+  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>(       \
     out,                                                                     \
     query,                                                                   \
     key_cache,                                                               \
@@ -696,20 +727,24 @@ void paged_attention_v1_launcher(
     block_tables,                                                            \
     context_lens,                                                            \
     max_context_len,                                                         \
-    alibi_slopes);
+    alibi_slopes,                                                            \
+    k_scale,                                                                 \
+    k_zp,                                                                    \
+    v_scale,                                                                 \
+    v_zp);
 
 // NOTE: To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)       \
   switch (block_size) {                                               \
     case 8:                                                           \
-      CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE);          \
+      CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);                \
       break;                                                          \
     case 16:                                                          \
-      CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE);         \
+      CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);               \
       break;                                                          \
     case 32:                                                          \
-      CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE);         \
+      CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);               \
       break;                                                          \
     default:                                                          \
       TORCH_CHECK(false, "Unsupported block size: ", block_size);     \
@@ -728,24 +763,40 @@ void paged_attention_v1(
   int block_size,
   int max_context_len,
   const c10::optional<torch::Tensor>& alibi_slopes,
-  const std::string& kv_cache_dtype) {
+  const std::string& kv_cache_dtype,
+  const float k_scale = 1.0f,
+  const float k_zp = 0.0f,
+  const float v_scale = 1.0f,
+  const float v_zp = 0.0f) {
   if (kv_cache_dtype == "auto") {
     if (query.dtype() == at::ScalarType::Float) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
+      CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, AUTO);
     } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
+      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, AUTO);
     } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
+      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, AUTO);
     } else {
       TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
     }
+#ifdef ENABLE_FP8_E5M2
   } else if (kv_cache_dtype == "fp8_e5m2") {
     if (query.dtype() == at::ScalarType::Float) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
+      CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, FP8_E5M2);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, FP8_E5M2);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, FP8_E5M2);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+#endif
+  } else if (kv_cache_dtype == "int8") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(float, int8_t, INT8);
     } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
+      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, int8_t, INT8);
     } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
+      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, int8_t, INT8);
     } else {
       TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
     }
@@ -755,8 +806,8 @@ void paged_attention_v1(
 }
 
 #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                                  \
-  aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,             \
-  IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>                                                       \
+  aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,        \
+  KV_CACHE_DTYPE, PARTITION_SIZE>                                                             \
   <<<grid, block, shared_mem_size, stream>>>(                                                 \
     exp_sums_ptr,                                                                             \
     max_logits_ptr,                                                                           \
@@ -772,7 +823,11 @@ void paged_attention_v1(
     alibi_slopes_ptr,                                                                         \
     q_stride,                                                                                 \
     kv_block_stride,                                                                          \
-    kv_head_stride);                                                                          \
+    kv_head_stride,                                                                           \
+    k_scale,                                                                                  \
+    k_zp,                                                                                     \
+    v_scale,                                                                                  \
+    v_zp);                                                                                    \
   aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE>           \
   <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                                   \
     out_ptr,                                                                                  \
@@ -786,7 +841,7 @@ template<
   typename T,
   typename CACHE_T,
   int BLOCK_SIZE,
-  bool IS_FP8_E5M2_KV_CACHE,
+  kv_cache_dtype KV_CACHE_DTYPE,
   int NUM_THREADS = 128,
   int PARTITION_SIZE = 512>
 void paged_attention_v2_launcher(
@@ -802,7 +857,11 @@ void paged_attention_v2_launcher(
   torch::Tensor& block_tables,
   torch::Tensor& context_lens,
   int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes) {
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const float k_scale,
+  const float k_zp,
+  const float v_scale,
+  const float v_zp) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -872,8 +931,8 @@ void paged_attention_v2_launcher(
   }
 }
 
-#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE)           \
-  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>(     \
+#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE)                 \
+  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>(           \
     out,                                                                         \
     exp_sums,                                                                    \
     max_logits,                                                                  \
@@ -886,20 +945,24 @@ void paged_attention_v2_launcher(
     block_tables,                                                                \
     context_lens,                                                                \
     max_context_len,                                                             \
-    alibi_slopes);
+    alibi_slopes,                                                                \
+    k_scale,                                                                     \
+    k_zp,                                                                        \
+    v_scale,                                                                     \
+    v_zp);
 
 // NOTE: To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
-#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE)       \
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)             \
   switch (block_size) {                                                     \
     case 8:                                                                 \
-      CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE);                \
+      CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);                      \
       break;                                                                \
     case 16:                                                                \
-      CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE);               \
+      CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);                     \
       break;                                                                \
     case 32:                                                                \
-      CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE);               \
+      CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);                     \
       break;                                                                \
     default:                                                                \
       TORCH_CHECK(false, "Unsupported block size: ", block_size);           \
@@ -921,24 +984,40 @@ void paged_attention_v2(
   int block_size,
   int max_context_len,
   const c10::optional<torch::Tensor>& alibi_slopes,
-  const std::string& kv_cache_dtype) {
+  const std::string& kv_cache_dtype,
+  const float k_scale = 1.0f,
+  const float k_zp = 0.0f,
+  const float v_scale = 1.0f,
+  const float v_zp = 0.0f) {
   if (kv_cache_dtype == "auto") {
     if (query.dtype() == at::ScalarType::Float) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
+      CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, AUTO);
     } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
+      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, AUTO);
     } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
+      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, AUTO);
     } else {
       TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
     }
+#ifdef ENABLE_FP8_E5M2
   } else if (kv_cache_dtype == "fp8_e5m2") {
     if (query.dtype() == at::ScalarType::Float) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
+      CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, FP8_E5M2);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, FP8_E5M2);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, FP8_E5M2);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+#endif
+  } else if (kv_cache_dtype == "int8") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(float, int8_t, INT8);
     } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
+      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, int8_t, INT8);
     } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
+      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, int8_t, INT8);
     } else {
       TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
     }

+ 7 - 0
kernels/attention/dtype_float32.cuh

@@ -86,6 +86,13 @@ inline __device__ float4 add(float4 a, float4 b) {
   return c;
 }
 
+inline __device__ Float4_ add(Float4_ a, Float4_ b) {
+  Float4_ c;
+  c.x = add(a.x, b.x);
+  c.y = add(a.y, b.y);
+  return c;
+}
+
 // Vector multiplication.
 template<>
 inline __device__ float mul<float, float>(float a, float b) {

+ 49 - 0
kernels/attention/dtype_int8.cuh

@@ -0,0 +1,49 @@
+#pragma once
+
+#include <stdint.h>
+#include "attention_generic.cuh"
+#include "dtype_float32.cuh"
+
+namespace aphrodite {
+// int8 vector types for quantization of kv cache
+
+template<>
+struct Vec<int8_t, 1> {
+    using Type = int8_t;
+};
+
+template<>
+struct Vec<int8_t, 2> {
+    using Type = int16_t;
+};
+
+template<>
+struct Vec<int8_t, 4> {
+    using Type = int32_t;
+};
+
+template<>
+struct Vec<int8_t, 8> {
+    using Type = int64_t;
+};
+
+template<>
+struct FloatVec<int8_t> {
+    using Type = float;
+};
+
+template<>
+struct FloatVec<int16_t> {
+    using Type = float2;
+};
+
+template<>
+struct FloatVec<int32_t> {
+    using Type = Float4_;
+};
+
+template<>
+struct FloatVec<int64_t> {
+    using Type = Float8_;
+};
+} // namespace aphrodite
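Note: the Vec/FloatVec specializations above treat N int8 lanes as a single N-byte integer so the kernels can load and store them in one access. A quick illustration of that packing (little-endian, as on typical CUDA targets):

import struct

# Four int8 lanes packed into one 32-bit word, mirroring Vec<int8_t, 4>::Type == int32_t.
lanes = (-128, -1, 0, 127)
packed = struct.unpack("<i", struct.pack("<4b", *lanes))[0]
print(hex(packed & 0xFFFFFFFF))  # 0x7f00ff80 -- the last lane lands in the high byte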

+ 10 - 6
kernels/cache.h

@@ -16,12 +16,16 @@ void copy_blocks(
   const std::map<int64_t, std::vector<int64_t>>& block_mapping);
 
 void reshape_and_cache(
-  torch::Tensor& key,
-  torch::Tensor& value,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  torch::Tensor& slot_mapping,
-  const std::string& kv_cache_dtype);
+  torch::Tensor& key,
+  torch::Tensor& value,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& slot_mapping,
+  const std::string& kv_cache_dtype,
+  const float k_scale = 1.0f,
+  const float k_zp = 0.0f,
+  const float v_scale = 1.0f,
+  const float v_zp = 0.0f);
 
 void gather_cached_kv(
   torch::Tensor& key,

+ 46 - 15
kernels/cache_kernels.cu

@@ -4,6 +4,7 @@
 
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
+#include "quantization/int8_kvcache/quant_utils.cuh"
 #ifdef ENABLE_FP8_E5M2
 #include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
 #endif
@@ -13,6 +14,13 @@
 #include <map>
 #include <vector>
 
+enum kv_cache_dtype {
+  AUTO,
+#ifdef ENABLE_FP8_E5M2
+  FP8_E5M2,
+#endif
+  INT8
+};
+
 #ifdef USE_ROCM
   #include <hip/hip_bf16.h>
   typedef __hip_bfloat16 __nv_bfloat16;
@@ -151,7 +159,7 @@ void copy_blocks(
 
 namespace aphrodite {
 
-template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
+template<typename scalar_t, typename cache_t, kv_cache_dtype KV_CACHE_DTYPE>
 __global__ void reshape_and_cache_kernel(
   const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
   const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
@@ -163,7 +171,11 @@ __global__ void reshape_and_cache_kernel(
   const int num_heads,
   const int head_size,
   const int block_size,
-  const int x) {
+  const int x,
+  const float k_scale,
+  const float k_zp,
+  const float v_scale,
+  const float v_zp) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   if (slot_idx < 0) {
@@ -195,12 +207,13 @@ __global__ void reshape_and_cache_kernel(
                                   + block_offset;
     scalar_t tgt_key = key[src_key_idx];
     scalar_t tgt_value = value[src_value_idx];
-    if constexpr (is_fp8_e5m2_kv_cache) {
+    if constexpr (KV_CACHE_DTYPE == INT8) {
+      key_cache[tgt_key_idx] = int8::quant(tgt_key, k_scale, k_zp);
+      value_cache[tgt_value_idx] = int8::quant(tgt_value, v_scale, v_zp);
 #ifdef ENABLE_FP8_E5M2
+    } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
       key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
       value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
-#else
-      assert(false);
 #endif
     } else {
       key_cache[tgt_key_idx] = tgt_key;
@@ -211,8 +224,8 @@ __global__ void reshape_and_cache_kernel(
 
 } // namespace aphrodite
 
-#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE)                                \
-  aphrodite::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
+#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_CACHE_DTYPE)                                      \
+  aphrodite::reshape_and_cache_kernel<KV_T, CACHE_T, KV_CACHE_DTYPE><<<grid, block, 0, stream>>>(  \
     reinterpret_cast<KV_T*>(key.data_ptr()),                                                       \
     reinterpret_cast<KV_T*>(value.data_ptr()),                                                     \
     reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                              \
@@ -223,7 +236,11 @@ __global__ void reshape_and_cache_kernel(
     num_heads,                                                                                     \
     head_size,                                                                                     \
     block_size,                                                                                    \
-    x);
+    x,                                                                                             \
+    k_scale,                                                                                       \
+    k_zp,                                                                                          \
+    v_scale,                                                                                       \
+    v_zp);
 
 void reshape_and_cache(
   torch::Tensor& key,           // [num_tokens, num_heads, head_size]
@@ -231,7 +248,11 @@ void reshape_and_cache(
   torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
   torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
   torch::Tensor& slot_mapping,  // [num_tokens]
-  const std::string& kv_cache_dtype)
+  const std::string& kv_cache_dtype,
+  const float k_scale = 1.0f,
+  const float k_zp = 0.0f,
+  const float v_scale = 1.0f,
+  const float v_zp = 0.0f)
 {
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
@@ -248,19 +269,29 @@ void reshape_and_cache(
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   if (kv_cache_dtype == "auto") {
     if (key.dtype() == at::ScalarType::Float) {
-      CALL_RESHAPE_AND_CACHE(float, float, false);
+      CALL_RESHAPE_AND_CACHE(float, float, AUTO);
     } else if (key.dtype() == at::ScalarType::Half) {
-      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
+      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, AUTO);
     } else if (key.dtype() == at::ScalarType::BFloat16) {
-      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
+      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, AUTO);
     }
+#ifdef ENABLE_FP8_E5M2
   } else if (kv_cache_dtype == "fp8_e5m2") {
     if (key.dtype() == at::ScalarType::Float) {
-      CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
+      CALL_RESHAPE_AND_CACHE(float, uint8_t, FP8_E5M2);
+    } else if (key.dtype() == at::ScalarType::Half) {
+      CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, FP8_E5M2);
+    } else if (key.dtype() == at::ScalarType::BFloat16) {
+      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, FP8_E5M2);
+    }
+#endif
+  } else if (kv_cache_dtype == "int8") {
+    if (key.dtype() == at::ScalarType::Float) {
+      CALL_RESHAPE_AND_CACHE(float, int8_t, INT8);
     } else if (key.dtype() == at::ScalarType::Half) {
-      CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
+      CALL_RESHAPE_AND_CACHE(uint16_t, int8_t, INT8);
     } else if (key.dtype() == at::ScalarType::BFloat16) {
-      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
+      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, int8_t, INT8);
     }
   } else {
     TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
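Note: a hedged host-side sketch of calling the widened reshape_and_cache with an int8 cache; the extension import path and tensor sizes are assumptions, only the trailing scale/zero-point arguments and the cache layout come from this patch:

import torch
from aphrodite._C import cache_ops   # import path assumed; adjust to the local build

num_tokens, num_heads, head_size = 4, 8, 128
num_blocks, block_size, x = 16, 16, 16   # x: packing factor of the key cache layout

key = torch.randn(num_tokens, num_heads, head_size, dtype=torch.half, device="cuda")
value = torch.randn_like(key)
key_cache = torch.empty(num_blocks, num_heads, head_size // x, block_size, x,
                        dtype=torch.int8, device="cuda")
value_cache = torch.empty(num_blocks, num_heads, head_size, block_size,
                          dtype=torch.int8, device="cuda")
slot_mapping = torch.arange(num_tokens, dtype=torch.long, device="cuda")

k_scale, k_zp, v_scale, v_zp = 0.05, 0.0, 0.05, 0.0   # illustrative values
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                            "int8", k_scale, k_zp, v_scale, v_zp)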

+ 2 - 1
kernels/dispatch_utils.h

@@ -20,7 +20,8 @@
   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \
   AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)   \
-  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
 
 #define APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...)           \
   AT_DISPATCH_SWITCH(                                                    \

+ 10 - 2
kernels/ops.h

@@ -14,7 +14,11 @@ void paged_attention_v1(
   int block_size,
   int max_context_len,
   const c10::optional<torch::Tensor>& alibi_slopes,
-  const std::string& kv_cache_dtype);
+  const std::string& kv_cache_dtype,
+  float k_scale = 1.0f,
+  float k_zp = 0.0f,
+  float v_scale = 1.0f,
+  float v_zp = 0.0f);
 
 void paged_attention_v2(
   torch::Tensor& out,
@@ -31,7 +35,11 @@ void paged_attention_v2(
   int block_size,
   int max_context_len,
   const c10::optional<torch::Tensor>& alibi_slopes,
-  const std::string& kv_cache_dtype);
+  const std::string& kv_cache_dtype,
+  float k_scale = 1.0f,
+  float k_zp = 0.0f,
+  float v_scale = 1.0f,
+  float v_zp = 0.0f);
 
 void rms_norm(
   torch::Tensor& out,

+ 284 - 0
kernels/quantization/int8_kvcache/quant_utils.cuh

@@ -0,0 +1,284 @@
+// Adapted from FasterTransformer, https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+#pragma once
+
+#include <assert.h>
+#include <stdint.h>
+#include <float.h>
+#include <type_traits>
+#include "../../attention/attention_dtypes.h"
+
+namespace aphrodite {
+namespace int8 {
+// float32 to int8
+inline __device__ int8_t quant(float a, const float scale, const float zp)
+{
+    int8_t int8;
+    int8 = round(max(-128.f, min(127.f, (a - zp) / scale)));
+    return int8;
+}
+
+// float32x2 to int8x2
+inline __device__ short quant(float2 a, const float scale, const float zp)
+{
+    union {
+        int8_t int8[2];
+        short  int16;
+    };
+
+    int8[0] = quant(a.x, scale, zp);
+    int8[1] = quant(a.y, scale, zp);
+    return int16;
+}
+
+// float32x4 to int8x4
+inline __device__ int32_t quant(float4 a, const float scale, const float zp)
+{
+    union {
+        int8_t  int8[4];
+        int32_t int32;
+    };
+
+    int8[0] = quant(a.x, scale, zp);
+    int8[1] = quant(a.y, scale, zp);
+    int8[2] = quant(a.z, scale, zp);
+    int8[3] = quant(a.w, scale, zp);
+    return int32;
+}
+
+// float16 to int8
+inline __device__ int8_t quant(uint16_t a, const float scale, const float zp)
+{
+    int8_t int8;
+    float  b = half_to_float(a);
+    int8     = quant(b, scale, zp);
+    return int8;
+}
+
+// float16x2 to int8x2
+inline __device__ int16_t quant(uint32_t a, const float scale, const float zp)
+{
+    union {
+        int8_t int8[2];
+        short  int16;
+    };
+    float2 b = half2_to_float2(a);
+
+    int8[0] = quant(b.x, scale, zp);
+    int8[1] = quant(b.y, scale, zp);
+    return int16;
+}
+
+// float16x4 to int8x4
+inline __device__ int32_t quant(uint2 a, const float scale, const float zp)
+{
+    union {
+        int16_t int16[2];
+        int32_t int32;
+    };
+
+    int16[0] = quant(a.x, scale, zp);
+    int16[1] = quant(a.y, scale, zp);
+    return int32;
+}
+
+// float16x8 to int8x8
+inline __device__ int64_t quant(uint4 a, const float scale, const float zp)
+{
+    union {
+        int16_t int16[4];
+        int64_t int64;
+    };
+
+    int16[0] = quant(a.x, scale, zp);
+    int16[1] = quant(a.y, scale, zp);
+    int16[2] = quant(a.z, scale, zp);
+    int16[3] = quant(a.w, scale, zp);
+    return int64;
+}
+
+// bf16 to int8
+inline __device__ int8_t quant(__nv_bfloat16 a, const float scale, const float zp)
+{
+    int8_t int8;
+    float  b = to_float(a);
+    int8     = quant(b, scale, zp);
+    return int8;
+}
+
+//bf16x2 to int8x2
+inline __device__ int16_t quant(__nv_bfloat162 a, const float scale, const float zp)
+{
+    union {
+        int8_t int8[2];
+        short  int16;
+    };
+    float2 b = bf1622float2(a);
+
+    int8[0] = quant(b.x, scale, zp);
+    int8[1] = quant(b.y, scale, zp);
+    return int16;
+}
+
+// bf16x4 to int8x4
+inline __device__ int32_t quant(bf16_4_t a, const float scale, const float zp)
+{
+    union {
+        int16_t int16[2];
+        int32_t int32;
+    };
+
+    int16[0] = quant(a.x, scale, zp);
+    int16[1] = quant(a.y, scale, zp);
+    return int32;
+}
+
+// bf16x8 to int8x8
+inline __device__ int64_t quant(bf16_8_t a, const float scale, const float zp)
+{
+    union {
+        int16_t int16[4];
+        int64_t int64;
+    };
+
+    int16[0] = quant(a.x, scale, zp);
+    int16[1] = quant(a.y, scale, zp);
+    int16[2] = quant(a.z, scale, zp);
+    int16[3] = quant(a.w, scale, zp);
+    return int64;
+}
+
+// int8 to float32, then `vec_conversion` to target format
+inline __device__ float dequant(int8_t a, const float scale, const float zp)
+{
+    float b = a * scale + zp;
+    return b;
+}
+
+// int8x2 to float32x2
+inline __device__ float2 dequant(int16_t a, const float scale, const float zp)
+{
+    union {
+        int8_t  int8[2];
+        int16_t int16;
+    };
+    int16 = a;
+
+    float2 b;
+    b.x = int8[0] * scale + zp;
+    b.y = int8[1] * scale + zp;
+    return b;
+}
+
+// int8x4 to float32x4
+inline __device__ Float4_ dequant(int32_t a, const float scale, const float zp)
+{
+    union {
+        int8_t  int8[4];
+        int32_t int32;
+    };
+    int32 = a;
+
+    Float4_ b;
+    b.x.x = (int8[0] * scale) + zp;
+    b.x.y = (int8[1] * scale) + zp;
+    b.y.x = (int8[2] * scale) + zp;
+    b.y.y = (int8[3] * scale) + zp;
+    return b;
+}
+
+// int8x8 to float32x8
+inline __device__ Float8_ dequant(int64_t a, const float scale, const float zp)
+{
+    union {
+        int16_t int16[4];
+        int64_t int64;
+    };
+    int64 = a;
+
+    Float8_ b;
+    b.x = dequant(int16[0], scale, zp);
+    b.y = dequant(int16[1], scale, zp);
+    b.z = dequant(int16[2], scale, zp);
+    b.w = dequant(int16[3], scale, zp);
+    return b;
+}
+
+template<typename Tout, typename Tin>
+__inline__ __device__ Tout vec_conversion(const Tin& x)
+{
+    return x;
+}
+
+template<>
+__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
+{
+    union {
+        half2    float16;
+        uint32_t uint32;
+    };
+
+    float16 = __float22half2_rn(a);
+    return uint32;
+}
+
+template<>
+__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
+{
+    uint2  b;
+    float2 val;
+    val.x = a.x.x;
+    val.y = a.x.y;
+    b.x   = vec_conversion<uint32_t, float2>(val);
+
+    val.x = a.y.x;
+    val.y = a.y.y;
+    b.y   = vec_conversion<uint32_t, float2>(val);
+
+    return b;
+}
+
+template<>
+__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
+{
+    float4 b;
+    b.x = a.x.x;
+    b.y = a.x.y;
+    b.z = a.y.x;
+    b.w = a.y.y;
+    return b;
+}
+
+template<>
+__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
+{
+    uint4 b;
+    b.x = vec_conversion<uint32_t, float2>(a.x);
+    b.y = vec_conversion<uint32_t, float2>(a.y);
+    b.z = vec_conversion<uint32_t, float2>(a.z);
+    b.w = vec_conversion<uint32_t, float2>(a.w);
+    return b;
+}
+
+template<>
+__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) {
+    __nv_bfloat162 b;
+    from_float(b, a);
+    return b;
+}
+
+template<>
+__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) {
+    bf16_4_t b;
+    from_float(b, a);
+    return b;
+}
+
+template<>
+__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) {
+    bf16_8_t b;
+    from_float(b, a);
+    return b;
+}
+
+} // namespace int8
+} // namespace aphrodite
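Note: a NumPy reference of the per-tensor affine scheme implemented above; the CUDA code applies the same scalar math lane-by-lane on packed vectors:

import numpy as np

def quant(a: np.ndarray, scale: float, zp: float) -> np.ndarray:
    # Mirrors int8::quant: round(clamp((a - zp) / scale, -128, 127)).
    return np.clip(np.round((a - zp) / scale), -128, 127).astype(np.int8)

def dequant(q: np.ndarray, scale: float, zp: float) -> np.ndarray:
    # Mirrors int8::dequant: q * scale + zp.
    return q.astype(np.float32) * scale + zp

x = np.array([-1.0, -0.03, 0.0, 0.41, 1.0], dtype=np.float32)
q = quant(x, scale=0.01, zp=0.0)         # -> [-100, -3, 0, 41, 100]
x_hat = dequant(q, scale=0.01, zp=0.0)   # round-trips within one quantization step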