Browse Source

fix: missing cache_config for dbrx

AlpinDale 7 months ago
parent
commit
b2cb5a92e9

+ 2 - 1
aphrodite/modeling/models/dbrx.py

@@ -246,11 +246,12 @@ class DbrxFusedNormAttention(nn.Module):
     def __init__(
         self,
         config: DbrxConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.d_model = config.d_model
-        self.attn = DbrxAttention(config, quant_config)
+        self.attn = DbrxAttention(config, cache_config, quant_config)
         self.norm_1 = nn.LayerNorm(self.d_model)
         self.norm_2 = nn.LayerNorm(self.d_model)
 

+ 2 - 1
aphrodite/quantization/compressed_tensors/schemes/utils.py

@@ -4,6 +4,7 @@ import torch
 
 from aphrodite._quant_C import quant_ops
 
+
 # cutlass
 def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
                          scale_a: torch.Tensor, scale_b: torch.Tensor,
@@ -44,4 +45,4 @@ def scaled_int8_quant(
                                device=input.device,
                                dtype=torch.float32)
     quant_ops.dynamic_scaled_int8_quant(output, input, input_scales)
-    return output, input_scales
+    return output, input_scales

+ 25 - 53
kernels/cpu/pybind.cpp

@@ -7,66 +7,38 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   pybind11::module ops = m.def_submodule("ops", "Aphrodite custom operators");
 
   // Attention ops
-  ops.def(
-    "paged_attention_v1",
-    &paged_attention_v1,
-    "Compute the attention between an input query and the cached keys/values using PagedAttention.");
-  ops.def(
-    "paged_attention_v2",
-    &paged_attention_v2,
-    "PagedAttention V2.");
+  ops.def("paged_attention_v1", &paged_attention_v1,
+          "Compute the attention between an input query and the cached "
+          "keys/values using PagedAttention.");
+  ops.def("paged_attention_v2", &paged_attention_v2, "PagedAttention V2.");
 
   // Activation ops
-  ops.def(
-    "silu_and_mul",
-    &silu_and_mul,
-    "Activation function used in SwiGLU.");
-  ops.def(
-    "gelu_and_mul",
-    &gelu_and_mul,
-    "Activation function used in GeGLU with `none` approximation.");
-  ops.def(
-    "gelu_tanh_and_mul",
-    &gelu_tanh_and_mul,
-    "Activation function used in GeGLU with `tanh` approximation.");
-  ops.def(
-    "gelu_new",
-    &gelu_new,
-    "GELU implementation used in GPT-2.");
-  ops.def(
-    "gelu_fast",
-    &gelu_fast,
-    "Approximate GELU implementation.");
+  ops.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU.");
+  ops.def("gelu_and_mul", &gelu_and_mul,
+          "Activation function used in GeGLU with `none` approximation.");
+  ops.def("gelu_tanh_and_mul", &gelu_tanh_and_mul,
+          "Activation function used in GeGLU with `tanh` approximation.");
+  ops.def("gelu_new", &gelu_new, "GELU implementation used in GPT-2.");
+  ops.def("gelu_fast", &gelu_fast, "Approximate GELU implementation.");
 
   // Layernorm
-  ops.def(
-    "rms_norm",
-    &rms_norm,
-    "Apply Root Mean Square (RMS) Normalization to the input tensor.");
+  ops.def("rms_norm", &rms_norm,
+          "Apply Root Mean Square (RMS) Normalization to the input tensor.");
 
-  ops.def(
-    "fused_add_rms_norm",
-    &fused_add_rms_norm,
-    "In-place fused Add and RMS Normalization");
+  ops.def("fused_add_rms_norm", &fused_add_rms_norm,
+          "In-place fused Add and RMS Normalization");
 
   // Rotary embedding
-  ops.def(
-    "rotary_embedding",
-    &rotary_embedding,
-    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
+  ops.def("rotary_embedding", &rotary_embedding,
+          "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
 
   // Cache ops
-  pybind11::module cache_ops = m.def_submodule("cache_ops", "Aphrodite cache ops");
-  cache_ops.def(
-    "swap_blocks",
-    &swap_blocks,
-    "Swap in (out) the cache blocks from src to dst");
-  cache_ops.def(
-    "copy_blocks",
-    &copy_blocks,
-    "Copy the cache blocks from src to dst");
-  cache_ops.def(
-    "reshape_and_cache",
-    &reshape_and_cache,
-    "Reshape the key and value tensors and cache them");
+  pybind11::module cache_ops =
+      m.def_submodule("cache_ops", "Aphrodite cache ops");
+  cache_ops.def("swap_blocks", &swap_blocks,
+                "Swap in (out) the cache blocks from src to dst");
+  cache_ops.def("copy_blocks", &copy_blocks,
+                "Copy the cache blocks from src to dst");
+  cache_ops.def("reshape_and_cache", &reshape_and_cache,
+                "Reshape the key and value tensors and cache them");
 }