from collections import namedtuple
from functools import partial
import math
from typing import NamedTuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import time

try:
    import cudnn
except ImportError:
    cudnn = None

# Minimal timing record: `mean` holds the mean runtime in seconds.
Timing = NamedTuple('Timing', [('mean', float)])

from einops import rearrange, repeat
from flash_attn.utils.benchmark import benchmark_forward, benchmark_backward, benchmark_combined, benchmark_all, benchmark_fwd_bwd, pytorch_profiler
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
from flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3

from triton.testing import do_bench
try:
    from triton_fused_attention import attention as triton_attention
except ImportError:
    triton_attention = None
def time_fwd(func, *args, repeats=30, verbose=True, desc="", **kwargs):
    # Time `func(*args, **kwargs)` with triton's do_bench (warmup/rep are in ms)
    # and return the mean runtime in seconds.
    return Timing(do_bench(lambda: func(*args, **kwargs), warmup=5, rep=repeats) * 1e-3)
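
# Hedged usage sketch (torch.bmm stands in for any CUDA op; `x`, `y`, and
# `flops_count` are hypothetical):
#
#     t = time_fwd(torch.bmm, x, y, repeats=30)
#     print(f"{t.mean * 1e3:.3f} ms, {flops_count / t.mean * 1e-12:.1f} TFLOPS")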
def flops(batch, nheads, seqlen_q, seqlen_k, headdim, causal=False, window_size=(-1, -1)):
    # FLOPs of one attention forward pass: two GEMMs (Q @ K^T and P @ V), each
    # 2 * seqlen_q * avg_seqlen * headdim multiply-adds per (batch, head).
    if causal:
        avg_seqlen = (max(0, seqlen_k - seqlen_q) + seqlen_k) / 2
    else:
        if window_size == (-1, -1):
            avg_seqlen = seqlen_k
        else:
            # Local attention: count the visible columns per query row, with the
            # window aligned to the bottom right of the attention matrix.
            row_idx = torch.arange(seqlen_q, device='cuda')
            col_left = torch.maximum(row_idx + seqlen_k - seqlen_q - window_size[0], torch.tensor(0))
            col_right = torch.minimum(row_idx + seqlen_k - seqlen_q + window_size[1], torch.tensor(seqlen_k - 1))
            avg_seqlen = (col_right - col_left + 1).float().mean().item()
    return batch * nheads * 2 * seqlen_q * avg_seqlen * headdim * 2
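
# Worked example (matches the default sweep below): batch=2, nheads=16,
# seqlen_q = seqlen_k = 8192, headdim=128, non-causal full attention gives
# 2 * 16 * 2 * 8192 * 8192 * 128 * 2 ≈ 1.10e12 FLOPs -- one factor of 2 for the
# multiply-accumulate, one for the two GEMMs. With seqlen_q == seqlen_k, causal
# masking halves avg_seqlen, so the causal figure is exactly half of that.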
def convert_to_cudnn_type(torch_type):
    if torch_type == torch.float16:
        return cudnn.data_type.HALF
    elif torch_type == torch.bfloat16:
        return cudnn.data_type.BFLOAT16
    elif torch_type == torch.float32:
        return cudnn.data_type.FLOAT
    elif torch_type == torch.int32:
        return cudnn.data_type.INT32
    elif torch_type == torch.int64:
        return cudnn.data_type.INT64
    else:
        raise ValueError("Unsupported tensor data type.")
def cudnn_spda_setup(q, k, v, causal=False, window_size_left=-1):
    # Build a cuDNN SDPA forward graph once; the returned closure just replays it.
    # Inputs are expected in (b, nheads, seqlen, headdim) layout.
    b, nheads, seqlen_q, headdim = q.shape
    _, nheads_k, seqlen_k, _ = k.shape
    assert v.shape == (b, nheads_k, seqlen_k, headdim)
    assert cudnn is not None, 'CUDNN is not available'
    q_gpu, k_gpu, v_gpu = q, k, v
    o_gpu = torch.empty_like(q_gpu)
    stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device)
    graph = cudnn.pygraph(
        io_data_type=convert_to_cudnn_type(q.dtype),
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    q = graph.tensor_like(q_gpu.detach())
    k = graph.tensor_like(k_gpu.detach())
    v = graph.tensor_like(v_gpu.detach())
    o, stats = graph.sdpa(
        name="sdpa",
        q=q,
        k=k,
        v=v,
        is_inference=False,
        attn_scale=1.0 / math.sqrt(headdim),
        use_causal_mask=causal or window_size_left >= 0,
        sliding_window_length=window_size_left if window_size_left >= 0 and not causal else None,
    )
    o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
    stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)
    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph.check_support()
    graph.build_plans()
    variant_pack = {
        q: q_gpu,
        k: k_gpu,
        v: v_gpu,
        o: o_gpu,
        stats: stats_gpu,
    }
    workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)

    def run(*args, **kwargs):
        graph.execute(variant_pack, workspace)
        return o_gpu

    return run
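
# Hedged usage sketch: the closure captures `variant_pack`, so every call reruns
# the prebuilt graph on the same buffers and overwrites `o_gpu` in place. Graph
# build cost is paid once, outside the timed region, which is what we want when
# benchmarking.
#
#     run = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
#                            causal=True)
#     out = run()  # (b, nheads, seqlen_q, headdim)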
def cudnn_spda_bwd_setup(q, k, v, o, g, lse, causal=False, window_size_left=-1):
    # Build a cuDNN SDPA backward graph once; the returned closure just replays it.
    b, nheads, seqlen_q, headdim = q.shape
    _, nheads_k, seqlen_k, _ = k.shape
    assert v.shape == (b, nheads_k, seqlen_k, headdim)
    assert g.shape == (b, nheads, seqlen_q, headdim)
    assert o.shape == (b, nheads, seqlen_q, headdim)
    assert lse.shape == (b, nheads, seqlen_q, 1)
    assert cudnn is not None, 'CUDNN is not available'
    q_gpu, k_gpu, v_gpu, o_gpu, g_gpu = q, k, v, o, g
    dq_gpu = torch.empty_like(q_gpu)
    dk_gpu = torch.empty_like(k_gpu)
    dv_gpu = torch.empty_like(v_gpu)
    graph = cudnn.pygraph(
        io_data_type=convert_to_cudnn_type(q.dtype),
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    q = graph.tensor_like(q_gpu.detach())
    k = graph.tensor_like(k_gpu.detach())
    v = graph.tensor_like(v_gpu.detach())
    o = graph.tensor_like(o_gpu.detach())
    g = graph.tensor_like(g_gpu.detach())
    stats = graph.tensor_like(lse.detach())
    dq, dk, dv = graph.sdpa_backward(
        name="sdpa_backward",
        q=q,
        k=k,
        v=v,
        o=o,
        dO=g,
        stats=stats,
        attn_scale=1.0 / math.sqrt(headdim),
        use_causal_mask=causal or window_size_left >= 0,
        sliding_window_length=window_size_left if window_size_left >= 0 and not causal else None,
    )
    dq.set_output(True).set_dim(dq_gpu.shape).set_stride(dq_gpu.stride())
    dk.set_output(True).set_dim(dk_gpu.shape).set_stride(dk_gpu.stride())
    dv.set_output(True).set_dim(dv_gpu.shape).set_stride(dv_gpu.stride())
    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph.check_support()
    graph.build_plans()
    variant_pack = {
        q: q_gpu,
        k: k_gpu,
        v: v_gpu,
        o: o_gpu,
        g: g_gpu,
        stats: lse,
        dq: dq_gpu,
        dk: dk_gpu,
        dv: dv_gpu,
    }
    workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)

    def run(*args, **kwargs):
        graph.execute(variant_pack, workspace)
        return dq_gpu, dk_gpu, dv_gpu

    return run
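
# Hedged usage sketch, mirroring the forward setup above; `lse` is the
# log-sum-exp stats tensor from the forward pass, shape (b, nheads, seqlen_q, 1):
#
#     bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
#                                o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2),
#                                causal=True)
#     dq, dk, dv = bwd()  # same layout as q, k, v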
torch.manual_seed(0)

# Benchmark configuration. batch_size, seqlen, and headdim below are defaults
# that the sweep loops further down override.
repeats = 10
dropout_p = 0.0
causal = False
dtype = torch.bfloat16
dtype_gen = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
device = 'cuda'
verbose = True
varlen = False
page_size = None
softcap = 0.0
V_colmajor = False
deterministic = False
batch_size = 2
seqlen = 8192
dim = 2048
headdim = 256
bs_seqlen_vals = [(2, 8192)]

time_f = {}
time_b = {}
for headdim in [128]:
    nheads = dim // headdim
    nheads_kv = nheads

    for batch_size, seqlen in bs_seqlen_vals:
        num_splits = 1
        window_size = (-1, -1)
        sink_token_length = 0
        pack_gqa = None
        seqlen_q = seqlen
        leftpad_k = None

        q = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True)
        k = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True)
        v = torch.randn(batch_size, seqlen, nheads_kv, headdim, device=device, dtype=dtype_gen, requires_grad=True)
        q, k, v = [x.detach().to(dtype).requires_grad_() for x in [q, k, v]]
        # Same values as v but with the head dimension laid out column-major.
        v_colmajor = v.detach().transpose(-1, -3).contiguous().transpose(-1, -3).requires_grad_()
        v_fa3 = v if not V_colmajor else v_colmajor

        # Gradient, output, and LSE stats buffers for the backward benchmarks.
        g = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True)
        o = torch.randn(batch_size, seqlen_q, nheads, headdim, device=device, dtype=dtype_gen, requires_grad=True)
        stats = torch.randn(batch_size, seqlen_q, nheads, 1, device=device, dtype=torch.float32)
        a = torch.randn(batch_size, seqlen, seqlen, device=device, dtype=dtype_gen)  # unused below
        b = torch.randn(batch_size, dim * 2, seqlen, device=device, dtype=dtype_gen).transpose(-1, -2)  # unused below
        if varlen:
            # Pack (b, s, h, d) into (b * s, h, d) plus cumulative sequence lengths.
            q_unpad, k_unpad, v_unpad = [rearrange(x.detach(), "b s h d -> (b s) h d").requires_grad_() for x in [q, k, v]]
            cu_seqlens_q = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen_q
            cu_seqlens_k = torch.arange(batch_size + 1, device=device, dtype=torch.int32) * seqlen
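
        # Hedged sketch of the cu_seqlens convention (all lengths equal here):
        # with batch_size=2 and seqlen_q=8192, cu_seqlens_q == [0, 8192, 16384],
        # and packed sequence i occupies rows cu_seqlens_q[i]:cu_seqlens_q[i+1]
        # of the (b * s, h, d) tensors.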
        if page_size is not None:
            assert seqlen % page_size == 0
            k_paged, v_paged = [rearrange(x, "b (n p) h d -> (b n) p h d", p=page_size) for x in [k, v]]
            page_table = rearrange(torch.arange(batch_size * seqlen // page_size, device=device, dtype=torch.int32),
                                   "(b s) -> b s", s=seqlen // page_size)
        else:
            page_table = None
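
        # Hedged paged-KV sketch: with batch_size=2, seqlen=8192, page_size=256,
        # there are 8192 / 256 = 32 pages per sequence, k_paged has shape
        # (2 * 32, 256, nheads_kv, headdim), and page_table row b holds the
        # physical page indices backing sequence b (here simply b*32 .. b*32 + 31).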
        for causal in [False, True]:
            print(f"\n### {headdim = }, {causal = }, {seqlen = } ###")
            nFLOPS = flops(batch_size, nheads, seqlen_q, seqlen, headdim, causal=causal, window_size=window_size)

            if cudnn is not None:
                if headdim <= 256 and dtype != torch.float8_e4m3fn:
                    # Build the cuDNN graphs once per config; cuDNN expects (b, h, s, d).
                    cudnn_spda = cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal, window_size_left=window_size[0])
                    cudnn_spda_bwd = cudnn_spda_bwd_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), o.transpose(1, 2), g.transpose(1, 2), stats.transpose(1, 2), causal=causal, window_size_left=window_size[0])

            # FlashAttention-2 forward and backward.
            if dtype != torch.float8_e4m3fn:
                if not varlen:
                    m0 = time_fwd(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2')
                else:
                    m0 = time_fwd(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, repeats=repeats, verbose=verbose, desc='Fav2')
                time_f[(causal, headdim, batch_size, seqlen), "Flash2"] = m0.mean
                time.sleep(1)
                if not varlen:
                    _, m0b = benchmark_backward(flash_attn_func, q, k, v, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
                                                repeats=repeats, verbose=False, desc='Fav2')
                else:
                    _, m0b = benchmark_backward(flash_attn_varlen_func, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen, dropout_p, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
                                                repeats=repeats, verbose=False, desc='Fav2')
                time_b[(causal, headdim, batch_size, seqlen), "Flash2"] = m0b.mean

            # Triton fused attention forward (expects (b, h, s, d)).
            if headdim <= 256 and dtype != torch.float8_e4m3fn:
                if triton_attention is not None:
                    qt, kt, vt = [x.detach().transpose(1, 2).contiguous().requires_grad_() for x in [q, k, v]]
                    time.sleep(1)
                    m3 = time_fwd(triton_attention, qt, kt, vt, causal, 1 / math.sqrt(headdim), repeats=repeats, verbose=verbose, desc='Triton')
                    time_f[(causal, headdim, batch_size, seqlen), "Triton"] = m3.mean

            # cuDNN forward and backward, replaying the graphs built above.
            if cudnn is not None:
                if headdim <= 256 and dtype != torch.float8_e4m3fn:
                    time.sleep(1)
                    m2 = time_fwd(cudnn_spda, repeats=repeats, verbose=verbose, desc='CuDNN')
                    time_f[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2.mean
                    time.sleep(1)
                    m2b = time_fwd(cudnn_spda_bwd, repeats=repeats, verbose=verbose, desc='CuDNN')
                    time_b[(causal, headdim, batch_size, seqlen), "cuDNN"] = m2b.mean

            # FlashAttention-3 forward.
            time.sleep(1)
            if not varlen:
                m1 = time_fwd(flash_attn_func_v3, q, k if page_size is None else k_paged, v_fa3 if page_size is None else v_paged, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
            else:
                m1 = time_fwd(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, None, None, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, num_splits=num_splits, pack_gqa=pack_gqa, repeats=repeats, verbose=verbose, desc='Fav3')
            time_f[(causal, headdim, batch_size, seqlen), "Flash3"] = m1.mean

            # FlashAttention-3 backward (no fp8 backward).
            if dtype != torch.float8_e4m3fn:
                time.sleep(1)
                if not varlen:
                    _, m1b = benchmark_backward(flash_attn_func_v3, q, k, v, causal=causal, window_size=window_size, sink_token_length=sink_token_length, softcap=softcap, deterministic=deterministic,
                                                repeats=repeats, verbose=False, desc='Fav3')
                else:
                    _, m1b = benchmark_backward(flash_attn_varlen_func_v3, q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, None, None, seqlen_q, seqlen, causal=causal, window_size=window_size, softcap=softcap, deterministic=deterministic,
                                                repeats=repeats, verbose=False, desc='Fav3')
                time_b[(causal, headdim, batch_size, seqlen), "Flash3"] = m1b.mean

            # Report. The backward pass does ~2.5x the forward FLOPs (gradients
            # for Q, K, and V plus recomputation), hence the 2.5 factor below.
            if dtype != torch.float8_e4m3fn:
                print(f'Fav2 fwd: {m0.mean * 1e3:.3f}ms, {(nFLOPS / m0.mean * 1e-12):.1f} TFLOPS')
                print(f'Fav2 bwd: {m0b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m0b.mean * 1e-12):.1f} TFLOPS')
            if headdim <= 256 and dtype != torch.float8_e4m3fn:
                if triton_attention is not None:
                    print(f'Triton fwd: {m3.mean * 1e3:.3f}ms, {(nFLOPS / m3.mean * 1e-12):.1f} TFLOPS')
                if cudnn is not None:
                    print(f'CuDNN fwd: {m2.mean * 1e3:.3f}ms, {(nFLOPS / m2.mean * 1e-12):.1f} TFLOPS')
                    print(f'CuDNN bwd: {m2b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m2b.mean * 1e-12):.1f} TFLOPS')
            print(f'Fav3 fwd: {m1.mean * 1e3:.3f}ms, {(nFLOPS / m1.mean * 1e-12):.1f} TFLOPS')
            if dtype != torch.float8_e4m3fn:
                print(f'Fav3 bwd: {m1b.mean * 1e3:.3f}ms, {(2.5 * nFLOPS / m1b.mean * 1e-12):.1f} TFLOPS')
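
# Hedged post-processing sketch: time_f / time_b map
# ((causal, headdim, batch_size, seqlen), impl) -> mean seconds, e.g.
#
#     for (cfg, impl), t in time_f.items():
#         print(impl, cfg, f"{t * 1e3:.3f} ms")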