import argparse
import copy
import itertools
import math
import pickle as pkl
import time
from typing import Callable, Iterable, List, Tuple

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES

from aphrodite import _custom_ops as ops
from aphrodite.common.utils import FlexibleArgumentParser
from aphrodite.quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales)
from aphrodite.quantization.utils.marlin_utils_test import MarlinWorkspace
from aphrodite.quantization.utils.quant_utils import (gptq_pack, pack_rows,
                                                      quantize_weights)
from aphrodite.scalar_type import ScalarType, scalar_types

DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
DEFAULT_TP_SIZES = [1]


def machete_pack_weights(w_q: torch.Tensor,
                         wtype: ScalarType) -> torch.Tensor:
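    """Pack a row-quantized weight matrix into Machete's pre-packed
    B-operand (column-major) layout."""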
    w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
    w_q = w_q.t().contiguous().t()  # make col major
    return ops.machete_prepack_B(w_q, wtype)


def make_bench_tensors(
    atype: torch.dtype, wtype: ScalarType, group_size: int, m: int, n: int,
    k: int
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                                    torch.Tensor]]]:
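    """Create a random (m, k) activation tensor and a list of quantized
    weight tuples (w_ref, w_q, w_s, w_zp), with enough weight matrices that
    their combined quantized size exceeds the GPU's L2 cache."""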
    assert wtype.is_integer(), "TODO: support floating point weights"

    # We want to make sure the weights don't stay resident in L2 cache between
    #  runs, so we construct enough weight matrices to exceed the L2 cache,
    #  which is 50 MB on an H100, i.e. total weight size > 2 * 50 MB.
    num_weights = math.ceil(2 * 50 * 1024**2 * 8 / (k * n * wtype.size_bits))
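    # e.g. for k = n = 4096 with 4-bit weights, each quantized matrix is
    #  4096 * 4096 * 4 / 8 bytes = 8 MiB, so num_weights = ceil(12.5) = 13.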

    a = torch.randn((m, k), device="cuda", dtype=atype) * 5
    weights = [
        torch.randn((k, n), device="cuda", dtype=atype)
        for _ in range(num_weights)
    ]
    quantized_weights = [
        quantize_weights(w, wtype, group_size) for w in weights
    ]

    return a, quantized_weights


# bench
def bench_fn(label: str, sub_label: str, description: str,
             fn: Callable) -> TMeasurement:
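    """Time `fn` using torch.utils.benchmark's blocked autorange."""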

    min_run_time = 1
    return TBenchmark.Timer(
        stmt="fn()",
        globals={
            "fn": fn
        },
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)


def loop_over_weights(
    a: torch.Tensor, weights: List[Tuple[torch.Tensor, torch.Tensor,
                                         torch.Tensor, torch.Tensor]],
    fn: Callable[[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
                 None]):
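    """Call `fn(a, w_ref, w_q, w_s)` once for each quantized weight tuple."""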
    for w_ref, w_q, w_s, _ in weights:
        fn(a, w_ref, w_q, w_s)


def bench(atype: torch.dtype,
          wtype: ScalarType,
          group_size: int,
          m: int,
          k: int,
          n: int,
          label: str,
          sub_label: str,
          benchmark_marlinv1: bool = True,
          sweep_schedules: bool = True) -> Iterable[TMeasurement]:
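    """Benchmark torch.matmul, Marlin v1 (optional), and Machete for a single
    (m, k, n) shape, returning a TMeasurement for each benchmarked kernel."""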
    a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k)
    sub_label += f", L={len(weights)}"

    weights_machete = [(w_ref, machete_pack_weights(w_q, wtype), w_s, w_zp)
                       for w_ref, w_q, w_s, w_zp in weights]

    timers = []
    # pytorch impl
    timers.append(
        bench_fn(
            label, sub_label, "torch.matmul", lambda: loop_over_weights(
                a,
                weights,
                lambda a, w_ref, w_q, w_s: torch.matmul(a, w_ref),
            )))

    if benchmark_marlinv1:
        w_ref = weights[0][0]

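        # Empty zero-point, g_idx, and sort-index tensors: this benchmark
        # exercises the symmetric, non-act-order GPTQ-Marlin path.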
        w_zp_empty = torch.empty(0, dtype=torch.int, device=w_ref.device)
        sort_indices = torch.empty(0, dtype=torch.int, device=w_ref.device)
        g_idx = torch.empty(0, dtype=torch.int, device=w_ref.device)

        def marlinv1_pack_weights(w_q: torch.Tensor) -> torch.Tensor:
            w_q_gptq = gptq_pack(w_q, wtype.size_bits, *w_ref.shape)
            return ops.gptq_marlin_repack(w_q_gptq, sort_indices, *w_ref.shape,
                                          wtype.size_bits)

        def marlinv1_permute_scales(w_s: torch.Tensor) -> torch.Tensor:
            return marlin_permute_scales(w_s, *w_ref.shape, group_size)

        weights_marlinv1 = [(w_ref, marlinv1_pack_weights(w_q),
                             marlinv1_permute_scales(w_s), w_zp)
                            for w_ref, w_q, w_s, w_zp in weights]

        workspace = MarlinWorkspace(w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
                                    GPTQ_MARLIN_MAX_PARALLEL)

        # marlinv1
        timers.append(
            bench_fn(
                label, sub_label, "marlin_orig", lambda: loop_over_weights(
                    a, weights_marlinv1, lambda a, w_ref, w_q, w_s: ops.
                    gptq_marlin_gemm(a,
                                     w_q,
                                     w_s,
                                     w_zp_empty,
                                     g_idx,
                                     sort_indices,
                                     workspace.scratch,
                                     wtype,
                                     size_m=a.shape[0],
                                     size_n=w_ref.shape[1],
                                     size_k=w_ref.shape[0],
                                     is_k_full=True))))

    # machete
    timers.append(
        bench_fn(
            label, sub_label, "machete_heuristic", lambda: loop_over_weights(
                a, weights_machete, lambda a, _, w_q, w_s: ops.machete_gemm(
                    a, w_q, wtype, b_scales=w_s, b_group_size=group_size))))

    if sweep_schedules:
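        # Benchmark every schedule Machete reports as supported for this
        # weight type; only the fastest measurement is appended to `timers`.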
        print("Finding best schedule for machete")
        best = None
        best_schedule = None
        schedules = ops.machete_supported_schedules(wtype)
        for schedule in reversed(schedules):

            def run(a, _, w_q, w_s, schedule=schedule):
                ops.machete_gemm(a,
                                 w_q,
                                 wtype,
                                 w_s,
                                 b_group_size=group_size,
                                 schedule=schedule)

            res = bench_fn(label, sub_label, "machete_best",
                           lambda: loop_over_weights(a, weights_machete, run))

            print(f"  {res.median:5.5} ", schedule)
            if not best or res.median < best.median:
                best = res
                best_schedule = schedule
        print("Best schedule:", best_schedule)
        timers.append(best)

    return timers


# runner
def print_timers(timers: Iterable[TMeasurement]):
    compare = TBenchmark.Compare(timers)
    compare.print()


def run(dtype: torch.dtype, sweep_schedules: bool,
        MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
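    """Benchmark every (m, k, n) shape in MKNs and return all measurements."""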

    results = []
    for m, k, n in MKNs:
        timers = bench(dtype,
                       scalar_types.uint4b8,
                       128,
                       m,
                       k,
                       n,
                       f"{dtype}-gemm",
                       f"MKN=({m}x{k}x{n})",
                       sweep_schedules=sweep_schedules)
        print_timers(timers)
        results.extend(timers)

    return results


# output makers
def make_output(
    data: Iterable[TMeasurement],
    MKNs: Iterable[Tuple[int, int, int]],
    base_description: str,
    timestamp=None,
):
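    """Print a combined comparison of all measurements and pickle the raw
    data to `{base_description}-{timestamp}.pkl`."""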

    print(f"== All Results {base_description} ====")
    print_timers(data)

    # pickle all the results
    timestamp = int(time.time()) if timestamp is None else timestamp
    with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
        pkl.dump(data, f)


# argparse runners


def run_square_bench(args):
    dim_sizes = list(
        range(args.dim_start, args.dim_end + 1, args.dim_increment))
    MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
    data = run(args.dtype, args.sweep_schedules, MKNs)

    make_output(data, MKNs, f"square_bench-{args.dtype}")


def run_range_bench(args):
    dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
    n = len(dim_sizes)
    Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
    Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
    Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
    MKNs = list(zip(Ms, Ks, Ns))
    data = run(args.dtype, args.sweep_schedules, MKNs)

    make_output(data, MKNs, f"range_bench-{args.dtype}")


def run_model_bench(args):

    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
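        """Return the model's (K, N) weight shapes with the tensor-parallel
        split dimension divided by tp_size."""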
        KNs = []
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KNs.append(KN)
        return KNs

    model_bench_data = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        Ms = args.batch_sizes
        KNs = model_shapes(model, tp_size)
        MKNs = []
        for m in Ms:
            for k, n in KNs:
                MKNs.append((m, k, n))

        data = run(args.dtype, args.sweep_schedules, MKNs)
        model_bench_data.append(data)

    # Print all results
    for data, model_tp in zip(model_bench_data, models_tps):
        model, tp_size = model_tp
        print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
        print_timers(data)

    timestamp = int(time.time())

    all_data = []
    for d in model_bench_data:
        all_data.extend(d)
    # pickle all data
    with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
        pkl.dump(all_data, f)


if __name__ == "__main__":

    def to_torch_dtype(dt):
        if dt == "bfloat16":
            return torch.bfloat16
        if dt == "float16":
            return torch.float16
        raise ValueError("unsupported dtype")

    parser = FlexibleArgumentParser(
        description="""
Benchmark Machete GEMM.
    To run square GEMMs:
        python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
    
    To run constant N and K and sweep M:
        python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
    
    To run dimensions from a model:
        python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
    
    Output:
        - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter,
    )

    parser.add_argument(
        "--dtype",
        type=to_torch_dtype,
        required=True,
        help="Available options are ['bfloat16', 'float16']",
    )
    parser.add_argument(
        "--sweep-schedules",
        action="store_true",
        help="Run a sweep over all supported schedules",
    )
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    square_parser = subparsers.add_parser("square_bench")
    square_parser.add_argument("--dim-start", type=int, required=True)
    square_parser.add_argument("--dim-end", type=int, required=True)
    square_parser.add_argument("--dim-increment", type=int, required=True)
    square_parser.set_defaults(func=run_square_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--dim-start", type=int, required=True)
    range_parser.add_argument("--dim-end", type=int, required=True)
    range_parser.add_argument("--dim-increment", type=int, required=True)
    range_parser.add_argument("--m-constant", type=int, default=None)
    range_parser.add_argument("--n-constant", type=int, default=None)
    range_parser.add_argument("--k-constant", type=int, default=None)
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    model_parser.add_argument("--tp-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_TP_SIZES)
    model_parser.add_argument("--batch-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_BATCH_SIZES)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    args.func(args)