# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved
# Copyright 2024-present Andrej Karpathy & the llm.c team. All rights reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# The following code is adapted from:
# https://github.com/unslothai/unsloth/blob/038e6d4c8d40207a87297ab3aaf787c19b1006d1/unsloth/kernels/rms_layernorm.py

import torch
import triton
import triton.language as tl

from aphrodite.modeling.layers.ops.utils import calculate_settings


@triton.jit
def _rms_layernorm_forward(
    Y, Y_row_stride,
    X, X_row_stride,
    W, W_row_stride,
    r, r_row_stride,
    n_cols, eps,
    BLOCK_SIZE : tl.constexpr,
):
    """
    Fast RMS Layernorm kernel
    Inspiration from a Triton tutorial:
    https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    """
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    Y += row_idx * Y_row_stride
    X += row_idx * X_row_stride
    r += row_idx * r_row_stride

    X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
    W_row = tl.load(W + col_offsets, mask = mask, other = 0)  # .to(tl.float32)

    row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
    inv_var = tl.math.rsqrt(row_var + eps)
    tl.store(r, inv_var)
    normed = X_row * inv_var
    normed = normed.to(W_row.dtype)  # Exact copy from HF
    output = normed * W_row
    tl.store(Y + col_offsets, output, mask = mask)
pass


@triton.jit
def _gemma_rms_layernorm_forward(
    Y, Y_row_stride,
    X, X_row_stride,
    W, W_row_stride,
    r, r_row_stride,
    n_cols, eps,
    BLOCK_SIZE : tl.constexpr,
):
    # Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31
    # and https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L33
    # exactly. Essentially all in float32!
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    Y += row_idx * Y_row_stride
    X += row_idx * X_row_stride
    r += row_idx * r_row_stride

    X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
    W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)

    row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
    inv_var = tl.math.rsqrt(row_var + eps)
    tl.store(r, inv_var)
    normed = X_row * inv_var
    output = normed * (W_row + 1.0)
    tl.store(Y + col_offsets, output, mask = mask)
pass
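

# --- Illustrative reference (not part of the adapted Unsloth kernels) -------
# A minimal eager-mode sketch of what the two kernels above compute, handy for
# sanity-checking their output: each row is scaled by rsqrt(mean(x^2) + eps)
# computed in float32 (the kernels also write this per-row inverse RMS into
# `r`), then multiplied by W (standard) or by (W + 1) (Gemma). The helper name
# below is hypothetical and only exists for illustration.
def _reference_rms_layernorm(X : torch.Tensor, W : torch.Tensor, eps : float,
                             gemma : bool = False) -> torch.Tensor:
    X_f32 = X.float()
    inv_var = torch.rsqrt(X_f32.pow(2).mean(dim = -1, keepdim = True) + eps)
    normed = X_f32 * inv_var
    if gemma:
        # Gemma keeps everything in float32 and offsets the weight by 1.
        return (normed * (W.float() + 1.0)).to(X.dtype)
    # Standard path casts back to the weight dtype before scaling, as in HF.
    return (normed.to(W.dtype) * W).to(X.dtype)
# -----------------------------------------------------------------------------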


class Fast_RMS_Layernorm(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        X : torch.Tensor,
        W : torch.Tensor,
        eps : float,
        gemma : bool = False,
    ):
        shape = X.shape
        dim : int = shape[-1]
        X = X.view(-1, dim)
        n_rows : int
        n_cols : int
        n_rows, n_cols = X.shape
        BLOCK_SIZE : int
        num_warps  : int
        BLOCK_SIZE, num_warps = calculate_settings(n_cols)

        Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = X.device)
        r = torch.empty(n_rows, dtype = torch.float32, device = X.device)

        fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
        with torch.cuda.device(X.device):
            fx[(n_rows,)](
                Y, Y.stride(0),
                X, X.stride(0),
                W, W.stride(0),
                r, r.stride(0),
                n_cols, eps,
                BLOCK_SIZE = BLOCK_SIZE,
                num_warps = num_warps,
            )
        ctx.eps = eps
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.num_warps = num_warps
        ctx.GEMMA = gemma
        ctx.save_for_backward(X, W, r)
        return Y.view(*shape)
    pass
pass


# [TODO] Unsure why RMS Layernorm is not torch.compiling properly
@torch.compiler.disable
def fast_rms_layernorm(layernorm, X : torch.Tensor, gemma : bool = False):
    W   : torch.Tensor = layernorm.weight
    eps : float = layernorm.variance_epsilon if \
        hasattr(layernorm, "variance_epsilon") \
        else layernorm.eps
    out = Fast_RMS_Layernorm.apply(X, W, eps, gemma)
    return out
pass
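

# --- Usage sketch (illustrative, not part of the original module) -----------
# fast_rms_layernorm only needs an object with a `weight` tensor and either a
# `variance_epsilon` or an `eps` attribute, e.g. a HF-style RMSNorm module.
# The toy module below is hypothetical and exists purely to show the call;
# the Triton kernels require a CUDA device.
if __name__ == "__main__":
    class _ToyRMSNorm(torch.nn.Module):
        def __init__(self, dim : int, eps : float = 1e-6):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.ones(dim))
            self.variance_epsilon = eps

    if torch.cuda.is_available():
        norm = _ToyRMSNorm(4096).to("cuda", torch.float16)
        x = torch.randn(2, 128, 4096, device = "cuda", dtype = torch.float16)
        y = fast_rms_layernorm(norm, x)                        # standard RMSNorm
        y_gemma = fast_rms_layernorm(norm, x, gemma = True)    # Gemma variant
        print(y.shape, y_gemma.shape)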