import math
import pickle
import re
from collections import defaultdict
from typing import List

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from torch.utils.benchmark import Measurement as TMeasurement

from aphrodite.common.utils import FlexibleArgumentParser

# Matches an ``MKN=(MxKxN)`` tag; group 1 is M (batch size), group 2 is "KxN".
_MKN_RE = re.compile(r"MKN=\((\d+)x(\d+x\d+)\)")


def parse_mkn(sub_label: str):
    """Extract the batch size and weight shape from a measurement sub-label.

    Args:
        sub_label: text containing an ``MKN=(MxKxN)`` tag.

    Returns:
        A ``(M, "KxN")`` tuple, where ``M`` is converted to ``int`` so that
        batch sizes sort and plot numerically rather than as strings.

    Raises:
        ValueError: if no ``MKN=(MxKxN)`` tag is present.
    """
    match = _MKN_RE.search(sub_label)
    if match is None:
        raise ValueError(f"MKN not found in sub_label: {sub_label!r}")
    return int(match.group(1)), match.group(2)


def group_by_shape(measurements: "List[TMeasurement]"):
    """Group benchmark measurements by their ``KxN`` shape.

    Args:
        measurements: objects exposing ``task_spec.sub_label``,
            ``task_spec.description`` and ``median``.

    Returns:
        A dict mapping each ``"KxN"`` shape (in first-seen order) to a list of
        ``{"kernel", "batch_size", "median"}`` row dicts.

    Raises:
        ValueError: if a measurement's sub-label has no ``MKN=(...)`` tag.
    """
    results = defaultdict(list)
    for measurement in measurements:
        batch_size, kn = parse_mkn(measurement.task_spec.sub_label)
        results[kn].append({
            "kernel": measurement.task_spec.description,
            "batch_size": batch_size,
            "median": measurement.median,
        })
    return results


def plot_results(results, output: str = "graph_machete_bench.pdf") -> None:
    """Draw one median-time vs. batch-size subplot per shape and save it.

    Args:
        results: mapping from ``"KxN"`` shape to row dicts, as returned by
            ``group_by_shape``.
        output: path the figure is saved to.

    Raises:
        ValueError: if ``results`` is empty (there is nothing to plot).
    """
    if not results:
        raise ValueError("no measurements to plot")

    n_cols = 2
    n_rows = math.ceil(len(results) / n_cols)
    # squeeze=False keeps ``axs`` 2-D for every grid size, so flatten() is safe.
    fig, axs = plt.subplots(n_rows,
                            n_cols,
                            figsize=(12, 5 * n_rows),
                            squeeze=False)
    axs = axs.flatten()

    for ax, (shape, rows) in zip(axs, results.items()):
        df = pd.DataFrame(rows)
        sns.lineplot(data=df,
                     x="batch_size",
                     y="median",
                     hue="kernel",
                     style="kernel",
                     markers=True,
                     dashes=False,
                     palette="Dark2",
                     ax=ax)
        # Numeric x on a log2 axis: points are ordered by value, and
        # power-of-two sweeps stay evenly spaced. Ticks are placed exactly at
        # the measured batch sizes.
        batch_sizes = sorted(df["batch_size"].unique())
        ax.set_xscale("log", base=2)
        ax.set_xticks(batch_sizes)
        ax.set_xticklabels([str(b) for b in batch_sizes])
        ax.minorticks_off()
        ax.set_title(f"Shape: {shape}")
        ax.set_ylabel("time (median, s)")

    # With an odd number of shapes the last panel is unused; hide it.
    for ax in axs[len(results):]:
        ax.set_visible(False)

    fig.tight_layout()
    fig.savefig(output)
    plt.close(fig)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Plot machete kernel benchmark results (a pickled list '
        'of torch.utils.benchmark Measurements) as median time vs. batch '
        'size, one subplot per KxN shape.')
    parser.add_argument('filename',
                        type=str,
                        help='pickle file containing the measurements')
    parser.add_argument('--output',
                        type=str,
                        default='graph_machete_bench.pdf',
                        help='path of the figure to write')

    args = parser.parse_args()

    # NOTE: pickle.load can execute arbitrary code; only open files you trust.
    with open(args.filename, 'rb') as f:
        measurements: List[TMeasurement] = pickle.load(f)

    plot_results(group_by_shape(measurements), args.output)