layer_norm.py 34 KB


  1. # Copyright (c) 2024, Tri Dao.
  2. # Implement dropout + residual + layer_norm / rms_norm.
  3. # Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
  4. # For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
  5. # This is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
  6. # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
  7. import math
  8. import torch
  9. import torch.nn.functional as F
  10. from torch.cuda.amp import custom_fwd, custom_bwd
  11. import triton
  12. import triton.language as tl
  13. def layer_norm_ref(
  14. x,
  15. weight,
  16. bias,
  17. residual=None,
  18. x1=None,
  19. weight1=None,
  20. bias1=None,
  21. eps=1e-6,
  22. dropout_p=0.0,
  23. rowscale=None,
  24. prenorm=False,
  25. dropout_mask=None,
  26. dropout_mask1=None,
  27. upcast=False,
  28. ):
  29. dtype = x.dtype
  30. if upcast:
  31. x = x.float()
  32. weight = weight.float()
  33. bias = bias.float() if bias is not None else None
  34. residual = residual.float() if residual is not None else residual
  35. x1 = x1.float() if x1 is not None else None
  36. weight1 = weight1.float() if weight1 is not None else None
  37. bias1 = bias1.float() if bias1 is not None else None
  38. if x1 is not None:
  39. assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
  40. if rowscale is not None:
  41. x = x * rowscale[..., None]
  42. if dropout_p > 0.0:
  43. if dropout_mask is not None:
  44. x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
  45. else:
  46. x = F.dropout(x, p=dropout_p)
  47. if x1 is not None:
  48. if dropout_mask1 is not None:
  49. x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
  50. else:
  51. x1 = F.dropout(x1, p=dropout_p)
  52. if x1 is not None:
  53. x = x + x1
  54. if residual is not None:
  55. x = (x + residual).to(x.dtype)
  56. out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
  57. dtype
  58. )
  59. if weight1 is None:
  60. return out if not prenorm else (out, x)
  61. else:
  62. out1 = F.layer_norm(
  63. x.to(weight1.dtype), x.shape[-1:], weight=weight1, bias=bias1, eps=eps
  64. ).to(dtype)
  65. return (out, out1) if not prenorm else (out, out1, x)
  66. def rms_norm_ref(
  67. x,
  68. weight,
  69. bias,
  70. residual=None,
  71. x1=None,
  72. weight1=None,
  73. bias1=None,
  74. eps=1e-6,
  75. dropout_p=0.0,
  76. rowscale=None,
  77. prenorm=False,
  78. dropout_mask=None,
  79. dropout_mask1=None,
  80. upcast=False,
  81. ):
  82. dtype = x.dtype
  83. if upcast:
  84. x = x.float()
  85. weight = weight.float()
  86. bias = bias.float() if bias is not None else None
  87. residual = residual.float() if residual is not None else residual
  88. x1 = x1.float() if x1 is not None else None
  89. weight1 = weight1.float() if weight1 is not None else None
  90. bias1 = bias1.float() if bias1 is not None else None
  91. if x1 is not None:
  92. assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
  93. if rowscale is not None:
  94. x = x * rowscale[..., None]
  95. if dropout_p > 0.0:
  96. if dropout_mask is not None:
  97. x = x.masked_fill(~dropout_mask, 0.0) / (1.0 - dropout_p)
  98. else:
  99. x = F.dropout(x, p=dropout_p)
  100. if x1 is not None:
  101. if dropout_mask1 is not None:
  102. x1 = x1.masked_fill(~dropout_mask1, 0.0) / (1.0 - dropout_p)
  103. else:
  104. x1 = F.dropout(x1, p=dropout_p)
  105. if x1 is not None:
  106. x = x + x1
  107. if residual is not None:
  108. x = (x + residual).to(x.dtype)
  109. rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
  110. out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
  111. if weight1 is None:
  112. return out if not prenorm else (out, x)
  113. else:
  114. out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
  115. dtype
  116. )
  117. return (out, out1) if not prenorm else (out, out1, x)
# Fused forward kernel; each program instance handles one row (row = program_id(0)).
# Per row it builds
#     x = dropout(x * rowscale) [+ dropout(x1)] [+ residual]
# optionally stores that pre-normalization sum to RESIDUAL_OUT, writes
#     Y  = x_hat * W  (+ B)
#     Y1 = x_hat * W1 (+ B1)   (only when W1 is given)
# and stores per-row Mean (LayerNorm only) and Rstd for the backward pass.
# Autotuning only varies num_warps; `key` lists the arguments whose values
# trigger a new tuning run.
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=1),
        triton.Config({}, num_warps=2),
        triton.Config({}, num_warps=4),
        triton.Config({}, num_warps=8),
        triton.Config({}, num_warps=16),
        triton.Config({}, num_warps=32),
    ],
    key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
)
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
# HAS_X1 / HAS_W1 / HAS_B1 are derived from whether the pointer arguments are None.
@triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
@triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
@triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    X1,  # pointer to the optional second input (added to x)
    W1,  # pointer to the optional second set of weights
    B1,  # pointer to the optional second set of biases
    Y1,  # pointer to the optional second output (normalized x with W1/B1)
    RESIDUAL_OUT,  # pointer to the pre-normalization sum output
    ROWSCALE,  # pointer to per-row scale factors
    SEEDS,  # Dropout seeds for each row
    DROPOUT_MASK,  # pointer to the bool keep-mask output (x rows, then x1 rows at offset M)
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    stride_x1_row,
    stride_y1_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,  # Dropout probability
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,  # block width; must be >= N so a whole row fits in one block
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    STORE_DROPOUT_MASK: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_X1: tl.constexpr,
    HAS_W1: tl.constexpr,
    HAS_B1: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    if HAS_X1:
        X1 += row * stride_x1_row
    if HAS_W1:
        Y1 += row * stride_y1_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    # Padding lanes (cols >= N) load 0 so they do not contribute to the sums below.
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_ROWSCALE:
        rowscale = tl.load(ROWSCALE + row).to(tl.float32)
        x *= rowscale
    if HAS_DROPOUT:
        # Compute dropout mask
        # 7 rounds is good enough, and reduces register pressure
        # The backward kernel regenerates this mask from the same seed and cols.
        keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
        if STORE_DROPOUT_MASK:
            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
    if HAS_X1:
        x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
        if HAS_ROWSCALE:
            # NOTE(review): reads ROWSCALE at offset M + row; the host wrapper in
            # this file asserts rowscale is None when x1 is given, so this path
            # looks unreachable — confirm before relying on it.
            rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
            x1 *= rowscale
        if HAS_DROPOUT:
            # Compute dropout mask
            # 7 rounds is good enough, and reduces register pressure
            # x1 uses its own seed, stored after the M seeds for x.
            keep_mask = (
                tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
            )
            x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
            if STORE_DROPOUT_MASK:
                tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
        x += x1
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        # Pre-normalization sum; converted to RESIDUAL_OUT's element type on store.
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        # Zero the padding lanes, which would otherwise contribute (-mean)^2.
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    # `mean` only exists on the LayerNorm path; the conditional on the
    # constexpr IS_RMS_NORM keeps the RMSNorm path from referencing it.
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    # Write output
    tl.store(Y + cols, y, mask=mask)
    if HAS_W1:
        # Second affine output shares x_hat with the first one.
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
        if HAS_B1:
            b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
        y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
        tl.store(Y1 + cols, y1, mask=mask)
  243. def _layer_norm_fwd(
  244. x,
  245. weight,
  246. bias,
  247. eps,
  248. residual=None,
  249. x1=None,
  250. weight1=None,
  251. bias1=None,
  252. dropout_p=0.0,
  253. rowscale=None,
  254. out_dtype=None,
  255. residual_dtype=None,
  256. is_rms_norm=False,
  257. return_dropout_mask=False,
  258. ):
  259. if residual is not None:
  260. residual_dtype = residual.dtype
  261. M, N = x.shape
  262. assert x.stride(-1) == 1
  263. if residual is not None:
  264. assert residual.stride(-1) == 1
  265. assert residual.shape == (M, N)
  266. assert weight.shape == (N,)
  267. assert weight.stride(-1) == 1
  268. if bias is not None:
  269. assert bias.stride(-1) == 1
  270. assert bias.shape == (N,)
  271. if x1 is not None:
  272. assert x1.shape == x.shape
  273. assert rowscale is None
  274. assert x1.stride(-1) == 1
  275. if weight1 is not None:
  276. assert weight1.shape == (N,)
  277. assert weight1.stride(-1) == 1
  278. if bias1 is not None:
  279. assert bias1.shape == (N,)
  280. assert bias1.stride(-1) == 1
  281. if rowscale is not None:
  282. assert rowscale.is_contiguous()
  283. assert rowscale.shape == (M,)
  284. # allocate output
  285. y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
  286. assert y.stride(-1) == 1
  287. if weight1 is not None:
  288. y1 = torch.empty_like(y)
  289. assert y1.stride(-1) == 1
  290. else:
  291. y1 = None
  292. if (
  293. residual is not None
  294. or (residual_dtype is not None and residual_dtype != x.dtype)
  295. or dropout_p > 0.0
  296. or rowscale is not None
  297. or x1 is not None
  298. ):
  299. residual_out = torch.empty(
  300. M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype
  301. )
  302. assert residual_out.stride(-1) == 1
  303. else:
  304. residual_out = None
  305. mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
  306. rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
  307. if dropout_p > 0.0:
  308. seeds = torch.randint(
  309. 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64
  310. )
  311. else:
  312. seeds = None
  313. if return_dropout_mask and dropout_p > 0.0:
  314. dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)
  315. else:
  316. dropout_mask = None
  317. # Less than 64KB per feature: enqueue fused kernel
  318. MAX_FUSED_SIZE = 65536 // x.element_size()
  319. BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
  320. if N > BLOCK_N:
  321. raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
  322. with torch.cuda.device(x.device.index):
  323. _layer_norm_fwd_1pass_kernel[(M,)](
  324. x,
  325. y,
  326. weight,
  327. bias,
  328. residual,
  329. x1,
  330. weight1,
  331. bias1,
  332. y1,
  333. residual_out,
  334. rowscale,
  335. seeds,
  336. dropout_mask,
  337. mean,
  338. rstd,
  339. x.stride(0),
  340. y.stride(0),
  341. residual.stride(0) if residual is not None else 0,
  342. residual_out.stride(0) if residual_out is not None else 0,
  343. x1.stride(0) if x1 is not None else 0,
  344. y1.stride(0) if y1 is not None else 0,
  345. M,
  346. N,
  347. eps,
  348. dropout_p,
  349. is_rms_norm,
  350. BLOCK_N,
  351. residual is not None,
  352. residual_out is not None,
  353. bias is not None,
  354. dropout_p > 0.0,
  355. dropout_mask is not None,
  356. rowscale is not None,
  357. )
  358. # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
  359. if dropout_mask is not None and x1 is not None:
  360. dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
  361. else:
  362. dropout_mask1 = None
  363. return (
  364. y,
  365. y1,
  366. mean,
  367. rstd,
  368. residual_out if residual_out is not None else x,
  369. seeds,
  370. dropout_mask,
  371. dropout_mask1,
  372. )
  373. @triton.autotune(
  374. configs=[
  375. triton.Config({}, num_warps=1),
  376. triton.Config({}, num_warps=2),
  377. triton.Config({}, num_warps=4),
  378. triton.Config({}, num_warps=8),
  379. triton.Config({}, num_warps=16),
  380. triton.Config({}, num_warps=32),
  381. ],
  382. key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
  383. )
  384. # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
  385. # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
  386. # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
  387. @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
  388. @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
  389. @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
  390. @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
  391. @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
  392. @triton.jit
  393. def _layer_norm_bwd_kernel(
  394. X, # pointer to the input
  395. W, # pointer to the weights
  396. B, # pointer to the biases
  397. Y, # pointer to the output to be recomputed
  398. DY, # pointer to the output gradient
  399. DX, # pointer to the input gradient
  400. DW, # pointer to the partial sum of weights gradient
  401. DB, # pointer to the partial sum of biases gradient
  402. DRESIDUAL,
  403. W1,
  404. DY1,
  405. DX1,
  406. DW1,
  407. DB1,
  408. DRESIDUAL_IN,
  409. ROWSCALE,
  410. SEEDS,
  411. Mean, # pointer to the mean
  412. Rstd, # pointer to the 1/std
  413. stride_x_row, # how much to increase the pointer when moving by 1 row
  414. stride_y_row,
  415. stride_dy_row,
  416. stride_dx_row,
  417. stride_dres_row,
  418. stride_dy1_row,
  419. stride_dx1_row,
  420. stride_dres_in_row,
  421. M, # number of rows in X
  422. N, # number of columns in X
  423. eps, # epsilon to avoid division by zero
  424. dropout_p,
  425. rows_per_program,
  426. IS_RMS_NORM: tl.constexpr,
  427. BLOCK_N: tl.constexpr,
  428. HAS_DRESIDUAL: tl.constexpr,
  429. STORE_DRESIDUAL: tl.constexpr,
  430. HAS_BIAS: tl.constexpr,
  431. HAS_DROPOUT: tl.constexpr,
  432. HAS_ROWSCALE: tl.constexpr,
  433. HAS_DY1: tl.constexpr,
  434. HAS_DX1: tl.constexpr,
  435. HAS_B1: tl.constexpr,
  436. RECOMPUTE_OUTPUT: tl.constexpr,
  437. ):
  438. # Map the program id to the elements of X, DX, and DY it should compute.
  439. row_block_id = tl.program_id(0)
  440. row_start = row_block_id * rows_per_program
  441. # Do not early exit if row_start >= M, because we need to write DW and DB
  442. cols = tl.arange(0, BLOCK_N)
  443. mask = cols < N
  444. X += row_start * stride_x_row
  445. if HAS_DRESIDUAL:
  446. DRESIDUAL += row_start * stride_dres_row
  447. if STORE_DRESIDUAL:
  448. DRESIDUAL_IN += row_start * stride_dres_in_row
  449. DY += row_start * stride_dy_row
  450. DX += row_start * stride_dx_row
  451. if HAS_DY1:
  452. DY1 += row_start * stride_dy1_row
  453. if HAS_DX1:
  454. DX1 += row_start * stride_dx1_row
  455. if RECOMPUTE_OUTPUT:
  456. Y += row_start * stride_y_row
  457. w = tl.load(W + cols, mask=mask).to(tl.float32)
  458. if RECOMPUTE_OUTPUT and HAS_BIAS:
  459. b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
  460. if HAS_DY1:
  461. w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
  462. dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
  463. if HAS_BIAS:
  464. db = tl.zeros((BLOCK_N,), dtype=tl.float32)
  465. if HAS_DY1:
  466. dw1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
  467. if HAS_B1:
  468. db1 = tl.zeros((BLOCK_N,), dtype=tl.float32)
  469. row_end = min((row_block_id + 1) * rows_per_program, M)
  470. for row in range(row_start, row_end):
  471. # Load data to SRAM
  472. x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)
  473. dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)
  474. if HAS_DY1:
  475. dy1 = tl.load(DY1 + cols, mask=mask, other=0).to(tl.float32)
  476. if not IS_RMS_NORM:
  477. mean = tl.load(Mean + row)
  478. rstd = tl.load(Rstd + row)
  479. # Compute dx
  480. xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
  481. xhat = tl.where(mask, xhat, 0.0)
  482. if RECOMPUTE_OUTPUT:
  483. y = xhat * w + b if HAS_BIAS else xhat * w
  484. tl.store(Y + cols, y, mask=mask)
  485. wdy = w * dy
  486. dw += dy * xhat
  487. if HAS_BIAS:
  488. db += dy
  489. if HAS_DY1:
  490. wdy += w1 * dy1
  491. dw1 += dy1 * xhat
  492. if HAS_B1:
  493. db1 += dy1
  494. if not IS_RMS_NORM:
  495. c1 = tl.sum(xhat * wdy, axis=0) / N
  496. c2 = tl.sum(wdy, axis=0) / N
  497. dx = (wdy - (xhat * c1 + c2)) * rstd
  498. else:
  499. c1 = tl.sum(xhat * wdy, axis=0) / N
  500. dx = (wdy - xhat * c1) * rstd
  501. if HAS_DRESIDUAL:
  502. dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)
  503. dx += dres
  504. # Write dx
  505. if STORE_DRESIDUAL:
  506. tl.store(DRESIDUAL_IN + cols, dx, mask=mask)
  507. if HAS_DX1:
  508. if HAS_DROPOUT:
  509. keep_mask = (
  510. tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
  511. )
  512. dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
  513. else:
  514. dx1 = dx
  515. tl.store(DX1 + cols, dx1, mask=mask)
  516. if HAS_DROPOUT:
  517. keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
  518. dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
  519. if HAS_ROWSCALE:
  520. rowscale = tl.load(ROWSCALE + row).to(tl.float32)
  521. dx *= rowscale
  522. tl.store(DX + cols, dx, mask=mask)
  523. X += stride_x_row
  524. if HAS_DRESIDUAL:
  525. DRESIDUAL += stride_dres_row
  526. if STORE_DRESIDUAL:
  527. DRESIDUAL_IN += stride_dres_in_row
  528. if RECOMPUTE_OUTPUT:
  529. Y += stride_y_row
  530. DY += stride_dy_row
  531. DX += stride_dx_row
  532. if HAS_DY1:
  533. DY1 += stride_dy1_row
  534. if HAS_DX1:
  535. DX1 += stride_dx1_row
  536. tl.store(DW + row_block_id * N + cols, dw, mask=mask)
  537. if HAS_BIAS:
  538. tl.store(DB + row_block_id * N + cols, db, mask=mask)
  539. if HAS_DY1:
  540. tl.store(DW1 + row_block_id * N + cols, dw1, mask=mask)
  541. if HAS_B1:
  542. tl.store(DB1 + row_block_id * N + cols, db1, mask=mask)
  543. def _layer_norm_bwd(
  544. dy,
  545. x,
  546. weight,
  547. bias,
  548. eps,
  549. mean,
  550. rstd,
  551. dresidual=None,
  552. dy1=None,
  553. weight1=None,
  554. bias1=None,
  555. seeds=None,
  556. dropout_p=0.0,
  557. rowscale=None,
  558. has_residual=False,
  559. has_x1=False,
  560. is_rms_norm=False,
  561. x_dtype=None,
  562. recompute_output=False,
  563. ):
  564. M, N = x.shape
  565. assert x.stride(-1) == 1
  566. assert dy.stride(-1) == 1
  567. assert dy.shape == (M, N)
  568. if dresidual is not None:
  569. assert dresidual.stride(-1) == 1
  570. assert dresidual.shape == (M, N)
  571. assert weight.shape == (N,)
  572. assert weight.stride(-1) == 1
  573. if bias is not None:
  574. assert bias.stride(-1) == 1
  575. assert bias.shape == (N,)
  576. if dy1 is not None:
  577. assert weight1 is not None
  578. assert dy1.shape == dy.shape
  579. assert dy1.stride(-1) == 1
  580. if weight1 is not None:
  581. assert weight1.shape == (N,)
  582. assert weight1.stride(-1) == 1
  583. if bias1 is not None:
  584. assert bias1.shape == (N,)
  585. assert bias1.stride(-1) == 1
  586. if seeds is not None:
  587. assert seeds.is_contiguous()
  588. assert seeds.shape == (M if not has_x1 else M * 2,)
  589. if rowscale is not None:
  590. assert rowscale.is_contiguous()
  591. assert rowscale.shape == (M,)
  592. # allocate output
  593. dx = (
  594. torch.empty_like(x)
  595. if x_dtype is None
  596. else torch.empty(M, N, dtype=x_dtype, device=x.device)
  597. )
  598. dresidual_in = (
  599. torch.empty_like(x)
  600. if has_residual
  601. and (dx.dtype != x.dtype or dropout_p > 0.0 or rowscale is not None or has_x1)
  602. else None
  603. )
  604. dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
  605. y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
  606. if recompute_output:
  607. assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
  608. # Less than 64KB per feature: enqueue fused kernel
  609. MAX_FUSED_SIZE = 65536 // x.element_size()
  610. BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
  611. if N > BLOCK_N:
  612. raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
  613. sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
  614. _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
  615. _db = (
  616. torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
  617. if bias is not None
  618. else None
  619. )
  620. _dw1 = torch.empty_like(_dw) if weight1 is not None else None
  621. _db1 = torch.empty_like(_db) if bias1 is not None else None
  622. rows_per_program = math.ceil(M / sm_count)
  623. grid = (sm_count,)
  624. with torch.cuda.device(x.device.index):
  625. _layer_norm_bwd_kernel[grid](
  626. x,
  627. weight,
  628. bias,
  629. y,
  630. dy,
  631. dx,
  632. _dw,
  633. _db,
  634. dresidual,
  635. weight1,
  636. dy1,
  637. dx1,
  638. _dw1,
  639. _db1,
  640. dresidual_in,
  641. rowscale,
  642. seeds,
  643. mean,
  644. rstd,
  645. x.stride(0),
  646. 0 if not recompute_output else y.stride(0),
  647. dy.stride(0),
  648. dx.stride(0),
  649. dresidual.stride(0) if dresidual is not None else 0,
  650. dy1.stride(0) if dy1 is not None else 0,
  651. dx1.stride(0) if dx1 is not None else 0,
  652. dresidual_in.stride(0) if dresidual_in is not None else 0,
  653. M,
  654. N,
  655. eps,
  656. dropout_p,
  657. rows_per_program,
  658. is_rms_norm,
  659. BLOCK_N,
  660. dresidual is not None,
  661. dresidual_in is not None,
  662. bias is not None,
  663. dropout_p > 0.0,
  664. )
  665. dw = _dw.sum(0).to(weight.dtype)
  666. db = _db.sum(0).to(bias.dtype) if bias is not None else None
  667. dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
  668. db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
  669. # Don't need to compute dresidual_in separately in this case
  670. if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
  671. dresidual_in = dx
  672. if has_x1 and dropout_p == 0.0:
  673. dx1 = dx
  674. return (
  675. (dx, dw, db, dresidual_in, dx1, dw1, db1)
  676. if not recompute_output
  677. else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
  678. )
  679. class LayerNormFn(torch.autograd.Function):
  680. @staticmethod
  681. def forward(
  682. ctx,
  683. x,
  684. weight,
  685. bias,
  686. residual=None,
  687. x1=None,
  688. weight1=None,
  689. bias1=None,
  690. eps=1e-6,
  691. dropout_p=0.0,
  692. rowscale=None,
  693. prenorm=False,
  694. residual_in_fp32=False,
  695. is_rms_norm=False,
  696. return_dropout_mask=False,
  697. ):
  698. x_shape_og = x.shape
  699. # reshape input data into 2D tensor
  700. x = x.reshape(-1, x.shape[-1])
  701. if x.stride(-1) != 1:
  702. x = x.contiguous()
  703. if residual is not None:
  704. assert residual.shape == x_shape_og
  705. residual = residual.reshape(-1, residual.shape[-1])
  706. if residual.stride(-1) != 1:
  707. residual = residual.contiguous()
  708. if x1 is not None:
  709. assert x1.shape == x_shape_og
  710. assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
  711. x1 = x1.reshape(-1, x1.shape[-1])
  712. if x1.stride(-1) != 1:
  713. x1 = x1.contiguous()
  714. weight = weight.contiguous()
  715. if bias is not None:
  716. bias = bias.contiguous()
  717. if weight1 is not None:
  718. weight1 = weight1.contiguous()
  719. if bias1 is not None:
  720. bias1 = bias1.contiguous()
  721. if rowscale is not None:
  722. rowscale = rowscale.reshape(-1).contiguous()
  723. residual_dtype = (
  724. residual.dtype
  725. if residual is not None
  726. else (torch.float32 if residual_in_fp32 else None)
  727. )
  728. y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
  729. x,
  730. weight,
  731. bias,
  732. eps,
  733. residual,
  734. x1,
  735. weight1,
  736. bias1,
  737. dropout_p=dropout_p,
  738. rowscale=rowscale,
  739. residual_dtype=residual_dtype,
  740. is_rms_norm=is_rms_norm,
  741. return_dropout_mask=return_dropout_mask,
  742. )
  743. ctx.save_for_backward(
  744. residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
  745. )
  746. ctx.x_shape_og = x_shape_og
  747. ctx.eps = eps
  748. ctx.dropout_p = dropout_p
  749. ctx.is_rms_norm = is_rms_norm
  750. ctx.has_residual = residual is not None
  751. ctx.has_x1 = x1 is not None
  752. ctx.prenorm = prenorm
  753. ctx.x_dtype = x.dtype
  754. y = y.reshape(x_shape_og)
  755. y1 = y1.reshape(x_shape_og) if y1 is not None else None
  756. residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
  757. dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
  758. dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
  759. if not return_dropout_mask:
  760. if weight1 is None:
  761. return y if not prenorm else (y, residual_out)
  762. else:
  763. return (y, y1) if not prenorm else (y, y1, residual_out)
  764. else:
  765. if weight1 is None:
  766. return (
  767. (y, dropout_mask, dropout_mask1)
  768. if not prenorm
  769. else (y, residual_out, dropout_mask, dropout_mask1)
  770. )
  771. else:
  772. return (
  773. (y, y1, dropout_mask, dropout_mask1)
  774. if not prenorm
  775. else (y, y1, residual_out, dropout_mask, dropout_mask1)
  776. )
  777. @staticmethod
  778. def backward(ctx, dy, *args):
  779. x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
  780. dy = dy.reshape(-1, dy.shape[-1])
  781. if dy.stride(-1) != 1:
  782. dy = dy.contiguous()
  783. assert dy.shape == x.shape
  784. if weight1 is not None:
  785. dy1, args = args[0], args[1:]
  786. dy1 = dy1.reshape(-1, dy1.shape[-1])
  787. if dy1.stride(-1) != 1:
  788. dy1 = dy1.contiguous()
  789. assert dy1.shape == x.shape
  790. else:
  791. dy1 = None
  792. if ctx.prenorm:
  793. dresidual = args[0]
  794. dresidual = dresidual.reshape(-1, dresidual.shape[-1])
  795. if dresidual.stride(-1) != 1:
  796. dresidual = dresidual.contiguous()
  797. assert dresidual.shape == x.shape
  798. else:
  799. dresidual = None
  800. dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
  801. dy,
  802. x,
  803. weight,
  804. bias,
  805. ctx.eps,
  806. mean,
  807. rstd,
  808. dresidual,
  809. dy1,
  810. weight1,
  811. bias1,
  812. seeds,
  813. ctx.dropout_p,
  814. rowscale,
  815. ctx.has_residual,
  816. ctx.has_x1,
  817. ctx.is_rms_norm,
  818. x_dtype=ctx.x_dtype,
  819. )
  820. return (
  821. dx.reshape(ctx.x_shape_og),
  822. dw,
  823. db,
  824. dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
  825. dx1.reshape(ctx.x_shape_og) if dx1 is not None else None,
  826. dw1,
  827. db1,
  828. None,
  829. None,
  830. None,
  831. None,
  832. None,
  833. None,
  834. None,
  835. )
  836. def layer_norm_fn(
  837. x,
  838. weight,
  839. bias,
  840. residual=None,
  841. x1=None,
  842. weight1=None,
  843. bias1=None,
  844. eps=1e-6,
  845. dropout_p=0.0,
  846. rowscale=None,
  847. prenorm=False,
  848. residual_in_fp32=False,
  849. is_rms_norm=False,
  850. return_dropout_mask=False,
  851. ):
  852. return LayerNormFn.apply(
  853. x,
  854. weight,
  855. bias,
  856. residual,
  857. x1,
  858. weight1,
  859. bias1,
  860. eps,
  861. dropout_p,
  862. rowscale,
  863. prenorm,
  864. residual_in_fp32,
  865. is_rms_norm,
  866. return_dropout_mask,
  867. )
  868. def rms_norm_fn(
  869. x,
  870. weight,
  871. bias,
  872. residual=None,
  873. x1=None,
  874. weight1=None,
  875. bias1=None,
  876. eps=1e-6,
  877. dropout_p=0.0,
  878. rowscale=None,
  879. prenorm=False,
  880. residual_in_fp32=False,
  881. return_dropout_mask=False,
  882. ):
  883. return LayerNormFn.apply(
  884. x,
  885. weight,
  886. bias,
  887. residual,
  888. x1,
  889. weight1,
  890. bias1,
  891. eps,
  892. dropout_p,
  893. rowscale,
  894. prenorm,
  895. residual_in_fp32,
  896. True,
  897. return_dropout_mask,
  898. )
  899. class RMSNorm(torch.nn.Module):
  900. def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
  901. factory_kwargs = {"device": device, "dtype": dtype}
  902. super().__init__()
  903. self.eps = eps
  904. if dropout_p > 0.0:
  905. self.drop = torch.nn.Dropout(dropout_p)
  906. else:
  907. self.drop = None
  908. self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
  909. self.register_parameter("bias", None)
  910. self.reset_parameters()
  911. def reset_parameters(self):
  912. torch.nn.init.ones_(self.weight)
  913. def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
  914. return rms_norm_fn(
  915. x,
  916. self.weight,
  917. self.bias,
  918. residual=residual,
  919. eps=self.eps,
  920. dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
  921. prenorm=prenorm,
  922. residual_in_fp32=residual_in_fp32,
  923. )
  924. class LayerNormLinearFn(torch.autograd.Function):
  925. @staticmethod
  926. @custom_fwd
  927. def forward(
  928. ctx,
  929. x,
  930. norm_weight,
  931. norm_bias,
  932. linear_weight,
  933. linear_bias,
  934. residual=None,
  935. eps=1e-6,
  936. prenorm=False,
  937. residual_in_fp32=False,
  938. is_rms_norm=False,
  939. ):
  940. x_shape_og = x.shape
  941. # reshape input data into 2D tensor
  942. x = x.reshape(-1, x.shape[-1])
  943. if x.stride(-1) != 1:
  944. x = x.contiguous()
  945. if residual is not None:
  946. assert residual.shape == x_shape_og
  947. residual = residual.reshape(-1, residual.shape[-1])
  948. if residual.stride(-1) != 1:
  949. residual = residual.contiguous()
  950. norm_weight = norm_weight.contiguous()
  951. if norm_bias is not None:
  952. norm_bias = norm_bias.contiguous()
  953. residual_dtype = (
  954. residual.dtype
  955. if residual is not None
  956. else (torch.float32 if residual_in_fp32 else None)
  957. )
  958. y, _, mean, rstd, residual_out, *rest = _layer_norm_fwd(
  959. x,
  960. norm_weight,
  961. norm_bias,
  962. eps,
  963. residual,
  964. out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
  965. residual_dtype=residual_dtype,
  966. is_rms_norm=is_rms_norm,
  967. )
  968. y = y.reshape(x_shape_og)
  969. dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
  970. linear_weight = linear_weight.to(dtype)
  971. linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
  972. out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
  973. # We don't store y, will be recomputed in the backward pass to save memory
  974. ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
  975. ctx.x_shape_og = x_shape_og
  976. ctx.eps = eps
  977. ctx.is_rms_norm = is_rms_norm
  978. ctx.has_residual = residual is not None
  979. ctx.prenorm = prenorm
  980. ctx.x_dtype = x.dtype
  981. ctx.linear_bias_is_none = linear_bias is None
  982. return out if not prenorm else (out, residual_out.reshape(x_shape_og))
  983. @staticmethod
  984. @custom_bwd
  985. def backward(ctx, dout, *args):
  986. x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
  987. dout = dout.reshape(-1, dout.shape[-1])
  988. dy = F.linear(dout, linear_weight.t())
  989. dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
  990. if dy.stride(-1) != 1:
  991. dy = dy.contiguous()
  992. assert dy.shape == x.shape
  993. if ctx.prenorm:
  994. dresidual = args[0]
  995. dresidual = dresidual.reshape(-1, dresidual.shape[-1])
  996. if dresidual.stride(-1) != 1:
  997. dresidual = dresidual.contiguous()
  998. assert dresidual.shape == x.shape
  999. else:
  1000. dresidual = None
  1001. dx, dnorm_weight, dnorm_bias, dresidual_in, _, _, _, y = _layer_norm_bwd(
  1002. dy,
  1003. x,
  1004. norm_weight,
  1005. norm_bias,
  1006. ctx.eps,
  1007. mean,
  1008. rstd,
  1009. dresidual=dresidual,
  1010. has_residual=ctx.has_residual,
  1011. is_rms_norm=ctx.is_rms_norm,
  1012. x_dtype=ctx.x_dtype,
  1013. recompute_output=True,
  1014. )
  1015. dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
  1016. return (
  1017. dx.reshape(ctx.x_shape_og),
  1018. dnorm_weight,
  1019. dnorm_bias,
  1020. dlinear_weight,
  1021. dlinear_bias,
  1022. dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
  1023. None,
  1024. None,
  1025. None,
  1026. None,
  1027. )
  1028. def layer_norm_linear_fn(
  1029. x,
  1030. norm_weight,
  1031. norm_bias,
  1032. linear_weight,
  1033. linear_bias,
  1034. residual=None,
  1035. eps=1e-6,
  1036. prenorm=False,
  1037. residual_in_fp32=False,
  1038. is_rms_norm=False,
  1039. ):
  1040. return LayerNormLinearFn.apply(
  1041. x,
  1042. norm_weight,
  1043. norm_bias,
  1044. linear_weight,
  1045. linear_bias,
  1046. residual,
  1047. eps,
  1048. prenorm,
  1049. residual_in_fp32,
  1050. is_rms_norm,
  1051. )