
(1/N) Triton Backend: integrate Triton layernorm kernels (#1125)

* (1/N) Triton Backend: integrate Triton layernorm kernels

* add tests
AlpinDale · 1 month ago · 0c17153073

+ 5 - 0
aphrodite/common/envs.py

@@ -60,6 +60,7 @@ if TYPE_CHECKING:
     APHRODITE_TEST_DYNAMO_FULLGRAPH_CAPTURE: int = 0
     APHRODITE_USE_TRITON_AWQ: bool = False
     APHRODITE_DYNAMO_USE_CUSTOM_DISPATCHER: bool = False
+    APHRODITE_USE_TRITON_LAYERNORM: bool = False
 
 
 def get_default_cache_root():
@@ -401,6 +402,10 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     # If set, Aphrodite will use Triton implementations of AWQ.
     "APHRODITE_USE_TRITON_AWQ":
     lambda: bool(int(os.getenv("APHRODITE_USE_TRITON_AWQ", "0"))),
+
+    # If set, Aphrodite will use Triton implementations of layernorm.
+    "APHRODITE_USE_TRITON_LAYERNORM":
+    lambda: bool(int(os.getenv("APHRODITE_USE_TRITON_LAYERNORM", "0"))),
 }
 
 # end-env-vars-definition
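
For orientation, a tiny sketch of how the flag above is parsed (illustrative only; the comment table is an observation about bool(int(...)), not code from the commit):

    import os

    # The lambda parses the value as bool(int(value)) with "0" as the default:
    #   unset -> False, "0" -> False, "1" -> True,
    #   non-integer strings such as "true" raise ValueError.
    os.environ["APHRODITE_USE_TRITON_LAYERNORM"] = "1"
    assert bool(int(os.getenv("APHRODITE_USE_TRITON_LAYERNORM", "0")))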

+ 6 - 0
aphrodite/modeling/_custom_op.py

@@ -1,5 +1,6 @@
 import torch.nn as nn
 
+import aphrodite.common.envs as envs
 from aphrodite.common.utils import is_cpu, is_hip, is_xpu
 from aphrodite.platforms import current_platform
 
@@ -47,6 +48,9 @@ class CustomOp(nn.Module):
         # NOTE: This is a placeholder for future extensions.
         return self.forward_native(*args, **kwargs)
 
+    def forward_triton(self, *args, **kwargs):
+        raise NotImplementedError
+
     def dispatch_forward(self):
         # NOTE: Here we assume that Aphrodite was built for only one
         # specific backend. Currently, we do not support dynamic dispatching.
@@ -58,5 +62,7 @@ class CustomOp(nn.Module):
             return self.forward_tpu
         elif is_xpu():
             return self.forward_xpu
+        elif envs.APHRODITE_USE_TRITON_LAYERNORM:
+            return self.forward_triton
         else:
             return self.forward_cuda
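
A minimal sketch of the dispatch behaviour this adds (a hypothetical stand-in class, not the real CustomOp): on a CUDA build with the flag set, every CustomOp resolves to forward_triton, which is why the ops below that have no Triton kernel of their own simply delegate back to forward_cuda.

    # Hypothetical stand-in mirroring the new dispatch branch (illustrative only).
    class _FakeCustomOp:
        def forward_cuda(self, *args, **kwargs):
            return "cuda kernel"

        def forward_triton(self, *args, **kwargs):
            # Ops without a dedicated Triton kernel keep this CUDA fallback.
            return self.forward_cuda(*args, **kwargs)

        def dispatch_forward(self, use_triton_layernorm: bool):
            # Mirrors the elif above: the Triton path wins over the CUDA
            # default only when APHRODITE_USE_TRITON_LAYERNORM is True.
            if use_triton_layernorm:
                return self.forward_triton
            return self.forward_cuda

    op = _FakeCustomOp()
    assert op.dispatch_forward(use_triton_layernorm=True)() == "cuda kernel"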

+ 18 - 0
aphrodite/modeling/layers/activation.py

@@ -45,6 +45,9 @@ class SiluAndMul(CustomOp):
         ops.silu_and_mul(out, x)
         return out
 
+    def forward_triton(self, x: torch.Tensor) -> torch.Tensor:
+        return self.forward_cuda(x)
+
 
 class GeluAndMul(CustomOp):
     """An activation function for GeGLU.
@@ -90,6 +93,9 @@ class GeluAndMul(CustomOp):
             ops.gelu_tanh_and_mul(out, x)
         return out
 
+    def forward_triton(self, x: torch.Tensor) -> torch.Tensor:
+        return self.forward_cuda(x)
+
     def extra_repr(self) -> str:
         return f'approximate={repr(self.approximate)}'
 
@@ -113,6 +119,9 @@ class NewGELU(CustomOp):
 
         return ops.gelu_new(x)
 
+    def forward_triton(self, x: torch.Tensor) -> torch.Tensor:
+        return self.forward_cuda(x)
+
 
 class FastGELU(CustomOp):
 
@@ -132,6 +141,9 @@ class FastGELU(CustomOp):
 
         return ops.gelu_fast(x)
 
+    def forward_triton(self, x: torch.Tensor) -> torch.Tensor:
+        return self.forward_cuda(x)
+
 
 class QuickGELU(CustomOp):
 
@@ -153,6 +165,9 @@ class QuickGELU(CustomOp):
         ops.gelu_quick(out, x)
         return out
 
+    def forward_triton(self, x: torch.Tensor) -> torch.Tensor:
+        return self.forward_cuda(x)
+
 
 class ReLUSquaredActivation(CustomOp):
     """
@@ -166,6 +181,9 @@ class ReLUSquaredActivation(CustomOp):
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         return self.forward_native(x)
 
+    def forward_triton(self, x: torch.Tensor) -> torch.Tensor:
+        return self.forward_cuda(x)
+
 
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.

+ 24 - 0
aphrodite/modeling/layers/layernorm.py

@@ -112,6 +112,18 @@ class RMSNorm(CustomOp):
             self.variance_epsilon,
         )
 
+    def forward_triton(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        from aphrodite.modeling.layers.ops.layernorm import fast_rms_layernorm
+
+        if residual is not None:
+            x = x + residual
+            return fast_rms_layernorm(self, x, gemma=False), x
+        return fast_rms_layernorm(self, x, gemma=False)
+
     def extra_repr(self) -> str:
         s = f"hidden_size={self.weight.data.size(0)}"
         s += f", eps={self.variance_epsilon}"
@@ -177,3 +189,15 @@ class GemmaRMSNorm(CustomOp):
                 self.forward_static)
             self._is_compiled = True
         return self.forward_native(x, residual)
+
+    def forward_triton(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        from aphrodite.modeling.layers.ops.layernorm import fast_rms_layernorm
+
+        if residual is not None:
+            x = x + residual
+            return fast_rms_layernorm(self, x, gemma=True), x
+        return fast_rms_layernorm(self, x, gemma=True)
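
For reference, a rough pure-PyTorch restatement of what these forward_triton methods compute (a sketch based on the kernels added below, not code from the commit): when a residual is passed it is added first, the sum is RMS-normalized, and the same summed tensor is handed back as the new residual, matching the fused CUDA op's (output, residual) convention. The Gemma variant works in float32 and applies the weight as (w + 1).

    from typing import Optional, Tuple, Union

    import torch

    def rms_layernorm_reference(
        x: torch.Tensor,
        weight: torch.Tensor,
        eps: float,
        residual: Optional[torch.Tensor] = None,
        gemma: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Rough reference for fast_rms_layernorm's semantics (sketch only)."""
        if residual is not None:
            # The summed tensor is normalized and also handed back as the new
            # residual, mirroring the fused CUDA op's return contract.
            x = x + residual
        var = x.float().pow(2).mean(dim=-1, keepdim=True)
        inv = torch.rsqrt(var + eps)
        if gemma:
            # Gemma variant: compute fully in float32, weight offset by +1.
            out = (x.float() * inv * (weight.float() + 1.0)).to(x.dtype)
        else:
            out = (x.float() * inv).to(x.dtype) * weight
        return (out, x) if residual is not None else out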

+ 0 - 0
aphrodite/modeling/layers/ops/__init__.py


+ 167 - 0
aphrodite/modeling/layers/ops/layernorm.py

@@ -0,0 +1,167 @@
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved
+# Copyright 2024-present Andrej Karpathy & the llm.c team. All rights reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# The following code is adapted from:
+# https://github.com/unslothai/unsloth/blob/038e6d4c8d40207a87297ab3aaf787c19b1006d1/unsloth/kernels/rms_layernorm.py
+
+import torch
+import triton
+import triton.language as tl
+
+MAX_FUSED_SIZE : int = 65536
+next_power_of_2 = triton.next_power_of_2
+
+
+# Calculate the optimal block size and number of warps for the layernorm kernel
+# borrowed from https://github.com/unslothai/unsloth/blob/038e6d4c8d40207a87297ab3aaf787c19b1006d1/unsloth/kernels/utils.py#L49-L59
+def calculate_settings(n : int) -> tuple[int, int]:
+    BLOCK_SIZE : int = next_power_of_2(n)
+    if BLOCK_SIZE > MAX_FUSED_SIZE:
+        raise RuntimeError(
+            f"Cannot launch Triton kernel since n = {n} exceeds "
+            f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
+    num_warps : int = 4
+    if   BLOCK_SIZE >= 32768:
+        num_warps = 32
+    elif BLOCK_SIZE >=  8192:
+        num_warps = 16
+    elif BLOCK_SIZE >=  2048:
+        num_warps = 8
+    return BLOCK_SIZE, num_warps
+pass
+
+
+@triton.jit
+def _rms_layernorm_forward(
+    Y, Y_row_stride,
+    X, X_row_stride,
+    W, W_row_stride,
+    r, r_row_stride,
+    n_cols, eps,
+    BLOCK_SIZE : tl.constexpr
+):
+    """
+        Fast RMS Layernorm kernel
+        Inspiration from a Triton tutorial:
+        https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
+    """
+    row_idx = tl.program_id(0)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+
+    Y += row_idx * Y_row_stride
+    X += row_idx * X_row_stride
+    r += row_idx * r_row_stride
+
+    X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
+    W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)
+
+    row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
+    inv_var = tl.math.rsqrt(row_var + eps)
+    tl.store(r, inv_var)
+    normed = X_row * inv_var
+    normed = normed.to(W_row.dtype) # Exact copy from HF
+    output = normed * W_row
+    tl.store(Y + col_offsets, output, mask = mask)
+pass
+
+
+@triton.jit
+def _gemma_rms_layernorm_forward(
+    Y, Y_row_stride,
+    X, X_row_stride,
+    W, W_row_stride,
+    r, r_row_stride,
+    n_cols, eps,
+    BLOCK_SIZE : tl.constexpr,
+):
+    # Copies https://github.com/google-deepmind/gemma/blob/main/gemma/layers.py#L31
+    # and https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/rms_normalization.py#L33
+    # exactly. Essentially all in float32!
+    row_idx = tl.program_id(0)
+    col_offsets = tl.arange(0, BLOCK_SIZE)
+    mask = col_offsets < n_cols
+
+    Y += row_idx * Y_row_stride
+    X += row_idx * X_row_stride
+    r += row_idx * r_row_stride
+
+    X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
+    W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)
+
+    row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
+    inv_var = tl.math.rsqrt(row_var + eps)
+    tl.store(r, inv_var)
+    normed = X_row * inv_var
+    output = normed * (W_row + 1.0)
+
+    tl.store(Y + col_offsets, output, mask = mask)
+pass
+
+
+class Fast_RMS_Layernorm(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        X : torch.Tensor,
+        W : torch.Tensor,
+        eps : float,
+        gemma : bool = False,
+    ):
+        shape = X.shape
+        dim : int = shape[-1]
+        X = X.view(-1, dim)
+        n_rows : int
+        n_cols : int
+        n_rows, n_cols = X.shape
+        BLOCK_SIZE : int
+        num_warps  : int
+        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+        Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = X.device)
+        r = torch.empty(n_rows, dtype = torch.float32, device = X.device)
+
+        fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward
+        with torch.cuda.device(X.device):
+            fx[(n_rows,)](
+                Y, Y.stride(0),
+                X, X.stride(0),
+                W, W.stride(0),
+                r, r.stride(0),
+                n_cols, eps,
+                BLOCK_SIZE = BLOCK_SIZE,
+                num_warps  = num_warps,
+            )
+        ctx.eps = eps
+        ctx.BLOCK_SIZE = BLOCK_SIZE
+        ctx.num_warps  = num_warps
+        ctx.GEMMA = gemma
+        ctx.save_for_backward(X, W, r)
+        return Y.view(*shape)
+    pass
+pass
+
+
+# [TODO] Unsure why RMS Layernorm is not torch.compiling properly
+@torch.compiler.disable
+def fast_rms_layernorm(layernorm, X : torch.Tensor, gemma : bool = False):
+    W : torch.Tensor = layernorm.weight
+    eps : float = layernorm.variance_epsilon if \
+        hasattr(layernorm, "variance_epsilon") \
+        else layernorm.eps
+    out = Fast_RMS_Layernorm.apply(X, W, eps, gemma)
+    return out
+pass
+
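
As a quick orientation on the launch parameters (the loop below just replays calculate_settings for a few hidden sizes; it is illustrative, not part of the kernel): the wrapper launches one program per row over a grid of (n_rows,), so a whole row must fit in one block of at most MAX_FUSED_SIZE = 65536 columns.

    import triton

    # hidden_size -> (BLOCK_SIZE, num_warps) as chosen by calculate_settings:
    #   4096 -> (4096, 8)    8192 -> (8192, 16)    40000 -> (65536, 32)
    # Hidden sizes whose next power of two exceeds 65536 raise a RuntimeError.
    for n_cols in (4096, 8192, 40000):
        block = triton.next_power_of_2(n_cols)
        if block >= 32768:
            warps = 32
        elif block >= 8192:
            warps = 16
        elif block >= 2048:
            warps = 8
        else:
            warps = 4
        print(f"{n_cols} -> BLOCK_SIZE={block}, num_warps={warps}")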

+ 9 - 0
aphrodite/modeling/layers/rotary_embedding.py

@@ -195,6 +195,15 @@ class RotaryEmbedding(CustomOp):
                                  self.cos_sin_cache, self.is_neox_style)
         return query, key
 
+    def forward_triton(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        return self.forward_cuda(positions, query, key, offsets)
+
     def extra_repr(self) -> str:
         s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
         s += f", max_position_embeddings={self.max_position_embeddings}"

+ 43 - 0
tests/kernels/test_layernorm.py

@@ -60,3 +60,46 @@ def test_rms_norm(
     else:
         opcheck(torch.ops._C.rms_norm,
                 (out, x, layer.weight.data, layer.variance_epsilon))
+
+
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_rms_norm_triton(
+    num_tokens: int,
+    hidden_size: int,
+    add_residual: bool,
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+) -> None:
+    """
+    Test RMSNorm's Triton kernel by comparing its output to the native
+    PyTorch reference implementation (forward_native).
+    """
+    torch.random.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+
+    # Explicitly move the layer to the selected device.
+    layer = RMSNorm(hidden_size).to(device=device, dtype=dtype)
+    layer.weight.data.normal_(mean=1.0, std=0.1)
+
+    scale = 1 / (2 * hidden_size)
+    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) * scale
+    residual = torch.randn_like(x) * scale if add_residual else None
+
+    ref_out = layer.forward_native(x, residual)
+    triton_out = layer.forward_triton(x, residual)
+
+    if add_residual:
+        torch.testing.assert_close(triton_out[0], ref_out[0],
+                                   atol=1e-2, rtol=1e-2)
+        torch.testing.assert_close(triton_out[1], ref_out[1],
+                                   atol=1e-2, rtol=1e-2)
+    else:
+        torch.testing.assert_close(triton_out, ref_out, atol=1e-2, rtol=1e-2)
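
To close, a hedged end-to-end sketch of exercising the new path outside the unit test (assumes a CUDA device with Triton installed, and that dispatch_forward is resolved when the layer is constructed, as _custom_op.py suggests):

    import os

    # Enable the Triton path before the layer is built, so dispatch_forward
    # picks forward_triton.
    os.environ["APHRODITE_USE_TRITON_LAYERNORM"] = "1"

    import torch

    from aphrodite.modeling.layers.layernorm import RMSNorm

    layer = RMSNorm(4096).to(device="cuda", dtype=torch.float16)
    x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
    residual = torch.randn_like(x)

    # __call__ runs the method selected by dispatch_forward(), i.e.
    # forward_triton here; compare against the native reference.
    out, new_residual = layer(x, residual)
    ref_out, ref_residual = layer.forward_native(x, residual)
    torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)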