from typing import Any, Dict, List, NamedTuple, Optional, TypeVar

import torch
from torch.nn.parameter import Parameter

from aphrodite import _custom_ops as ops
from aphrodite.modeling.layers.linear import (ColumnParallelLinear, LinearBase,
                                              LinearMethodBase,
                                              QKVParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig


class AutoQuantConfig(QuantizationConfig):
    """Config class for AutoQuant.

    Three ``quant_mode`` values are accepted:

    * ``"weight_only"``: 4-bit packed weights, multiplied with the
      ``autoquant_s4_f16_gemm`` kernel.
    * ``"llm_int8"`` / ``"smoothquant"``: 8-bit weights, multiplied through
      bitsandbytes.

    Reference: https://arxiv.org/abs/2208.07339
    """

    # Every quant_mode that AutoQuantLinearMethod knows how to execute.
    _SUPPORTED_QUANT_MODES = ("llm_int8", "smoothquant", "weight_only")

    def __init__(
            self,
            weight_bits: int,
            group_size: int,
            zero_point: bool,
            from_float: bool,
            quant_mode: str,  # llm_int8, smoothquant, weight_only
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.from_float = from_float
        self.quant_mode = quant_mode

        # Reject unknown modes up front; otherwise they would silently fall
        # into the bitsandbytes branch of the linear method.
        if quant_mode not in self._SUPPORTED_QUANT_MODES:
            raise ValueError(
                f"Unsupported AutoQuant quant_mode {quant_mode!r}; expected "
                f"one of {self._SUPPORTED_QUANT_MODES}.")
        if quant_mode == "weight_only" and self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AutoQuant weight_only, but got {self.weight_bits} bits.")
        if quant_mode in ["llm_int8", "smoothquant"] and self.weight_bits != 8:
            raise ValueError(
                "Currently, only 8-bit weight quantization is supported for "
                "AutoQuant llm_int8 or smoothquant, "
                f"but got {self.weight_bits} bits.")
        # Number of quantized values stored per int32 word.
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (f"AutoQuantConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point}, "
                f"from_float={self.from_float}, "
                f"quant_mode={self.quant_mode})")

    def get_name(self) -> str:
        return "autoquant"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    def get_min_capability(self) -> int:
        # Requires compute capability 7.5 (Turing) or newer.
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AutoQuantConfig":
        """Build a config from a checkpoint's quantization config dict.

        ``from_float`` (default False) and ``quant_mode`` (default
        ``"weight_only"``) are optional. A missing required key raises from
        ``get_from_keys``.
        """
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        # Optional keys: look them up directly rather than catching every
        # exception, which would also hide unrelated errors.
        from_float = config.get("from_float", False)
        quant_mode = config.get("quant_mode", "weight_only")
        return cls(weight_bits, group_size, zero_point, from_float, quant_mode)

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["AutoQuantLinearMethod"]:
        """Return the AutoQuant linear method for linear layers, else None."""
        if isinstance(layer, LinearBase):
            return AutoQuantLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]


class AutoQuantLinearMethod(LinearMethodBase):
    """Linear method for AutoQuant.

    In ``weight_only`` mode the layer is expected to carry packed int4
    parameters consumed by the s4 GEMM kernel. Every other mode keeps a
    dense ``weight`` that is multiplied through bitsandbytes.

    Args:
        quant_config: The AutoQuant quantization config.
    """

    def __init__(self, quant_config: AutoQuantConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Allocate the (uninitialised) parameters of ``layer``.

        * ``weight_only`` and not ``from_float``: int32 ``qweight``
          (input, output // pack_factor), int32 ``qzeros``
          (input // group_size, output // pack_factor) and ``scales``
          (input // group_size, output) in ``params_dtype``.
        * Otherwise: a dense ``weight`` of shape (output, input).

        Raises:
            ValueError: in ``weight_only`` mode, when the per-partition input
                size is not a multiple of ``group_size`` or the output size
                is not a multiple of ``pack_factor``.
        """
        cfg = self.quant_config
        weight_only = cfg.quant_mode == "weight_only"

        if weight_only and input_size_per_partition % cfg.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        output_size_per_partition = sum(output_partition_sizes)
        if weight_only and output_size_per_partition % cfg.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        if not weight_only or cfg.from_float:
            # Dense weight: a non-weight_only mode, or a float checkpoint
            # that replace_quant_params converts after loading.
            dense = Parameter(torch.empty(output_size_per_partition,
                                          input_size_per_partition,
                                          dtype=params_dtype),
                              requires_grad=False)
            set_weight_attrs(dense, {"input_dim": 1, "output_dim": 0})
            layer.register_parameter("weight", dense)
            set_weight_attrs(dense, extra_weight_attrs)
            return

        num_groups = input_size_per_partition // cfg.group_size
        packed_out = output_size_per_partition // cfg.pack_factor

        def packed_attrs():
            # A fresh dict per parameter; the output dim is the packed one.
            return {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": cfg.pack_factor,
            }

        qweight = Parameter(torch.empty(input_size_per_partition,
                                        packed_out,
                                        dtype=torch.int32),
                            requires_grad=False)
        set_weight_attrs(qweight, packed_attrs())
        qzeros = Parameter(torch.empty(num_groups,
                                       packed_out,
                                       dtype=torch.int32),
                           requires_grad=False)
        set_weight_attrs(qzeros, packed_attrs())
        scales = Parameter(torch.empty(num_groups,
                                       output_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(scales, {"input_dim": 0, "output_dim": 1})

        for param_name, tensor in (("qweight", qweight), ("qzeros", qzeros),
                                   ("scales", scales)):
            layer.register_parameter(param_name, tensor)
            set_weight_attrs(tensor, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the quantized matmul for ``x``, adding ``bias`` if given.

        ``weight_only`` reads ``layer.qweight`` / ``layer.scales_zeros``.
        Other modes read ``layer.weight`` / ``layer.state``.
        NOTE(review): replace_quant_params stores the bitsandbytes state in
        ``linear_weights["state"]``; confirm something exposes it as
        ``layer.state``.
        """
        if self.quant_config.quant_mode == "weight_only":
            packed = layer.qweight
            flat_x = x.reshape(-1, x.shape[-1])
            # Unpacked output width = packed columns * values per int32.
            out_features = packed.shape[-1] * self.quant_config.pack_factor
            result = ops.autoquant_s4_f16_gemm(flat_x, packed,
                                               layer.scales_zeros)
            if bias is not None:
                result = result + bias
            return result.reshape(*x.shape[:-1], out_features)

        weight = layer.weight
        state = layer.state
        if weight.CB is not None:
            # Hand the int8 data and its scales over to the matmul state.
            state.CB, state.SCB = weight.CB, weight.SCB
            weight.CB = weight.SCB = None
        import bitsandbytes as bnb
        out = bnb.matmul(x, weight, bias=bias, state=state)
        converted = (not state.has_fp16_weights and state.CB is not None
                     and state.CxB is not None)
        if converted:
            # The first inference pass converted the 8-bit row-major weight
            # to the turing/ampere layout; the row-major copy is no longer
            # needed.
            del state.CB
            weight.data = state.CxB
        return out


# Generic type variable bounded by torch.nn.Module.
# NOTE(review): not referenced by the code in this section; confirm whether
# it is still needed.
T = TypeVar("T", bound="torch.nn.Module")


# Quantization parameters produced by cal_qparams_per_group_minmax and read
# by quant(). ``scales`` is always set; ``zero_points`` may be None, in which
# case quant() applies no offset.
QParams = NamedTuple("QParams", [
    ("scales", torch.Tensor),
    ("zero_points", Optional[torch.Tensor]),
])
QParams.__doc__ = "A class to hold the quantization parameters."


@torch.no_grad()
def cal_qparams_per_group_minmax(w: torch.Tensor,
                                 n_bits: int = 4,
                                 group_size: int = 128):
    """Compute per-group asymmetric min/max quantization parameters.

    ``w`` is split along its second dimension into groups of ``group_size``
    columns. For each group::

        scale      = max(max - min, 1e-5) / (2**n_bits - 1)
        zero_point = clamp(-round(min / scale), 0, 2**n_bits - 1)

    Args:
        w: 2-D tensor (out_channels, in_channels).
        n_bits: Quantization bit width.
        group_size: Columns per group; must divide ``in_channels``
            (checked with ``assert``).

    Returns:
        QParams whose ``scales`` and ``zero_points`` both have shape
        (out_channels, in_channels // group_size, 1).
    """
    out_channels, in_channels = w.shape
    assert in_channels >= group_size, \
        'Input channels should be greater than or equal to group_size.'
    assert in_channels % group_size == 0, \
        'Input channels should be divisible by group_size.'
    groups = w.reshape(out_channels, -1, group_size)
    group_min = groups.min(dim=-1, keepdim=True).values
    group_max = groups.max(dim=-1, keepdim=True).values

    levels = 2**n_bits - 1
    # The 1e-5 floor avoids a zero scale for constant groups.
    scales = (group_max - group_min).clamp_(min=1e-5).div_(levels)
    zero_points = torch.round(group_min / scales).neg().clamp_(0, levels)
    return QParams(scales=scales, zero_points=zero_points)


def convert_s4(qw: torch.Tensor,
               qz: torch.Tensor,
               s: torch.Tensor,
               group_size: int = 128):
    """Convert packed int4 weights, zeros and scales into the kernel layout.

    Args:
        qw: Packed int32 weights. Elsewhere in this module they are shaped
            (in_features, out_features // 8).
        qz: Packed int32 zero points.
        s: Per-group scales.
        group_size: Input channels per quantization group.

    Returns:
        ``(qweight, scales_zeros)``: a tensor shaped like ``qw``, and an
        int32 tensor shaped like ``s`` (annotated "half2" in the original
        code; presumably packed scale/zero pairs, so confirm against the
        kernel).

    All three inputs must be contiguous; this is checked with ``assert``.
    """
    for tensor in (qw, qz, s):
        assert tensor.is_contiguous()
    converted_qw = torch.zeros_like(qw)
    scales_zeros = torch.zeros_like(s, dtype=torch.int32)  # half2
    # Passed to the kernel but not returned; presumably scratch space.
    workspace = torch.zeros_like(s)
    # Kernel dimensions: unpacked column count (8 int4 codes per int32 word)
    # and the row count of qw.
    n_cols = qw.size(-1) * 8
    n_rows = qw.size(0)
    ops.autoquant_convert_s4_k_m8(converted_qw, scales_zeros, workspace, qw,
                                  s, qz, n_cols, n_rows, group_size)
    return converted_qw, scales_zeros


def tp_m_s4(x: torch.Tensor, tp: int = 1):
    """Regroup ``x`` for ``tp`` shards, moving the shard axis last.

    ``x`` is viewed as ``(x.size(0) // 32, tp, -1, 128)`` and permuted to
    ``(x.size(0) // 32, -1, 128, tp)``. A contiguous tensor is returned.

    NOTE(review): presumably interleaves tensor-parallel shards of packed s4
    data; it is not called in this section.
    """
    leading = x.size(0) // 32
    blocks = x.view(leading, tp, -1, 128)
    shard_last = blocks.permute(0, 2, 3, 1)
    return shard_last.contiguous()


def quant(weight: torch.Tensor,
          qparams: "Optional[QParams]" = None,
          *,
          n_bits: Optional[int] = None) -> torch.Tensor:
    """Quantize ``weight`` to integer codes.

    Each code is ``round(weight / scale + zero_point)``, or
    ``round(weight / scale)`` when ``zero_points`` is None. The result holds
    the integer codes themselves, not a de-quantized ("fake quantized")
    float tensor.

    Args:
        weight (torch.Tensor): The weight tensor with shape
            (out_features, in_features).
        qparams (Optional[QParams]): A namedtuple containing 'scales'
            and 'zero_points'. Supported scale shapes are per tensor
            ``[1]``, per channel ``[out_c, 1]`` and per group
            ``[out_c, in_c // group_size, 1]``. When None, the parameters
            are computed with ``cal_qparams_per_group_minmax``, using
            ``n_bits`` if given and that function's default otherwise.
        n_bits (Optional[int]): When given, clamp the codes to the
            representable range: ``[0, 2**n_bits - 1]`` with zero points,
            ``[-2**(n_bits - 1), 2**(n_bits - 1) - 1]`` without. When None
            (the default), no clamping is applied.

    Returns:
        torch.Tensor: int32 codes with the same shape as ``weight``.
    """
    if qparams is None:
        if n_bits is None:
            qparams = cal_qparams_per_group_minmax(weight)
        else:
            qparams = cal_qparams_per_group_minmax(weight, n_bits)
    scales = qparams.scales
    zero_points = qparams.zero_points
    out_c, in_c = weight.shape
    # Per-group scales are 3-D: view the weight as
    # [out_c, in_c // group_size, group_size] so the scales broadcast.
    per_group = len(scales.shape) > 2
    if per_group:
        weight = weight.reshape(out_c, scales.shape[1], -1)
    if zero_points is None:
        real_qweight = (weight / scales).round()
    else:
        real_qweight = ((weight + (scales * zero_points)) / scales).round()
    if n_bits is not None:
        if zero_points is None:
            lo, hi = -(2**(n_bits - 1)), 2**(n_bits - 1) - 1
        else:
            lo, hi = 0, 2**n_bits - 1
        real_qweight = real_qweight.clamp(lo, hi)
    if per_group:
        real_qweight = real_qweight.reshape(out_c, in_c)
    return real_qweight.to(torch.int32)


# Core quantization routine: real (not simulated) per-group min/max
# quantization, followed by packing the 4-bit codes into int32 words.
def quantize_tensor(
    weight,
    n_bits=4,
    group_size=128,
):
    """Quantize a float weight to packed 4-bit codes with per-group params.

    Args:
        weight: 2-D float tensor (out_features, in_features).
        n_bits: Bit width. Only 4 is supported, because ``pack_order``
            describes exactly 8 codes per int32 word.
        group_size: Input channels sharing one scale / zero point.

    Returns:
        ``(qweight, scales, qzeros)``:
          * ``qweight``: int32 (in_features, out_features // 8).
          * ``scales``: (in_features // group_size, out_features).
          * ``qzeros``: int32 (in_features // group_size,
            out_features // 8), or None when no zero points are produced.

    Raises:
        ValueError: if ``n_bits != 4`` or ``out_features`` is not a
            multiple of 8.
    """
    if n_bits != 4:
        raise ValueError(
            f"quantize_tensor only supports 4-bit packing, got {n_bits}.")
    pack_num = 32 // n_bits
    # Slot i of each int32 word holds column pack_order[i] of its group of
    # adjacent output columns.
    pack_order = [0, 2, 4, 6, 1, 3, 5, 7]
    q_max = 2**n_bits - 1
    out_features, in_features = weight.shape
    if out_features % pack_num != 0:
        raise ValueError(f"out_features ({out_features}) must be divisible "
                         f"by {pack_num} to pack {n_bits}-bit codes.")

    # Pass group_size through: the packed zero points below are sized by it.
    qparams = cal_qparams_per_group_minmax(weight, n_bits, group_size)
    # Clamp so every code fits in n_bits. An out-of-range (or negative) code
    # would spill into neighbouring slots when OR-ed into the packed word.
    codes = quant(weight, qparams).clamp_(0, q_max)
    qweight = _pack_int4_columns(codes.t().contiguous(), n_bits, pack_order)

    scales = qparams.scales.squeeze(-1).t().contiguous()
    qzeros = None
    if qparams.zero_points is not None:
        zeros = qparams.zero_points.to(torch.int32)
        zeros = zeros.squeeze(-1).t().contiguous()
        qzeros = _pack_int4_columns(zeros, n_bits, pack_order)
    return qweight, scales, qzeros


def _pack_int4_columns(codes: torch.Tensor, n_bits: int,
                       pack_order: List[int]) -> torch.Tensor:
    """Pack each run of ``len(pack_order)`` adjacent columns into one int32.

    Column ``j * len(pack_order) + pack_order[i]`` of ``codes`` is written to
    bits ``[i * n_bits, (i + 1) * n_bits)`` of output column ``j``.
    ``codes`` must be a contiguous int32 tensor whose column count is a
    multiple of ``len(pack_order)``.
    """
    pack_num = len(pack_order)
    rows, cols = codes.shape
    grouped = codes.view(rows, cols // pack_num, pack_num)
    packed = torch.zeros((rows, cols // pack_num),
                         dtype=torch.int32,
                         device=codes.device)
    # One vectorised OR per slot instead of a Python loop over every column.
    for slot, src in enumerate(pack_order):
        packed |= grouped[:, :, src] << (slot * n_bits)
    return packed


def replace_quant_params(model,
                         quant_config,
                         modules_to_not_convert="lm_head"):
    """Convert the parallel linear layers of ``model`` in place for AutoQuant.

    Children are visited recursively. Each ``ColumnParallelLinear``,
    ``QKVParallelLinear`` or ``RowParallelLinear`` whose attribute name is
    not listed in ``modules_to_not_convert`` is converted as follows:

    * ``from_float`` with ``llm_int8``/``smoothquant``: the float ``weight``
      is wrapped as a bitsandbytes ``Int8Params`` with a ``MatmulLtState``.
    * ``from_float`` with ``weight_only``: the float ``weight`` is quantized
      to 4 bits and replaced by ``qweight`` / ``scales_zeros``.
    * Otherwise: the pre-packed ``qweight`` / ``qzeros`` / ``scales`` are
      converted to ``qweight`` / ``scales_zeros`` in the kernel layout.

    Args:
        model: Module whose children are converted.
        quant_config: AutoQuant config. ``from_float``, ``quant_mode``,
            ``pack_factor`` and ``group_size`` are read.
        modules_to_not_convert (`str` or `list`, *optional*, defaults to
            `lm_head`): Child names to leave untouched. In practice the
            `lm_head` is kept in full precision for numerical stability.
    """
    if not isinstance(modules_to_not_convert, list):
        modules_to_not_convert = [modules_to_not_convert]
    for name, module in model.named_children():
        if next(module.children(), None) is not None:
            replace_quant_params(module, quant_config, modules_to_not_convert)
        if not isinstance(
                module,
            (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear)):
            continue
        if name in modules_to_not_convert:
            continue
        if not quant_config.from_float:  # load packed int4 weight
            _convert_packed_s4(module, quant_config)
            continue
        module.linear_weights.pop("weight")
        if quant_config.quant_mode in ("llm_int8", "smoothquant"):
            _convert_float_to_int8(module, quant_config)
        elif quant_config.quant_mode == "weight_only":
            _convert_float_to_s4(module, quant_config)


def _convert_float_to_int8(module, quant_config):
    """Replace a float ``weight`` with bitsandbytes int8 params and state."""
    param = module._parameters["weight"]
    import bitsandbytes as bnb
    new_value = bnb.nn.Int8Params(param.data,
                                  requires_grad=False,
                                  has_fp16_weights=False)
    state = bnb.MatmulLtState()
    # smoothquant uses threshold 0.0; llm_int8 uses 6.0.
    if quant_config.quant_mode == "smoothquant":
        state.threshold = 0.0
    else:
        state.threshold = 6.0
    state.has_fp16_weights = False
    state.memory_efficient_backward = False
    state.use_pool = True
    module._parameters["weight"] = new_value
    module.linear_weights["weight"] = new_value
    module.linear_weights["state"] = state
    # NOTE(review): these attrs describe a packed (input, output) layout,
    # while the int8 weight keeps the float weight's shape. Confirm they
    # are not consumed after conversion.
    set_weight_attrs(
        new_value, {
            "input_dim": 0,
            "output_dim": 1,
            "packed_dim": 1,
            "pack_factor": quant_config.pack_factor,
        })
    del param
    torch.cuda.empty_cache()


def _convert_float_to_s4(module, quant_config):
    """Quantize a float ``weight`` to int4 and store it in kernel layout."""
    param = module._parameters["weight"]
    data_fp = param.cuda()
    # NOTE(review): the group size is fixed at 128 here (matching
    # convert_s4's default), independent of quant_config.group_size.
    _qweight, _scales, _qzeros = quantize_tensor(data_fp,
                                                 n_bits=4,
                                                 group_size=128)
    qweight, scales_zeros = convert_s4(_qweight, _qzeros, _scales)
    torch.cuda.synchronize()
    _register_s4_params(module, qweight, scales_zeros, quant_config)
    # Drop every local reference to the float weight before emptying the
    # cache, so its memory can actually be released.
    del _qweight, _qzeros, _scales, param, data_fp
    delattr(module, "weight")
    torch.cuda.empty_cache()


def _convert_packed_s4(module, quant_config):
    """Re-layout a pre-packed int4 checkpoint (qweight/qzeros/scales)."""
    for key in ("qweight", "qzeros", "scales"):
        module.linear_weights.pop(key)
    old_qweight = module._parameters["qweight"]
    old_qzeros = module._parameters["qzeros"]
    old_scales = module._parameters["scales"]
    # Use the configured group size: create_weights sizes qzeros/scales
    # with quant_config.group_size, so convert_s4's default of 128 would
    # misread checkpoints quantized with any other group size.
    qweight, scales_zeros = convert_s4(old_qweight.data,
                                       old_qzeros.data,
                                       old_scales.data,
                                       group_size=quant_config.group_size)
    del old_qweight, old_qzeros, old_scales
    for key in ("qweight", "qzeros", "scales"):
        delattr(module, key)
    _register_s4_params(module, qweight, scales_zeros, quant_config)
    torch.cuda.synchronize()
    torch.cuda.empty_cache()


def _register_s4_params(module, qweight, scales_zeros, quant_config):
    """Register kernel-layout ``qweight`` / ``scales_zeros`` on ``module``."""
    param_qweight = Parameter(qweight, requires_grad=False)
    param_scales_zeros = Parameter(scales_zeros, requires_grad=False)
    module.register_parameter("qweight", param_qweight)
    module.register_parameter("scales_zeros", param_scales_zeros)
    set_weight_attrs(
        param_qweight, {
            "input_dim": 0,
            "output_dim": 1,
            "packed_dim": 1,
            "pack_factor": quant_config.pack_factor,
        })
    set_weight_attrs(param_scales_zeros, {
        "input_dim": 0,
        "output_dim": 1,
    })
    module.linear_weights["qweight"] = param_qweight
    module.linear_weights["scales_zeros"] = param_scales_zeros