1
0

layer_norm.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112
  1. # Copyright (c) 2024, Tri Dao.
  2. # Implement dropout + residual + layer_norm / rms_norm.
  3. # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
  4. # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
  5. # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
  6. # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
  7. import math
  8. import torch
  9. import torch.nn.functional as F
  10. from torch.cuda.amp import custom_fwd, custom_bwd
  11. import triton
  12. import triton.language as tl
  13. def layer_norm_ref(
  14. x,
  15. weight,
  16. bias,
  17. residual=None,
  18. x1=None,
  19. weight1=None,
  20. bias1=None,
  21. eps=1e-6,
  22. dropout_p=0.0,
  23. rowscale=None,
  24. prenorm=False,
  25. dropout_mask=None,
  26. dropout_mask1=None,
  27. upcast=False,
  28. ):
  29. dtype = x.dtype
  30. if upcast:
  31. x = x.float()
  32. weight = weight.float()
  33. bias = bias.float() if bias is not None else None
  34. residual = residual.float() if residual is not None else residual
  35. x1 = x1.float() if x1 is not None else None
  36. weight1 = weight1.float() if weight1 is not None else None
  37. bias1 = bias1.float() if bias1 is not None else None
  38. if x1 is not None:
  39. assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
  40. if rowscale is not None:
  41. x = x * rowscale[..., None]
  42. if dropout_p > 0.0:
  43. if dropout_mask is not None:
  44. x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
  45. else:
  46. x = F.dropout(x, p=dropout_p)
  47. if x1 is not None:
  48. if dropout_mask1 is not None:
  49. x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
  50. else:
  51. x1 = F.dropout(x1, p=dropout_p)
  52. if x1 is not None:
  53. x = x + x1
  54. if residual is not None:
  55. x = (x + residual).to(x.dtype)
  56. out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
  57. dtype
  58. )
  59. if weight1 is None:
  60. return out if not prenorm else (out, x)
  61. else:
  62. out1 = F.layer_norm(
  63. x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
  64. ).to(dtype)
  65. return (out, out1) if not prenorm else (out, out1, x)
  66. def rms_norm_ref(
  67. x,
  68. weight,
  69. bias,
  70. residual=None,
  71. x1=None,
  72. weight1=None,
  73. bias1=None,
  74. eps=1e-6,
  75. dropout_p=0.0,
  76. rowscale=None,
  77. prenorm=False,
  78. dropout_mask=None,
  79. dropout_mask1=None,
  80. upcast=False,
  81. ):
  82. dtype = x.dtype
  83. if upcast:
  84. x = x.float()
  85. weight = weight.float()
  86. bias = bias.float() if bias is not None else None
  87. residual = residual.float() if residual is not None else residual
  88. x1 = x1.float() if x1 is not None else None
  89. weight1 = weight1.float() if weight1 is not None else None
  90. bias1 = bias1.float() if bias1 is not None else None
  91. if x1 is not None:
  92. assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
  93. if rowscale is not None:
  94. x = x * rowscale[..., None]
  95. if dropout_p > 0.0:
  96. if dropout_mask is not None:
  97. x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
  98. else:
  99. x = F.dropout(x, p=dropout_p)
  100. if x1 is not None:
  101. if dropout_mask1 is not None:
  102. x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
  103. else:
  104. x1 = F.dropout(x1, p=dropout_p)
  105. if x1 is not None:
  106. x = x + x1
  107. if residual is not None:
  108. x = (x + residual).to(x.dtype)
  109. rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
  110. out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
  111. if weight1 is None:
  112. return out if not prenorm else (out, x)
  113. else:
  114. out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
  115. dtype
  116. )
  117. return (out, out1) if not prenorm else (out, out1, x)
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    X1,  # pointer to the optional second input (added to X before normalization)
    W1,  # pointer to the optional weights of the parallel norm
    B1,  # pointer to the optional biases of the parallel norm
    Y1,  # pointer to the output of the parallel norm
    RESIDUAL_OUT,  # pointer to the residual
    ROWSCALE,  # pointer to the optional per-row scale applied to X
    SEEDS,  # Dropout seeds for each row
    DROPOUT_MASK,  # pointer to the optional keep-mask output, row stride N
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    stride_x1_row,
    stride_y1_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,  # Dropout probability
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    STORE_DROPOUT_MASK: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_X1: tl.constexpr,
    HAS_W1: tl.constexpr,
    HAS_B1: tl.constexpr,
):
    # One program normalizes one row. The whole row is processed as a single
    # BLOCK_N-wide vector, so the host must guarantee N <= BLOCK_N.
    #
    # Per row:  x = dropout(x * rowscale) [+ dropout(x1)] [+ residual]
    #           (optionally written to RESIDUAL_OUT),
    #           then y = norm(x) * w (+ b) and, with W1, y1 = norm(x) * w1 (+ b1).
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    if HAS_X1:
        X1 += row * stride_x1_row
    if HAS_W1:
        Y1 += row * stride_y1_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    # Out-of-range lanes load as 0 so they do not perturb the row sums below.
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_ROWSCALE:
        rowscale = tl.load(ROWSCALE + row).to(tl.float32)
        x *= rowscale
    if HAS_DROPOUT:
        # Compute dropout mask
        # 7 rounds is good enough, and reduces register pressure
        # The mask is a pure function of (seed, column), which is what lets the
        # backward kernel regenerate it instead of storing it.
        keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
        if STORE_DROPOUT_MASK:
            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
    if HAS_X1:
        x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
        # NOTE(review): the host wrapper in this file asserts rowscale is None
        # whenever x1 is given, so this ROWSCALE + M + row read looks unused.
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
            x1 *= rowscale
        if HAS_DROPOUT:
            # Compute dropout mask
            # 7 rounds is good enough, and reduces register pressure
            # x1 uses the second half of SEEDS / DROPOUT_MASK (rows M..2M-1).
            keep_mask = (
                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
            )
            x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
            if STORE_DROPOUT_MASK:
                tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
        x += x1
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        # Pre-norm sum; the backward pass re-reads it as its input.
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        # RMSNorm: no centering, "var" is the mean square.
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    # Write output
    tl.store(Y + cols, y, mask=mask)
    if HAS_W1:
        # Parallel norm: same normalized x_hat, different affine parameters.
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
        if HAS_B1:
            b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
        y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
        tl.store(Y1 + cols, y1, mask=mask)
  243. def _layer_norm_fwd(
  244. x,
  245. weight,
  246. bias,
  247. eps,
  248. residual=None,
  249. x1=None,
  250. weight1=None,
  251. bias1=None,
  252. dropout_p=0.0,
  253. rowscale=None,
  254. out_dtype=None,
  255. residual_dtype=None,
  256. is_rms_norm=False,
  257. return_dropout_mask=False,
  258. out=None,
  259. residual_out=None
  260. ):
  261. if residual is not None:
  262. residual_dtype = residual.dtype
  263. M, N = x.shape
  264. assert x.stride(-1) == 1
  265. if residual is not None:
  266. assert residual.stride(-1) == 1
  267. assert residual.shape == (M, N)
  268. assert weight.shape == (N,)
  269. assert weight.stride(-1) == 1
  270. if bias is not None:
  271. assert bias.stride(-1) == 1
  272. assert bias.shape == (N,)
  273. if x1 is not None:
  274. assert x1.shape == x.shape
  275. assert rowscale is None
  276. assert x1.stride(-1) == 1
  277. if weight1 is not None:
  278. assert weight1.shape == (N,)
  279. assert weight1.stride(-1) == 1
  280. if bias1 is not None:
  281. assert bias1.shape == (N,)
  282. assert bias1.stride(-1) == 1
  283. if rowscale is not None:
  284. assert rowscale.is_contiguous()
  285. assert rowscale.shape == (M,)
  286. # allocate output
  287. if out is None:
  288. out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
  289. else:
  290. assert out.shape == x.shape
  291. assert out.stride(-1) == 1
  292. if weight1 is not None:
  293. y1 = torch.empty_like(out)
  294. assert y1.stride(-1) == 1
  295. else:
  296. y1 = None
  297. if (
  298. residual is not None
  299. or (residual_dtype is not None and residual_dtype != x.dtype)
  300. or dropout_p > 0.0
  301. or rowscale is not None
  302. or x1 is not None
  303. ):
  304. if residual_out is None:
  305. residual_out = torch.empty(
  306. M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
  307. )
  308. else:
  309. assert residual_out.shape == x.shape
  310. assert residual_out.stride(-1) == 1
  311. else:
  312. residual_out = None
  313. mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
  314. rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
  315. if dropout_p > 0.0:
  316. seeds = torch.randint(
  317. 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
  318. )
  319. else:
  320. seeds = None
  321. if return_dropout_mask and dropout_p > 0.0:
  322. dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
  323. else:
  324. dropout_mask = None
  325. # Less than 64KB per feature: enqueue fused kernel
  326. MAX_FUSED_SIZE = 65536 // x.element_size()
  327. BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
  328. if N > BLOCK_N:
  329. raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
  330. with torch.cuda.device(x.device.index):
  331. _layer_norm_fwd_1pass_kernel[(M,)](
  332. x,
  333. out,
  334. weight,
  335. bias,
  336. residual,
  337. x1,
  338. weight1,
  339. bias1,
  340. y1,
  341. residual_out,
  342. rowscale,
  343. seeds,
  344. dropout_mask,
  345. mean,
  346. rstd,
  347. x.stride(0),
  348. out.stride(0),
  349. residual.stride(0) if residual is not None else 0,
  350. residual_out.stride(0) if residual_out is not None else 0,
  351. x1.stride(0) if x1 is not None else 0,
  352. y1.stride(0) if y1 is not None else 0,
  353. M,
  354. N,
  355. eps,
  356. dropout_p,
  357. is_rms_norm,
  358. BLOCK_N,
  359. residual is not None,
  360. residual_out is not None,
  361. bias is not None,
  362. dropout_p > 0.0,
  363. dropout_mask is not None,
  364. rowscale is not None,
  365. )
  366. # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
  367. if dropout_mask is not None and x1 is not None:
  368. dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
  369. else:
  370. dropout_mask1 = None
  371. return (
  372. out,
  373. y1,
  374. mean,
  375. rstd,
  376. residual_out if residual_out is not None else x,
  377. seeds,
  378. dropout_mask,
  379. dropout_mask1,
  380. )
  381. @triton.autotune(
  382. configs=[
  383. triton.Config({}, num_warps=1),
  384. triton.Config({}, num_warps=2),
  385. triton.Config({}, num_warps=4),
  386. triton.Config({}, num_warps=8),
  387. triton.Config({}, num_warps=16),
  388. triton.Config({}, num_warps=32),
  389. ],
  390. key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
  391. )
  392. # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
  393. # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
  394. # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
  395. @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
  396. @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
  397. @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
  398. @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
  399. @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
  400. @triton.jit
  401. def _layer_norm_bwd_kernel(
  402. X, # pointer to the input
  403. W, # pointer to the weights
  404. B, # pointer to the biases
  405. Y, # pointer to the output to be recomputed
  406. DY, # pointer to the output gradient
  407. DX, # pointer to the input gradient
  408. DW, # pointer to the partial sum of weights gradient
  409. DB, # pointer to the partial sum of biases gradient
  410. DRESIDUAL,
  411. W1,
  412. DY1,
  413. DX1,
  414. DW1,
  415. DB1,
  416. DRESIDUAL_IN,
  417. ROWSCALE,
  418. SEEDS,
  419. Mean, # pointer to the mean
  420. Rstd, # pointer to the 1/std
  421. stride_x_row, # how much to increase the pointer when moving by 1 row
  422. stride_y_row,
  423. stride_dy_row,
  424. stride_dx_row,
  425. stride_dres_row,
  426. stride_dy1_row,
  427. stride_dx1_row,
  428. stride_dres_in_row,
  429. M, # number of rows in X
  430. N, # number of columns in X
  431. eps, # epsilon to avoid division by zero
  432. dropout_p,
  433. rows_per_program,
  434. IS_RMS_NORM: tl.constexpr,
  435. BLOCK_N: tl.constexpr,
  436. HAS_DRESIDUAL: tl.constexpr,
  437. STORE_DRESIDUAL: tl.constexpr,
  438. HAS_BIAS: tl.constexpr,
  439. HAS_DROPOUT: tl.constexpr,
  440. HAS_ROWSCALE: tl.constexpr,
  441. HAS_DY1: tl.constexpr,
  442. HAS_DX1: tl.constexpr,
  443. HAS_B1: tl.constexpr,
  444. RECOMPUTE_OUTPUT: tl.constexpr,
  445. ):
  446. # Map the program id to the elements of X, DX, and DY it should compute.
  447. row_block_id = tl.program_id(0)
  448. row_start = row_block_id * rows_per_program
  449. # Do not early exit if row_start >= M, because we need to write DW and DB
  450. cols = tl.arange(0, BLOCK_N)
  451. mask = cols < N
  452. X += row_start * stride_x_row
  453. if HAS_DRESIDUAL:
  454. DRESIDUAL += row_start * stride_dres_row
  455. if STORE_DRESIDUAL:
  456. DRESIDUAL_IN += row_start * stride_dres_in_row
  457. DY += row_start * stride_dy_row
  458. DX += row_start * stride_dx_row
  459. if HAS_DY1:
  460. DY1 += row_start * stride_dy1_row
  461. if HAS_DX1:
  462. DX1 += row_start * stride_dx1_row
  463. if RECOMPUTE_OUTPUT:
  464. Y += row_start * stride_y_row
  465. w = tl.load(W + cols, mask=mask).to(tl.float32)
  466. if RECOMPUTE_OUTPUT and HAS_BIAS:
  467. b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
  468. if HAS_DY1:
  469. w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
  470. dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
  471. if HAS_BIAS:
  472. db = tl.zeros((BLOCK_N,), dtype=tl.float32)
  473. if HAS_DY1:
  474. dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
  475. if HAS_B1:
  476. db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
  477. row_end = min((row_block_id + 1) * rows_per_program, M)
  478. for row in range(row_start, row_end):
  479. # Load data to SRAM
  480. x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
  481. dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
  482. if HAS_DY1:
  483. dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
  484. if not IS_RMS_NORM:
  485. mean = tl.load(Mean + row)
  486. rstd = tl.load(Rstd + row)
  487. # Compute dx
  488. xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
  489. xhat = tl.where(mask, xhat, 0.0)
  490. if RECOMPUTE_OUTPUT:
  491. y = xhat * w + b if HAS_BIAS else xhat * w
  492. tl.store(Y + cols, y, mask=mask)
  493. wdy = w * dy
  494. dw += dy * xhat
  495. if HAS_BIAS:
  496. db += dy
  497. if HAS_DY1:
  498. wdy += w1 * dy1
  499. dw1 += dy1 * xhat
  500. if HAS_B1:
  501. db1 += dy1
  502. if not IS_RMS_NORM:
  503. c1 = tl.sum(xhat * wdy, axis=0) / N
  504. c2 = tl.sum(wdy, axis=0) / N
  505. dx = (wdy - (xhat * c1 + c2)) * rstd
  506. else:
  507. c1 = tl.sum(xhat * wdy, axis=0) / N
  508. dx = (wdy - xhat * c1) * rstd
  509. if HAS_DRESIDUAL:
  510. dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
  511. dx += dres
  512. # Write dx
  513. if STORE_DRESIDUAL:
  514. tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
  515. if HAS_DX1:
  516. if HAS_DROPOUT:
  517. keep_mask = (
  518. tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
  519. )
  520. dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
  521. else:
  522. dx1 = dx
  523. tl.store(DX1 + cols, dx1, mask=mask)
  524. if HAS_DROPOUT:
  525. keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
  526. dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
  527. if HAS_ROWSCALE:
  528. rowscale = tl.load(ROWSCALE + row).to(tl.float32)
  529. dx *= rowscale
  530. tl.store(DX + cols, dx, mask=mask)
  531. X += stride_x_row
  532. if HAS_DRESIDUAL:
  533. DRESIDUAL += stride_dres_row
  534. if STORE_DRESIDUAL:
  535. DRESIDUAL_IN += stride_dres_in_row
  536. if RECOMPUTE_OUTPUT:
  537. Y += stride_y_row
  538. DY += stride_dy_row
  539. DX += stride_dx_row
  540. if HAS_DY1:
  541. DY1 += stride_dy1_row
  542. if HAS_DX1:
  543. DX1 += stride_dx1_row
  544. tl.store(DW + row_block_id * N + cols, dw, mask=mask)
  545. if HAS_BIAS:
  546. tl.store(DB + row_block_id * N + cols, db, mask=mask)
  547. if HAS_DY1:
  548. tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
  549. if HAS_B1:
  550. tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
  551. def _layer_norm_bwd(
  552. dy,
  553. x,
  554. weight,
  555. bias,
  556. eps,
  557. mean,
  558. rstd,
  559. dresidual=None,
  560. dy1=None,
  561. weight1=None,
  562. bias1=None,
  563. seeds=None,
  564. dropout_p=0.0,
  565. rowscale=None,
  566. has_residual=False,
  567. has_x1=False,
  568. is_rms_norm=False,
  569. x_dtype=None,
  570. recompute_output=False,
  571. ):
  572. M, N = x.shape
  573. assert x.stride(-1) == 1
  574. assert dy.stride(-1) == 1
  575. assert dy.shape == (M, N)
  576. if dresidual is not None:
  577. assert dresidual.stride(-1) == 1
  578. assert dresidual.shape == (M, N)
  579. assert weight.shape == (N,)
  580. assert weight.stride(-1) == 1
  581. if bias is not None:
  582. assert bias.stride(-1) == 1
  583. assert bias.shape == (N,)
  584. if dy1 is not None:
  585. assert weight1 is not None
  586. assert dy1.shape == dy.shape
  587. assert dy1.stride(-1) == 1
  588. if weight1 is not None:
  589. assert weight1.shape == (N,)
  590. assert weight1.stride(-1) == 1
  591. if bias1 is not None:
  592. assert bias1.shape == (N,)
  593. assert bias1.stride(-1) == 1
  594. if seeds is not None:
  595. assert seeds.is_contiguous()
  596. assert seeds.shape == (M if not has_x1 else M * 2,)
  597. if rowscale is not None:
  598. assert rowscale.is_contiguous()
  599. assert rowscale.shape == (M,)
  600. # allocate output
  601. dx = (
  602. torch.empty_like(x)
  603. if x_dtype is None
  604. else torch.empty(M, N, dtype=x_dtype, device=x.device)
  605. )
  606. dresidual_in = (
  607. torch.empty_like(x)
  608. if has_residual
  609. and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
  610. else None
  611. )
  612. dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
  613. y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
  614. if recompute_output:
  615. assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
  616. # Less than 64KB per feature: enqueue fused kernel
  617. MAX_FUSED_SIZE = 65536 // x.element_size()
  618. BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
  619. if N > BLOCK_N:
  620. raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
  621. sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
  622. _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
  623. _db = (
  624. torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
  625. if bias is not None
  626. else None
  627. )
  628. _dw1 = torch.empty_like(_dw) if weight1 is not None else None
  629. _db1 = torch.empty_like(_db) if bias1 is not None else None
  630. rows_per_program = math.ceil(M / sm_count)
  631. grid = (sm_count,)
  632. with torch.cuda.device(x.device.index):
  633. _layer_norm_bwd_kernel[grid](
  634. x,
  635. weight,
  636. bias,
  637. y,
  638. dy,
  639. dx,
  640. _dw,
  641. _db,
  642. dresidual,
  643. weight1,
  644. dy1,
  645. dx1,
  646. _dw1,
  647. _db1,
  648. dresidual_in,
  649. rowscale,
  650. seeds,
  651. mean,
  652. rstd,
  653. x.stride(0),
  654. 0 if not recompute_output else y.stride(0),
  655. dy.stride(0),
  656. dx.stride(0),
  657. dresidual.stride(0) if dresidual is not None else 0,
  658. dy1.stride(0) if dy1 is not None else 0,
  659. dx1.stride(0) if dx1 is not None else 0,
  660. dresidual_in.stride(0) if dresidual_in is not None else 0,
  661. M,
  662. N,
  663. eps,
  664. dropout_p,
  665. rows_per_program,
  666. is_rms_norm,
  667. BLOCK_N,
  668. dresidual is not None,
  669. dresidual_in is not None,
  670. bias is not None,
  671. dropout_p > 0.0,
  672. )
  673. dw = _dw.sum(0).to(weight.dtype)
  674. db = _db.sum(0).to(bias.dtype) if bias is not None else None
  675. dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
  676. db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
  677. # Don't need to compute dresidual_in separately in this case
  678. if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
  679. dresidual_in = dx
  680. if has_x1 and dropout_p == 0.0:
  681. dx1 = dx
  682. return (
  683. (dx, dw, db, dresidual_in, dx1, dw1, db1)
  684. if not recompute_output
  685. else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
  686. )
class LayerNormFn(torch.autograd.Function):
    """Autograd function for fused dropout + residual + LayerNorm/RMSNorm.

    ``forward`` flattens all activations to 2D ``(-1, hidden)`` and calls
    ``_layer_norm_fwd``; ``backward`` calls ``_layer_norm_bwd`` and reshapes the
    gradients back to the original input shape.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        residual=None,
        x1=None,
        weight1=None,
        bias1=None,
        eps=1e-6,
        dropout_p=0.0,
        rowscale=None,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
        return_dropout_mask=False,
        out=None,
        residual_out=None
    ):
        """Run the fused forward pass.

        ``residual`` and ``x1`` must have the same shape as ``x``. When no
        ``residual`` is given, ``residual_in_fp32`` requests a float32 pre-norm
        sum; otherwise the residual's dtype is used.

        Returns:
            ``y`` (or ``(y, y1)`` when ``weight1`` is given); with ``prenorm`` the
            pre-norm sum is appended after the normalized outputs; with
            ``return_dropout_mask`` ``(dropout_mask, dropout_mask1)`` are appended
            last. All tensors are reshaped to ``x``'s original shape.
        """
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
            if residual.stride(-1) != 1:
                residual = residual.contiguous()
        if x1 is not None:
            assert x1.shape == x_shape_og
            assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
            x1 = x1.reshape(-1, x1.shape[-1])
            if x1.stride(-1) != 1:
                x1 = x1.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        if weight1 is not None:
            weight1 = weight1.contiguous()
        if bias1 is not None:
            bias1 = bias1.contiguous()
        if rowscale is not None:
            rowscale = rowscale.reshape(-1).contiguous()
        # dtype of the pre-norm sum (residual_out): the residual's dtype if one
        # is given, float32 if requested, otherwise None (the kernel wrapper then uses x.dtype).
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        # NOTE(review): reshape may copy a non-viewable `out`/`residual_out`, in
        # which case the caller's buffer is not written in place.
        if out is not None:
            out = out.reshape(-1, out.shape[-1])
        if residual_out is not None:
            residual_out = residual_out.reshape(-1, residual_out.shape[-1])
        y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            residual,
            x1,
            weight1,
            bias1,
            dropout_p=dropout_p,
            rowscale=rowscale,
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
            return_dropout_mask=return_dropout_mask,
            out=out,
            residual_out=residual_out
        )
        # residual_out is the normalized input (x itself when no separate buffer
        # was needed), so it is what backward treats as "x".
        ctx.save_for_backward(
            residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
        )
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.dropout_p = dropout_p
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.has_x1 = x1 is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        y = y.reshape(x_shape_og)
        y1 = y1.reshape(x_shape_og) if y1 is not None else None
        residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
        dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
        dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
        if not return_dropout_mask:
            if weight1 is None:
                return y if not prenorm else (y, residual_out)
            else:
                return (y, y1) if not prenorm else (y, y1, residual_out)
        else:
            if weight1 is None:
                return (
                    (y, dropout_mask, dropout_mask1)
                    if not prenorm
                    else (y, residual_out, dropout_mask, dropout_mask1)
                )
            else:
                return (
                    (y, y1, dropout_mask, dropout_mask1)
                    if not prenorm
                    else (y, y1, residual_out, dropout_mask, dropout_mask1)
                )

    @staticmethod
    def backward(ctx, dy, *args):
        """Compute gradients for the fused forward pass.

        ``args`` holds the gradients of the remaining forward outputs in order:
        ``dy1`` (only when ``weight1`` was given), then the pre-norm gradient
        (only with ``prenorm``); gradients of the dropout masks are ignored.

        Returns:
            One entry per forward input (16 total): gradients for ``x``, ``weight``,
            ``bias``, ``residual``, ``x1``, ``weight1``, ``bias1``, then None for
            ``eps``, ``dropout_p``, ``rowscale``, ``prenorm``, ``residual_in_fp32``,
            ``is_rms_norm``, ``return_dropout_mask``, ``out`` and ``residual_out``.
        """
        x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
        dy = dy.reshape(-1, dy.shape[-1])
        if dy.stride(-1) != 1:
            dy = dy.contiguous()
        assert dy.shape == x.shape
        if weight1 is not None:
            dy1, args = args[0], args[1:]
            dy1 = dy1.reshape(-1, dy1.shape[-1])
            if dy1.stride(-1) != 1:
                dy1 = dy1.contiguous()
            assert dy1.shape == x.shape
        else:
            dy1 = None
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            if dresidual.stride(-1) != 1:
                dresidual = dresidual.contiguous()
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
            dy,
            x,
            weight,
            bias,
            ctx.eps,
            mean,
            rstd,
            dresidual,
            dy1,
            weight1,
            bias1,
            seeds,
            ctx.dropout_p,
            rowscale,
            ctx.has_residual,
            ctx.has_x1,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
        )
        return (
            dx.reshape(ctx.x_shape_og),
            dw,
            db,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
            dw1,
            db1,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
  854. def layer_norm_fn(
  855. x,
  856. weight,
  857. bias,
  858. residual=None,
  859. x1=None,
  860. weight1=None,
  861. bias1=None,
  862. eps=1e-6,
  863. dropout_p=0.0,
  864. rowscale=None,
  865. prenorm=False,
  866. residual_in_fp32=False,
  867. is_rms_norm=False,
  868. return_dropout_mask=False,
  869. out=None,
  870. residual_out=None
  871. ):
  872. return LayerNormFn.apply(
  873. x,
  874. weight,
  875. bias,
  876. residual,
  877. x1,
  878. weight1,
  879. bias1,
  880. eps,
  881. dropout_p,
  882. rowscale,
  883. prenorm,
  884. residual_in_fp32,
  885. is_rms_norm,
  886. return_dropout_mask,
  887. out,
  888. residual_out
  889. )
  890. def rms_norm_fn(
  891. x,
  892. weight,
  893. bias,
  894. residual=None,
  895. x1=None,
  896. weight1=None,
  897. bias1=None,
  898. eps=1e-6,
  899. dropout_p=0.0,
  900. rowscale=None,
  901. prenorm=False,
  902. residual_in_fp32=False,
  903. return_dropout_mask=False,
  904. out=None,
  905. residual_out=None
  906. ):
  907. return LayerNormFn.apply(
  908. x,
  909. weight,
  910. bias,
  911. residual,
  912. x1,
  913. weight1,
  914. bias1,
  915. eps,
  916. dropout_p,
  917. rowscale,
  918. prenorm,
  919. residual_in_fp32,
  920. True,
  921. return_dropout_mask,
  922. out,
  923. residual_out
  924. )
  925. class RMSNorm(torch.nn.Module):
  926. def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
  927. factory_kwargs = {"device": device, "dtype": dtype}
  928. super().__init__()
  929. self.eps = eps
  930. if dropout_p > 0.0:
  931. self.drop = torch.nn.Dropout(dropout_p)
  932. else:
  933. self.drop = None
  934. self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
  935. self.register_parameter("bias", None)
  936. self.reset_parameters()
  937. def reset_parameters(self):
  938. torch.nn.init.ones_(self.weight)
  939. def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
  940. return rms_norm_fn(
  941. x,
  942. self.weight,
  943. self.bias,
  944. residual=residual,
  945. eps=self.eps,
  946. dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
  947. prenorm=prenorm,
  948. residual_in_fp32=residual_in_fp32,
  949. )
  950. class LayerNormLinearFn(torch.autograd.Function):
  951. @staticmethod
  952. @custom_fwd
  953. def forward(
  954. ctx,
  955. x,
  956. norm_weight,
  957. norm_bias,
  958. linear_weight,
  959. linear_bias,
  960. residual=None,
  961. eps=1e-6,
  962. prenorm=False,
  963. residual_in_fp32=False,
  964. is_rms_norm=False,
  965. ):
  966. x_shape_og = x.shape
  967. # reshape input data into 2D tensor
  968. x = x.reshape(-1, x.shape[-1])
  969. if x.stride(-1) != 1:
  970. x = x.contiguous()
  971. if residual is not None:
  972. assert residual.shape == x_shape_og
  973. residual = residual.reshape(-1, residual.shape[-1])
  974. if residual.stride(-1) != 1:
  975. residual = residual.contiguous()
  976. norm_weight = norm_weight.contiguous()
  977. if norm_bias is not None:
  978. norm_bias = norm_bias.contiguous()
  979. residual_dtype = (
  980. residual.dtype
  981. if residual is not None
  982. else (torch.float32 if residual_in_fp32 else None)
  983. )
  984. y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
  985. x,
  986. norm_weight,
  987. norm_bias,
  988. eps,
  989. residual,
  990. out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
  991. residual_dtype=residual_dtype,
  992. is_rms_norm=is_rms_norm,
  993. )
  994. y = y.reshape(x_shape_og)
  995. dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
  996. linear_weight = linear_weight.to(dtype)
  997. linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
  998. out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
  999. # We don't store y, will be recomputed in the backward pass to save memory
  1000. ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
  1001. ctx.x_shape_og = x_shape_og
  1002. ctx.eps = eps
  1003. ctx.is_rms_norm = is_rms_norm
  1004. ctx.has_residual = residual is not None
  1005. ctx.prenorm = prenorm
  1006. ctx.x_dtype = x.dtype
  1007. ctx.linear_bias_is_none = linear_bias is None
  1008. return out if not prenorm else (out, residual_out.reshape(x_shape_og))
  1009. @staticmethod
  1010. @custom_bwd
  1011. def backward(ctx, dout, *args):
  1012. x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
  1013. dout = dout.reshape(-1, dout.shape[-1])
  1014. dy = F.linear(dout, linear_weight.t())
  1015. dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
  1016. if dy.stride(-1) != 1:
  1017. dy = dy.contiguous()
  1018. assert dy.shape == x.shape
  1019. if ctx.prenorm:
  1020. dresidual = args[0]
  1021. dresidual = dresidual.reshape(-1, dresidual.shape[-1])
  1022. if dresidual.stride(-1) != 1:
  1023. dresidual = dresidual.contiguous()
  1024. assert dresidual.shape == x.shape
  1025. else:
  1026. dresidual = None
  1027. dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
  1028. dy,
  1029. x,
  1030. norm_weight,
  1031. norm_bias,
  1032. ctx.eps,
  1033. mean,
  1034. rstd,
  1035. dresidual=dresidual,
  1036. has_residual=ctx.has_residual,
  1037. is_rms_norm=ctx.is_rms_norm,
  1038. x_dtype=ctx.x_dtype,
  1039. recompute_output=True,
  1040. )
  1041. dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
  1042. return (
  1043. dx.reshape(ctx.x_shape_og),
  1044. dnorm_weight,
  1045. dnorm_bias,
  1046. dlinear_weight,
  1047. dlinear_bias,
  1048. dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
  1049. None,
  1050. None,
  1051. None,
  1052. None,
  1053. )
  1054. def layer_norm_linear_fn(
  1055. x,
  1056. norm_weight,
  1057. norm_bias,
  1058. linear_weight,
  1059. linear_bias,
  1060. residual=None,
  1061. eps=1e-6,
  1062. prenorm=False,
  1063. residual_in_fp32=False,
  1064. is_rms_norm=False,
  1065. ):
  1066. return LayerNormLinearFn.apply(
  1067. x,
  1068. norm_weight,
  1069. norm_bias,
  1070. linear_weight,
  1071. linear_bias,
  1072. residual,
  1073. eps,
  1074. prenorm,
  1075. residual_in_fp32,
  1076. is_rms_norm,
  1077. )