import argparse
import copy
import itertools
import pickle as pkl
import time
from typing import Callable, Iterable, List, Tuple

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES

from aphrodite import _custom_ops as ops
from aphrodite.common.utils import FlexibleArgumentParser

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]

# helpers


def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
    # Round and saturate to the representable fp8 e4m3 range, then cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    # Round and saturate to the int8 range [-128, 127], then cast.
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
                      k: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # a is row-major (m, k); b is created as (n, k) and transposed, so the
    # (k, n) operand is column-major, the layout the scaled-mm kernels expect.
    a = torch.randn((m, k), device='cuda') * 5
    b = torch.randn((n, k), device='cuda').t() * 5

    if dtype == torch.int8:
        return to_int8(a), to_int8(b)
    if dtype == torch.float8_e4m3fn:
        return to_fp8(a), to_fp8(b)

    raise ValueError("unsupported dtype")

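# A minimal usage sketch for the helpers above (illustrative only; assumes a
# CUDA device is available):
#
#   >>> a, b = make_rand_tensors(torch.int8, m=16, n=4096, k=4096)
#   >>> a.shape, a.dtype     # (torch.Size([16, 4096]), torch.int8)
#   >>> b.shape, b.stride()  # (torch.Size([4096, 4096]), (1, 4096))
#
# b.stride() == (1, 4096) confirms the column-major layout of the transposed
# operand.
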
# impl


def pytorch_mm_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
                    scale_b: torch.Tensor,
                    out_dtype: torch.dtype) -> torch.Tensor:
    # Unquantized baseline: the scales are ignored.
    return torch.mm(a, b)


def pytorch_fp8_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
                     scale_b: torch.Tensor,
                     out_dtype: torch.dtype) -> torch.Tensor:
    return torch._scaled_mm(a,
                            b,
                            scale_a=scale_a,
                            scale_b=scale_b,
                            out_dtype=out_dtype)


def pytorch_fp8_impl_fast_accum(a: torch.Tensor, b: torch.Tensor,
                                scale_a: torch.Tensor, scale_b: torch.Tensor,
                                out_dtype: torch.dtype) -> torch.Tensor:
    return torch._scaled_mm(a,
                            b,
                            scale_a=scale_a,
                            scale_b=scale_b,
                            out_dtype=out_dtype,
                            use_fast_accum=True)


def cutlass_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
                 scale_b: torch.Tensor,
                 out_dtype: torch.dtype) -> torch.Tensor:
    return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype)

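# All implementations above share the signature
# (a, b, scale_a, scale_b, out_dtype) -> Tensor, so bench_fn below can time
# any of them through a single code path. A hypothetical extra backend would
# only need to match that signature, e.g.:
#
#   def my_backend_impl(a, b, scale_a, scale_b, out_dtype):
#       return some_scaled_mm(a, b, scale_a, scale_b, out_dtype)
#
# (my_backend_impl and some_scaled_mm are placeholders, not real APIs.)
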
# bench


def bench_fn(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
             scale_b: torch.Tensor, out_dtype: torch.dtype, label: str,
             sub_label: str, fn: Callable, description: str) -> TMeasurement:

    min_run_time = 1

    globals = {
        "a": a,
        "b": b,
        "scale_a": scale_a,
        "scale_b": scale_b,
        "out_dtype": out_dtype,
        "fn": fn,
    }
    return TBenchmark.Timer(
        stmt="fn(a, b, scale_a, scale_b, out_dtype)",
        globals=globals,
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)

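# Illustrative one-off measurement (a sketch; assumes CUDA and a PyTorch
# build with torch._scaled_mm, and the shapes are arbitrary):
#
#   >>> a, b = make_rand_tensors(torch.float8_e4m3fn, m=16, n=4096, k=4096)
#   >>> one = torch.tensor(1.0, device="cuda", dtype=torch.float32)
#   >>> m = bench_fn(a, b, one, one, torch.bfloat16, "demo",
#   ...              "MKN=(16x4096x4096)", pytorch_fp8_impl,
#   ...              "pytorch_fp8_fp8_bf16_scaled_mm")
#   >>> print(m.median)  # median runtime in seconds
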
def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
               sub_label: str) -> Iterable[TMeasurement]:
    assert dtype == torch.int8
    a, b = make_rand_tensors(torch.int8, m, n, k)
    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)

    timers = []

    # pytorch impl
    timers.append(
        bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
                 b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
                 torch.bfloat16, label, sub_label, pytorch_mm_impl,
                 "pytorch_bf16_bf16_bf16_matmul-no-scales"))

    # cutlass impl
    timers.append(
        bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
                 cutlass_impl, "cutlass_i8_i8_bf16_scaled_mm"))

    return timers

def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
              sub_label: str) -> Iterable[TMeasurement]:
    assert dtype == torch.float8_e4m3fn
    a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)

    timers = []

    # pytorch impl w. bf16
    timers.append(
        bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
                 b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
                 torch.bfloat16, label, sub_label, pytorch_mm_impl,
                 "pytorch_bf16_bf16_bf16_matmul-no-scales"))

    # pytorch impl: bf16 output, without fp8 fast accum
    timers.append(
        bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
                 pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm"))

    # pytorch impl: bf16 output, with fp8 fast accum
    timers.append(
        bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
                 pytorch_fp8_impl_fast_accum,
                 "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"))

    # pytorch impl: fp16 output, without fp8 fast accum
    timers.append(
        bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
                 pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm"))

    # pytorch impl: fp16 output, with fp8 fast accum
    timers.append(
        bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
                 pytorch_fp8_impl_fast_accum,
                 "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"))

    # cutlass impl: bf16 output
    timers.append(
        bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
                 cutlass_impl, "cutlass_fp8_fp8_bf16_scaled_mm"))

    # cutlass impl: fp16 output
    timers.append(
        bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
                 cutlass_impl, "cutlass_fp8_fp8_fp16_scaled_mm"))

    return timers

def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str,
          sub_label: str) -> Iterable[TMeasurement]:
    if dtype == torch.int8:
        return bench_int8(dtype, m, k, n, label, sub_label)
    if dtype == torch.float8_e4m3fn:
        return bench_fp8(dtype, m, k, n, label, sub_label)
    raise ValueError("unsupported type")

# runner
def print_timers(timers: Iterable[TMeasurement]):
    compare = TBenchmark.Compare(timers)
    compare.print()


def run(dtype: torch.dtype,
        MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
    results = []
    for m, k, n in MKNs:
        timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
                       f"MKN=({m}x{k}x{n})")
        print_timers(timers)
        results.extend(timers)

    return results

# output makers
def make_output(data: Iterable[TMeasurement],
                MKNs: Iterable[Tuple[int, int, int]],
                base_description: str,
                timestamp=None):

    print(f"== All Results {base_description} ====")
    print_timers(data)

    # pickle all the results
    timestamp = int(time.time()) if timestamp is None else timestamp
    with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
        pkl.dump(data, f)

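# The pickle can be reloaded later for offline analysis (a sketch; the
# filename below is hypothetical):
#
#   >>> import pickle
#   >>> with open("square_bench-torch.int8-1700000000.pkl", "rb") as f:
#   ...     timers = pickle.load(f)
#   >>> print_timers(timers)
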
# argparse runners


def run_square_bench(args):
    dim_sizes = list(
        range(args.dim_start, args.dim_end + 1, args.dim_increment))
    MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
    data = run(args.dtype, MKNs)

    make_output(data, MKNs, f"square_bench-{args.dtype}")


def run_range_bench(args):
    dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
    n = len(dim_sizes)
    Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
    Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
    Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
    MKNs = list(zip(Ms, Ks, Ns))
    data = run(args.dtype, MKNs)

    make_output(data, MKNs, f"range_bench-{args.dtype}")

def run_model_bench(args):
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}] {model}")

    def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
        KNs = []
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KNs.append(KN)
        return KNs
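    # Illustrative (hypothetical numbers): for a weight shape [4096, 12288]
    # with tp_split_dim=1 and tp_size=2, model_shapes yields [4096, 6144]:
    # the tensor-parallel split halves the sharded dimension.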
    model_bench_data = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        Ms = args.batch_sizes
        KNs = model_shapes(model, tp_size)
        MKNs = []
        for m in Ms:
            for k, n in KNs:
                MKNs.append((m, k, n))

        data = run(args.dtype, MKNs)
        model_bench_data.append(data)

    # Print all results
    for data, model_tp in zip(model_bench_data, models_tps):
        model, tp_size = model_tp
        print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
        print_timers(data)

    timestamp = int(time.time())

    all_data = []
    for d in model_bench_data:
        all_data.extend(d)
    # pickle all data
    with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
        pkl.dump(all_data, f)

if __name__ == '__main__':

    def to_torch_dtype(dt):
        if dt == "int8":
            return torch.int8
        if dt == "fp8":
            return torch.float8_e4m3fn
        raise ValueError("unsupported dtype")

    parser = FlexibleArgumentParser(
        description="""
Benchmark Cutlass GEMM.

    To run square GEMMs:
        python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64

    To run constant N and K and sweep M:
        python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384

    To run dimensions from a model:
        python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1

    Output:
        - a .pkl file containing a list of raw torch.utils.benchmark.Measurement objects for the pytorch and cutlass implementations of the various GEMMs.
            """,  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter)

    parser.add_argument("--dtype",
                        type=to_torch_dtype,
                        required=True,
                        help="Available options are ['int8', 'fp8']")
    subparsers = parser.add_subparsers(dest="cmd")

    square_parser = subparsers.add_parser("square_bench")
    square_parser.add_argument("--dim-start", type=int, required=True)
    square_parser.add_argument("--dim-end", type=int, required=True)
    square_parser.add_argument("--dim-increment", type=int, required=True)
    square_parser.set_defaults(func=run_square_bench)

    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--dim-start", type=int, required=True)
    range_parser.add_argument("--dim-end", type=int, required=True)
    range_parser.add_argument("--dim-increment", type=int, required=True)
    range_parser.add_argument("--m-constant", type=int, default=None)
    range_parser.add_argument("--n-constant", type=int, default=None)
    range_parser.add_argument("--k-constant", type=int, default=None)
    range_parser.set_defaults(func=run_range_bench)

    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument("--models",
                              nargs="+",
                              type=str,
                              default=DEFAULT_MODELS,
                              choices=WEIGHT_SHAPES.keys())
    model_parser.add_argument("--tp-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_TP_SIZES)
    model_parser.add_argument("--batch-sizes",
                              nargs="+",
                              type=int,
                              default=DEFAULT_BATCH_SIZES)
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    args.func(args)