import argparse
import copy
import itertools
import math
import pickle as pkl
import time
from typing import Callable, Iterable, List, Tuple

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES

from aphrodite import _custom_ops as ops
from aphrodite.common.utils import FlexibleArgumentParser
from aphrodite.quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales)
from aphrodite.quantization.utils.marlin_utils_test import MarlinWorkspace
from aphrodite.quantization.utils.quant_utils import (gptq_pack, pack_rows,
                                                      quantize_weights)
from aphrodite.scalar_type import ScalarType, scalar_types

DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
DEFAULT_TP_SIZES = [1]

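# Descriptive note (added for clarity): Machete consumes weights in a custom
# pre-packed layout. pack_rows packs the quantized integer values along the k
# dimension into int32 words, and machete_prepack_B then rearranges the
# column-major result into the layout the Machete kernel expects.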
def machete_pack_weights(w_q: torch.tensor, wtype: ScalarType) -> torch.tensor:
    w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
    w_q = w_q.t().contiguous().t()  # make col major
    return ops.machete_prepack_B(w_q, wtype)

def make_bench_tensors(
    atype: torch.dtype, wtype: ScalarType, group_size: int, m: int, n: int,
    k: int
) -> Tuple[torch.tensor, List[Tuple[torch.tensor, torch.tensor, torch.tensor,
                                    torch.tensor]]]:
    assert wtype.is_integer(), "TODO: support floating point weights"

    # We want to make sure the weights don't stay resident in L2 cache between
    # runs, so we construct enough weight copies to exceed the L2 cache, which
    # is 50 MB on an H100; we therefore target a total weight size > 2 * 50 MB.
    num_weights = math.ceil(2 * 50 * 1024**2 * 8 / (k * n * wtype.size_bits))
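    # Illustrative arithmetic (added example, not from the original source):
    # with k = n = 4096 and 4-bit weights, each copy is 4096 * 4096 * 4 bits
    # = 8 MiB, so num_weights = ceil(100 MiB / 8 MiB) = 13 copies.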
    a = torch.randn((m, k), device="cuda", dtype=atype) * 5
    weights = [
        torch.randn((k, n), device="cuda", dtype=atype)
        for _ in range(num_weights)
    ]
    quantized_weights = [
        quantize_weights(w, wtype, group_size) for w in weights
    ]

    return a, quantized_weights

# impl

# bench
def bench_fn(label: str, sub_label: str, description: str,
             fn: Callable) -> TMeasurement:
    min_run_time = 1
    return TBenchmark.Timer(
        stmt="fn()",
        globals={
            "fn": fn
        },
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)

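# Usage sketch (hypothetical tensors a and w, not part of the original
# script):
#     t = bench_fn("gemm", "MKN=(16x4096x4096)", "torch.matmul",
#                  lambda: torch.matmul(a, w))
#     print(t.median)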
def loop_over_weights(
        a: torch.tensor, weights: List[Tuple[torch.tensor, torch.tensor,
                                             torch.tensor, torch.tensor]],
        fn: Callable[[torch.tensor, torch.tensor, torch.tensor, torch.tensor],
                     None]):
    for w_ref, w_q, w_s, _ in weights:
        fn(a, w_ref, w_q, w_s)

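# Descriptive note (added for clarity): bench() times three implementations
# of the same quantized GEMM: a torch.matmul baseline on the dequantized
# reference weights, Marlin v1 (gptq_marlin_gemm), and Machete with its
# internal heuristic schedule. With sweep_schedules=True it additionally
# times every supported Machete schedule and records the best one.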
def bench(atype: torch.dtype,
          wtype: ScalarType,
          group_size: int,
          m: int,
          k: int,
          n: int,
          label: str,
          sub_label: str,
          benchmark_marlinv1: bool = True,
          sweep_schedules: bool = True) -> Iterable[TMeasurement]:
    a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k)
    sub_label += f", L={len(weights)}"

    weights_machete = [(w_ref, machete_pack_weights(w_q, wtype), w_s, w_zp)
                       for w_ref, w_q, w_s, w_zp in weights]

    timers = []
    # pytorch impl
    timers.append(
        bench_fn(
            label, sub_label, "torch.matmul", lambda: loop_over_weights(
                a,
                weights,
                lambda a, w_ref, w_q, w_s: torch.matmul(a, w_ref),
            )))

    if benchmark_marlinv1:
        w_ref = weights[0][0]
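
        # The empty tensors below stand in for features this benchmark does
        # not exercise: no zero-points, no act-order (g_idx), and no sorted
        # indices are passed to the Marlin kernel.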
        w_zp_empty = torch.empty(0, dtype=torch.int, device=w_ref.device)
        sort_indices = torch.empty(0, dtype=torch.int, device=w_ref.device)
        g_idx = torch.empty(0, dtype=torch.int, device=w_ref.device)

        def marlinv1_pack_weights(w_q: torch.tensor) -> torch.tensor:
            w_q_gptq = gptq_pack(w_q, wtype.size_bits, *w_ref.shape)
            return ops.gptq_marlin_repack(w_q_gptq, sort_indices,
                                          *w_ref.shape, wtype.size_bits)

        def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor:
            return marlin_permute_scales(w_s, *w_ref.shape, group_size)

        weights_marlinv1 = [(w_ref, marlinv1_pack_weights(w_q),
                             marlinv1_permute_scales(w_s), w_zp)
                            for w_ref, w_q, w_s, w_zp in weights]

        workspace = MarlinWorkspace(w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
                                    GPTQ_MARLIN_MAX_PARALLEL)

        # marlinv1
        timers.append(
            bench_fn(
                label, sub_label, "marlin_orig", lambda: loop_over_weights(
                    a, weights_marlinv1, lambda a, w_ref, w_q, w_s: ops.
                    gptq_marlin_gemm(a,
                                     w_q,
                                     w_s,
                                     w_zp_empty,
                                     g_idx,
                                     sort_indices,
                                     workspace.scratch,
                                     wtype,
                                     size_m=a.shape[0],
                                     size_n=w_ref.shape[1],
                                     size_k=w_ref.shape[0],
                                     is_k_full=True))))

    # machete
    timers.append(
        bench_fn(
            label, sub_label, "machete_heuristic", lambda: loop_over_weights(
                a, weights_machete, lambda a, _, w_q, w_s: ops.machete_gemm(
                    a, w_q, wtype, b_scales=w_s, b_group_size=group_size))))

    if sweep_schedules:
        print("Finding best schedule for machete")
        best = None
        best_schedule = None
        schedules = ops.machete_supported_schedules(wtype)
        for schedule in reversed(schedules):

            def run(a, _, w_q, w_s, schedule=schedule):
                ops.machete_gemm(a,
                                 w_q,
                                 wtype,
                                 w_s,
                                 b_group_size=group_size,
                                 schedule=schedule)

            res = bench_fn(label, sub_label, "machete_best",
                           lambda: loop_over_weights(a, weights_machete, run))

            print(f"  {res.median:5.5} ", schedule)
            if not best or res.median < best.median:
                best = res
                best_schedule = schedule
        print("Best schedule:", best_schedule)
        timers.append(best)

    return timers

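# Example invocation (illustrative values, not from the original script):
#     timers = bench(torch.float16, scalar_types.uint4b8, 128, m=16, k=4096,
#                    n=4096, label="gemm", sub_label="MKN=(16x4096x4096)")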
# runner
def print_timers(timers: Iterable[TMeasurement]):
    compare = TBenchmark.Compare(timers)
    compare.print()

def run(dtype: torch.dtype, sweep_schedules: bool,
        MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
    results = []
    for m, k, n in MKNs:
        timers = bench(dtype,
                       scalar_types.uint4b8,
                       128,
                       m,
                       k,
                       n,
                       f"{dtype}-gemm",
                       f"MKN=({m}x{k}x{n})",
                       sweep_schedules=sweep_schedules)
        print_timers(timers)
        results.extend(timers)
    return results

# output makers
def make_output(
    data: Iterable[TMeasurement],
    MKNs: Iterable[Tuple[int, int, int]],
    base_description: str,
    timestamp=None,
):
    print(f"== All Results {base_description} ====")
    print_timers(data)

    # pickle all the results
    timestamp = int(time.time()) if timestamp is None else timestamp
    with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
        pkl.dump(data, f)

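# To inspect a saved run later (a minimal sketch; the timestamp suffix is
# whatever make_output generated for that run):
#     with open("square_bench-torch.float16-<timestamp>.pkl", "rb") as f:
#         timers = pkl.load(f)
#     TBenchmark.Compare(timers).print()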
# argparse runners
def run_square_bench(args):
    dim_sizes = list(
        range(args.dim_start, args.dim_end + 1, args.dim_increment))
    MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
    data = run(args.dtype, args.sweep_schedules, MKNs)

    make_output(data, MKNs, f"square_bench-{args.dtype}")


def run_range_bench(args):
    dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
    n = len(dim_sizes)
    Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
    Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
    Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
    MKNs = list(zip(Ms, Ks, Ns))
    data = run(args.dtype, args.sweep_schedules, MKNs)

    make_output(data, MKNs, f"range_bench-{args.dtype}")

def run_model_bench(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
        KNs = []
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KNs.append(KN)
        return KNs
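    # Note (assumption based on the loop above): WEIGHT_SHAPES maps a model
    # name to ([K, N], tp_split_dim) entries in weight_shapes.py; the
    # tensor-parallel split divides whichever dimension tp_split_dim indexes.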
    model_bench_data = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        Ms = args.batch_sizes
        KNs = model_shapes(model, tp_size)
        MKNs = []
        for m in Ms:
            for k, n in KNs:
                MKNs.append((m, k, n))

        data = run(args.dtype, args.sweep_schedules, MKNs)
        model_bench_data.append(data)

    # Print all results
    for data, model_tp in zip(model_bench_data, models_tps):
        model, tp_size = model_tp
        print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
        print_timers(data)

    timestamp = int(time.time())

    all_data = []
    for d in model_bench_data:
        all_data.extend(d)
    # pickle all data
    with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
        pkl.dump(all_data, f)

if __name__ == "__main__":

    def to_torch_dtype(dt):
        if dt == "bfloat16":
            return torch.bfloat16
        if dt == "float16":
            return torch.float16
        raise ValueError("unsupported dtype")

    parser = FlexibleArgumentParser(
        description="""
Benchmark Machete GEMM.

    To run square GEMMs:
        python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64

    To run with constant N and K and sweep M:
        python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384

    To run dimensions from a model:
        python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1

    Output:
        - a .pkl file containing a list of raw torch.utils.benchmark
          Measurements for the PyTorch and CUTLASS implementations of the
          various GEMMs.
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--dtype",
        type=to_torch_dtype,
        required=True,
        help="Available options are ['bfloat16', 'float16']",
    )
    parser.add_argument(
        "--sweep-schedules",
        action="store_true",
        help="Run a sweep over all supported schedules",
    )

    subparsers = parser.add_subparsers(dest="cmd", required=True)

    square_parser = subparsers.add_parser("square_bench")
    square_parser.add_argument("--dim-start", type=int, required=True)
    square_parser.add_argument("--dim-end", type=int, required=True)
    square_parser.add_argument("--dim-increment", type=int, required=True)
    square_parser.set_defaults(func=run_square_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--dim-start", type=int, required=True)
    range_parser.add_argument("--dim-end", type=int, required=True)
    range_parser.add_argument("--dim-increment", type=int, required=True)
    range_parser.add_argument("--m-constant", type=int, default=None)
    range_parser.add_argument("--n-constant", type=int, default=None)
    range_parser.add_argument("--k-constant", type=int, default=None)
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    model_parser.add_argument("--tp-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_TP_SIZES)
    model_parser.add_argument("--batch-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_BATCH_SIZES)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    args.func(args)