Browse Source

Add Triton implementation for benchmarking

Tri Dao 2 năm trước cách đây
mục cha
commit
50ca23488d
2 tập tin đã thay đổi với 443 bổ sung0 xóa
  1. 79 0
      benchmarks/benchmark_causal.py
  2. 364 0
      flash_attn/triton/fused_attention.py

+ 79 - 0
benchmarks/benchmark_causal.py

@@ -0,0 +1,79 @@
+from functools import partial
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import rearrange, repeat
+
+from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
+from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+from flash_attn.triton.fused_attention import attention as attention
+
+
+def attention_pytorch(qkv, dropout_p=0.0, causal=False):
+    """
+    Arguments:
+        qkv: (batch_size, seqlen, 3, nheads, head_dim)
+        dropout_p: float
+    Output:
+        output: (batch_size, seqlen, nheads, head_dim)
+    """
+    batch_size, seqlen, _, nheads, d = qkv.shape
+    q, k, v = qkv.unbind(dim=2)
+    q = rearrange(q, 'b t h d -> (b h) t d')
+    k = rearrange(k, 'b s h d -> (b h) d s')
+    softmax_scale = 1.0 / math.sqrt(d)
+    # Preallocate attn_weights for `baddbmm`
+    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
+    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
+                       '(b h) t s -> b h t s', h=nheads)
+    if causal:
+        # "triu_tril_cuda_template" not implemented for 'BFloat16'
+        # So we have to construct the mask in float
+        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
+        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+        scores = scores + causal_mask.to(dtype=scores.dtype)
+    attention = torch.softmax(scores, dim=-1)
+    attention_drop = F.dropout(attention, dropout_p)
+    output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
+    return output.to(dtype=qkv.dtype)
+
+
+def attention_triton(q, k, v):
+    """
+    No dropout and only support causal=True.
+    Triton implementation seems to require q, k, v being contiguous?
+    Arguments:
+        q, k, v: (batch_size, nheads, seqlen, head_dim)
+    Output:
+        output: (batch_size, nheads, seqlen, head_dim)
+    """
+    softmax_scale = 1.0 / math.sqrt(q.shape[-1])
+    return attention(q, k, v, softmax_scale)
+
+
+torch.manual_seed(0)
+repeats = 30
+batch_size = 2
+seqlen = 2048
+nheads = 12
+headdim = 128
+dropout_p = 0.0
+causal = True
+dtype = torch.bfloat16
+device = 'cuda'
+
+qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
+                  requires_grad=True)
+cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
+                          device=qkv.device)
+
+benchmark_all(flash_attn_unpadded_qkvpacked_func, rearrange(qkv, 'b s ... -> (b s) ...'),
+              cu_seqlens, seqlen, dropout_p, causal=causal, repeats=repeats, desc='FlashAttention')
+benchmark_all(attention_pytorch, qkv, dropout_p, causal=causal,
+              repeats=repeats, desc='PyTorch Attention')
+
+q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype,
+                       requires_grad=True) for _ in range(3)]
+benchmark_all(attention_triton, q, k, v, repeats=repeats, desc='FlashAttention Triton')

+ 364 - 0
flash_attn/triton/fused_attention.py

@@ -0,0 +1,364 @@
+# [2022-10-23] Downloaded from https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
+# for benchmarking.
+# Fixing some dtype casting to make it work for bfloat16
+
+"""
+Fused Attention
+===============
+This is a Triton implementation of the Flash Attention algorithm
+(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
+"""
+
+import pytest
+import torch
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _fwd_kernel(
+    Q, K, V, sm_scale,
+    TMP, L, M,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
+    Out,
+    stride_qz, stride_qh, stride_qm, stride_qk,
+    stride_kz, stride_kh, stride_kn, stride_kk,
+    stride_vz, stride_vh, stride_vk, stride_vn,
+    stride_oz, stride_oh, stride_om, stride_on,
+    Z, H, N_CTX,
+    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    start_m = tl.program_id(0)
+    off_hz = tl.program_id(1)
+    # initialize offsets
+    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
+    off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
+    off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
+    # Initialize pointers to Q, K, V
+    q_ptrs = Q + off_q
+    k_ptrs = K + off_k
+    v_ptrs = V + off_v
+    # initialize pointer to m and l
+    t_ptrs = TMP + off_hz * N_CTX + offs_m
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    # load q: it will stay in SRAM throughout
+    q = tl.load(q_ptrs)
+    # loop over k, v and update accumulator
+    for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
+        start_n = tl.multiple_of(start_n, BLOCK_N)
+        # -- compute qk ----
+        k = tl.load(k_ptrs + start_n * stride_kn)
+        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+        qk += tl.dot(q, k, trans_b=True)
+        qk *= sm_scale
+        qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
+        # -- compute m_ij, p, l_ij
+        m_ij = tl.max(qk, 1)
+        p = tl.exp(qk - m_ij[:, None])
+        l_ij = tl.sum(p, 1)
+        # -- update m_i and l_i
+        m_i_new = tl.maximum(m_i, m_ij)
+        alpha = tl.exp(m_i - m_i_new)
+        beta = tl.exp(m_ij - m_i_new)
+        l_i_new = alpha * l_i + beta * l_ij
+        # -- update output accumulator --
+        # scale p
+        p_scale = beta / l_i_new
+        p = p * p_scale[:, None]
+        # scale acc
+        acc_scale = l_i / l_i_new * alpha
+        tl.store(t_ptrs, acc_scale)
+        acc_scale = tl.load(t_ptrs)  # BUG: have to store and immediately load
+        acc = acc * acc_scale[:, None]
+        # update acc
+        v = tl.load(v_ptrs + start_n * stride_vk)
+        p = p.to(q.dtype)
+        acc += tl.dot(p, v)
+        # update m_i and l_i
+        l_i = l_i_new
+        m_i = m_i_new
+    # rematerialize offsets to save registers
+    start_m = tl.program_id(0)
+    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    # write back l and m
+    l_ptrs = L + off_hz * N_CTX + offs_m
+    m_ptrs = M + off_hz * N_CTX + offs_m
+    tl.store(l_ptrs, l_i)
+    tl.store(m_ptrs, m_i)
+    # initialize pointers to output
+    offs_n = tl.arange(0, BLOCK_DMODEL)
+    off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
+    out_ptrs = Out + off_o
+    tl.store(out_ptrs, acc)
+
+
+@triton.jit
+def _bwd_preprocess(
+    Out, DO, L,
+    NewDO, Delta,
+    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
+):
+    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
+    off_n = tl.arange(0, D_HEAD)
+    # load
+    o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
+    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
+    denom = tl.load(L + off_m).to(tl.float32)
+    # compute
+    do = do / denom[:, None]
+    delta = tl.sum(o * do, axis=1)
+    # write-back
+    tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
+    tl.store(Delta + off_m, delta)
+
+
+@triton.jit
+def _bwd_kernel(
+    Q, K, V, sm_scale, Out, DO,
+    DQ, DK, DV,
+    L, M,
+    D,
+    stride_qz, stride_qh, stride_qm, stride_qk,
+    stride_kz, stride_kh, stride_kn, stride_kk,
+    stride_vz, stride_vh, stride_vk, stride_vn,
+    Z, H, N_CTX,
+    num_block,
+    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    off_hz = tl.program_id(0)
+    off_z = off_hz // H
+    off_h = off_hz % H
+    # offset pointers for batch/head
+    Q += off_z * stride_qz + off_h * stride_qh
+    K += off_z * stride_qz + off_h * stride_qh
+    V += off_z * stride_qz + off_h * stride_qh
+    DO += off_z * stride_qz + off_h * stride_qh
+    DQ += off_z * stride_qz + off_h * stride_qh
+    DK += off_z * stride_qz + off_h * stride_qh
+    DV += off_z * stride_qz + off_h * stride_qh
+    for start_n in range(0, num_block):
+        lo = start_n * BLOCK_M
+        # initialize row/col offsets
+        offs_qm = lo + tl.arange(0, BLOCK_M)
+        offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
+        offs_m = tl.arange(0, BLOCK_N)
+        offs_k = tl.arange(0, BLOCK_DMODEL)
+        # initialize pointers to value-like data
+        q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
+        k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
+        v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
+        do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
+        dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
+        # pointer to row-wise quantities in value-like data
+        D_ptrs = D + off_hz * N_CTX
+        m_ptrs = M + off_hz * N_CTX
+        # initialize dv amd dk
+        dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+        dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+        # k and v stay in SRAM throughout
+        k = tl.load(k_ptrs)
+        v = tl.load(v_ptrs)
+        # loop over rows
+        for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
+            offs_m_curr = start_m + offs_m
+            # load q, k, v, do on-chip
+            q = tl.load(q_ptrs)
+            # recompute p = softmax(qk, dim=-1).T
+            # NOTE: `do` is pre-divided by `l`; no normalization here
+            qk = tl.dot(q, k, trans_b=True)
+            qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
+            m = tl.load(m_ptrs + offs_m_curr)
+            p = tl.exp(qk * sm_scale - m[:, None])
+            # compute dv
+            do = tl.load(do_ptrs)
+            dv += tl.dot(p.to(q.dtype), do, trans_a=True)
+            # compute dp = dot(v, do)
+            Di = tl.load(D_ptrs + offs_m_curr)
+            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
+            dp += tl.dot(do, v, trans_b=True)
+            # compute ds = p * (dp - delta[:, None])
+            ds = p * dp * sm_scale
+            # compute dk = dot(ds.T, q)
+            dk += tl.dot(ds.to(q.dtype), q, trans_a=True)
+            # # compute dq
+            dq = tl.load(dq_ptrs, eviction_policy="evict_last")
+            dq += tl.dot(ds.to(q.dtype), k)
+            tl.store(dq_ptrs, dq, eviction_policy="evict_last")
+            # # increment pointers
+            dq_ptrs += BLOCK_M * stride_qm
+            q_ptrs += BLOCK_M * stride_qm
+            do_ptrs += BLOCK_M * stride_qm
+        # write-back
+        dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
+        dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
+        tl.store(dv_ptrs, dv)
+        tl.store(dk_ptrs, dk)
+
+
+class _attention(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, q, k, v, sm_scale):
+        BLOCK = 128
+        # shape constraints
+        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+        assert Lq == Lk and Lk == Lv
+        assert Lk in {16, 32, 64, 128}
+        o = torch.empty_like(q)
+        grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
+        tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
+        L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
+        m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
+        num_warps = 4 if Lk <= 64 else 8
+
+        _fwd_kernel[grid](
+            q, k, v, sm_scale,
+            tmp, L, m,
+            o,
+            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
+            q.shape[0], q.shape[1], q.shape[2],
+            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
+            BLOCK_DMODEL=Lk, num_warps=num_warps,
+            num_stages=1,
+        )
+        ctx.save_for_backward(q, k, v, o, L, m)
+        ctx.BLOCK = BLOCK
+        ctx.grid = grid
+        ctx.sm_scale = sm_scale
+        ctx.BLOCK_DMODEL = Lk
+        return o
+
+    @staticmethod
+    def backward(ctx, do):
+        q, k, v, o, l, m = ctx.saved_tensors
+        do = do.contiguous()
+        dq = torch.zeros_like(q, dtype=torch.float32)
+        dk = torch.empty_like(k)
+        dv = torch.empty_like(v)
+        do_scaled = torch.empty_like(do)
+        delta = torch.empty_like(l)
+        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
+            o, do, l,
+            do_scaled, delta,
+            BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
+        )
+
+        # NOTE: kernel currently buggy for other values of `num_warps`
+        num_warps = 8
+        _bwd_kernel[(ctx.grid[1],)](
+            q, k, v, ctx.sm_scale,
+            o, do_scaled,
+            dq, dk, dv,
+            l, m,
+            delta,
+            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
+            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
+            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
+            q.shape[0], q.shape[1], q.shape[2],
+            ctx.grid[0],
+            BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
+            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
+            num_stages=1,
+        )
+        return dq, dk, dv, None
+
+
+attention = _attention.apply
+
+
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(3, 2, 2048, 64)])
+def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
+    torch.manual_seed(20)
+    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5).requires_grad_()
+    sm_scale = 0.3
+    dout = torch.randn_like(q)
+    # reference implementation
+    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
+    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
+    for z in range(Z):
+        for h in range(H):
+            p[:, :, M == 0] = float("-inf")
+    p = torch.softmax(p.float(), dim=-1).half()
+    ref_out = torch.matmul(p, v)
+    ref_out.backward(dout)
+    ref_dv, v.grad = v.grad.clone(), None
+    ref_dk, k.grad = k.grad.clone(), None
+    ref_dq, q.grad = q.grad.clone(), None
+    # triton implementation
+    tri_out = attention(q, k, v, sm_scale)
+    tri_out.backward(dout)
+    tri_dv, v.grad = v.grad.clone(), None
+    tri_dk, k.grad = k.grad.clone(), None
+    tri_dq, q.grad = q.grad.clone(), None
+    # compare
+    triton.testing.assert_almost_equal(ref_out, tri_out)
+    triton.testing.assert_almost_equal(ref_dv, tri_dv)
+    triton.testing.assert_almost_equal(ref_dk, tri_dk)
+    triton.testing.assert_almost_equal(ref_dq, tri_dq)
+
+
+try:
+    from flash_attn.flash_attn_interface import flash_attn_func
+    HAS_FLASH = True
+except BaseException:
+    HAS_FLASH = False
+
+BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
+# vary seq length for fixed head and batch=4
+configs = [triton.testing.Benchmark(
+    x_names=['N_CTX'],
+    x_vals=[2**i for i in range(10, 16)],
+    line_arg='provider',
+    line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
+    line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
+    styles=[('red', '-'), ('blue', '-')],
+    ylabel='ms',
+    plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
+    args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
+) for mode in ['bwd']]
+
+
+@triton.testing.perf_report(configs)
+def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.float16, device="cuda"):
+    assert mode in ['fwd', 'bwd']
+    warmup = 25
+    rep = 100
+    if provider == "triton":
+        q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+        k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+        v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
+        sm_scale = 1.3
+        fn = lambda: attention(q, k, v, sm_scale)
+        if mode == 'bwd':
+            o = fn()
+            do = torch.randn_like(o)
+            fn = lambda: o.backward(do, retain_graph=True)
+        ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
+        return ms
+    if provider == "flash":
+        lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
+        cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
+        cu_seqlens[1:] = lengths.cumsum(0)
+        qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
+        fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
+        if mode == 'bwd':
+            o = fn()
+            do = torch.randn_like(o)
+            fn = lambda: o.backward(do, retain_graph=True)
+        ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
+        return ms
+
+# only works on A100 at the moment
+# bench_flash_attention.run(save_path='.', print_data=True)