Browse Source

Port mamba kernels to Aphrodite (#595)

* kernels

* fix interface

* clean up dockerfile
AlpinDale 7 months ago
parent
commit
f5d52320da

+ 2 - 0
CMakeLists.txt

@@ -165,6 +165,8 @@ if(APHRODITE_GPU_LANG STREQUAL "CUDA")
   FetchContent_MakeAvailable(cutlass)
 
   list(APPEND APHRODITE_EXT_SRC
+    "kernels/mamba/mamba_ssm/selective_scan_fwd.cu"
+    "kernels/mamba/causal_conv1d/causal_conv1d.cu"
     "kernels/quantization/aqlm/gemm_kernels.cu"
     "kernels/quantization/awq/gemm_kernels.cu"
     "kernels/quantization/quip/origin_order.cu"

+ 0 - 24
Dockerfile

@@ -40,10 +40,6 @@ RUN pip install packaging wheel
 RUN --mount=type=cache,target=/root/.cache/pip \
     python3 -m pip install -r requirements-cuda.txt
 
-COPY requirements-mamba.txt requirements-mamba.txt
-RUN python3 -m pip install packaging
-RUN python3 -m pip install -r requirements-mamba.txt
-
 # cuda arch list used by torch
 # can be useful for both `dev` and `test`
 # explicitly set the list to avoid issues with torch 2.2
@@ -101,22 +97,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
 
 #################### DEV IMAGE ####################
 
-#################### MAMBA Build IMAGE ####################
-FROM dev as mamba-builder
-# max jobs used for build
-ARG max_jobs=2
-ENV MAX_JOBS=${max_jobs}
-
-WORKDIR /usr/src/mamba
-
-COPY requirements-mamba.txt requirements-mamba.txt
-
-# Download the wheel or build it if a pre-compiled release doesn't exist
-RUN pip --verbose wheel -r requirements-mamba.txt \
-    --no-build-isolation --no-deps --no-cache-dir
-
-#################### MAMBA Build IMAGE ####################
-
 #################### Aphrodite installation IMAGE ####################
 # image with Aphrodite installed
 FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04 AS aphrodite-base
@@ -137,10 +117,6 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/aphrodite-workspace
     --mount=type=cache,target=/root/.cache/pip \
     python3 -m pip install dist/*.whl --verbose
 
-RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \
-    --mount=type=cache,target=/root/.cache/pip \
-    python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
-
 RUN --mount=type=cache,target=/root/.cache/pip \
     python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.9/flashinfer-0.0.9+cu121torch2.3-cp310-cp310-linux_x86_64.whl
 #################### Aphrodite installation IMAGE ####################

+ 1 - 1
aphrodite/_custom_ops.py

@@ -374,7 +374,7 @@ def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
                       initial_states_: Optional[torch.Tensor],
                       final_states_out_: Optional[torch.Tensor],
                       silu_activation: bool) -> torch.Tensor:
-    return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_,
+    return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_, None,
                                           initial_states_, final_states_out_,
                                           silu_activation)
 

+ 12 - 0
aphrodite/modeling/layers/mamba/__init__.py

@@ -0,0 +1,12 @@
+from aphrodite.modeling.layers.mamba.ops.causal_conv1d import (
+    causal_conv1d_fn, causal_conv1d_update)
+from aphrodite.modeling.layers.mamba.ops.mamba_ssm import (
+    selective_scan_fn, selective_state_update)
+
+
+__all__ = [
+    'causal_conv1d_fn',
+    'causal_conv1d_update',
+    'selective_scan_fn',
+    'selective_state_update',
+]

+ 13 - 6
aphrodite/modeling/layers/mamba/ops/causal_conv1d.py

@@ -3,7 +3,8 @@
 from typing import Optional
 
 import torch
-from causal_conv1d_cuda import causal_conv1d_fwd, causal_conv1d_update
+
+from aphrodite import _custom_ops as ops
 
 
 def causal_conv1d_fn(
@@ -58,12 +59,17 @@ def causal_conv1d_fn(
     else:
         final_states_out = None
 
-    out = causal_conv1d_fwd(x, weight, bias, seq_idx, initial_states,
-                            final_states_out, activation in ["silu", "swish"])
+    out = ops.causal_conv1d_fwd(x, weight, bias, seq_idx, initial_states,
+                                final_states_out, activation
+                                in ["silu", "swish"])
     return (out, None) if not return_final_states else (out, final_states_out)
 
 
-def causal_conv1d_up(x, conv_state, weight, bias=None, activation=None):
+def causal_conv1d_update(x: torch.Tensor,
+                         conv_state: torch.Tensor,
+                         weight: torch.Tensor,
+                         bias: Optional[torch.Tensor] = None,
+                         activation: Optional[str] = None):
     """
     x: (batch, dim)
     conv_state: (batch, dim, width)
@@ -73,5 +79,6 @@ def causal_conv1d_up(x, conv_state, weight, bias=None, activation=None):
     """
     if activation not in [None, "silu", "swish"]:
         raise NotImplementedError("activation must be None, silu, or swish")
-    activation = activation in ["silu", "swish"]
-    return causal_conv1d_update(x, conv_state, weight, bias, activation)
+    activation_bool = activation in ["silu", "swish"]
+    return ops.causal_conv1d_update(x, conv_state, weight, bias,
+                                    activation_bool)

+ 4 - 2
aphrodite/modeling/layers/mamba/ops/mamba_ssm.py

@@ -1,3 +1,5 @@
+# Copyright (c) 2024, Tri Dao, Albert Gu.
+
 import torch
 import triton
 import triton.language as tl
@@ -331,14 +333,14 @@ def selective_scan_fn(u,
     ),
                     device=u.device,
                     dtype=torch.float32,
-                    requires_grad=u.requires_grad)
+                    requires_grad=False)
     x[:, :, 0, 0::2] = 1
     if prev_state is not None:
         x[:, :, 0, 1::2].copy_(prev_state)
     out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias,
                                            delta_softplus, position_indices, x)
     last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
-    if z is not None:
+    if z is None:
         return out if not return_last_state else (out, last_state)
     else:
         out_z = rest[0]

+ 5 - 4
aphrodite/modeling/models/jamba.py

@@ -4,9 +4,6 @@ from dataclasses import dataclass
 from typing import Dict, Iterable, List, Optional, Tuple
 
 import torch
-from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
-from mamba_ssm.ops.triton.selective_state_update import selective_state_update
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import JambaConfig
@@ -27,6 +24,10 @@ from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.mamba import (causal_conv1d_fn,
+                                             causal_conv1d_update,
+                                             selective_scan_fn,
+                                             selective_state_update)
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@@ -162,7 +163,7 @@ class JambaMambaMixer(nn.Module):
                     (self.conv_kernel_size - hidden_states.shape[-1], 0))
                 cache_params.conv_state.copy_(conv_states)
 
-            hidden_states = causal_conv1d_fn(
+            hidden_states, _ = causal_conv1d_fn(
                 hidden_states,
                 conv_weights,
                 self.conv1d.bias,

+ 1 - 1
kernels/mamba/causal_conv1d/causal_conv1d.cu

@@ -94,7 +94,7 @@ at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
                              const c10::optional<at::Tensor>& seq_idx_,
                              const c10::optional<at::Tensor>& seq_pos_idx_,
                              const c10::optional<at::Tensor>& initial_states_,
-                             c10::optional<at::Tensor>& final_states_out_,
+                             const c10::optional<at::Tensor>& final_states_out_,
                              bool silu_activation) {
   auto input_type = x.scalar_type();
   auto weight_type = weight.scalar_type();

+ 0 - 1
kernels/mamba/mamba_ssm/selective_scan_fwd.cu

@@ -410,7 +410,6 @@ void selective_scan_fwd_launch(SSMParamsBase& params, cudaStream_t stream) {
             constexpr int kSmemSize =
                 Ktraits::kSmemSize +
                 kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
-            // printf("smem_size = %d\n", kSmemSize);
             dim3 grid(params.batch, params.dim / kNRows);
             auto kernel = &selective_scan_fwd_kernel<Ktraits>;
             if (kSmemSize >= 48 * 1024) {

+ 14 - 0
kernels/ops.h

@@ -66,6 +66,20 @@ std::vector<torch::Tensor> selective_scan_fwd(
     const c10::optional<torch::Tensor>& index_,
     const c10::optional<torch::Tensor>& x);
 
+at::Tensor causal_conv1d_update(const at::Tensor& x,
+                                const at::Tensor& conv_state,
+                                const at::Tensor& weight,
+                                const c10::optional<at::Tensor>& bias_,
+                                bool silu_activation);
+
+at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
+                             const c10::optional<at::Tensor>& bias_,
+                             const c10::optional<at::Tensor>& seq_idx_,
+                             const c10::optional<at::Tensor>& seq_pos_idx_,
+                             const c10::optional<at::Tensor>& initial_states_,
+                             const c10::optional<at::Tensor>& final_states_out_,
+                             bool silu_activation);
+
 #ifndef USE_ROCM
 using fptr_t = int64_t;
 fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,

+ 23 - 4
kernels/torch_bindings.cpp

@@ -222,13 +222,32 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
            &dynamic_scaled_int8_quant);
 
+  // Mamba kernels
   ops.def(
       "selective_scan_fwd(Tensor! u, Tensor! delta,"
-      "                   Tensor! A, Tensor! B, Tensor C,"
-      "                   Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
-      "                   bool delta_softplus,"
-      "                   Tensor? index_, Tensor? x) -> Tensor[]");
+      "Tensor! A, Tensor! B, Tensor! C,"
+      "Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
+      "bool delta_softplus,"
+      "Tensor? index_, Tensor? x) -> Tensor[]");
   ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
+
+  ops.def(
+      "causal_conv1d_update(Tensor! x,"
+      "Tensor! conv_state,"
+      "Tensor! weight,"
+      "Tensor? bias_,"
+      "bool silu_activation) -> Tensor");
+  ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
+
+  ops.def(
+      "causal_conv1d_fwd(Tensor! x, Tensor! weight,"
+      "Tensor? bias_,"
+      "Tensor? seq_idx_,"
+      "Tensor? seq_pos_idx_,"
+      "Tensor? initial_states_,"
+      "Tensor? final_states_out_,"
+      "bool silu_activation) -> Tensor");
+  ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
 }
 
 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {

+ 0 - 2
requirements-mamba.txt

@@ -1,2 +0,0 @@
-causal-conv1d >= 1.2.1
-mamba-ssm >= 1.2.2