# Copyright (c) 2024, Tri Dao.
# Implement dropout + residual + layer_norm / rms_norm.

# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.

import math

import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_fwd, custom_bwd

import triton
import triton.language as tl


def layer_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
        dtype
    )
    if weight1 is None:
        return out if not prenorm else (out, x)
    else:
        out1 = F.layer_norm(
            x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
        ).to(dtype)
        return (out, out1) if not prenorm else (out, out1, x)


def rms_norm_ref(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    dropout_mask=None,
    dropout_mask1=None,
    upcast=False,
):
    dtype = x.dtype
    if upcast:
        x = x.float()
        weight = weight.float()
        bias = bias.float() if bias is not None else None
        residual = residual.float() if residual is not None else residual
        x1 = x1.float() if x1 is not None else None
        weight1 = weight1.float() if weight1 is not None else None
        bias1 = bias1.float() if bias1 is not None else None
    if x1 is not None:
        assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
    if rowscale is not None:
        x = x * rowscale[..., None]
    if dropout_p > 0.0:
        if dropout_mask is not None:
            x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
        else:
            x = F.dropout(x, p=dropout_p)
        if x1 is not None:
            if dropout_mask1 is not None:
                x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
            else:
                x1 = F.dropout(x1, p=dropout_p)
    if x1 is not None:
        x = x + x1
    if residual is not None:
        x = (x + residual).to(x.dtype)
    rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
    out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
    if weight1 is None:
        return out if not prenorm else (out, x)
    else:
        out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
            dtype
        )
        return (out, out1) if not prenorm else (out, out1, x)
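

# Illustrative sketch (not part of the original module): the eager reference
# implementations above are mainly useful for checking the fused Triton path.
# The helper below is hypothetical; it assumes a CUDA device and uses the
# rms_norm_fn wrapper defined later in this file. Shapes, dtype, and tolerances
# are arbitrary choices for illustration.
def _example_check_rms_norm_against_ref(M=4, N=1024, dtype=torch.float16):
    """Hypothetical helper: compare the fused RMSNorm kernel to rms_norm_ref."""
    x = torch.randn(M, N, device="cuda", dtype=dtype)
    weight = torch.randn(N, device="cuda", dtype=dtype)
    out_ref = rms_norm_ref(x, weight, None, upcast=True)
    out_fused = rms_norm_fn(x, weight, None)
    # Loose tolerances: the fused kernel accumulates in fp32 but stores in fp16.
    torch.testing.assert_close(out_fused.float(), out_ref.float(), atol=1e-2, rtol=1e-2)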


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    X1,
    W1,
    B1,
    Y1,
    RESIDUAL_OUT,  # pointer to the residual
    ROWSCALE,
    SEEDS,  # Dropout seeds for each row
    DROPOUT_MASK,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    stride_x1_row,
    stride_y1_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,  # Dropout probability
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    STORE_DROPOUT_MASK: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_X1: tl.constexpr,
    HAS_W1: tl.constexpr,
    HAS_B1: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    if HAS_X1:
        X1 += row * stride_x1_row
    if HAS_W1:
        Y1 += row * stride_y1_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_ROWSCALE:
        rowscale = tl.load(ROWSCALE + row).to(tl.float32)
        x *= rowscale
    if HAS_DROPOUT:
        # Compute dropout mask
        # 7 rounds is good enough, and reduces register pressure
        keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
        if STORE_DROPOUT_MASK:
            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
    if HAS_X1:
        x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
            x1 *= rowscale
        if HAS_DROPOUT:
            # Compute dropout mask
            # 7 rounds is good enough, and reduces register pressure
            keep_mask = (
                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
            )
            x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
            if STORE_DROPOUT_MASK:
                tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
        x += x1
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    # Write output
    tl.store(Y + cols, y, mask=mask)
    if HAS_W1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
        if HAS_B1:
            b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
        y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
        tl.store(Y1 + cols, y1, mask=mask)


def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    dropout_p=0.0,
    rowscale=None,
    out_dtype=None,
    residual_dtype=None,
    is_rms_norm=False,
    return_dropout_mask=False,
):
    if residual is not None:
        residual_dtype = residual.dtype
    M, N = x.shape
    assert x.stride(-1) == 1
    if residual is not None:
        assert residual.stride(-1) == 1
        assert residual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if x1 is not None:
        assert x1.shape == x.shape
        assert rowscale is None
        assert x1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    assert y.stride(-1) == 1
    if weight1 is not None:
        y1 = torch.empty_like(y)
        assert y1.stride(-1) == 1
    else:
        y1 = None
    if (
        residual is not None
        or (residual_dtype is not None and residual_dtype != x.dtype)
        or dropout_p > 0.0
        or rowscale is not None
        or x1 is not None
    ):
        residual_out = torch.empty(
            M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
        )
        assert residual_out.stride(-1) == 1
    else:
        residual_out = None
    mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
    if dropout_p > 0.0:
        seeds = torch.randint(
            2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
        )
    else:
        seeds = None
    if return_dropout_mask and dropout_p > 0.0:
        dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
    else:
        dropout_mask = None
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[(M,)](
            x,
            y,
            weight,
            bias,
            residual,
            x1,
            weight1,
            bias1,
            y1,
            residual_out,
            rowscale,
            seeds,
            dropout_mask,
            mean,
            rstd,
            x.stride(0),
            y.stride(0),
            residual.stride(0) if residual is not None else 0,
            residual_out.stride(0) if residual_out is not None else 0,
            x1.stride(0) if x1 is not None else 0,
            y1.stride(0) if y1 is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            is_rms_norm,
            BLOCK_N,
            residual is not None,
            residual_out is not None,
            bias is not None,
            dropout_p > 0.0,
            dropout_mask is not None,
            rowscale is not None,
        )
    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
    if dropout_mask is not None and x1 is not None:
        dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
    else:
        dropout_mask1 = None
    return (
        y,
        y1,
        mean,
        rstd,
        residual_out if residual_out is not None else x,
        seeds,
        dropout_mask,
        dropout_mask1,
    )


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
# @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
@triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
@triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
@triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
@triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
@triton.jit
def _layer_norm_bwd_kernel(
    X,  # pointer to the input
    W,  # pointer to the weights
    B,  # pointer to the biases
    Y,  # pointer to the output to be recomputed
    DY,  # pointer to the output gradient
    DX,  # pointer to the input gradient
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    DRESIDUAL,
    W1,
    DY1,
    DX1,
    DW1,
    DB1,
    DRESIDUAL_IN,
    ROWSCALE,
    SEEDS,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_dy_row,
    stride_dx_row,
    stride_dres_row,
    stride_dy1_row,
    stride_dx1_row,
    stride_dres_in_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,
    rows_per_program,
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_DRESIDUAL: tl.constexpr,
    STORE_DRESIDUAL: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_DY1: tl.constexpr,
    HAS_DX1: tl.constexpr,
    HAS_B1: tl.constexpr,
    RECOMPUTE_OUTPUT: tl.constexpr,
):
    # Map the program id to the elements of X, DX, and DY it should compute.
    row_block_id = tl.program_id(0)
    row_start = row_block_id * rows_per_program
    # Do not early exit if row_start >= M, because we need to write DW and DB
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    X += row_start * stride_x_row
    if HAS_DRESIDUAL:
        DRESIDUAL += row_start * stride_dres_row
    if STORE_DRESIDUAL:
        DRESIDUAL_IN += row_start * stride_dres_in_row
    DY += row_start * stride_dy_row
    DX += row_start * stride_dx_row
    if HAS_DY1:
        DY1 += row_start * stride_dy1_row
    if HAS_DX1:
        DX1 += row_start * stride_dx1_row
    if RECOMPUTE_OUTPUT:
        Y += row_start * stride_y_row
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if RECOMPUTE_OUTPUT and HAS_BIAS:
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
    if HAS_DY1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
    dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_BIAS:
        db = tl.zeros((BLOCK_N,), dtype=tl.float32)
    if HAS_DY1:
        dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
        if HAS_B1:
            db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
    row_end = min((row_block_id + 1) * rows_per_program, M)
    for row in range(row_start, row_end):
        # Load data to SRAM
        x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
        dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
        if HAS_DY1:
            dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
        if not IS_RMS_NORM:
            mean = tl.load(Mean + row)
        rstd = tl.load(Rstd + row)
        # Compute dx
        xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
        xhat = tl.where(mask, xhat, 0.0)
        if RECOMPUTE_OUTPUT:
            y = xhat * w + b if HAS_BIAS else xhat * w
            tl.store(Y + cols, y, mask=mask)
        wdy = w * dy
        dw += dy * xhat
        if HAS_BIAS:
            db += dy
        if HAS_DY1:
            wdy += w1 * dy1
            dw1 += dy1 * xhat
            if HAS_B1:
                db1 += dy1
        if not IS_RMS_NORM:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            c2 = tl.sum(wdy, axis=0) / N
            dx = (wdy - (xhat * c1 + c2)) * rstd
        else:
            c1 = tl.sum(xhat * wdy, axis=0) / N
            dx = (wdy - xhat * c1) * rstd
        if HAS_DRESIDUAL:
            dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
            dx += dres
        # Write dx
        if STORE_DRESIDUAL:
            tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
        if HAS_DX1:
            if HAS_DROPOUT:
                keep_mask = (
                    tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
                )
                dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
            else:
                dx1 = dx
            tl.store(DX1 + cols, dx1, mask=mask)
        if HAS_DROPOUT:
            keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
            dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + row).to(tl.float32)
            dx *= rowscale
        tl.store(DX + cols, dx, mask=mask)

        X += stride_x_row
        if HAS_DRESIDUAL:
            DRESIDUAL += stride_dres_row
        if STORE_DRESIDUAL:
            DRESIDUAL_IN += stride_dres_in_row
        if RECOMPUTE_OUTPUT:
            Y += stride_y_row
        DY += stride_dy_row
        DX += stride_dx_row
        if HAS_DY1:
            DY1 += stride_dy1_row
        if HAS_DX1:
            DX1 += stride_dx1_row
    tl.store(DW + row_block_id * N + cols, dw, mask=mask)
    if HAS_BIAS:
        tl.store(DB + row_block_id * N + cols, db, mask=mask)
    if HAS_DY1:
        tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
        if HAS_B1:
            tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)


def _layer_norm_bwd(
    dy,
    x,
    weight,
    bias,
    eps,
    mean,
    rstd,
    dresidual=None,
    dy1=None,
    weight1=None,
    bias1=None,
    seeds=None,
    dropout_p=0.0,
    rowscale=None,
    has_residual=False,
    has_x1=False,
    is_rms_norm=False,
    x_dtype=None,
    recompute_output=False,
):
    M, N = x.shape
    assert x.stride(-1) == 1
    assert dy.stride(-1) == 1
    assert dy.shape == (M, N)
    if dresidual is not None:
        assert dresidual.stride(-1) == 1
        assert dresidual.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if dy1 is not None:
        assert weight1 is not None
        assert dy1.shape == dy.shape
        assert dy1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if seeds is not None:
        assert seeds.is_contiguous()
        assert seeds.shape == (M if not has_x1 else M * 2,)
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    # allocate output
    dx = (
        torch.empty_like(x)
        if x_dtype is None
        else torch.empty(M, N, dtype=x_dtype, device=x.device)
    )
    dresidual_in = (
        torch.empty_like(x)
        if has_residual
        and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
        else None
    )
    dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
    y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
    if recompute_output:
        assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
    _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
    _db = (
        torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
        if bias is not None
        else None
    )
    _dw1 = torch.empty_like(_dw) if weight1 is not None else None
    _db1 = torch.empty_like(_db) if bias1 is not None else None
    rows_per_program = math.ceil(M / sm_count)
    grid = (sm_count,)
    with torch.cuda.device(x.device.index):
        _layer_norm_bwd_kernel[grid](
            x,
            weight,
            bias,
            y,
            dy,
            dx,
            _dw,
            _db,
            dresidual,
            weight1,
            dy1,
            dx1,
            _dw1,
            _db1,
            dresidual_in,
            rowscale,
            seeds,
            mean,
            rstd,
            x.stride(0),
            0 if not recompute_output else y.stride(0),
            dy.stride(0),
            dx.stride(0),
            dresidual.stride(0) if dresidual is not None else 0,
            dy1.stride(0) if dy1 is not None else 0,
            dx1.stride(0) if dx1 is not None else 0,
            dresidual_in.stride(0) if dresidual_in is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            rows_per_program,
            is_rms_norm,
            BLOCK_N,
            dresidual is not None,
            dresidual_in is not None,
            bias is not None,
            dropout_p > 0.0,
        )
    dw = _dw.sum(0).to(weight.dtype)
    db = _db.sum(0).to(bias.dtype) if bias is not None else None
    dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
    db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
    # Don't need to compute dresidual_in separately in this case
    if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
        dresidual_in = dx
    if has_x1 and dropout_p == 0.0:
        dx1 = dx
    return (
        (dx, dw, db, dresidual_in, dx1, dw1, db1)
        if not recompute_output
        else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
    )


class LayerNormFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        residual=None,
        x1=None,
        weight1=None,
        bias1=None,
        eps=1e-6,
        dropout_p=0.0,
        rowscale=None,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
        return_dropout_mask=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        if x1 is not None:
            assert x1.shape == x_shape_og
            assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
            x1 = x1.reshape(-1, x1.shape[-1])
            if x1.stride(-1) != 1:
                x1 = x1.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        if weight1 is not None:
            weight1 = weight1.contiguous()
        if bias1 is not None:
            bias1 = bias1.contiguous()
        if rowscale is not None:
            rowscale = rowscale.reshape(-1).contiguous()
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            residual,
            x1,
            weight1,
            bias1,
            dropout_p=dropout_p,
            rowscale=rowscale,
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
            return_dropout_mask=return_dropout_mask,
        )
        ctx.save_for_backward(
            residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
        )
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.dropout_p = dropout_p
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.has_x1 = x1 is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        y = y.reshape(x_shape_og)
        y1 = y1.reshape(x_shape_og) if y1 is not None else None
        residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
        dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
        dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
        if not return_dropout_mask:
            if weight1 is None:
                return y if not prenorm else (y, residual_out)
            else:
                return (y, y1) if not prenorm else (y, y1, residual_out)
        else:
            if weight1 is None:
                return (
                    (y, dropout_mask, dropout_mask1)
                    if not prenorm
                    else (y, residual_out, dropout_mask, dropout_mask1)
                )
            else:
                return (
                    (y, y1, dropout_mask, dropout_mask1)
                    if not prenorm
                    else (y, y1, residual_out, dropout_mask, dropout_mask1)
                )

    @staticmethod
    def backward(ctx, dy, *args):
        x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
        dy = dy.reshape(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if weight1 is not None:
            dy1, args = args[0], args[1:]
            dy1 = dy1.reshape(-1, dy1.shape[-1])
            if dy1.stride(-1) != 1:
                dy1 = dy1.contiguous()
            assert dy1.shape == x.shape
        else:
            dy1 = None
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
            dy,
            x,
            weight,
            bias,
            ctx.eps,
            mean,
            rstd,
            dresidual,
            dy1,
            weight1,
            bias1,
            seeds,
            ctx.dropout_p,
            rowscale,
            ctx.has_residual,
            ctx.has_x1,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
        )
        return (
            dx.reshape(ctx.x_shape_og),
            dw,
            db,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
            dw1,
            db1,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


def layer_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
    return_dropout_mask=False,
):
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
        return_dropout_mask,
    )
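

# Example usage (an illustrative sketch, not part of the original API surface):
# a typical pre-norm transformer block fuses the residual add with the norm by
# passing `residual=...` and `prenorm=True`, which returns both the normalized
# output and the updated residual stream. The helper name and argument shapes
# below are assumptions made for illustration only.
def _example_prenorm_layer_norm(hidden_states, residual, weight, bias, eps=1e-6):
    """Hypothetical helper: fused residual-add + LayerNorm for a pre-norm block."""
    out, new_residual = layer_norm_fn(
        hidden_states,
        weight,
        bias,
        residual=residual,
        eps=eps,
        prenorm=True,
        residual_in_fp32=True,  # keep the residual stream in fp32 for stability
    )
    return out, new_residual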


def rms_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    return_dropout_mask=False,
):
    return LayerNormFn.apply(
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        True,
        return_dropout_mask,
    )


class RMSNorm(torch.nn.Module):
    def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if dropout_p > 0.0:
            self.drop = torch.nn.Dropout(dropout_p)
        else:
            self.drop = None
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
        return rms_norm_fn(
            x,
            self.weight,
            self.bias,
            residual=residual,
            eps=self.eps,
            dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
        )
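

# Example usage (a minimal sketch; the device, dtype, and sizes are assumptions):
# the RMSNorm module can be dropped in where a bias-free norm is expected, and it
# also supports the fused residual path via `prenorm=True`.
def _example_rms_norm_module(batch=2, seqlen=128, hidden=4096):
    """Hypothetical helper: plain and fused-residual RMSNorm module usage."""
    norm = RMSNorm(hidden, eps=1e-5, device="cuda", dtype=torch.float16)
    x = torch.randn(batch, seqlen, hidden, device="cuda", dtype=torch.float16)
    y = norm(x)  # plain RMSNorm
    y, residual = norm(x, residual=x, prenorm=True, residual_in_fp32=True)  # fused add + norm
    return y, residual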


class LayerNormLinearFn(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(
        ctx,
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual=None,
        eps=1e-6,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        norm_weight = norm_weight.contiguous()
        if norm_bias is not None:
            norm_bias = norm_bias.contiguous()
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
            x,
            norm_weight,
            norm_bias,
            eps,
            residual,
            out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
        )
        y = y.reshape(x_shape_og)
        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
        linear_weight = linear_weight.to(dtype)
        linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
        out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
        # We don't store y, will be recomputed in the backward pass to save memory
        ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        ctx.linear_bias_is_none = linear_bias is None
        return out if not prenorm else (out, residual_out.reshape(x_shape_og))

    @staticmethod
    @custom_bwd
    def backward(ctx, dout, *args):
        x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
        dy = F.linear(dout, linear_weight.t())
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
            dy,
            x,
            norm_weight,
            norm_bias,
            ctx.eps,
            mean,
            rstd,
            dresidual=dresidual,
            has_residual=ctx.has_residual,
            is_rms_norm=ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
            recompute_output=True,
        )
        dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
        return (
            dx.reshape(ctx.x_shape_og),
            dnorm_weight,
            dnorm_bias,
            dlinear_weight,
            dlinear_bias,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            None,
            None,
            None,
            None,
        )


def layer_norm_linear_fn(
    x,
    norm_weight,
    norm_bias,
    linear_weight,
    linear_bias,
    residual=None,
    eps=1e-6,
    prenorm=False,
    residual_in_fp32=False,
    is_rms_norm=False,
):
    return LayerNormLinearFn.apply(
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual,
        eps,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
    )
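

# Example usage (illustrative sketch; tensor shapes and dtypes are assumptions):
# layer_norm_linear_fn fuses norm + matmul at the autograd level. The norm output
# is not saved for backward but recomputed from the stored residual (see
# recompute_output=True above), trading a little extra compute for less
# activation memory. The helper name below is hypothetical.
def _example_norm_then_linear(x, norm_weight, norm_bias, linear_weight, linear_bias=None):
    """Hypothetical helper: normalized projection, e.g. the input projection of an MLP."""
    return layer_norm_linear_fn(
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        eps=1e-6,
        is_rms_norm=False,
    )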