- # Copyright (c) 2024, Tri Dao.
- # Implements fused dropout + residual + layer_norm / rms_norm.
- # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
- # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
- # This is faster for hidden dimensions up to about 8k, but beyond that it becomes much slower due to register spilling.
- # The models we train have hidden dims of at most 8k anyway (e.g. Llama 70B uses 8192), so this is fine.
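- # The op implemented here (see layer_norm_ref / rms_norm_ref below for the eager
- # reference): given input x, optional second input x1, optional residual, and
- # optional per-row scale and dropout,
- #     z = dropout(x * rowscale) [+ dropout(x1)] [+ residual]
- #     LayerNorm: y = (z - mean(z)) / sqrt(var(z) + eps) * w (+ b)
- #     RMSNorm:   y = z / sqrt(mean(z^2) + eps) * w (+ b)
- # If weight1/bias1 are given, a second output y1 is produced from the same z.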
- import math
- import torch
- import torch.nn.functional as F
- from torch.cuda.amp import custom_fwd, custom_bwd
- import triton
- import triton.language as tl
- def layer_norm_ref(
- x,
- weight,
- bias,
- residual=None,
- x1=None,
- weight1=None,
- bias1=None,
- eps=1e-6,
- dropout_p=0.0,
- rowscale=None,
- prenorm=False,
- dropout_mask=None,
- dropout_mask1=None,
- upcast=False,
- ):
- dtype = x.dtype
- if upcast:
- x = x.float()
- weight = weight.float()
- bias = bias.float() if bias is not None else None
- residual = residual.float() if residual is not None else residual
- x1 = x1.float() if x1 is not None else None
- weight1 = weight1.float() if weight1 is not None else None
- bias1 = bias1.float() if bias1 is not None else None
- if x1 is not None:
- assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
- if rowscale is not None:
- x = x * rowscale[..., None]
- if dropout_p > 0.0:
- if dropout_mask is not None:
- x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
- else:
- x = F.dropout(x, p=dropout_p)
- if x1 is not None:
- if dropout_mask1 is not None:
- x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
- else:
- x1 = F.dropout(x1, p=dropout_p)
- if x1 is not None:
- x = x + x1
- if residual is not None:
- x = (x + residual).to(x.dtype)
- out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
- dtype
- )
- if weight1 is None:
- return out if not prenorm else (out, x)
- else:
- out1 = F.layer_norm(
- x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
- ).to(dtype)
- return (out, out1) if not prenorm else (out, out1, x)
- def rms_norm_ref(
- x,
- weight,
- bias,
- residual=None,
- x1=None,
- weight1=None,
- bias1=None,
- eps=1e-6,
- dropout_p=0.0,
- rowscale=None,
- prenorm=False,
- dropout_mask=None,
- dropout_mask1=None,
- upcast=False,
- ):
- dtype = x.dtype
- if upcast:
- x = x.float()
- weight = weight.float()
- bias = bias.float() if bias is not None else None
- residual = residual.float() if residual is not None else residual
- x1 = x1.float() if x1 is not None else None
- weight1 = weight1.float() if weight1 is not None else None
- bias1 = bias1.float() if bias1 is not None else None
- if x1 is not None:
- assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
- if rowscale is not None:
- x = x * rowscale[..., None]
- if dropout_p > 0.0:
- if dropout_mask is not None:
- x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
- else:
- x = F.dropout(x, p=dropout_p)
- if x1 is not None:
- if dropout_mask1 is not None:
- x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
- else:
- x1 = F.dropout(x1, p=dropout_p)
- if x1 is not None:
- x = x + x1
- if residual is not None:
- x = (x + residual).to(x.dtype)
- rstd = 1 / torch.sqrt(x.square().mean(dim=-1, keepdim=True) + eps)
- out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
- if weight1 is None:
- return out if not prenorm else (out, x)
- else:
- out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
- dtype
- )
- return (out, out1) if not prenorm else (out, out1, x)
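- # Example (sketch, not part of the original file): the reference implementations can
- # be called directly for testing, e.g.
- #     x, w = torch.randn(2, 8, 512), torch.ones(512)
- #     y = layer_norm_ref(x, w, torch.zeros(512))
- #     y, z = rms_norm_ref(x, w, None, residual=torch.zeros_like(x), prenorm=True)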
- @triton.autotune(
- configs=[
- triton.Config({}, num_warps=1),
- triton.Config({}, num_warps=2),
- triton.Config({}, num_warps=4),
- triton.Config({}, num_warps=8),
- triton.Config({}, num_warps=16),
- triton.Config({}, num_warps=32),
- ],
- key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
- )
- # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
- # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
- @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
- @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
- @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
- @triton.jit
- def _layer_norm_fwd_1pass_kernel(
- X, # pointer to the input
- Y, # pointer to the output
- W, # pointer to the weights
- B, # pointer to the biases
- RESIDUAL, # pointer to the residual
- X1,
- W1,
- B1,
- Y1,
- RESIDUAL_OUT, # pointer to where the updated residual (the value that gets normalized) is stored
- ROWSCALE,
- SEEDS, # Dropout seeds for each row
- DROPOUT_MASK,
- Mean, # pointer to the mean
- Rstd, # pointer to the 1/std
- stride_x_row, # how much to increase the pointer when moving by 1 row
- stride_y_row,
- stride_res_row,
- stride_res_out_row,
- stride_x1_row,
- stride_y1_row,
- M, # number of rows in X
- N, # number of columns in X
- eps, # epsilon to avoid division by zero
- dropout_p, # Dropout probability
- IS_RMS_NORM: tl.constexpr,
- BLOCK_N: tl.constexpr,
- HAS_RESIDUAL: tl.constexpr,
- STORE_RESIDUAL_OUT: tl.constexpr,
- HAS_BIAS: tl.constexpr,
- HAS_DROPOUT: tl.constexpr,
- STORE_DROPOUT_MASK: tl.constexpr,
- HAS_ROWSCALE: tl.constexpr,
- HAS_X1: tl.constexpr,
- HAS_W1: tl.constexpr,
- HAS_B1: tl.constexpr,
- ):
- # Map the program id to the row of X and Y it should compute.
- row = tl.program_id(0)
- X += row * stride_x_row
- Y += row * stride_y_row
- if HAS_RESIDUAL:
- RESIDUAL += row * stride_res_row
- if STORE_RESIDUAL_OUT:
- RESIDUAL_OUT += row * stride_res_out_row
- if HAS_X1:
- X1 += row * stride_x1_row
- if HAS_W1:
- Y1 += row * stride_y1_row
- # Compute mean and variance
- cols = tl.arange(0, BLOCK_N)
- x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
- if HAS_ROWSCALE:
- rowscale = tl.load(ROWSCALE + row).to(tl.float32)
- x *= rowscale
- if HAS_DROPOUT:
- # Compute dropout mask
- # 7 rounds is good enough, and reduces register pressure
- keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
- x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
- if STORE_DROPOUT_MASK:
- tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
- if HAS_X1:
- x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
- if HAS_ROWSCALE:
- rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
- x1 *= rowscale
- if HAS_DROPOUT:
- # Compute dropout mask
- # 7 rounds is good enough, and reduces register pressure
- keep_mask = (
- tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
- )
- x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
- if STORE_DROPOUT_MASK:
- tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
- x += x1
- if HAS_RESIDUAL:
- residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
- x += residual
- if STORE_RESIDUAL_OUT:
- tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
- if not IS_RMS_NORM:
- mean = tl.sum(x, axis=0) / N
- tl.store(Mean + row, mean)
- xbar = tl.where(cols < N, x - mean, 0.0)
- var = tl.sum(xbar * xbar, axis=0) / N
- else:
- xbar = tl.where(cols < N, x, 0.0)
- var = tl.sum(xbar * xbar, axis=0) / N
- rstd = 1 / tl.sqrt(var + eps)
- tl.store(Rstd + row, rstd)
- # Normalize and apply linear transformation
- mask = cols < N
- w = tl.load(W + cols, mask=mask).to(tl.float32)
- if HAS_BIAS:
- b = tl.load(B + cols, mask=mask).to(tl.float32)
- x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
- y = x_hat * w + b if HAS_BIAS else x_hat * w
- # Write output
- tl.store(Y + cols, y, mask=mask)
- if HAS_W1:
- w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
- if HAS_B1:
- b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
- y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
- tl.store(Y1 + cols, y1, mask=mask)
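- # The kernel above is launched with a 1D grid of M programs (one per row) by
- # _layer_norm_fwd below; BLOCK_N covers the whole row, so each row is reduced and
- # normalized in a single pass.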
- def _layer_norm_fwd(
- x,
- weight,
- bias,
- eps,
- residual=None,
- x1=None,
- weight1=None,
- bias1=None,
- dropout_p=0.0,
- rowscale=None,
- out_dtype=None,
- residual_dtype=None,
- is_rms_norm=False,
- return_dropout_mask=False,
- out=None,
- residual_out=None
- ):
- if residual is not None:
- residual_dtype = residual.dtype
- M, N = x.shape
- assert x.stride(-1) == 1
- if residual is not None:
- assert residual.stride(-1) == 1
- assert residual.shape == (M, N)
- assert weight.shape == (N,)
- assert weight.stride(-1) == 1
- if bias is not None:
- assert bias.stride(-1) == 1
- assert bias.shape == (N,)
- if x1 is not None:
- assert x1.shape == x.shape
- assert rowscale is None
- assert x1.stride(-1) == 1
- if weight1 is not None:
- assert weight1.shape == (N,)
- assert weight1.stride(-1) == 1
- if bias1 is not None:
- assert bias1.shape == (N,)
- assert bias1.stride(-1) == 1
- if rowscale is not None:
- assert rowscale.is_contiguous()
- assert rowscale.shape == (M,)
- # allocate output
- if out is None:
- out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
- else:
- assert out.shape == x.shape
- assert out.stride(-1) == 1
- if weight1 is not None:
- y1 = torch.empty_like(out)
- assert y1.stride(-1) == 1
- else:
- y1 = None
- if (
- residual is not None
- or (residual_dtype is not None and residual_dtype != x.dtype)
- or dropout_p > 0.0
- or rowscale is not None
- or x1 is not None
- ):
- if residual_out is None:
- residual_out = torch.empty(
- M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
- )
- else:
- assert residual_out.shape == x.shape
- assert residual_out.stride(-1) == 1
- else:
- residual_out = None
- mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
- rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
- if dropout_p > 0.0:
- seeds = torch.randint(
- 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
- )
- else:
- seeds = None
- if return_dropout_mask and dropout_p > 0.0:
- dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
- else:
- dropout_mask = None
- # Less than 64KB per feature: enqueue fused kernel
- MAX_FUSED_SIZE = 65536 // x.element_size()
- BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
- if N > BLOCK_N:
- raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
- with torch.cuda.device(x.device.index):
- _layer_norm_fwd_1pass_kernel[(M,)](
- x,
- out,
- weight,
- bias,
- residual,
- x1,
- weight1,
- bias1,
- y1,
- residual_out,
- rowscale,
- seeds,
- dropout_mask,
- mean,
- rstd,
- x.stride(0),
- out.stride(0),
- residual.stride(0) if residual is not None else 0,
- residual_out.stride(0) if residual_out is not None else 0,
- x1.stride(0) if x1 is not None else 0,
- y1.stride(0) if y1 is not None else 0,
- M,
- N,
- eps,
- dropout_p,
- is_rms_norm,
- BLOCK_N,
- residual is not None,
- residual_out is not None,
- bias is not None,
- dropout_p > 0.0,
- dropout_mask is not None,
- rowscale is not None,
- )
- # residual_out is None only if residual is None, residual_dtype == input dtype, dropout_p == 0.0, rowscale is None and x1 is None
- if dropout_mask is not None and x1 is not None:
- dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
- else:
- dropout_mask1 = None
- return (
- out,
- y1,
- mean,
- rstd,
- residual_out if residual_out is not None else x,
- seeds,
- dropout_mask,
- dropout_mask1,
- )
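- # Example (sketch, illustrative shapes/dtypes only): a quick parity check of the
- # fused forward against the eager reference above.
- #     x = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
- #     w = torch.ones(1024, device="cuda", dtype=torch.float16)
- #     b = torch.zeros(1024, device="cuda", dtype=torch.float16)
- #     y, *_ = _layer_norm_fwd(x, w, b, 1e-6)
- #     y_ref = layer_norm_ref(x, w, b, eps=1e-6, upcast=True)
- #     torch.testing.assert_close(y, y_ref, rtol=1e-2, atol=1e-2)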
- @triton.autotune(
- configs=[
- triton.Config({}, num_warps=1),
- triton.Config({}, num_warps=2),
- triton.Config({}, num_warps=4),
- triton.Config({}, num_warps=8),
- triton.Config({}, num_warps=16),
- triton.Config({}, num_warps=32),
- ],
- key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
- )
- # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
- # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
- # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
- @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
- @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
- @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
- @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
- @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
- @triton.jit
- def _layer_norm_bwd_kernel(
- X, # pointer to the input
- W, # pointer to the weights
- B, # pointer to the biases
- Y, # pointer to the output to be recomputed
- DY, # pointer to the output gradient
- DX, # pointer to the input gradient
- DW, # pointer to the partial sum of weights gradient
- DB, # pointer to the partial sum of biases gradient
- DRESIDUAL,
- W1,
- DY1,
- DX1,
- DW1,
- DB1,
- DRESIDUAL_IN,
- ROWSCALE,
- SEEDS,
- Mean, # pointer to the mean
- Rstd, # pointer to the 1/std
- stride_x_row, # how much to increase the pointer when moving by 1 row
- stride_y_row,
- stride_dy_row,
- stride_dx_row,
- stride_dres_row,
- stride_dy1_row,
- stride_dx1_row,
- stride_dres_in_row,
- M, # number of rows in X
- N, # number of columns in X
- eps, # epsilon to avoid division by zero
- dropout_p,
- rows_per_program,
- IS_RMS_NORM: tl.constexpr,
- BLOCK_N: tl.constexpr,
- HAS_DRESIDUAL: tl.constexpr,
- STORE_DRESIDUAL: tl.constexpr,
- HAS_BIAS: tl.constexpr,
- HAS_DROPOUT: tl.constexpr,
- HAS_ROWSCALE: tl.constexpr,
- HAS_DY1: tl.constexpr,
- HAS_DX1: tl.constexpr,
- HAS_B1: tl.constexpr,
- RECOMPUTE_OUTPUT: tl.constexpr,
- ):
- # Map the program id to the elements of X, DX, and DY it should compute.
- row_block_id = tl.program_id(0)
- row_start = row_block_id * rows_per_program
- # Do not early exit if row_start >= M, because we need to write DW and DB
- cols = tl.arange(0, BLOCK_N)
- mask = cols < N
- X += row_start * stride_x_row
- if HAS_DRESIDUAL:
- DRESIDUAL += row_start * stride_dres_row
- if STORE_DRESIDUAL:
- DRESIDUAL_IN += row_start * stride_dres_in_row
- DY += row_start * stride_dy_row
- DX += row_start * stride_dx_row
- if HAS_DY1:
- DY1 += row_start * stride_dy1_row
- if HAS_DX1:
- DX1 += row_start * stride_dx1_row
- if RECOMPUTE_OUTPUT:
- Y += row_start * stride_y_row
- w = tl.load(W + cols, mask=mask).to(tl.float32)
- if RECOMPUTE_OUTPUT and HAS_BIAS:
- b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
- if HAS_DY1:
- w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
- dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
- if HAS_BIAS:
- db = tl.zeros((BLOCK_N,), dtype=tl.float32)
- if HAS_DY1:
- dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
- if HAS_B1:
- db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
- row_end = min((row_block_id + 1) * rows_per_program, M)
- for row in range(row_start, row_end):
- # Load data to SRAM
- x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
- dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
- if HAS_DY1:
- dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
- if not IS_RMS_NORM:
- mean = tl.load(Mean + row)
- rstd = tl.load(Rstd + row)
- # Compute dx
- xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
- xhat = tl.where(mask, xhat, 0.0)
- if RECOMPUTE_OUTPUT:
- y = xhat * w + b if HAS_BIAS else xhat * w
- tl.store(Y + cols, y, mask=mask)
- wdy = w * dy
- dw += dy * xhat
- if HAS_BIAS:
- db += dy
- if HAS_DY1:
- wdy += w1 * dy1
- dw1 += dy1 * xhat
- if HAS_B1:
- db1 += dy1
- if not IS_RMS_NORM:
- c1 = tl.sum(xhat * wdy, axis=0) / N
- c2 = tl.sum(wdy, axis=0) / N
- dx = (wdy - (xhat * c1 + c2)) * rstd
- else:
- c1 = tl.sum(xhat * wdy, axis=0) / N
- dx = (wdy - xhat * c1) * rstd
- if HAS_DRESIDUAL:
- dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
- dx += dres
- # Write dx
- if STORE_DRESIDUAL:
- tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
- if HAS_DX1:
- if HAS_DROPOUT:
- keep_mask = (
- tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
- )
- dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
- else:
- dx1 = dx
- tl.store(DX1 + cols, dx1, mask=mask)
- if HAS_DROPOUT:
- keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
- dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
- if HAS_ROWSCALE:
- rowscale = tl.load(ROWSCALE + row).to(tl.float32)
- dx *= rowscale
- tl.store(DX + cols, dx, mask=mask)
- X += stride_x_row
- if HAS_DRESIDUAL:
- DRESIDUAL += stride_dres_row
- if STORE_DRESIDUAL:
- DRESIDUAL_IN += stride_dres_in_row
- if RECOMPUTE_OUTPUT:
- Y += stride_y_row
- DY += stride_dy_row
- DX += stride_dx_row
- if HAS_DY1:
- DY1 += stride_dy1_row
- if HAS_DX1:
- DX1 += stride_dx1_row
- tl.store(DW + row_block_id * N + cols, dw, mask=mask)
- if HAS_BIAS:
- tl.store(DB + row_block_id * N + cols, db, mask=mask)
- if HAS_DY1:
- tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
- if HAS_B1:
- tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
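- # Backward math used above: with xhat the normalized input and wdy = w * dy
- # (+ w1 * dy1 when a second output exists), each row's input gradient is
- #     LayerNorm: dx = rstd * (wdy - xhat * mean(xhat * wdy) - mean(wdy))
- #     RMSNorm:   dx = rstd * (wdy - xhat * mean(xhat * wdy))
- # dw/db (and dw1/db1) are accumulated in registers over the rows each program owns
- # and written out as one partial row per program; _layer_norm_bwd sums the partials.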
- def _layer_norm_bwd(
- dy,
- x,
- weight,
- bias,
- eps,
- mean,
- rstd,
- dresidual=None,
- dy1=None,
- weight1=None,
- bias1=None,
- seeds=None,
- dropout_p=0.0,
- rowscale=None,
- has_residual=False,
- has_x1=False,
- is_rms_norm=False,
- x_dtype=None,
- recompute_output=False,
- ):
- M, N = x.shape
- assert x.stride(-1) == 1
- assert dy.stride(-1) == 1
- assert dy.shape == (M, N)
- if dresidual is not None:
- assert dresidual.stride(-1) == 1
- assert dresidual.shape == (M, N)
- assert weight.shape == (N,)
- assert weight.stride(-1) == 1
- if bias is not None:
- assert bias.stride(-1) == 1
- assert bias.shape == (N,)
- if dy1 is not None:
- assert weight1 is not None
- assert dy1.shape == dy.shape
- assert dy1.stride(-1) == 1
- if weight1 is not None:
- assert weight1.shape == (N,)
- assert weight1.stride(-1) == 1
- if bias1 is not None:
- assert bias1.shape == (N,)
- assert bias1.stride(-1) == 1
- if seeds is not None:
- assert seeds.is_contiguous()
- assert seeds.shape == (M if not has_x1 else M * 2,)
- if rowscale is not None:
- assert rowscale.is_contiguous()
- assert rowscale.shape == (M,)
- # allocate output
- dx = (
- torch.empty_like(x)
- if x_dtype is None
- else torch.empty(M, N, dtype=x_dtype, device=x.device)
- )
- dresidual_in = (
- torch.empty_like(x)
- if has_residual
- and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
- else None
- )
- dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
- y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
- if recompute_output:
- assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
- # Less than 64KB per feature: enqueue fused kernel
- MAX_FUSED_SIZE = 65536 // x.element_size()
- BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
- if N > BLOCK_N:
- raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
- sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
- _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
- _db = (
- torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
- if bias is not None
- else None
- )
- _dw1 = torch.empty_like(_dw) if weight1 is not None else None
- _db1 = torch.empty_like(_db) if bias1 is not None else None
- rows_per_program = math.ceil(M / sm_count)
- grid = (sm_count,)
- with torch.cuda.device(x.device.index):
- _layer_norm_bwd_kernel[grid](
- x,
- weight,
- bias,
- y,
- dy,
- dx,
- _dw,
- _db,
- dresidual,
- weight1,
- dy1,
- dx1,
- _dw1,
- _db1,
- dresidual_in,
- rowscale,
- seeds,
- mean,
- rstd,
- x.stride(0),
- 0 if not recompute_output else y.stride(0),
- dy.stride(0),
- dx.stride(0),
- dresidual.stride(0) if dresidual is not None else 0,
- dy1.stride(0) if dy1 is not None else 0,
- dx1.stride(0) if dx1 is not None else 0,
- dresidual_in.stride(0) if dresidual_in is not None else 0,
- M,
- N,
- eps,
- dropout_p,
- rows_per_program,
- is_rms_norm,
- BLOCK_N,
- dresidual is not None,
- dresidual_in is not None,
- bias is not None,
- dropout_p > 0.0,
- )
- dw = _dw.sum(0).to(weight.dtype)
- db = _db.sum(0).to(bias.dtype) if bias is not None else None
- dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
- db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
- # Don't need to compute dresidual_in separately in this case
- if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
- dresidual_in = dx
- if has_x1 and dropout_p == 0.0:
- dx1 = dx
- return (
- (dx, dw, db, dresidual_in, dx1, dw1, db1)
- if not recompute_output
- else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
- )
- class LayerNormFn(torch.autograd.Function):
- @staticmethod
- def forward(
- ctx,
- x,
- weight,
- bias,
- residual=None,
- x1=None,
- weight1=None,
- bias1=None,
- eps=1e-6,
- dropout_p=0.0,
- rowscale=None,
- prenorm=False,
- residual_in_fp32=False,
- is_rms_norm=False,
- return_dropout_mask=False,
- out=None,
- residual_out=None
- ):
- x_shape_og = x.shape
- # reshape input data into 2D tensor
- x = x.reshape(-1, x.shape[-1])
- if x.stride(-1) != 1:
- x = x.contiguous()
- if residual is not None:
- assert residual.shape == x_shape_og
- residual = residual.reshape(-1, residual.shape[-1])
- if residual.stride(-1) != 1:
- residual = residual.contiguous()
- if x1 is not None:
- assert x1.shape == x_shape_og
- assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
- x1 = x1.reshape(-1, x1.shape[-1])
- if x1.stride(-1) != 1:
- x1 = x1.contiguous()
- weight = weight.contiguous()
- if bias is not None:
- bias = bias.contiguous()
- if weight1 is not None:
- weight1 = weight1.contiguous()
- if bias1 is not None:
- bias1 = bias1.contiguous()
- if rowscale is not None:
- rowscale = rowscale.reshape(-1).contiguous()
- residual_dtype = (
- residual.dtype
- if residual is not None
- else (torch.float32 if residual_in_fp32 else None)
- )
- if out is not None:
- out = out.reshape(-1, out.shape[-1])
- if residual_out is not None:
- residual_out = residual_out.reshape(-1, residual_out.shape[-1])
- y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
- x,
- weight,
- bias,
- eps,
- residual,
- x1,
- weight1,
- bias1,
- dropout_p=dropout_p,
- rowscale=rowscale,
- residual_dtype=residual_dtype,
- is_rms_norm=is_rms_norm,
- return_dropout_mask=return_dropout_mask,
- out=out,
- residual_out=residual_out
- )
- ctx.save_for_backward(
- residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
- )
- ctx.x_shape_og = x_shape_og
- ctx.eps = eps
- ctx.dropout_p = dropout_p
- ctx.is_rms_norm = is_rms_norm
- ctx.has_residual = residual is not None
- ctx.has_x1 = x1 is not None
- ctx.prenorm = prenorm
- ctx.x_dtype = x.dtype
- y = y.reshape(x_shape_og)
- y1 = y1.reshape(x_shape_og) if y1 is not None else None
- residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
- dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
- dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
- if not return_dropout_mask:
- if weight1 is None:
- return y if not prenorm else (y, residual_out)
- else:
- return (y, y1) if not prenorm else (y, y1, residual_out)
- else:
- if weight1 is None:
- return (
- (y, dropout_mask, dropout_mask1)
- if not prenorm
- else (y, residual_out, dropout_mask, dropout_mask1)
- )
- else:
- return (
- (y, y1, dropout_mask, dropout_mask1)
- if not prenorm
- else (y, y1, residual_out, dropout_mask, dropout_mask1)
- )
- @staticmethod
- def backward(ctx, dy, *args):
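- # Note: the first saved tensor is `residual_out` from the forward pass (the actual
- # input to the normalization, after dropout/rowscale/residual); it is unpacked as
- # `x` here because that is what the backward kernel re-normalizes.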
- x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
- dy = dy.reshape(-1, dy.shape[-1])
- if dy.stride(-1) != 1:
- dy = dy.contiguous()
- assert dy.shape == x.shape
- if weight1 is not None:
- dy1, args = args[0], args[1:]
- dy1 = dy1.reshape(-1, dy1.shape[-1])
- if dy1.stride(-1) != 1:
- dy1 = dy1.contiguous()
- assert dy1.shape == x.shape
- else:
- dy1 = None
- if ctx.prenorm:
- dresidual = args[0]
- dresidual = dresidual.reshape(-1, dresidual.shape[-1])
- if dresidual.stride(-1) != 1:
- dresidual = dresidual.contiguous()
- assert dresidual.shape == x.shape
- else:
- dresidual = None
- dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
- dy,
- x,
- weight,
- bias,
- ctx.eps,
- mean,
- rstd,
- dresidual,
- dy1,
- weight1,
- bias1,
- seeds,
- ctx.dropout_p,
- rowscale,
- ctx.has_residual,
- ctx.has_x1,
- ctx.is_rms_norm,
- x_dtype=ctx.x_dtype,
- )
- return (
- dx.reshape(ctx.x_shape_og),
- dw,
- db,
- dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
- dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
- dw1,
- db1,
- None,
- None,
- None,
- None,
- None,
- None,
- None,
- None,
- None,
- )
- def layer_norm_fn(
- x,
- weight,
- bias,
- residual=None,
- x1=None,
- weight1=None,
- bias1=None,
- eps=1e-6,
- dropout_p=0.0,
- rowscale=None,
- prenorm=False,
- residual_in_fp32=False,
- is_rms_norm=False,
- return_dropout_mask=False,
- out=None,
- residual_out=None
- ):
- return LayerNormFn.apply(
- x,
- weight,
- bias,
- residual,
- x1,
- weight1,
- bias1,
- eps,
- dropout_p,
- rowscale,
- prenorm,
- residual_in_fp32,
- is_rms_norm,
- return_dropout_mask,
- out,
- residual_out
- )
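- # Example (sketch): typical pre-norm transformer-block usage, with the residual
- # stream carried separately; names below are illustrative only.
- #     hidden, residual = layer_norm_fn(
- #         hidden, norm.weight, norm.bias, residual=residual,
- #         dropout_p=0.1, prenorm=True, residual_in_fp32=True,
- #     )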
- def rms_norm_fn(
- x,
- weight,
- bias,
- residual=None,
- x1=None,
- weight1=None,
- bias1=None,
- eps=1e-6,
- dropout_p=0.0,
- rowscale=None,
- prenorm=False,
- residual_in_fp32=False,
- return_dropout_mask=False,
- out=None,
- residual_out=None
- ):
- return LayerNormFn.apply(
- x,
- weight,
- bias,
- residual,
- x1,
- weight1,
- bias1,
- eps,
- dropout_p,
- rowscale,
- prenorm,
- residual_in_fp32,
- True,
- return_dropout_mask,
- out,
- residual_out
- )
- class RMSNorm(torch.nn.Module):
- def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
- factory_kwargs = {"device": device, "dtype": dtype}
- super().__init__()
- self.eps = eps
- if dropout_p > 0.0:
- self.drop = torch.nn.Dropout(dropout_p)
- else:
- self.drop = None
- self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
- self.register_parameter("bias", None)
- self.reset_parameters()
- def reset_parameters(self):
- torch.nn.init.ones_(self.weight)
- def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
- return rms_norm_fn(
- x,
- self.weight,
- self.bias,
- residual=residual,
- eps=self.eps,
- dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
- prenorm=prenorm,
- residual_in_fp32=residual_in_fp32,
- )
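- # Example (sketch): module usage, assuming x is a CUDA tensor with last dim 4096;
- # with prenorm=True the updated residual is also returned for the next block.
- #     norm = RMSNorm(4096, eps=1e-5).cuda()
- #     y = norm(x)
- #     y, residual = norm(x, residual=residual, prenorm=True, residual_in_fp32=True)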
- class LayerNormLinearFn(torch.autograd.Function):
- @staticmethod
- @custom_fwd
- def forward(
- ctx,
- x,
- norm_weight,
- norm_bias,
- linear_weight,
- linear_bias,
- residual=None,
- eps=1e-6,
- prenorm=False,
- residual_in_fp32=False,
- is_rms_norm=False,
- ):
- x_shape_og = x.shape
- # reshape input data into 2D tensor
- x = x.reshape(-1, x.shape[-1])
- if x.stride(-1) != 1:
- x = x.contiguous()
- if residual is not None:
- assert residual.shape == x_shape_og
- residual = residual.reshape(-1, residual.shape[-1])
- if residual.stride(-1) != 1:
- residual = residual.contiguous()
- norm_weight = norm_weight.contiguous()
- if norm_bias is not None:
- norm_bias = norm_bias.contiguous()
- residual_dtype = (
- residual.dtype
- if residual is not None
- else (torch.float32 if residual_in_fp32 else None)
- )
- y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
- x,
- norm_weight,
- norm_bias,
- eps,
- residual,
- out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
- residual_dtype=residual_dtype,
- is_rms_norm=is_rms_norm,
- )
- y = y.reshape(x_shape_og)
- dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
- linear_weight = linear_weight.to(dtype)
- linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
- out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
- # We don't store y; it will be recomputed in the backward pass to save memory
- ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
- ctx.x_shape_og = x_shape_og
- ctx.eps = eps
- ctx.is_rms_norm = is_rms_norm
- ctx.has_residual = residual is not None
- ctx.prenorm = prenorm
- ctx.x_dtype = x.dtype
- ctx.linear_bias_is_none = linear_bias is None
- return out if not prenorm else (out, residual_out.reshape(x_shape_og))
- @staticmethod
- @custom_bwd
- def backward(ctx, dout, *args):
- x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
- dout = dout.reshape(-1, dout.shape[-1])
- dy = F.linear(dout, linear_weight.t())
- dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
- if dy.stride(-1) != 1:
- dy = dy.contiguous()
- assert dy.shape == x.shape
- if ctx.prenorm:
- dresidual = args[0]
- dresidual = dresidual.reshape(-1, dresidual.shape[-1])
- if dresidual.stride(-1) != 1:
- dresidual = dresidual.contiguous()
- assert dresidual.shape == x.shape
- else:
- dresidual = None
- dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
- dy,
- x,
- norm_weight,
- norm_bias,
- ctx.eps,
- mean,
- rstd,
- dresidual=dresidual,
- has_residual=ctx.has_residual,
- is_rms_norm=ctx.is_rms_norm,
- x_dtype=ctx.x_dtype,
- recompute_output=True,
- )
- dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
- return (
- dx.reshape(ctx.x_shape_og),
- dnorm_weight,
- dnorm_bias,
- dlinear_weight,
- dlinear_bias,
- dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
- None,
- None,
- None,
- None,
- )
- def layer_norm_linear_fn(
- x,
- norm_weight,
- norm_bias,
- linear_weight,
- linear_bias,
- residual=None,
- eps=1e-6,
- prenorm=False,
- residual_in_fp32=False,
- is_rms_norm=False,
- ):
- return LayerNormLinearFn.apply(
- x,
- norm_weight,
- norm_bias,
- linear_weight,
- linear_bias,
- residual,
- eps,
- prenorm,
- residual_in_fp32,
- is_rms_norm,
- )
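- # Example (sketch): layer_norm_linear_fn runs the fused norm kernel followed by a
- # regular F.linear inside one autograd function, so the normalized activations are
- # recomputed in backward instead of being saved. Tensors below are illustrative only.
- #     out = layer_norm_linear_fn(
- #         x, norm_weight, norm_bias, linear_weight, linear_bias,
- #         eps=1e-6, is_rms_norm=True,
- #     )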
|