|
@@ -0,0 +1,395 @@
|
|
|
+# Copyright (c) 2023, Tri Dao.
|
|
|
+# Implement residual + layer_norm / rms_norm.
|
|
|
+
|
|
|
+# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
|
|
+# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
|
|
+# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
|
|
+# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
|
|
+
|
|
|
+import math
|
|
|
+
|
|
|
+import torch
|
|
|
+import torch.nn.functional as F
|
|
|
+
|
|
|
+import triton
|
|
|
+import triton.language as tl
|
|
|
+
|
|
|
+
|
|
|
+def layer_norm_ref(x, weight, bias, residual=None, eps=1e-6, upcast=False):
|
|
|
+ dtype = x.dtype
|
|
|
+ if upcast:
|
|
|
+ weight = weight.float()
|
|
|
+ bias = bias.float() if bias is not None else None
|
|
|
+ if upcast:
|
|
|
+ x = x.float()
|
|
|
+ residual = residual.float() if residual is not None else residual
|
|
|
+ if residual is not None:
|
|
|
+ x = (x + residual).to(x.dtype)
|
|
|
+ out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(dtype)
|
|
|
+ return out if residual is None else (out, x)
|
|
|
+
|
|
|
+
|
|
|
+def rms_norm_ref(x, weight, bias, residual=None, eps=1e-6, upcast=False):
|
|
|
+ dtype = x.dtype
|
|
|
+ if upcast:
|
|
|
+ weight = weight.float()
|
|
|
+ bias = bias.float() if bias is not None else None
|
|
|
+ if upcast:
|
|
|
+ x = x.float()
|
|
|
+ residual = residual.float() if residual is not None else residual
|
|
|
+ if residual is not None:
|
|
|
+ x = (x + residual).to(x.dtype)
|
|
|
+ rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
|
|
+ out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
|
|
|
+ out = out.to(dtype)
|
|
|
+ return out if residual is None else (out, x)
|
|
|
+
|
|
|
+
|
|
|
+@triton.autotune(
|
|
|
+ configs=[
|
|
|
+ triton.Config({}, num_warps=1),
|
|
|
+ triton.Config({}, num_warps=2),
|
|
|
+ triton.Config({}, num_warps=4),
|
|
|
+ triton.Config({}, num_warps=8),
|
|
|
+ triton.Config({}, num_warps=16),
|
|
|
+ triton.Config({}, num_warps=32),
|
|
|
+ ],
|
|
|
+ key=["N", "HAS_RESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
|
|
|
+)
|
|
|
+# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
|
|
+# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
|
|
|
+@triton.jit
|
|
|
+def _layer_norm_fwd_1pass_kernel(
|
|
|
+ X, # pointer to the input
|
|
|
+ Y, # pointer to the output
|
|
|
+ W, # pointer to the weights
|
|
|
+ B, # pointer to the biases
|
|
|
+ RESIDUAL, # pointer to the residual
|
|
|
+ RESIDUAL_OUT, # pointer to the residual
|
|
|
+ Mean, # pointer to the mean
|
|
|
+ Rstd, # pointer to the 1/std
|
|
|
+ stride_x_row, # how much to increase the pointer when moving by 1 row
|
|
|
+ stride_y_row,
|
|
|
+ stride_res_row,
|
|
|
+ stride_res_out_row,
|
|
|
+ N, # number of columns in X
|
|
|
+ eps, # epsilon to avoid division by zero
|
|
|
+ IS_RMS_NORM: tl.constexpr,
|
|
|
+ BLOCK_N: tl.constexpr,
|
|
|
+ HAS_RESIDUAL: tl.constexpr,
|
|
|
+ HAS_BIAS: tl.constexpr,
|
|
|
+):
|
|
|
+ # Map the program id to the row of X and Y it should compute.
|
|
|
+ row = tl.program_id(0)
|
|
|
+ X += row * stride_x_row
|
|
|
+ Y += row * stride_y_row
|
|
|
+ if HAS_RESIDUAL:
|
|
|
+ RESIDUAL += row * stride_res_row
|
|
|
+ RESIDUAL_OUT += row * stride_res_out_row
|
|
|
+ # Compute mean and variance
|
|
|
+ cols = tl.arange(0, BLOCK_N)
|
|
|
+ x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
|
|
|
+ if HAS_RESIDUAL:
|
|
|
+ residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.).to(tl.float32)
|
|
|
+ x += residual
|
|
|
+ tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
|
|
|
+ if not IS_RMS_NORM:
|
|
|
+ mean = tl.sum(x, axis=0) / N
|
|
|
+ tl.store(Mean + row, mean)
|
|
|
+ xbar = tl.where(cols < N, x - mean, 0.)
|
|
|
+ var = tl.sum(xbar * xbar, axis=0) / N
|
|
|
+ else:
|
|
|
+ xbar = tl.where(cols < N, x, 0.)
|
|
|
+ var = tl.sum(xbar * xbar, axis=0) / N
|
|
|
+ rstd = 1 / tl.sqrt(var + eps)
|
|
|
+ tl.store(Rstd + row, rstd)
|
|
|
+ # Normalize and apply linear transformation
|
|
|
+ mask = cols < N
|
|
|
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
|
|
|
+ if HAS_BIAS:
|
|
|
+ b = tl.load(B + cols, mask=mask).to(tl.float32)
|
|
|
+ x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
|
|
+ y = x_hat * w + b if HAS_BIAS else x_hat * w
|
|
|
+ # Write output
|
|
|
+ tl.store(Y + cols, y, mask=mask)
|
|
|
+
|
|
|
+
|
|
|
+def _layer_norm_fwd(x, weight, bias, eps, residual=None, is_rms_norm=False):
|
|
|
+ M, N = x.shape
|
|
|
+ assert x.stride(-1) == 1
|
|
|
+ if residual is not None:
|
|
|
+ assert residual.stride(-1) == 1
|
|
|
+ assert residual.shape == (M, N)
|
|
|
+ assert weight.shape == (N,)
|
|
|
+ assert weight.stride(-1) == 1
|
|
|
+ if bias is not None:
|
|
|
+ assert bias.stride(-1) == 1
|
|
|
+ assert bias.shape == (N,)
|
|
|
+ # allocate output
|
|
|
+ y = torch.empty_like(x)
|
|
|
+ assert y.stride(-1) == 1
|
|
|
+ if residual is not None:
|
|
|
+ residual_out = torch.empty_like(residual)
|
|
|
+ assert residual_out.stride(-1) == 1
|
|
|
+ else:
|
|
|
+ residual_out = None
|
|
|
+ mean = torch.empty((M, ), dtype=torch.float32, device='cuda') if not is_rms_norm else None
|
|
|
+ rstd = torch.empty((M, ), dtype=torch.float32, device='cuda')
|
|
|
+ # Less than 64KB per feature: enqueue fused kernel
|
|
|
+ MAX_FUSED_SIZE = 65536 // x.element_size()
|
|
|
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
|
|
+ if N > BLOCK_N:
|
|
|
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
|
|
+ # heuristics for number of warps
|
|
|
+ with torch.cuda.device(x.device.index):
|
|
|
+ _layer_norm_fwd_1pass_kernel[(M,)](x, y, weight, bias, residual, residual_out,
|
|
|
+ mean, rstd,
|
|
|
+ x.stride(0), y.stride(0),
|
|
|
+ residual.stride(0) if residual is not None else 0,
|
|
|
+ residual_out.stride(0) if residual is not None else 0,
|
|
|
+ N, eps,
|
|
|
+ is_rms_norm,
|
|
|
+ BLOCK_N,
|
|
|
+ residual is not None,
|
|
|
+ bias is not None,
|
|
|
+ )
|
|
|
+ return y, mean, rstd, residual_out
|
|
|
+
|
|
|
+
|
|
|
+@triton.autotune(
|
|
|
+ configs=[
|
|
|
+ triton.Config({}, num_warps=1),
|
|
|
+ triton.Config({}, num_warps=2),
|
|
|
+ triton.Config({}, num_warps=4),
|
|
|
+ triton.Config({}, num_warps=8),
|
|
|
+ triton.Config({}, num_warps=16),
|
|
|
+ triton.Config({}, num_warps=32),
|
|
|
+ ],
|
|
|
+ key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS"],
|
|
|
+)
|
|
|
+# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
|
|
+# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
|
|
|
+# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
|
|
|
+@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
|
|
|
+@triton.jit
|
|
|
+def _layer_norm_bwd_kernel(
|
|
|
+ X, # pointer to the input
|
|
|
+ W, # pointer to the weights
|
|
|
+ B, # pointer to the biases
|
|
|
+ Y, # pointer to the output to be recomputed
|
|
|
+ DY, # pointer to the output gradient
|
|
|
+ DX, # pointer to the input gradient
|
|
|
+ DW, # pointer to the partial sum of weights gradient
|
|
|
+ DB, # pointer to the partial sum of biases gradient
|
|
|
+ DRESIDUAL,
|
|
|
+ DRESIDUAL_IN,
|
|
|
+ Mean, # pointer to the mean
|
|
|
+ Rstd, # pointer to the 1/std
|
|
|
+ stride_x_row, # how much to increase the pointer when moving by 1 row
|
|
|
+ stride_y_row,
|
|
|
+ stride_dy_row,
|
|
|
+ stride_dx_row,
|
|
|
+ stride_dres_row,
|
|
|
+ stride_dres_in_row,
|
|
|
+ M, # number of rows in X
|
|
|
+ N, # number of columns in X
|
|
|
+ eps, # epsilon to avoid division by zero
|
|
|
+ rows_per_program,
|
|
|
+ IS_RMS_NORM: tl.constexpr,
|
|
|
+ BLOCK_N: tl.constexpr,
|
|
|
+ HAS_DRESIDUAL: tl.constexpr,
|
|
|
+ STORE_DRESIDUAL: tl.constexpr,
|
|
|
+ HAS_BIAS: tl.constexpr,
|
|
|
+ RECOMPUTE_OUTPUT: tl.constexpr,
|
|
|
+):
|
|
|
+ # Map the program id to the elements of X, DX, and DY it should compute.
|
|
|
+ row_block_id = tl.program_id(0)
|
|
|
+ row_start = row_block_id * rows_per_program
|
|
|
+ cols = tl.arange(0, BLOCK_N)
|
|
|
+ mask = cols < N
|
|
|
+ X += row_start * stride_x_row
|
|
|
+ if HAS_DRESIDUAL:
|
|
|
+ DRESIDUAL += row_start * stride_dres_row
|
|
|
+ if STORE_DRESIDUAL:
|
|
|
+ DRESIDUAL_IN += row_start * stride_dres_in_row
|
|
|
+ DY += row_start * stride_dy_row
|
|
|
+ DX += row_start * stride_dx_row
|
|
|
+ if RECOMPUTE_OUTPUT:
|
|
|
+ Y += row_start * stride_y_row
|
|
|
+ w = tl.load(W + cols, mask=mask).to(tl.float32)
|
|
|
+ if RECOMPUTE_OUTPUT and HAS_BIAS:
|
|
|
+ b = tl.load(B + cols, mask=mask, other=0.).to(tl.float32)
|
|
|
+ dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
|
|
+ if HAS_BIAS:
|
|
|
+ db = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
|
|
+ row_end = min((row_block_id + 1) * rows_per_program, M)
|
|
|
+ for row in range(row_start, row_end):
|
|
|
+ # Load data to SRAM
|
|
|
+ x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
|
|
|
+ dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
|
|
|
+ if not IS_RMS_NORM:
|
|
|
+ mean = tl.load(Mean + row)
|
|
|
+ rstd = tl.load(Rstd + row)
|
|
|
+ # Compute dx
|
|
|
+ xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
|
|
+ xhat = tl.where(mask, xhat, 0.)
|
|
|
+ if RECOMPUTE_OUTPUT:
|
|
|
+ y = xhat * w + b if HAS_BIAS else xhat * w
|
|
|
+ tl.store(Y + cols, y, mask=mask)
|
|
|
+ wdy = w * dy
|
|
|
+ dw += dy * xhat
|
|
|
+ if HAS_BIAS:
|
|
|
+ db += dy
|
|
|
+ if not IS_RMS_NORM:
|
|
|
+ c1 = tl.sum(xhat * wdy, axis=0) / N
|
|
|
+ c2 = tl.sum(wdy, axis=0) / N
|
|
|
+ dx = (wdy - (xhat * c1 + c2)) * rstd
|
|
|
+ else:
|
|
|
+ c1 = tl.sum(xhat * wdy, axis=0) / N
|
|
|
+ dx = (wdy - xhat * c1) * rstd
|
|
|
+ if HAS_DRESIDUAL:
|
|
|
+ dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
|
|
|
+ dx += dres
|
|
|
+ # Write dx
|
|
|
+ if STORE_DRESIDUAL:
|
|
|
+ tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
|
|
|
+ tl.store(DX + cols, dx, mask=mask)
|
|
|
+
|
|
|
+ X += stride_x_row
|
|
|
+ if HAS_DRESIDUAL:
|
|
|
+ DRESIDUAL += stride_dres_row
|
|
|
+ if STORE_DRESIDUAL:
|
|
|
+ DRESIDUAL_IN += stride_dres_in_row
|
|
|
+ if RECOMPUTE_OUTPUT:
|
|
|
+ Y += stride_y_row
|
|
|
+ DY += stride_dy_row
|
|
|
+ DX += stride_dx_row
|
|
|
+ tl.store(DW + row_block_id * N + cols, dw, mask=mask)
|
|
|
+ if HAS_BIAS:
|
|
|
+ tl.store(DB + row_block_id * N + cols, db, mask=mask)
|
|
|
+
|
|
|
+
|
|
|
+def _layer_norm_bwd(dy, x, weight, bias, eps, mean, rstd, dresidual=None, is_rms_norm=False, x_dtype=None,
|
|
|
+ recompute_output=False):
|
|
|
+ M, N = x.shape
|
|
|
+ assert x.stride(-1) == 1
|
|
|
+ assert dy.stride(-1) == 1
|
|
|
+ assert dy.shape == (M, N)
|
|
|
+ if dresidual is not None:
|
|
|
+ assert dresidual.stride(-1) == 1
|
|
|
+ assert dresidual.shape == (M, N)
|
|
|
+ assert weight.shape == (N,)
|
|
|
+ assert weight.stride(-1) == 1
|
|
|
+ if bias is not None:
|
|
|
+ assert bias.stride(-1) == 1
|
|
|
+ assert bias.shape == (N,)
|
|
|
+ # allocate output
|
|
|
+ dx = torch.empty_like(x) if x_dtype is None else torch.empty(M, N, dtype=x_dtype, device=x.device)
|
|
|
+ dresidual_in = torch.empty_like(dresidual) if dresidual is not None and dx.dtype != dresidual.dtype else None
|
|
|
+ y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
|
|
|
+
|
|
|
+ # Less than 64KB per feature: enqueue fused kernel
|
|
|
+ MAX_FUSED_SIZE = 65536 // x.element_size()
|
|
|
+ BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
|
|
|
+ if N > BLOCK_N:
|
|
|
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
|
|
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
|
|
|
+ _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
|
|
|
+ _db = torch.empty((sm_count, N), dtype=torch.float32, device=bias.device) if bias is not None else None
|
|
|
+ rows_per_program = math.ceil(M / sm_count)
|
|
|
+ grid = (sm_count,)
|
|
|
+ with torch.cuda.device(x.device.index):
|
|
|
+ _layer_norm_bwd_kernel[grid](x, weight, bias, y,
|
|
|
+ dy, dx, _dw, _db, dresidual, dresidual_in,
|
|
|
+ mean, rstd,
|
|
|
+ x.stride(0),
|
|
|
+ 0 if not recompute_output else y.stride(0),
|
|
|
+ dy.stride(0), dx.stride(0),
|
|
|
+ dresidual.stride(0) if dresidual is not None else 0,
|
|
|
+ dresidual_in.stride(0) if dresidual_in is not None else 0,
|
|
|
+ M, N, eps,
|
|
|
+ rows_per_program,
|
|
|
+ is_rms_norm,
|
|
|
+ BLOCK_N,
|
|
|
+ dresidual is not None,
|
|
|
+ dresidual_in is not None,
|
|
|
+ bias is not None)
|
|
|
+ dw = _dw.sum(0).to(weight.dtype)
|
|
|
+ db = _db.sum(0).to(bias.dtype) if bias is not None else None
|
|
|
+ # Don't need to compute dresidual_in separately in this case
|
|
|
+ if dresidual is not None and dx.dtype == dresidual.dtype:
|
|
|
+ dresidual_in = dx
|
|
|
+ return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)
|
|
|
+
|
|
|
+
|
|
|
+class LayerNormFn(torch.autograd.Function):
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def forward(ctx, x, weight, bias, residual=None, eps=1e-6, is_rms_norm=False):
|
|
|
+ x_shape_og = x.shape
|
|
|
+ # reshape input data into 2D tensor
|
|
|
+ x = x.reshape(-1, x.shape[-1])
|
|
|
+ if x.stride(-1) != 1:
|
|
|
+ x = x.contiguous()
|
|
|
+ if residual is not None:
|
|
|
+ assert residual.shape == x_shape_og
|
|
|
+ residual = residual.reshape(-1, residual.shape[-1])
|
|
|
+ if residual.stride(-1) != 1:
|
|
|
+ residual = residual.contiguous()
|
|
|
+ weight = weight.contiguous()
|
|
|
+ if bias is not None:
|
|
|
+ bias = bias.contiguous()
|
|
|
+ y, mean, rstd, *rest = _layer_norm_fwd(x, weight, bias, eps, residual, is_rms_norm)
|
|
|
+ if residual is not None:
|
|
|
+ residual_out = rest[0]
|
|
|
+ ctx.save_for_backward(x if residual is None else residual_out, weight, bias, mean, rstd)
|
|
|
+ ctx.x_shape_og = x_shape_og
|
|
|
+ ctx.eps = eps
|
|
|
+ ctx.is_rms_norm = is_rms_norm
|
|
|
+ ctx.has_residual = residual is not None
|
|
|
+ ctx.x_dtype = x.dtype
|
|
|
+ y = y.reshape(x_shape_og)
|
|
|
+ return y if residual is None else (y, residual_out.reshape(x_shape_og))
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def backward(ctx, dy, *args):
|
|
|
+ x, weight, bias, mean, rstd = ctx.saved_tensors
|
|
|
+ dy = dy.reshape(-1, dy.shape[-1])
|
|
|
+ if dy.stride(-1) != 1:
|
|
|
+ dy = dy.contiguous()
|
|
|
+ assert dy.shape == x.shape
|
|
|
+ if ctx.has_residual:
|
|
|
+ dresidual = args[0]
|
|
|
+ dresidual = dresidual.reshape(-1, dresidual.shape[-1])
|
|
|
+ if dresidual.stride(-1) != 1:
|
|
|
+ dresidual = dresidual.contiguous()
|
|
|
+ assert dresidual.shape == x.shape
|
|
|
+ else:
|
|
|
+ dresidual = None
|
|
|
+ dx, dw, db, dresidual_in = _layer_norm_bwd(dy, x, weight, bias, ctx.eps, mean, rstd, dresidual,
|
|
|
+ ctx.is_rms_norm, x_dtype=ctx.x_dtype)
|
|
|
+ return dx.reshape(ctx.x_shape_og), dw, db, dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None, None, None
|
|
|
+
|
|
|
+
|
|
|
+def layer_norm_fn(x, weight, bias, residual=None, eps=1e-6, is_rms_norm=False):
|
|
|
+ return LayerNormFn.apply(x, weight, bias, residual, eps, is_rms_norm)
|
|
|
+
|
|
|
+
|
|
|
+def rms_norm_fn(x, weight, bias, residual=None, eps=1e-6):
|
|
|
+ return LayerNormFn.apply(x, weight, bias, residual, eps, True)
|
|
|
+
|
|
|
+
|
|
|
+class RMSNorm(torch.nn.Module):
|
|
|
+ def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
|
|
|
+ factory_kwargs = {"device": device, "dtype": dtype}
|
|
|
+ super().__init__()
|
|
|
+ self.eps = eps
|
|
|
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
|
|
+ self.register_parameter("bias", None)
|
|
|
+ self.reset_parameters()
|
|
|
+
|
|
|
+ def reset_parameters(self):
|
|
|
+ torch.nn.init.ones_(self.weight)
|
|
|
+
|
|
|
+ def forward(self, x, residual=None):
|
|
|
+ return layer_norm_fn(x, self.weight, self.bias, residual=residual, eps=self.eps, is_rms_norm=True)
|