import math
from contextlib import suppress
from pathlib import Path

import scipy.stats
import torch
from safetensors.torch import load_file

# The fused kernel lives in an optional compiled extension; when it is not
# built, ``hadamard_C`` stays undefined and calls into it raise NameError.
with suppress(ImportError):
    import aphrodite._hadamard_C as hadamard_C

# Precomputed Hadamard matrices shipped alongside this module, keyed by
# their size as a string.
HADA_TENSORS = load_file(
    Path(__file__).resolve().parent / "hadamard.safetensors")


class HadamardTransformFn(torch.autograd.Function):
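    """Autograd wrapper for the fused CUDA fast Hadamard transform.

    Only the forward pass is implemented; ``scale`` is stashed on ``ctx``
    but no ``backward`` is defined, so gradients are unsupported.
    """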

    @staticmethod
    def forward(ctx, x, scale=1.0):
        ctx._hadamard_transform_scale = scale  # pylint: disable=protected-access
        return hadamard_C.fast_hadamard_transform(x, scale)


def hadamard_transform(x, scale=1.0):
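    """Fast Hadamard transform of ``x``, scaled by ``scale``."""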
    return HadamardTransformFn.apply(x, scale)


def int2mask(i, int_map):
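    # Expand integer codes into 0/1 masks over the bit positions listed in
    # ``int_map``.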
    return ((i & int_map) > 0).int()


def mask2int(mask, int_map):
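    # Inverse of int2mask: pack 0/1 masks back into integer codes using the
    # bit values in ``int_map``.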
    return (int_map.unsqueeze(0) * mask.int()).sum(dim=-1)


def get_norm12():
    # The 29 elements of squared norm 12 in E8 + 1/4.
    return torch.tensor([
        [3, 1, 1, 1, 3, 3, 3, 3],
        [1, 3, 1, 1, 3, 3, 3, 3],
        [1, 1, 3, 1, 3, 3, 3, 3],
        [1, 1, 1, 3, 3, 3, 3, 3],
        [3, 3, 3, 1, 3, 3, 1, 1],
        [3, 3, 3, 1, 3, 1, 3, 1],
        [3, 3, 3, 1, 1, 3, 3, 1],
        [3, 3, 3, 1, 3, 1, 1, 3],
        [3, 3, 3, 1, 1, 3, 1, 3],
        [3, 3, 3, 1, 1, 1, 3, 3],
        [3, 3, 1, 3, 3, 3, 1, 1],
        [3, 3, 1, 3, 3, 1, 3, 1],
        [3, 3, 1, 3, 1, 3, 3, 1],
        [3, 3, 1, 3, 3, 1, 1, 3],
        [3, 3, 1, 3, 1, 3, 1, 3],
        [3, 3, 1, 3, 1, 1, 3, 3],
        [3, 1, 3, 3, 3, 3, 1, 1],
        [3, 1, 3, 3, 3, 1, 3, 1],
        [3, 1, 3, 3, 1, 3, 3, 1],
        [3, 1, 3, 3, 3, 1, 1, 3],
        [3, 1, 3, 3, 1, 3, 1, 3],
        [1, 3, 3, 3, 1, 1, 3, 3],
        [1, 3, 3, 3, 3, 3, 1, 1],
        [1, 3, 3, 3, 3, 1, 3, 1],
        [1, 3, 3, 3, 1, 3, 3, 1],
        [1, 3, 3, 3, 3, 1, 1, 3],
        [1, 3, 3, 3, 1, 3, 1, 3],
        [1, 1, 3, 3, 1, 3, 3, 3],
        [3, 3, 1, 1, 3, 3, 3, 1],
    ]) / 2


def get_packed_abs_grid():
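    # Build the absolute-value codebook: half-integer 8-vectors with even
    # coordinate sum and squared norm <= 10, plus the elements from
    # get_norm12, each row scaled by 4 and bit-packed into a single int64.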
    intr = torch.arange(-4, 4)
    d8 = torch.cartesian_prod(*[intr] * 8).float() + 1 / 2
    d8m2 = d8.sum(dim=-1) % 2 == 0  # even coordinate sum
    d8n = d8.norm(dim=-1)**2 <= 10  # squared norm at most 10
    d8abs = torch.unique(d8[d8m2 & d8n].abs(), dim=0)
    norm12 = get_norm12()
    cba = torch.concat([d8abs, norm12], dim=0)
    cba = cba[:, [0, 2, 1, 3, 4, 6, 5, 7]]
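    # Flip the sign of the last coordinate wherever the row sum is odd so
    # that every row ends up with an even coordinate sum.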
    cba[:, 7] *= (1 - 2 * (cba.sum(1) % 2))
    cba = cba * 4
    cba = cba.to(torch.int64)
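    # Pack the eight coordinates of each row into one int64, one byte per
    # coordinate, with coordinate 0 in the lowest byte.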
    acc = cba[:, 0]
    for i in range(7):
        acc = acc | (cba[:, (i + 1)] << ((i + 1) * 8))
    return acc


def next_power_of_2(n):
    """Smallest power of 2 greater than or equal to n."""
    if n == 0:
        return 1
    # Integer bit arithmetic avoids the float rounding of
    # 2**math.ceil(math.log(n, 2)), which can overshoot for exact powers
    # of 2.
    return 1 << (n - 1).bit_length()


def get_power_of_2(n):
    """Returns the highest power of 2 that divides n."""
    k = 0
    while n % 2 == 0:
        n //= 2
        k += 1
    return k, n


def get_hadK(n, use_rand=True):
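    """Pick the mixing block for a length-``n`` transform.

    Returns ``(hadK, K, padded_n)``: ``hadK`` is a K x K orthogonal block
    (``None`` when K == 1) and ``padded_n`` is ``n``, rounded up to a power
    of 2 when no suitable decomposition is available.
    """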
    exp, base = get_power_of_2(n)
    if base == 1:
        return None, 1, n
    if use_rand:
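        # Draw a random orthogonal base x base rotation instead of using a
        # stored Hadamard matrix.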
        rand_mat = torch.tensor(scipy.stats.special_ortho_group.rvs(base)).to(
            torch.float32)
        return rand_mat, base, n

    # Otherwise use a Hadamard matrix alone, padding n up if a suitable one
    # cannot be found.
    pad_n = next_power_of_2(n)
    if exp < 2 or str(base * 4) not in HADA_TENSORS:
        return None, 1, pad_n
    base_mat = HADA_TENSORS[str(base * 4)] / math.sqrt(base * 4)
    return base_mat, base * 4, n


def matmul_hadU_cuda(X, hadK, K, n, scale=None, transpose=False):
    # Zero-pad the last dimension up to n if needed.
    if n != X.shape[-1]:
        X = torch.nn.functional.pad(X, (0, n - X.shape[-1]))

    had_scale = (1.0 if scale is None else scale) / math.sqrt(n // K)
    if K == 1:
        return hadamard_transform(X.contiguous(), scale=had_scale)

    if transpose:
        hadK = hadK.T.contiguous()
    # Hadamard-transform each length-(n // K) chunk, then mix the K chunks
    # with the hadK block.
    out = X.view(-1, K, n // K)
    out = hadamard_transform(out.contiguous(), scale=had_scale)
    out = hadK @ out
    return out.reshape(X.shape)


def matmul_hadUt_cuda(X, hadK, K, n, scale=None):
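    """Transposed variant of matmul_hadU_cuda."""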
    return matmul_hadU_cuda(X, hadK, K, n, scale=scale, transpose=True)