# Copyright (c) 2023, Tri Dao.

"""Useful functions for writing test code."""

import torch
import torch.utils.benchmark as benchmark

def benchmark_forward(
    fn, *inputs, repeats=10, desc="", verbose=True, amp=False, amp_dtype=torch.float16, **kwinputs
):
    """Use PyTorch Benchmark on the forward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward pass")

    def amp_wrapper(*inputs, **kwinputs):
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            fn(*inputs, **kwinputs)

    t = benchmark.Timer(
        stmt="fn_amp(*inputs, **kwinputs)",
        globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m
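
# Hedged usage sketch (not part of the original file): time a plain matmul.
# Assumes a CUDA device, since amp_wrapper autocasts with device_type="cuda".
def _example_benchmark_forward():
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    # m is a torch.utils.benchmark Measurement; m.mean is seconds per call
    t, m = benchmark_forward(torch.matmul, a, b, repeats=30, desc="matmul", amp=True)
    return m.mean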

def benchmark_backward(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use PyTorch Benchmark on the backward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    else:
        if grad.shape != y.shape:
            raise RuntimeError("Grad shape does not match output shape")

    def f(*inputs, y, grad):
        # Set .grad to None so the timed region does not include the extra
        # gradient-accumulation op
        for x in inputs:
            if isinstance(x, torch.Tensor):
                x.grad = None
        y.backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt="f(*inputs, y=y, grad=grad)",
        globals={"f": f, "inputs": inputs, "y": y, "grad": grad},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m
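
# Hedged usage sketch (an assumption, not from the original file): inputs must
# be created with requires_grad=True, otherwise no graph is recorded and
# y.backward() inside f raises an error.
def _example_benchmark_backward():
    a = torch.randn(1024, 1024, device="cuda", requires_grad=True)
    b = torch.randn(1024, 1024, device="cuda", requires_grad=True)
    t, m = benchmark_backward(torch.matmul, a, b, repeats=30, desc="matmul bwd")
    return m.mean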

def benchmark_combined(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use PyTorch Benchmark on the forward+backward pass of an arbitrary function."""
    if verbose:
        print(desc, "- Forward + Backward pass")
    with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
        y = fn(*inputs, **kwinputs)
        if type(y) is tuple:
            y = y[0]
    if grad is None:
        grad = torch.randn_like(y)
    else:
        if grad.shape != y.shape:
            raise RuntimeError("Grad shape does not match output shape")

    def f(grad, *inputs, **kwinputs):
        # Clear accumulated gradients so the timed region is only forward + backward
        for x in inputs:
            if isinstance(x, torch.Tensor):
                x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            y = fn(*inputs, **kwinputs)
            if type(y) is tuple:
                y = y[0]
        # Backward runs outside autocast
        y.backward(grad, retain_graph=True)

    t = benchmark.Timer(
        stmt="f(grad, *inputs, **kwinputs)",
        globals={"f": f, "fn": fn, "inputs": inputs, "grad": grad, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    m = t.timeit(repeats)
    if verbose:
        print(m)
    return t, m

def benchmark_fwd_bwd(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use PyTorch Benchmark on the forward and backward passes (timed separately) of an arbitrary function."""
    return (
        benchmark_forward(
            fn,
            *inputs,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_backward(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
    )

def benchmark_all(
    fn,
    *inputs,
    grad=None,
    repeats=10,
    desc="",
    verbose=True,
    amp=False,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use PyTorch Benchmark on the forward, backward, and combined forward+backward passes of an arbitrary function."""
    return (
        benchmark_forward(
            fn,
            *inputs,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_backward(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
        benchmark_combined(
            fn,
            *inputs,
            grad=grad,
            repeats=repeats,
            desc=desc,
            verbose=verbose,
            amp=amp,
            amp_dtype=amp_dtype,
            **kwinputs,
        ),
    )
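
# Hedged usage sketch (an assumption): benchmark_all returns three
# (Timer, Measurement) pairs, in the order forward, backward, combined.
def _example_benchmark_all():
    a = torch.randn(1024, 1024, device="cuda", requires_grad=True)
    b = torch.randn(1024, 1024, device="cuda", requires_grad=True)
    (_, fwd), (_, bwd), (_, both) = benchmark_all(torch.matmul, a, b, verbose=False)
    print(f"fwd {fwd.mean:.3e}s  bwd {bwd.mean:.3e}s  fwd+bwd {both.mean:.3e}s")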

def pytorch_profiler(
    fn,
    *inputs,
    trace_filename=None,
    backward=False,
    amp=False,
    amp_dtype=torch.float16,
    cpu=False,
    verbose=True,
    **kwinputs,
):
    """Wrap benchmark functions in the PyTorch profiler to see CUDA information."""
    if backward:
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
            g = torch.randn_like(out)
    for _ in range(30):  # Warm up
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        # Backward should be done outside autocast
        if backward:
            out.backward(g, retain_graph=True)
    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [
        torch.profiler.ProfilerActivity.CUDA
    ]
    with torch.profiler.profile(
        activities=activities,
        record_shapes=True,
        # profile_memory=True,
        with_stack=True,
    ) as prof:
        if backward:
            for x in inputs:
                if isinstance(x, torch.Tensor):
                    x.grad = None
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            out = fn(*inputs, **kwinputs)
            if type(out) is tuple:
                out = out[0]
        if backward:
            out.backward(g, retain_graph=True)
    if verbose:
        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
        print(prof.key_averages().table(row_limit=50))
    if trace_filename is not None:
        prof.export_chrome_trace(trace_filename)
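
# Hedged usage sketch (an assumption): profile forward+backward of a matmul and
# dump a Chrome trace (openable at chrome://tracing or https://ui.perfetto.dev).
def _example_pytorch_profiler():
    a = torch.randn(1024, 1024, device="cuda", requires_grad=True)
    b = torch.randn(1024, 1024, device="cuda", requires_grad=True)
    pytorch_profiler(torch.matmul, a, b, backward=True, trace_filename="matmul_trace.json")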

def benchmark_memory(fn, *inputs, desc="", verbose=True, **kwinputs):
    """Measure peak CUDA memory allocated during a single call to fn, in GB."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    fn(*inputs, **kwinputs)
    torch.cuda.synchronize()
    # Convert bytes to "GB"; note the nonstandard unit: 1 GB here is 1000 MiB
    mem = torch.cuda.max_memory_allocated() / ((2**20) * 1000)
    if verbose:
        print(f"{desc} max memory: {mem:.2f}GB")
    torch.cuda.empty_cache()
    return mem
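
# Hedged usage sketch (an assumption): peak memory of a single large matmul.
def _example_benchmark_memory():
    a = torch.randn(4096, 4096, device="cuda")
    b = torch.randn(4096, 4096, device="cuda")
    return benchmark_memory(torch.matmul, a, b, desc="matmul")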