graph_machete_bench.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import math
  2. import pickle
  3. import re
  4. from collections import defaultdict
  5. from typing import List
  6. import matplotlib.pyplot as plt
  7. import pandas as pd
  8. import seaborn as sns
  9. from torch.utils.benchmark import Measurement as TMeasurement
  10. from aphrodite.common.utils import FlexibleArgumentParser
  11. if __name__ == "__main__":
  12. parser = FlexibleArgumentParser(
  13. description='Benchmark the latency of processing a single batch of '
  14. 'requests till completion.')
  15. parser.add_argument('filename', type=str)
  16. args = parser.parse_args()
  17. with open(args.filename, 'rb') as f:
  18. data: List[TMeasurement] = pickle.load(f)
  19. results = defaultdict(lambda: list())
  20. for v in data:
  21. result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label)
  22. if result is not None:
  23. KN = result.group(1)
  24. else:
  25. raise Exception("MKN not found")
  26. result = re.search(r"MKN=\((\d+)x\d+x\d+\)", v.task_spec.sub_label)
  27. if result is not None:
  28. M = result.group(1)
  29. else:
  30. raise Exception("MKN not found")
  31. kernel = v.task_spec.description
  32. results[KN].append({
  33. "kernel": kernel,
  34. "batch_size": M,
  35. "median": v.median
  36. })
  37. rows = int(math.ceil(len(results) / 2))
  38. fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
  39. axs = axs.flatten()
  40. axs_idx = 0
  41. for shape, data in results.items():
  42. plt.sca(axs[axs_idx])
  43. df = pd.DataFrame(data)
  44. sns.lineplot(data=df,
  45. x="batch_size",
  46. y="median",
  47. hue="kernel",
  48. style="kernel",
  49. markers=True,
  50. dashes=False,
  51. palette="Dark2")
  52. plt.title(f"Shape: {shape}")
  53. plt.ylabel("time (median, s)")
  54. axs_idx += 1
  55. plt.tight_layout()
  56. plt.savefig("graph_machete_bench.pdf")