Răsfoiți Sursa

[LayerNorm] Implement residual + LayerNorm/RMSNorm in Triton

Tri Dao 1 an în urmă
părinte
comite
79bd1a2d5d

+ 395 - 0
flash_attn/ops/triton/layernorm.py

@@ -0,0 +1,395 @@
+# Copyright (c) 2023, Tri Dao.
+# Implement residual + layer_norm / rms_norm.
+
+# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
+# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
+# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
+
+import math
+
+import torch
+import torch.nn.functional as F
+
+import triton
+import triton.language as tl
+
+
+def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, upcast=False):
+    dtype = x.dtype
+    if upcast:
+        weight = weight.float()
+        bias = bias.float() if bias is not None else None
+    if upcast:
+        x = x.float()
+        residual = residual.float() if residual is not None else residual
+    if residual is not None:
+        x = (x + residual).to(x.dtype)
+    out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(dtype)
+    return out if residual is None else (out, x)
+
+
+def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, upcast=False):
+    dtype = x.dtype
+    if upcast:
+        weight = weight.float()
+        bias = bias.float() if bias is not None else None
+    if upcast:
+        x = x.float()
+        residual = residual.float() if residual is not None else residual
+    if residual is not None:
+        x = (x + residual).to(x.dtype)
+    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
+    out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
+    out = out.to(dtype)
+    return out if residual is None else (out, x)
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16),
+        triton.Config({}, num_warps=32),
+    ],
+    key=["N", "HAS_RESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
+)
+# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
+# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
+@triton.jit
+def _layer_norm_fwd_1pass_kernel(
+    X,  # pointer to the input
+    Y,  # pointer to the output
+    W,  # pointer to the weights
+    B,  # pointer to the biases
+    RESIDUAL,  # pointer to the residual
+    RESIDUAL_OUT,  # pointer to the residual
+    Mean,  # pointer to the mean
+    Rstd,  # pointer to the 1/std
+    stride_x_row,  # how much to increase the pointer when moving by 1 row
+    stride_y_row,
+    stride_res_row,
+    stride_res_out_row,
+    N,  # number of columns in X
+    eps,  # epsilon to avoid division by zero
+    IS_RMS_NORM: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    HAS_RESIDUAL: tl.constexpr,
+    HAS_BIAS: tl.constexpr,
+):
+    # Map the program id to the row of X and Y it should compute.
+    row = tl.program_id(0)
+    X += row * stride_x_row
+    Y += row * stride_y_row
+    if HAS_RESIDUAL:
+        RESIDUAL += row * stride_res_row
+        RESIDUAL_OUT += row * stride_res_out_row
+    # Compute mean and variance
+    cols = tl.arange(0, BLOCK_N)
+    x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
+    if HAS_RESIDUAL:
+        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.).to(tl.float32)
+        x += residual
+        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
+    if not IS_RMS_NORM:
+        mean = tl.sum(x, axis=0) / N
+        tl.store(Mean + row, mean)
+        xbar = tl.where(cols < N, x - mean, 0.)
+        var = tl.sum(xbar * xbar, axis=0) / N
+    else:
+        xbar = tl.where(cols < N, x, 0.)
+        var = tl.sum(xbar * xbar, axis=0) / N
+    rstd = 1 / tl.sqrt(var + eps)
+    tl.store(Rstd + row, rstd)
+    # Normalize and apply linear transformation
+    mask = cols < N
+    w = tl.load(W + cols, mask=mask).to(tl.float32)
+    if HAS_BIAS:
+        b = tl.load(B + cols, mask=mask).to(tl.float32)
+    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
+    y = x_hat * w + b if HAS_BIAS else x_hat * w
+    # Write output
+    tl.store(Y + cols, y, mask=mask)
+
+
+def _layer_norm_fwd(x, weight, bias, eps, residual=None, is_rms_norm=False):
+    M, N = x.shape
+    assert x.stride(-1) == 1
+    if residual is not None:
+        assert residual.stride(-1) == 1
+        assert residual.shape == (M, N)
+    assert weight.shape == (N,)
+    assert weight.stride(-1) == 1
+    if bias is not None:
+        assert bias.stride(-1) == 1
+        assert bias.shape == (N,)
+    # allocate output
+    y = torch.empty_like(x)
+    assert y.stride(-1) == 1
+    if residual is not None:
+        residual_out = torch.empty_like(residual)
+        assert residual_out.stride(-1) == 1
+    else:
+        residual_out = None
+    mean = torch.empty((M, ), dtype=torch.float32, device='cuda') if not is_rms_norm else None
+    rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')
+    # Less than 64KB per feature: enqueue fused kernel
+    MAX_FUSED_SIZE = 65536 // x.element_size()
+    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+    if N > BLOCK_N:
+        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+    # heuristics for number of warps
+    with torch.cuda.device(x.device.index):
+        _layer_norm_fwd_1pass_kernel[(M,)](x, y, weight, bias, residual, residual_out,
+                                           mean, rstd,
+                                           x.stride(0), y.stride(0),
+                                           residual.stride(0) if residual is not None else 0,
+                                           residual_out.stride(0) if residual is not None else 0,
+                                           N, eps,
+                                           is_rms_norm,
+                                           BLOCK_N,
+                                           residual is not None,
+                                           bias is not None,
+                                           )
+    return y, mean, rstd, residual_out
+
+
+@triton.autotune(
+    configs=[
+        triton.Config({}, num_warps=1),
+        triton.Config({}, num_warps=2),
+        triton.Config({}, num_warps=4),
+        triton.Config({}, num_warps=8),
+        triton.Config({}, num_warps=16),
+        triton.Config({}, num_warps=32),
+    ],
+    key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
+)
+# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
+# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
+# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
+@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
+@triton.jit
+def _layer_norm_bwd_kernel(
+    X,   # pointer to the input
+    W,   # pointer to the weights
+    B,   # pointer to the biases
+    Y,   # pointer to the output to be recomputed
+    DY,  # pointer to the output gradient
+    DX,  # pointer to the input gradient
+    DW,  # pointer to the partial sum of weights gradient
+    DB,  # pointer to the partial sum of biases gradient
+    DRESIDUAL,
+    DRESIDUAL_IN,
+    Mean,   # pointer to the mean
+    Rstd,   # pointer to the 1/std
+    stride_x_row,  # how much to increase the pointer when moving by 1 row
+    stride_y_row,
+    stride_dy_row,
+    stride_dx_row,
+    stride_dres_row,
+    stride_dres_in_row,
+    M,  # number of rows in X
+    N,  # number of columns in X
+    eps,  # epsilon to avoid division by zero
+    rows_per_program,
+    IS_RMS_NORM: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    HAS_DRESIDUAL: tl.constexpr,
+    STORE_DRESIDUAL: tl.constexpr,
+    HAS_BIAS: tl.constexpr,
+    RECOMPUTE_OUTPUT: tl.constexpr,
+):
+    # Map the program id to the elements of X, DX, and DY it should compute.
+    row_block_id = tl.program_id(0)
+    row_start = row_block_id * rows_per_program
+    cols = tl.arange(0, BLOCK_N)
+    mask = cols < N
+    X += row_start * stride_x_row
+    if HAS_DRESIDUAL:
+        DRESIDUAL += row_start * stride_dres_row
+    if STORE_DRESIDUAL:
+        DRESIDUAL_IN += row_start * stride_dres_in_row
+    DY += row_start * stride_dy_row
+    DX += row_start * stride_dx_row
+    if RECOMPUTE_OUTPUT:
+        Y += row_start * stride_y_row
+    w = tl.load(W + cols, mask=mask).to(tl.float32)
+    if RECOMPUTE_OUTPUT and HAS_BIAS:
+        b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32)
+    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
+    if HAS_BIAS:
+        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
+    row_end = min((row_block_id + 1) * rows_per_program, M)
+    for row in range(row_start, row_end):
+        # Load data to SRAM
+        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
+        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
+        if not IS_RMS_NORM:
+            mean = tl.load(Mean + row)
+        rstd = tl.load(Rstd + row)
+        # Compute dx
+        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
+        xhat = tl.where(mask, xhat, 0.)
+        if RECOMPUTE_OUTPUT:
+            y = xhat * w + b if HAS_BIAS else xhat * w
+            tl.store(Y + cols, y, mask=mask)
+        wdy = w * dy
+        dw += dy * xhat
+        if HAS_BIAS:
+            db += dy
+        if not IS_RMS_NORM:
+            c1 = tl.sum(xhat * wdy, axis=0) / N
+            c2 = tl.sum(wdy, axis=0) / N
+            dx = (wdy - (xhat * c1 + c2)) * rstd
+        else:
+            c1 = tl.sum(xhat * wdy, axis=0) / N
+            dx = (wdy - xhat * c1) * rstd
+        if HAS_DRESIDUAL:
+            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
+            dx += dres
+        # Write dx
+        if STORE_DRESIDUAL:
+            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
+        tl.store(DX + cols, dx, mask=mask)
+
+        X += stride_x_row
+        if HAS_DRESIDUAL:
+            DRESIDUAL += stride_dres_row
+        if STORE_DRESIDUAL:
+            DRESIDUAL_IN += stride_dres_in_row
+        if RECOMPUTE_OUTPUT:
+            Y += stride_y_row
+        DY += stride_dy_row
+        DX += stride_dx_row
+    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
+    if HAS_BIAS:
+        tl.store(DB + row_block_id * N + cols, db, mask=mask)
+
+
+def _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, dresidual=None, is_rms_norm=False, x_dtype=None,
+                    recompute_output=False):
+    M, N = x.shape
+    assert x.stride(-1) == 1
+    assert dy.stride(-1) == 1
+    assert dy.shape == (M, N)
+    if dresidual is not None:
+        assert dresidual.stride(-1) == 1
+        assert dresidual.shape == (M, N)
+    assert weight.shape == (N,)
+    assert weight.stride(-1) == 1
+    if bias is not None:
+        assert bias.stride(-1) == 1
+        assert bias.shape == (N,)
+    # allocate output
+    dx = torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device)
+    dresidual_in = torch.empty_like(dresidual) if dresidual is not None and dx.dtype != dresidual.dtype else None
+    y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
+
+    # Less than 64KB per feature: enqueue fused kernel
+    MAX_FUSED_SIZE = 65536 // x.element_size()
+    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+    if N > BLOCK_N:
+        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
+    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
+    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
+    _db = torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) if bias is not None else None
+    rows_per_program = math.ceil(M / sm_count)
+    grid = (sm_count,)
+    with torch.cuda.device(x.device.index):
+        _layer_norm_bwd_kernel[grid](x, weight, bias, y,
+                                     dy, dx, _dw, _db, dresidual, dresidual_in,
+                                     mean, rstd,
+                                     x.stride(0),
+                                     0 if not recompute_output else y.stride(0),
+                                     dy.stride(0), dx.stride(0),
+                                     dresidual.stride(0) if dresidual is not None else 0,
+                                     dresidual_in.stride(0) if dresidual_in is not None else 0,
+                                     M, N, eps,
+                                     rows_per_program,
+                                     is_rms_norm,
+                                     BLOCK_N,
+                                     dresidual is not None,
+                                     dresidual_in is not None,
+                                     bias is not None)
+    dw = _dw.sum(0).to(weight.dtype)
+    db = _db.sum(0).to(bias.dtype) if bias is not None else None
+    # Don't need to compute dresidual_in separately in this case
+    if dresidual is not None and dx.dtype == dresidual.dtype:
+        dresidual_in = dx
+    return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
+
+
+class LayerNormFn(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, x, weight, bias, residual=None, eps=1e-6, is_rms_norm=False):
+        x_shape_og = x.shape
+        # reshape input data into 2D tensor
+        x = x.reshape(-1, x.shape[-1])
+        if x.stride(-1) != 1:
+            x = x.contiguous()
+        if residual is not None:
+            assert residual.shape == x_shape_og
+            residual = residual.reshape(-1, residual.shape[-1])
+            if residual.stride(-1) != 1:
+                residual = residual.contiguous()
+        weight = weight.contiguous()
+        if bias is not None:
+            bias = bias.contiguous()
+        y, mean, rstd, *rest = _layer_norm_fwd(x, weight, bias, eps, residual, is_rms_norm)
+        if residual is not None:
+            residual_out = rest[0]
+        ctx.save_for_backward(x if residual is None else residual_out, weight, bias, mean, rstd)
+        ctx.x_shape_og = x_shape_og
+        ctx.eps = eps
+        ctx.is_rms_norm = is_rms_norm
+        ctx.has_residual = residual is not None
+        ctx.x_dtype = x.dtype
+        y = y.reshape(x_shape_og)
+        return y if residual is None else (y, residual_out.reshape(x_shape_og))
+
+    @staticmethod
+    def backward(ctx, dy, *args):
+        x, weight, bias, mean, rstd = ctx.saved_tensors
+        dy = dy.reshape(-1, dy.shape[-1])
+        if dy.stride(-1) != 1:
+            dy = dy.contiguous()
+        assert dy.shape == x.shape
+        if ctx.has_residual:
+            dresidual = args[0]
+            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
+            if dresidual.stride(-1) != 1:
+                dresidual = dresidual.contiguous()
+            assert dresidual.shape == x.shape
+        else:
+            dresidual = None
+        dx, dw, db, dresidual_in = _layer_norm_bwd(dy, x, weight, bias, ctx.eps, mean, rstd, dresidual,
+                                                   ctx.is_rms_norm, x_dtype=ctx.x_dtype)
+        return dx.reshape(ctx.x_shape_og), dw, db, dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, None, None
+
+
+def layer_norm_fn(x, weight, bias, residual=None, eps=1e-6, is_rms_norm=False):
+    return LayerNormFn.apply(x, weight, bias, residual, eps, is_rms_norm)
+
+
+def rms_norm_fn(x, weight, bias, residual=None, eps=1e-6):
+    return LayerNormFn.apply(x, weight, bias, residual, eps, True)
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.eps = eps
+        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
+        self.register_parameter("bias", None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        torch.nn.init.ones_(self.weight)
+
+    def forward(self, x, residual=None):
+        return layer_norm_fn(x, self.weight, self.bias, residual=residual, eps=self.eps, is_rms_norm=True)

+ 8 - 1
flash_attn/utils/benchmark.py

@@ -213,7 +213,10 @@ def pytorch_profiler(
     """Wrap benchmark functions in Pytorch profiler to see CUDA information."""
     if backward:
         with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
-            g = torch.randn_like(fn(*inputs, **kwinputs))
+            out = fn(*inputs, **kwinputs)
+            if type(out) is tuple:
+                out = out[0]
+            g = torch.randn_like(out)
     for _ in range(30):  # Warm up
         if backward:
             for x in inputs:
@@ -221,6 +224,8 @@ def pytorch_profiler(
                     x.grad = None
         with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
             out = fn(*inputs, **kwinputs)
+            if type(out) is tuple:
+                out = out[0]
         # Backward should be done outside autocast
         if backward:
             out.backward(g, retain_graph=True)
@@ -239,6 +244,8 @@ def pytorch_profiler(
                     x.grad = None
         with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
             out = fn(*inputs, **kwinputs)
+            if type(out) is tuple:
+                out = out[0]
         if backward:
             out.backward(g, retain_graph=True)
     if verbose:

+ 103 - 0
tests/ops/triton/test_layer_norm.py

@@ -0,0 +1,103 @@
+import math
+from functools import partial
+
+import pytest
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from flash_attn.ops.triton.layernorm import layer_norm_fn, layer_norm_ref, rms_norm_ref
+
+
+is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
+
+
+@pytest.mark.parametrize("is_rms_norm", [False, True])
+# @pytest.mark.parametrize("is_rms_norm", [True])
+@pytest.mark.parametrize("has_residual", [True, False])
+# @pytest.mark.parametrize("has_residual", [True])
+@pytest.mark.parametrize(
+    "weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else [])
+)
+# @pytest.mark.parametrize("weight_dtype", [torch.float32])
+@pytest.mark.parametrize(
+    "input_dtype,residual_dtype",
+    [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
+    + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
+)
+# @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)])
+@pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 8192])
+# @pytest.mark.parametrize("hidden_size", [256])
+def test_layer_norm(
+    hidden_size, input_dtype, residual_dtype, weight_dtype, has_residual, is_rms_norm
+):
+    device = "cuda"
+    if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
+        atol = 5e-2
+    elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):
+        atol = 5e-3
+    else:
+        atol = 1e-4
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 8
+    seqlen = 512
+    # batch_size = 1
+    # seqlen = 1
+    layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
+    allclose = (
+        lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
+        <= 2 * (x_pt - x_ref).abs().max() + atol
+    )
+    x0 = torch.randn(
+        batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
+    )
+    x0_pt = x0.detach().clone().requires_grad_()
+    x0_ref = x0.detach().clone().requires_grad_()
+    if has_residual:
+        res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
+        res_pt = res.detach().clone().requires_grad_()
+        res_ref = res.detach().clone().requires_grad_()
+    else:
+        res, res_pt, res_ref = None, None, None
+    weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+    if not is_rms_norm:
+        bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
+    else:
+        bias = None
+    weight_pt = weight.detach().clone().requires_grad_()
+    weight_ref = weight.detach().clone().requires_grad_()
+    bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
+    bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
+    residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
+
+    out, *rest = layer_norm_fn(x0, weight, bias, residual=res, eps=1e-6, is_rms_norm=is_rms_norm)
+    out_pt, *rest_pt = layer_norm_ref_fn(x0_pt, weight_pt, bias_pt, residual=res_pt, eps=1e-6)
+    out_ref, *rest_ref = layer_norm_ref_fn(
+        x0_ref, weight_ref, bias_ref, residual=res_ref, eps=1e-6, upcast=True
+    )
+    if has_residual:
+        residual = rest[0]
+        residual_pt = rest_pt[0]
+        residual_ref = rest_ref[0]
+        residual_ref = x0_ref + res_ref
+    assert out.dtype == input_dtype
+    if has_residual:
+        assert residual.dtype == residual_dtype
+        assert allclose(residual, residual_pt, residual_ref)
+    assert allclose(out, out_pt, out_ref)
+
+    g = torch.randn_like(out) / batch_size
+    if not has_residual:
+        out.backward(g)
+        out_pt.backward(g)
+        out_ref.backward(g)
+    else:
+        (out * F.sigmoid(residual)).backward(g)
+        (out_pt * F.sigmoid(residual_pt)).backward(g)
+        (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
+    assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
+    if has_residual:
+        assert allclose(res.grad, res_pt.grad, res_ref.grad)
+    assert allclose(weight.grad, weight_pt.grad, weight_ref.grad)
+    if bias is not None:
+        assert allclose(bias.grad, bias_pt.grad, bias_ref.grad)