benchmark_gemm.py

import time

import torch
import torch.utils.benchmark as benchmark
from triton.testing import do_bench

# Pick the vendor BLAS label based on the PyTorch build (CUDA or ROCm).
if torch.version.cuda:
    backendBLAS = "cuBLAS"
elif torch.version.hip:
    backendBLAS = "hipBLAS"


def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, **kwinputs):
    """Use PyTorch's benchmark.Timer on the forward pass of an arbitrary function."""
    if verbose:
        print(desc, '- Forward pass')
    t = benchmark.Timer(
        stmt='fn(*inputs, **kwinputs)',
        globals={'fn': fn, 'inputs': inputs, 'kwinputs': kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m


torch.manual_seed(0)
repeats = 30
dtype = torch.float16
device = 'cuda'
verbose = False
m, n = 8192, 8192

tflops_matmul = {}
tflops_matmul1 = {}
for k in [512, 1024, 1536, 2048, 2560, 3072, 3584, 4096, 4608, 5120, 5632, 6144, 6656, 7168, 7680, 8192]:
    a = torch.randn(m, k, device=device, dtype=dtype)
    # b has shape (k, n) but is laid out in memory as (n, k), i.e. a TN GEMM.
    b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2)
    nFLOPS_matmul = 2 * m * n * k  # 2*m*n*k FLOPs for an (m, k) @ (k, n) matmul
    time.sleep(2)  # to reduce power throttling
    timing = benchmark_forward(torch.matmul, a, b, desc=backendBLAS, verbose=verbose, repeats=repeats)[1]
    tflops_matmul[k] = nFLOPS_matmul / timing.mean * 1e-12  # timing.mean is in seconds
    print(f'[torch.utils.benchmark] {backendBLAS}, {m = }, {n = }, {k = }: {timing.mean * 1e3:.3f}ms, {tflops_matmul[k]:.1f} TFLOPS')
    time.sleep(2)  # to reduce power throttling
    ms = do_bench(lambda: torch.matmul(a, b), warmup=10, rep=repeats)  # warmup/rep are in milliseconds
    tflops_matmul1[k] = nFLOPS_matmul / ms * 1e-9  # ms is in milliseconds
    print(f'[triton.testing.do_bench] {backendBLAS}, {m = }, {n = }, {k = }: {ms:.3f}ms, {tflops_matmul1[k]:.1f} TFLOPS')
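
# A minimal follow-up sketch, not part of the original script: compare the
# measured throughput against an assumed hardware peak to get an efficiency
# figure. The 312 TFLOPS value is the FP16 Tensor Core dense peak of an
# A100 SXM and is only an illustrative assumption; substitute the peak of
# the GPU you are actually running on.
ASSUMED_PEAK_TFLOPS = 312.0  # hypothetical reference peak, adjust per device
for k, tflops in tflops_matmul.items():
    print(f'{backendBLAS} efficiency at {k = }: {tflops / ASSUMED_PEAK_TFLOPS:.1%} of assumed peak')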