# Copyright (c) 2023, Tri Dao.
""" Useful functions for writing test code. """
import torch
import torch.utils.benchmark as benchmark


def benchmark_forward(
    fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
):
    """Use PyTorch Benchmark on the forward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward pass")

    def amp_wrapper(*inputs, **kwinputs):
        # Run under autocast so AMP timings include any cast overhead.
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)

    t = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m
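
# Example usage (a minimal sketch; assumes a CUDA device, and GELU is just an
# illustrative workload, not part of this module):
#     x = torch.randn(8, 4096, 4096, device="cuda")
#     t, m = benchmark_forward(torch.nn.functional.gelu, x, desc="GELU", repeats=30)
#     print(m.mean)  # mean time per run, in seconds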


def benchmark_backward(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use PyTorch Benchmark on the backward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    else:
        if grad.shape != y.shape:
            raise RuntimeError("Grad shape does not match output shape")

    def f(*inputs, y, grad):
        # Set .grad to None to avoid extra operation of gradient accumulation
        for x in inputs:
            if isinstance(x, torch.Tensor):
                x.grad = None
        y.backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt="f(*inputs, y=y, grad=grad)",
        globals={"f": f, "inputs": inputs, "y": y, "grad": grad},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m
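
# Example usage (a sketch; inputs must require grad so there is a graph to
# backpropagate through):
#     x = torch.randn(8, 4096, 4096, device="cuda", requires_grad=True)
#     t, m = benchmark_backward(torch.nn.functional.gelu, x, desc="GELU bwd")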


def benchmark_combined(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use PyTorch Benchmark on the forward+backward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward + Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    else:
        if grad.shape != y.shape:
            raise RuntimeError("Grad shape does not match output shape")

    def f(grad, *inputs, **kwinputs):
        # Clear accumulated grads, then run one full forward+backward step.
        for x in inputs:
            if isinstance(x, torch.Tensor):
                x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            y = fn(*inputs, **kwinputs)
            if type(y) is tuple:
                y = y[0]
        y.backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt="f(grad, *inputs, **kwinputs)",
        globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m
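
# Example usage (sketch): unlike benchmark_fwd_bwd below, this times each
# forward+backward pair as a single measurement:
#     t, m = benchmark_combined(torch.nn.functional.gelu, x, desc="GELU fwd+bwd")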


def benchmark_fwd_bwd(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use PyTorch Benchmark on the forward and backward passes, timed separately."""
    return (
        benchmark_forward(
            fn,
            *inputs,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_backward(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
    )


def benchmark_all(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use PyTorch Benchmark on the forward, backward, and combined forward+backward passes."""
    return (
        benchmark_forward(
            fn,
            *inputs,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_backward(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_combined(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
    )
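
# Example usage (sketch): each sub-benchmark returns a (Timer, Measurement) pair:
#     (tf, mf), (tb, mb), (tc, mc) = benchmark_all(torch.nn.functional.gelu, x, desc="GELU")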


def pytorch_profiler(
    fn,
    *inputs,
    trace_filename=None,
    backward=False,
    amp=False,
    amp_dtype=torch.float16,
    cpu=False,
    verbose=True,
    **kwinputs,
):
    """Wrap benchmark functions in PyTorch profiler to see CUDA information."""
    if backward:
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        g = torch.randn_like(out)
    for _ in range(30):  # Warm up
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        # Backward should be done outside autocast
        if backward:
            out.backward(g, retain_graph=True)
    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [
        torch.profiler.ProfilerActivity.CUDA
    ]
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        # profile_memory=True,
        with_stack=True,
    ) as prof:
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        if backward:
            out.backward(g, retain_graph=True)
    if verbose:
        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
        print(prof.key_averages().table(row_limit=50))
    if trace_filename is not None:
        prof.export_chrome_trace(trace_filename)
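
# Example usage (sketch; the exported trace can be opened at chrome://tracing or
# in Perfetto, and the filename here is just illustrative):
#     pytorch_profiler(torch.nn.functional.gelu, x, backward=True, cpu=True,
#                      trace_filename="gelu_trace.json")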


def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
    """Measure the peak CUDA memory allocated by a single call to fn."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    fn(*inputs, **kwinputs)
    torch.cuda.synchronize()
    # Convert peak allocated bytes to GiB. (The original divisor, (2**20) * 1000,
    # mixed binary and decimal units while printing "GB".)
    mem = torch.cuda.max_memory_allocated() / (2**30)
    if verbose:
        print(f"{desc} max memory: {mem}GiB")
    torch.cuda.empty_cache()
    return mem
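

# A minimal smoke test (an illustrative sketch, not part of the original module;
# assumes a CUDA device is available).
if __name__ == "__main__":
    x = torch.randn(8, 1024, 1024, device="cuda", requires_grad=True)
    benchmark_all(torch.nn.functional.gelu, x, desc="GELU", repeats=30)
    benchmark_memory(torch.nn.functional.gelu, x, desc="GELU")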