12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364 |
- import math
- import pickle
- import re
- from collections import defaultdict
- from typing import List
- import matplotlib.pyplot as plt
- import pandas as pd
- import seaborn as sns
- from torch.utils.benchmark import Measurement as TMeasurement
- from aphrodite.common.utils import FlexibleArgumentParser
# Matches the "MKN=(<M>x<K>x<N>)" tag inside a measurement's sub_label,
# capturing M and the "KxN" remainder as two separate groups.
_MKN_PATTERN = re.compile(r"MKN=\((\d+)x(\d+x\d+)\)")


def parse_mkn(sub_label):
    """Extract the batch size and shape from an ``MKN=(MxKxN)`` tag.

    Args:
        sub_label: Text containing a tag such as ``MKN=(16x4096x8192)``.

    Returns:
        A ``(M, "KxN")`` tuple of strings, e.g. ``("16", "4096x8192")``.
        The values are left as strings, exactly as they appear in the tag.

    Raises:
        ValueError: If ``sub_label`` contains no ``MKN=(...)`` tag.
    """
    match = _MKN_PATTERN.search(sub_label)
    if match is None:
        raise ValueError(f"MKN not found in sub_label: {sub_label!r}")
    return match.group(1), match.group(2)


def group_by_shape(measurements):
    """Group measurements into plot records keyed by their "KxN" shape.

    Args:
        measurements: Iterable of objects exposing ``task_spec.sub_label``,
            ``task_spec.description`` and ``median`` (the script passes the
            unpickled list of torch benchmark ``Measurement`` objects).

    Returns:
        A ``defaultdict(list)`` mapping each "KxN" string to a list of
        ``{"kernel", "batch_size", "median"}`` dicts, in input order.

    Raises:
        ValueError: If a measurement's sub_label has no ``MKN=(...)`` tag.
    """
    grouped = defaultdict(list)
    for measurement in measurements:
        batch_size, shape = parse_mkn(measurement.task_spec.sub_label)
        grouped[shape].append({
            "kernel": measurement.task_spec.description,
            "batch_size": batch_size,
            "median": measurement.median,
        })
    return grouped


def plot_shapes(results, output_path="graph_machete_bench.pdf"):
    """Draw one line plot per shape on a two-column grid and save it.

    Args:
        results: Non-empty mapping of shape string to a list of record
            dicts, as produced by ``group_by_shape``.
        output_path: Destination passed to ``savefig``.
    """
    rows = math.ceil(len(results) / 2)
    fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
    axs = axs.flatten()

    for ax, (shape, records) in zip(axs, results.items()):
        sns.lineplot(data=pd.DataFrame(records),
                     x="batch_size",
                     y="median",
                     hue="kernel",
                     style="kernel",
                     markers=True,
                     dashes=False,
                     palette="Dark2",
                     ax=ax)
        ax.set_title(f"Shape: {shape}")
        ax.set_ylabel("time (median, s)")

    # With an odd number of shapes the grid has one spare cell; hide it so
    # the PDF does not contain an empty framed plot.
    for ax in axs[len(results):]:
        ax.set_visible(False)

    fig.tight_layout()
    fig.savefig(output_path)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Plot the median kernel time against batch size for '
        'each KxN shape found in a pickled list of benchmark measurements.')
    parser.add_argument('filename', type=str)
    args = parser.parse_args()

    # NOTE: pickle.load can execute arbitrary code embedded in the file;
    # only pass files produced by a benchmark run you trust.
    with open(args.filename, 'rb') as f:
        data: List[TMeasurement] = pickle.load(f)

    results = group_by_shape(data)
    if not results:
        # plt.subplots(0, 2) would otherwise fail with an opaque error.
        raise SystemExit(f"No measurements found in {args.filename}")

    plot_shapes(results)
|