
kernel: add meta functions for ops to prevent graph breaks (#1019)

AlpinDale, 2 months ago
parent
commit
a113309876
+ 200 - 0
aphrodite/_custom_ops.py

@@ -202,6 +202,21 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                                   b_g_idx, use_exllama, bit)
 
 
+# TODO: has to be a better way to do this
+try:
+    torch.ops._C.gptq_gemm  # noqa B018
+    @torch.library.register_fake("_C::gptq_gemm")
+    def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
+                        b_gptq_qzeros: torch.Tensor,
+                        b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor,
+                        use_exllama: bool, bit: int) -> torch.Tensor:
+        return torch.empty((a.size(0), b_q_weight.size(1)),
+                           dtype=a.dtype,
+                           device=a.device)
+except Exception:
+    pass
+
+
 def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
                  bit: int) -> None:
     torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
@@ -231,6 +246,191 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                                             size_n, size_k)
 
 
+# TODO: has to be a better way to do this
+try:
+    torch.ops._C.gptq_marlin_24_gemm  # noqa B018
+    @torch.library.register_fake("_C::gptq_marlin_24_gemm")
+    def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
+                                  b_meta: torch.Tensor, b_scales: torch.Tensor,
+                                  workspace: torch.Tensor,
+                                  b_q_type: ScalarType, size_m: int,
+                                  size_n: int, size_k: int) -> torch.Tensor:
+        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
+
+    @torch.library.register_fake("_C::gptq_marlin_gemm")
+    def _gptq_marlin_gemm_fake(a: torch.Tensor,
+                               b_q_weight: torch.Tensor,
+                               b_scales: torch.Tensor,
+                               b_zeros: torch.Tensor,
+                               g_idx: torch.Tensor,
+                               perm: torch.Tensor,
+                               workspace: torch.Tensor,
+                               b_q_type: ScalarType,
+                               size_m: int,
+                               size_n: int,
+                               size_k: int,
+                               is_k_full: bool,
+                               has_zp: bool = False,
+                               use_fp32_reduce: bool = False) -> torch.Tensor:
+        return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
+
+    @torch.library.register_fake("_C::ggml_dequantize")
+    def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
+                              n: int) -> torch.Tensor:
+        return torch.empty((m, n), dtype=torch.float16, device=W.device)
+
+    @torch.library.register_fake("_C::ggml_mul_mat_vec_a8")
+    def _ggml_mul_mat_vec_a8_fake(
+        W: torch.Tensor,
+        X: torch.Tensor,
+        quant_type: int,
+        row: int,
+    ) -> torch.Tensor:
+        return torch.empty((1, row), dtype=torch.float16, device=W.device)
+
+    @torch.library.register_fake("_C::ggml_mul_mat_a8")
+    def _ggml_mul_mat_a8_fake(
+        W: torch.Tensor,
+        X: torch.Tensor,
+        quant_type: int,
+        row: int,
+    ) -> torch.Tensor:
+        batch = X.size(0)
+        return torch.empty((batch, row), dtype=torch.float16, device=W.device)
+
+    @torch.library.register_fake("_C::marlin_qqq_gemm")
+    def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
+                              s_tok: torch.Tensor, s_ch: torch.Tensor,
+                              s_group: torch.Tensor, workspace: torch.Tensor,
+                              size_m: int, size_n: int,
+                              size_k: int) -> torch.Tensor:
+        return torch.empty((size_m, size_n),
+                           dtype=torch.float16,
+                           device=a.device)
+
+    @torch.library.register_fake("_C::marlin_gemm")
+    def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
+                          b_scales: torch.Tensor, workspace: torch.Tensor,
+                          size_m: int, size_n: int,
+                          size_k: int) -> torch.Tensor:
+        return torch.empty((size_m, size_n),
+                           dtype=torch.float16,
+                           device=a.device)
+
+    @torch.library.register_fake("_C::awq_dequantize")
+    def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
+                             zeros: torch.Tensor, split_k_iters: int, thx: int,
+                             thy: int) -> torch.Tensor:
+        in_c = qweight.size(0)
+        qout_c = qweight.size(1)
+        out_c = qout_c * 8
+        return torch.empty((in_c, out_c),
+                           dtype=scales.dtype,
+                           device=scales.device)
+
+    @torch.library.register_fake("_C::awq_gemm")
+    def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
+                       qzeros: torch.Tensor, scales: torch.Tensor,
+                       split_k_iters: int) -> torch.Tensor:
+        num_in_feats = input.size(0)
+        return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8),
+                           dtype=input.dtype,
+                           device=input.device).sum(0)
+
+    @torch.library.register_fake("_C::aqlm_gemm")
+    def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor,
+                        codebooks: torch.Tensor, scales: torch.Tensor,
+                        codebook_partition_sizes: List[int],
+                        bias: Optional[torch.Tensor]) -> torch.Tensor:
+        out_features = codes.size(0) * codebooks.size(2)
+        flat_input = input.reshape((-1, input.size(-1)))
+        flat_output = torch.empty((flat_input.size(0), out_features),
+                                  dtype=input.dtype,
+                                  device=input.device)
+        output_sizes = list(input.shape)
+        output_sizes.pop()
+        output_sizes.append(-1)
+        return flat_output.reshape(tuple(output_sizes))
+
+    @torch.library.register_fake("_C::aqlm_dequant")
+    def _aqlm_dequant_fake(
+            codes: torch.Tensor, codebooks: torch.Tensor,
+            codebook_partition_sizes: List[int]) -> torch.Tensor:
+        in_features = codes.size(1) * 8
+        out_features = codes.size(0)
+        return torch.empty((out_features, in_features),
+                           dtype=codebooks.dtype,
+                           device=codebooks.device)
+
+    @torch.library.register_fake("_C::fp8_marlin_gemm")
+    def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
+                              b_scales: torch.Tensor, workspace: torch.Tensor,
+                              num_bits: int, size_m: int, size_n: int,
+                              size_k: int) -> torch.Tensor:
+        return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
+
+    @torch.library.register_fake("_C::machete_gemm")
+    def machete_gemm_fake(
+        a: torch.Tensor,
+        b_q: torch.
+        Tensor,  # Should be the tensor returned by machete_prepack_B
+        b_type: ScalarType,
+        b_scales: Optional[torch.Tensor] = None,
+        b_zeros: Optional[torch.Tensor] = None,
+        b_group_size: Optional[int] = None,
+        c: Optional[torch.Tensor] = None,
+        alpha: Optional[float] = None,
+        beta: Optional[float] = None,
+        schedule: Optional[str] = None,
+    ) -> torch.Tensor:
+        m = a.size(0)
+        n = b_q.size(1)
+        return torch.empty((m, n), device=a.device, dtype=a.dtype)
+
+    @torch.library.register_fake("_C::machete_prepack_B")
+    def machete_prepack_B_fake(b_q_weight: torch.Tensor,
+                               b_type: ScalarType) -> torch.Tensor:
+        return torch.empty_like(b_q_weight)
+
+    @torch.library.register_fake("_C::causal_conv1d_fwd")
+    def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
+                               bias_: Optional[torch.Tensor],
+                               seq_idx_: Optional[torch.Tensor],
+                               initial_states_: Optional[torch.Tensor],
+                               final_states_out_: Optional[torch.Tensor],
+                               silu_activation: bool) -> torch.Tensor:
+        return torch.empty_like(x)
+
+    @torch.library.register_fake("_C::causal_conv1d_update")
+    def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor,
+                                  weight: torch.Tensor,
+                                  bias_: Optional[torch.Tensor],
+                                  silu_activation: bool) -> torch.Tensor:
+        return torch.empty_like(x)
+
+    @torch.library.register_fake("_C::selective_scan_fwd")
+    def selective_scan_fwd_fake(
+            u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
+            B: torch.Tensor, C: torch.Tensor, D_: Optional[torch.Tensor],
+            z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor],
+            delta_softplus: bool, index_: Optional[torch.Tensor],
+            x: Optional[torch.Tensor]) -> List[torch.Tensor]:
+        a = torch.empty_like(u)
+        if x is not None:
+            b = x
+        else:
+            b = torch.empty((u.size(0), u.size(1), A.size(1)),
+                            dtype=u.dtype,
+                            device=u.device)
+        if z_ is not None:
+            c = torch.empty_like(z_)
+            return [a, b, c]
+        else:
+            return [a, b]
+except Exception:
+    pass
+
+
 # fp8 marlin
 def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                     b_scales: torch.Tensor, workspace: torch.Tensor,

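To see what these registrations buy in isolation: without a fake (meta) implementation, Dynamo must graph-break at an opaque custom op, so torch.compile(fullgraph=True) fails. Below is a minimal, self-contained sketch of the same register_fake pattern, not part of the commit; it assumes PyTorch >= 2.4 and uses a hypothetical mylib::scale op as a stand-in for the _C kernels above.

import torch

@torch.library.custom_op("mylib::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Real implementation (in aphrodite this would be a CUDA kernel).
    return x * factor

@scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Fake/meta implementation: only shape, dtype and device matter here.
    return torch.empty_like(x)

@torch.compile(fullgraph=True, backend="eager")
def f(x):
    return torch.ops.mylib.scale(x, 2.0) + 1.0

print(f(torch.randn(4)))  # traces end to end, no graph break

The try/except wrappers in the diff serve the same purpose as a hasattr check: on builds where a given _C kernel was not compiled in, the attribute access raises and the corresponding fake registration is simply skipped.
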
+ 7 - 1
aphrodite/common/envs.py

@@ -58,6 +58,7 @@ if TYPE_CHECKING:
     APHRODITE_RPC_GET_DATA_TIMEOUT_MS: int = 5000
     APHRODITE_FORCE_SINGLE_USER_PREFIX_CACHE: bool = False
     APHRODITE_TEST_DYNAMO_GRAPH_CAPTURE: int = 0
+    APHRODITE_TEST_DYNAMO_FULLGRAPH_CAPTURE: int = 0
     APHRODITE_USE_TRITON_AWQ: bool = False
     APHRODITE_DYNAMO_USE_CUSTOM_DISPATCHER: bool = False
 
@@ -202,6 +203,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     (os.environ.get("APHRODITE_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in
      ("true", "1")),
 
+    # Internal flag to enable Dynamo fullgraph capture
+    "APHRODITE_TEST_DYNAMO_FULLGRAPH_CAPTURE":
+    lambda: bool(
+        os.environ.get("APHRODITE_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
+
     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
     "LOCAL_RANK":
@@ -261,7 +267,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "APHRODITE_ATTENTION_BACKEND":
     lambda: os.getenv("APHRODITE_ATTENTION_BACKEND", None),
 
-    # If set, aphrodite will use flashinfer sampler
+    # If set, aphrodite will use custom sampling kernels
     "APHRODITE_USE_SAMPLING_KERNELS":
     lambda: bool(int(os.getenv("APHRODITE_USE_SAMPLING_KERNELS", "0"))),
 

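The new flag follows the same convention as the other test flags: any value other than the literal string "0" (including leaving the variable unset, since the default is "1") enables fullgraph capture. An illustrative check of that parsing, mirroring the lambda above:

import os

os.environ["APHRODITE_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0"
enabled = os.environ.get("APHRODITE_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"
print(enabled)  # False -> model_runner.py will call torch.compile with fullgraph=False
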
+ 1 - 1
aphrodite/modeling/models/mamba_cache.py

@@ -120,7 +120,7 @@ class MambaCacheManager:
                                   indices_for_current_run: List[int]):
         # move out all of the occupied but currently not running blocks
         # outside of the first n blocks
-        destination_indices = set([range(batch_size)])
+        destination_indices = range(batch_size)
         max_possible_batch_size = self.mamba_cache[0].shape[1]
         for destination_index in destination_indices:
             if destination_index in self._get_all_occupied_indices() and  \

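The one-line mamba_cache.py change fixes an iteration bug: set([range(batch_size)]) is a set containing a single range object, not the indices, so the loop body never saw integer indices. Illustrative snippet (not from the commit):

batch_size = 4

wrong = set([range(batch_size)])   # {range(0, 4)} -- one element, the range object itself
right = range(batch_size)          # yields 0, 1, 2, 3

print(list(wrong))  # [range(0, 4)]
print(list(right))  # [0, 1, 2, 3]
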
+ 8 - 3
aphrodite/worker/model_runner.py

@@ -78,6 +78,10 @@ _NUM_WARMUP_ITERS = 2
 TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
 
 
+# For now, bump up cache limits for recompilations during CUDA graph warmups.
+torch._dynamo.config.cache_size_limit = 128
+torch._dynamo.config.accumulated_cache_size_limit = 128
+
 @dataclass(frozen=True)
 class ModelInputForGPU(ModelRunnerInputBase):
     """
@@ -1068,9 +1072,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         if envs.APHRODITE_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo:
             logger.info("Compiling the model using torch.compile...")
             start_time = time.time()
-            self.model = torch.compile(self.model,
-                                       fullgraph=True,
-                                       backend="eager")
+            self.model = torch.compile(
+                self.model,
+                fullgraph=envs.APHRODITE_TEST_DYNAMO_FULLGRAPH_CAPTURE,
+                backend="eager")
             end_time = time.time()
             logger.info(
                 f"Model compiled in {end_time - start_time:.2f} seconds.")

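Raising the Dynamo cache limits matters because CUDA graph warmup runs the compiled model once per captured batch size, and with static shapes each new batch size is a recompilation; the default cache_size_limit (8 in recent PyTorch releases) would be exhausted and Dynamo would fall back to eager. A rough sketch of the effect, using stand-in batch sizes:

import torch

torch._dynamo.config.cache_size_limit = 128
torch._dynamo.config.accumulated_cache_size_limit = 128

@torch.compile(backend="eager", dynamic=False)
def step(x):
    return x * 2

for bs in (1, 2, 4, 8, 16, 32):   # stand-ins for CUDA graph warmup batch sizes
    step(torch.randn(bs, 8))      # each new static shape adds a cache entry / recompile
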
+ 3 - 3
aphrodite/worker/worker.py

@@ -28,11 +28,10 @@ from aphrodite.platforms import current_platform
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 from aphrodite.worker.cache_engine import CacheEngine
 from aphrodite.worker.embedding_model_runner import EmbeddingModelRunner
-from aphrodite.worker.enc_dec_model_runner import (
-    EncoderDecoderModelRunner)
+from aphrodite.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from aphrodite.worker.model_runner import GPUModelRunnerBase, ModelRunner
 from aphrodite.worker.worker_base import (LocalOrDistributedWorkerBase,
-                                                WorkerInput)
+                                          WorkerInput)
 
 
 class Worker(LocalOrDistributedWorkerBase):
@@ -141,6 +140,7 @@ class Worker(LocalOrDistributedWorkerBase):
             torch.cuda.set_device(self.device)
 
             _check_if_gpu_supports_dtype(self.model_config.dtype)
+            gc.collect()
             torch.cuda.empty_cache()
             self.init_gpu_memory = torch.cuda.mem_get_info()[0]
         else:

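The added gc.collect() makes the free-memory probe that follows more reliable: torch.cuda.empty_cache() only releases cached blocks whose tensors have already been freed, so collecting unreachable Python objects first lets mem_get_info() see memory that was pinned only by garbage. A small illustrative sequence:

import gc

import torch

if torch.cuda.is_available():
    gc.collect()               # drop tensors kept alive only by unreachable objects
    torch.cuda.empty_cache()   # return the now-unused cached blocks to the driver
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    print(f"free: {free_bytes / 2**30:.2f} GiB of {total_bytes / 2**30:.2f} GiB")
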
+ 4 - 4
kernels/cpu/torch_bindings.cpp

@@ -30,8 +30,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // PagedAttention V2.
   ops.def(
       "paged_attention_v2("
-      "    Tensor! out, Tensor exp_sums, Tensor max_logits,"
-      "    Tensor tmp_out, Tensor query, Tensor key_cache,"
+      "    Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
+      "    Tensor! tmp_out, Tensor query, Tensor key_cache,"
       "    Tensor value_cache, int num_kv_heads, float scale,"
       "    Tensor block_tables, Tensor seq_lens, int block_size,"
       "    int max_seq_len, Tensor? alibi_slopes,"
@@ -119,8 +119,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
 
   // Copy the cache blocks from src to dst.
   cache_ops.def(
-      "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
-      "block_mapping) -> ()");
+      "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
+      "Tensor block_mapping) -> ()");
   cache_ops.impl("copy_blocks", torch::kCPU, &copy_blocks);
 
   // Reshape the key and value tensors and cache them.

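Marking exp_sums, max_logits and tmp_out as Tensor! declares that the kernel writes into them; mismatches between the declared schema and the observed behaviour are exactly what the opcheck "test_schema" checks added later in this diff catch. A toy sketch of the rule, with a hypothetical namespace and op (not part of the commit):

import torch

lib = torch.library.Library("demo_schema", "FRAGMENT")  # hypothetical namespace
lib.define("fill_one(Tensor(a!) x) -> ()")              # argument declared mutable

def fill_one(x: torch.Tensor) -> None:
    x.fill_(1.0)                                        # ...and actually mutated

lib.impl("fill_one", fill_one, "CompositeExplicitAutograd")

# Passes because the schema matches the observed mutation; dropping the "(a!)"
# annotation above would make this check report a schema violation.
torch.library.opcheck(torch.ops.demo_schema.fill_one, (torch.zeros(4),),
                      test_utils=("test_schema",))
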
+ 13 - 1
kernels/quantization/gptq_marlin/awq_marlin_repack.cu

@@ -266,4 +266,16 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
   return out;
 }
 
-#endif
+#endif
+
+torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
+                                     c10::SymInt size_k, c10::SymInt size_n,
+                                     int64_t num_bits) {
+  int const pack_factor = 32 / num_bits;
+  auto options = torch::TensorOptions()
+                     .dtype(b_q_weight.dtype())
+                     .device(b_q_weight.device());
+  return torch::empty_symint(
+      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
+      options);
+}

+ 13 - 1
kernels/quantization/gptq_marlin/gptq_marlin_repack.cu

@@ -341,4 +341,16 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
   return out;
 }
 
-#endif
+#endif
+
+torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
+                                      torch::Tensor& perm, c10::SymInt size_k,
+                                      c10::SymInt size_n, int64_t num_bits) {
+  int const pack_factor = 32 / num_bits;
+  auto options = torch::TensorOptions()
+                     .dtype(b_q_weight.dtype())
+                     .device(b_q_weight.device());
+  return torch::empty_symint(
+      {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
+      options);
+}

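The new *_repack_meta functions never touch data; they only reproduce the output-shape arithmetic of the CUDA kernels so that FakeTensor/SymInt tracing works. A Python-level sketch of the same computation (assuming marlin::tile_size == 16, as in the kernel sources):

import torch

MARLIN_TILE_SIZE = 16  # assumption: mirrors marlin::tile_size in the .cu files

def gptq_marlin_repack_meta(b_q_weight: torch.Tensor, perm: torch.Tensor,
                            size_k: int, size_n: int,
                            num_bits: int) -> torch.Tensor:
    pack_factor = 32 // num_bits
    # Shape math only -- the "meta" device allocates no storage.
    return torch.empty(
        (size_k // MARLIN_TILE_SIZE, size_n * MARLIN_TILE_SIZE // pack_factor),
        dtype=b_q_weight.dtype, device="meta")
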
+ 8 - 0
kernels/quantization/quant_ops.h

@@ -84,9 +84,17 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,
                                  int64_t num_bits);
 
+torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
+                                      torch::Tensor& perm, c10::SymInt size_k,
+                                      c10::SymInt size_n, int64_t num_bits);
+
 torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
                                 int64_t size_n, int64_t num_bits);
 
+torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
+                                     c10::SymInt size_k, c10::SymInt size_n,
+                                     int64_t num_bits);
+
 torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                               torch::Tensor& b_scales, torch::Tensor& workspace,
                               int64_t num_bits, int64_t size_m, int64_t size_n,

+ 147 - 53
kernels/torch_bindings.cpp

@@ -37,8 +37,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // PagedAttention V2.
   ops.def(
       "paged_attention_v2("
-      "    Tensor! out, Tensor exp_sums, Tensor max_logits,"
-      "    Tensor tmp_out, Tensor query, Tensor key_cache,"
+      "    Tensor! out, Tensor! exp_sums, Tensor! max_logits,"
+      "    Tensor! tmp_out, Tensor query, Tensor key_cache,"
       "    Tensor value_cache, int num_kv_heads, float scale,"
       "    Tensor block_tables, Tensor seq_lens, int block_size,"
       "    int max_seq_len, Tensor? alibi_slopes,"
@@ -74,7 +74,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("gelu_quick", torch::kCUDA, &gelu_quick);
 
   // prepare_inputs advance_step
-  ops.def("advance_step", &advance_step);
+  ops.def(
+      "advance_step(int num_seqs, int num_queries, int block_size, "
+      "Tensor! input_tokens, Tensor sampled_token_ids, "
+      "Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
+      "Tensor block_tables) -> ()");
   ops.impl("advance_step", torch::kCUDA, &advance_step);
 
   // Layernorm
@@ -111,60 +115,108 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Quantization ops
 #ifndef USE_ROCM
   // Quantized GEMM for AQLM.
-  ops.def("aqlm_gemm", &aqlm_gemm);
+  ops.def(
+      "aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, "
+      "Tensor scales, int[] codebook_partition_sizes, Tensor? bias) "
+      "-> Tensor");
   ops.impl("aqlm_gemm", torch::kCUDA, &aqlm_gemm);
 
   // Decompression method for AQLM.
-  ops.def("aqlm_dequant", &aqlm_dequant);
+  ops.def(
+      "aqlm_dequant(Tensor codes, Tensor codebooks, "
+      "int[] codebook_partition_sizes) -> Tensor");
   ops.impl("aqlm_dequant", torch::kCUDA, &aqlm_dequant);
 
   // Quantized GEMM for AWQ.
-  ops.def("awq_gemm", &awq_gemm);
+  ops.def(
+      "awq_gemm(Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, "
+      "Tensor _zeros, int split_k_iters) -> Tensor");
   ops.impl("awq_gemm", torch::kCUDA, &awq_gemm);
 
   // Dequantization for AWQ.
-  ops.def("awq_dequantize", &awq_dequantize);
+  ops.def(
+      "awq_dequantize(Tensor _kernel, Tensor _scaling_factors, "
+      "Tensor _zeros, int split_k_iters, int thx, int thy) -> Tensor");
   ops.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
 
   // Dequantization for GGML.
-  ops.def("ggml_dequantize", &ggml_dequantize);
+  ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor");
   ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
 
   // mmvq kernel for GGML.
-  ops.def("ggml_mul_mat_vec_a8", &ggml_mul_mat_vec_a8);
+  ops.def(
+      "ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, int row) "
+      "-> Tensor");
   ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);
 
   // mmq kernel for GGML.
-  ops.def("ggml_mul_mat_a8", &ggml_mul_mat_a8);
+  ops.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, int row) -> Tensor");
   ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);
 
+  // Note about marlin kernel 'workspace' arguments:
+  // Technically these should be mutable since they are modified by the kernel.
+  // But since they are set back to zero once the kernel is finished we can
+  // hand wave and say that they have no net effect.
+  //
+  // The reason to mark 'workspace' as immutable is so that they don't interfere
+  // with using ScalarType arguments in the ops. If they are marked as mutable,
+  // pytorch throws an assert in
+  // 'torch._higher_order_ops._register_effectful_op' that prevents these
+  // kernels from being torch.compile'd.
+  // See the following document for more info on custom types and ops that use
+  // custom types:
+  // https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
   // Marlin (Dense) Optimized Quantized GEMM for GPTQ.
-  ops.def("marlin_gemm", &marlin_gemm);
+  ops.def(
+      "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+      "Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor");
   ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);
 
   // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
-  ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
+  ops.def(
+      "gptq_marlin_24_gemm(Tensor a, Tensor b_q_weight, Tensor b_meta, "
+      "Tensor b_scales, Tensor workspace, "
+      "__torch__.torch.classes._core_C.ScalarType b_q_type, "
+      "int size_m, int size_n, int size_k) -> Tensor");
   ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);
 
   // gptq_marlin Optimized Quantized GEMM for GPTQ.
-  ops.def("gptq_marlin_gemm", &gptq_marlin_gemm);
+  ops.def(
+      "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+      "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, "
+      "__torch__.torch.classes._core_C.ScalarType b_q_type, "
+      "int size_m, int size_n, int size_k, bool is_k_full, "
+      "bool has_zp, bool use_fp32_reduce, bool is_zp_float) -> Tensor");
   ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
 
   // gptq_marlin repack from GPTQ.
-  ops.def("gptq_marlin_repack", &gptq_marlin_repack);
+  ops.def(
+      "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
+      "SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
   ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
+  ops.impl("gptq_marlin_repack", torch::kMeta, &gptq_marlin_repack_meta);
 
   // awq_marlin repack from AWQ.
-  ops.def("awq_marlin_repack", &awq_marlin_repack);
+  ops.def(
+      "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
+      "SymInt size_n, int num_bits) -> Tensor");
   ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
+  ops.impl("awq_marlin_repack", torch::kMeta, &awq_marlin_repack_meta);
 
   // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
-  ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
+  ops.def(
+      "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
+      "Tensor! workspace, int num_bits, int size_m, int size_n, "
+      "int size_k) -> Tensor");
   ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
 
   #ifndef _WIN32
   // marlin_qqq_gemm for QQQ.
-  ops.def("marlin_qqq_gemm", &marlin_qqq_gemm);
+  ops.def(
+      "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, "
+      "Tensor s_tok, Tensor s_ch, Tensor s_group, "
+      "Tensor! workspace, int size_m, int size_n, "
+      "int size_k) -> Tensor");
   ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
 
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
@@ -177,9 +229,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 
   // Check if cutlass scaled_mm is supported for CUDA devices of the given
   // capability
-  ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
-  ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA,
-           &cutlass_scaled_mm_supports_fp8);
+  ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
+  ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
 
   // CUTLASS w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
   // quantization.
@@ -211,11 +262,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   #endif
 
   // QuIP# GEMV
-  ops.def("quip_gemv", &e8p_mm_origorder);
+  ops.def("quip_gemv(Tensor A, Tensor B, Tensor CB) -> Tensor",
+          &e8p_mm_origorder);
   ops.impl("quip_gemv", torch::kCUDA, &e8p_mm_origorder);
 
   // QuIP# Decompress
-  ops.def("quip_decompress", &decompress_e8p_origorder);
+  ops.def("quip_decompress(Tensor YIs, Tensor CB, Tensor Y) -> ()",
+          &decompress_e8p_origorder);
   ops.impl("quip_decompress", torch::kCUDA, &decompress_e8p_origorder);
 
   // fp6_llm
@@ -227,31 +280,73 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
            &fp_eXmY_linear_forward_cuda);
 
   // Sampling Kernels
-  ops.def("sampling_from_probs", &sampling_from_probs);
+  ops.def(
+      "sampling_from_probs(Tensor probs, Tensor uniform_samples, bool "
+      "deterministic) -> Tensor",
+      &sampling_from_probs);
   ops.impl("sampling_from_probs", torch::kCUDA, &sampling_from_probs);
-  ops.def("top_k_sampling_from_probs", &top_k_sampling_from_probs);
+
+  ops.def(
+      "top_k_sampling_from_probs(Tensor probs, Tensor uniform_samples,"
+      "                          Tensor? maybe_top_k_arr, int top_k_val,"
+      "                          bool deterministic) -> Tensor[]",
+      &top_k_sampling_from_probs);
   ops.impl("top_k_sampling_from_probs", torch::kCUDA,
            &top_k_sampling_from_probs);
-  ops.def("min_p_sampling_from_probs", &min_p_sampling_from_probs);
+
+  ops.def(
+      "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples,"
+      "                          Tensor? maybe_min_p_arr, float min_p_val,"
+      "                          bool deterministic) -> Tensor[]",
+      &min_p_sampling_from_probs);
   ops.impl("min_p_sampling_from_probs", torch::kCUDA,
            &min_p_sampling_from_probs);
-  ops.def("top_p_sampling_from_probs", &top_p_sampling_from_probs);
+
+  ops.def(
+      "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples,"
+      "                          Tensor? maybe_top_p_arr, float top_p_val,"
+      "                          bool deterministic) -> Tensor[]",
+      &top_p_sampling_from_probs);
   ops.impl("top_p_sampling_from_probs", torch::kCUDA,
            &top_p_sampling_from_probs);
-  ops.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs);
+
+  ops.def(
+      "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples,"
+      "                          Tensor? maybe_top_k_arr, float top_k_val,"
+      "                          Tensor? maybe_top_p_arr, float top_p_val,"
+      "                          bool deterministic) -> Tensor[]",
+      &top_k_top_p_sampling_from_probs);
   ops.impl("top_k_top_p_sampling_from_probs", torch::kCUDA,
            &top_k_top_p_sampling_from_probs);
-  ops.def("top_k_renorm_prob", &top_k_renorm_prob);
+
+  ops.def(
+      "top_k_renorm_prob(Tensor probs, Tensor? maybe_top_k_arr, int top_k_val) "
+      "-> Tensor",
+      &top_k_renorm_prob);
   ops.impl("top_k_renorm_prob", torch::kCUDA, &top_k_renorm_prob);
-  ops.def("top_p_renorm_prob", &top_p_renorm_prob);
+
+  ops.def(
+      "top_p_renorm_prob(Tensor probs, Tensor? maybe_top_p_arr, float "
+      "top_p_val) "
+      "-> Tensor",
+      &top_p_renorm_prob);
   ops.impl("top_p_renorm_prob", torch::kCUDA, &top_p_renorm_prob);
-  ops.def("top_k_mask_logits", &top_k_mask_logits);
+
+  ops.def(
+      "top_k_mask_logits(Tensor logits, Tensor? maybe_top_k_arr, int "
+      "top_k_val) -> Tensor",
+      &top_k_mask_logits);
   ops.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits);
 
 #endif
 
   // Quantized GEMM for GPTQ.
-  ops.def("gptq_gemm", &gptq_gemm);
+  // Note: even though the C++ inferred schema is correct for this op, it seems
+  // to prevent the meta function registry.
+  ops.def(
+      "gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
+      "Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
+      "-> Tensor");
   ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
 
   // Post processing for GPTQ.
@@ -277,8 +372,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 
   // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
   ops.def(
-      "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! "
-      "scale, Tensor? scale_ub) -> "
+      "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, "
+      "Tensor! scale, Tensor? scale_ub) -> "
       "()");
   ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
            &dynamic_per_token_scaled_fp8_quant);
@@ -321,7 +416,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "Tensor! A, Tensor! B, Tensor! C,"
       "Tensor? D_, Tensor? z_, Tensor? delta_bias_,"
       "bool delta_softplus,"
-      "Tensor? index_, Tensor? x) -> Tensor[]");
+      "Tensor? index_, Tensor(a! -> *)? x) -> Tensor(a)[]");
   ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
 
   ops.def(
@@ -353,8 +448,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
 
   // Copy the cache blocks from src to dst.
   cache_ops.def(
-      "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor "
-      "block_mapping) -> ()");
+      "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
+      "Tensor block_mapping) -> ()");
   cache_ops.impl("copy_blocks", torch::kCUDA, &copy_blocks);
 
   // Reshape the key and value tensors and cache them.
@@ -379,8 +474,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
 
   // Convert the key and value cache to fp8 data type.
   cache_ops.def(
-      "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, str "
-      "kv_cache_dtype) -> ()");
+      "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
+      "str kv_cache_dtype) -> ()");
   cache_ops.impl("convert_fp8", torch::kCUDA, &convert_fp8);
 }
 
@@ -388,24 +483,27 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
   // Cuda utils
 
   // Gets the specified device attribute.
-  cuda_utils.def("get_device_attribute", &get_device_attribute);
-  cuda_utils.impl("get_device_attribute", torch::kCUDA, &get_device_attribute);
+  cuda_utils.def("get_device_attribute(int attribute, int device_id) -> int");
+  cuda_utils.impl("get_device_attribute", &get_device_attribute);
 
   // Gets the maximum shared memory per block device attribute.
-  cuda_utils.def("get_max_shared_memory_per_block_device_attribute",
-                 &get_max_shared_memory_per_block_device_attribute);
+  cuda_utils.def(
+      "get_max_shared_memory_per_block_device_attribute(int device_id) -> int");
   cuda_utils.impl("get_max_shared_memory_per_block_device_attribute",
-                  torch::kCUDA,
                   &get_max_shared_memory_per_block_device_attribute);
 }
 
 #ifndef USE_ROCM
 TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
   // Custom all-reduce kernels
-  custom_ar.def("init_custom_ar", &init_custom_ar);
+  custom_ar.def(
+      "init_custom_ar(Tensor meta, Tensor rank_data, "
+      "str[] handles, int[] offsets, int rank, "
+      "bool full_nvlink) -> int");
   custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
-
-  custom_ar.def("should_custom_ar", &should_custom_ar);
+  custom_ar.def(
+      "should_custom_ar(Tensor inp, int max_size, int world_size, "
+      "bool full_nvlink) -> bool");
   custom_ar.impl("should_custom_ar", torch::kCUDA, &should_custom_ar);
 
   custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
@@ -417,21 +515,17 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
   custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
 
   custom_ar.def("dispose", &dispose);
-  custom_ar.impl("dispose", torch::kCPU, &dispose);
 
   custom_ar.def("meta_size", &meta_size);
-  custom_ar.impl("meta_size", torch::kCPU, &meta_size);
 
-  custom_ar.def("register_buffer", &register_buffer);
+  custom_ar.def(
+      "register_buffer(int fa, Tensor t, str[] handles, "
+      "int[] offsets) -> ()");
   custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer);
 
   custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
-  custom_ar.impl("get_graph_buffer_ipc_meta", torch::kCPU,
-                 &get_graph_buffer_ipc_meta);
 
   custom_ar.def("register_graph_buffers", &register_graph_buffers);
-  custom_ar.impl("register_graph_buffers", torch::kCPU,
-                 &register_graph_buffers);
 }
 #endif
 

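With explicit schema strings registered, each op carries full type information that torch.compile, FakeTensor and opcheck can consume. One way to confirm this from Python, assuming the compiled _C extension is importable:

import torch

import aphrodite._custom_ops  # noqa: F401  -- loads the extension and the fake impls

print(torch.ops._C.gptq_gemm.default._schema)
# Expected (roughly): _C::gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros,
#     Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) -> Tensor
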
+ 18 - 3
tests/kernels/test_activation.py

@@ -4,7 +4,9 @@ import pytest
 import torch
 
 from aphrodite.modeling.layers.activation import (FastGELU, GeluAndMul,
-                                                  NewGELU, SiluAndMul)
+                                                  NewGELU, QuickGELU,
+                                                  SiluAndMul)
+from tests.kernels.utils import opcheck
 
 from .allclose_default import get_default_atol, get_default_rtol
 
@@ -39,10 +41,13 @@ def test_act_and_mul(
     x = torch.randn(num_tokens, 2 * d, dtype=dtype)
     if activation == "silu":
         layer = SiluAndMul()
+        fn = torch.ops._C.silu_and_mul
     elif activation == "gelu":
         layer = GeluAndMul(approximate="none")
+        fn = torch.ops._C.gelu_and_mul
     elif activation == "gelu_tanh":
         layer = GeluAndMul(approximate="tanh")
+        fn = torch.ops._C.gelu_tanh_and_mul
     out = layer(x)
     ref_out = layer.forward_native(x)
     # The SiLU and GELU implementations are equivalent to the native PyTorch
@@ -50,7 +55,14 @@ def test_act_and_mul(
     torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
 
 
-@pytest.mark.parametrize("activation", [FastGELU, NewGELU])
+    d = x.shape[-1] // 2
+    output_shape = (x.shape[:-1] + (d, ))
+    out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+    opcheck(fn, (out, x))
+
+@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),
+                                        (NewGELU, torch.ops._C.gelu_new),
+                                        (QuickGELU, torch.ops._C.gelu_quick)])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
@@ -70,10 +82,13 @@ def test_activation(
         torch.cuda.manual_seed(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, d, dtype=dtype)
-    layer = activation()
+    layer = activation[0]()
+    fn = activation[1]
     out = layer(x)
     ref_out = layer.forward_native(x)
     torch.testing.assert_close(out,
                                ref_out,
                                atol=get_default_atol(out),
                                rtol=get_default_rtol(out))
+    out = torch.empty_like(x)
+    opcheck(fn, (out, x))

+ 14 - 0
tests/kernels/test_attention.py

@@ -8,6 +8,7 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
 from aphrodite import _custom_ops as ops
 from aphrodite.common.utils import get_max_shared_memory_bytes, is_hip
+from tests.kernels.utils import opcheck
 
 from .allclose_default import get_default_atol, get_default_rtol
 
@@ -198,6 +199,12 @@ def test_paged_attention(
             k_scale,
             v_scale,
         )
+        opcheck(torch.ops._C.paged_attention_v1,
+                (output, query, key_cache, value_cache, num_kv_heads, scale,
+                 block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
+                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
+                cond=(head_size == HEAD_SIZES[0]))
+
     elif version == "v2":
         num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
         assert PARTITION_SIZE % block_size == 0
@@ -230,6 +237,13 @@ def test_paged_attention(
             k_scale,
             v_scale,
         )
+        opcheck(torch.ops._C.paged_attention_v2,
+                (output, exp_sums, max_logits, tmp_output, query, key_cache,
+                 value_cache, num_kv_heads, scale, block_tables, seq_lens,
+                 block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
+                 k_scale, v_scale, 0, 0, 0, 64, 0),
+                cond=(head_size == HEAD_SIZES[0]))
+
     else:
         raise AssertionError(f"Unknown version: {version}")
 

+ 20 - 1
tests/kernels/test_cache.py

@@ -5,6 +5,7 @@ import pytest
 import torch
 
 from aphrodite import _custom_ops as ops
+from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
 
 COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -87,6 +88,10 @@ def test_copy_blocks(
     block_mapping_tensor = torch.tensor(block_mapping,
                                         dtype=torch.int64,
                                         device=device).view(-1, 2)
+    opcheck(torch.ops._C_cache_ops.copy_blocks,
+            (key_caches, value_caches, block_mapping_tensor),
+            test_utils=DEFAULT_OPCHECK_TEST_UTILS,
+            cond=(head_size == HEAD_SIZES[0]))
     ops.copy_blocks(key_caches, value_caches, block_mapping_tensor)
 
     # Run the reference implementation.
@@ -162,6 +167,10 @@ def test_reshape_and_cache(
     k_scale = v_scale = 1.0
 
     # Call the reshape_and_cache kernel.
+    opcheck(torch.ops._C_cache_ops.reshape_and_cache,
+            (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
+             k_scale, v_scale),
+            cond=(head_size == HEAD_SIZES[0]))
     ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
                           kv_cache_dtype, k_scale, v_scale)
 
@@ -269,6 +278,10 @@ def test_reshape_and_cache_flash(
     k_scale = v_scale = 1.0
 
     # Call the reshape_and_cache kernel.
+    opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
+            (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
+             k_scale, v_scale),
+            cond=(head_size == HEAD_SIZES[0]))
     ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                 slot_mapping, kv_cache_dtype, k_scale, v_scale)
 
@@ -364,8 +377,14 @@ def test_swap_blocks(
 
     src_key_caches_clone = src_key_caches[0].clone()
     src_value_caches_clone = src_value_caches[0].clone()
-
     # Call the swap_blocks kernel.
+    do_opcheck = (head_size == HEAD_SIZES[0])
+    opcheck(torch.ops._C_cache_ops.swap_blocks,
+            (src_key_caches[0], dist_key_caches[0], block_mapping_tensor),
+            cond=do_opcheck)
+    opcheck(torch.ops._C_cache_ops.swap_blocks,
+            (src_value_caches[0], dist_value_caches[0], block_mapping_tensor),
+            cond=do_opcheck)
     ops.swap_blocks(src_key_caches[0], dist_key_caches[0],
                     block_mapping_tensor)
     ops.swap_blocks(src_value_caches[0], dist_value_caches[0],

+ 11 - 0
tests/kernels/test_cutlass.py

@@ -9,6 +9,7 @@ import torch
 
 from aphrodite import _custom_ops as ops
 from aphrodite.platforms import current_platform
+from tests.kernels.utils import opcheck
 
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
@@ -108,6 +109,8 @@ def cutlass_int8_gemm_helper(m: int,
 
     torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
 
+    opcheck(torch.ops._C.cutlass_scaled_mm,
+            (out, a, b, scale_a, scale_b, bias))
 
 @pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
 @pytest.mark.parametrize("n", [2048, 4096, 8192, 16384, 24576, 256, 1024])
@@ -341,6 +344,14 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
     torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
     torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
 
+    if azp_per_token:
+        opcheck(torch.ops._C.cutlass_scaled_mm_azp,
+                (out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
+                 func_bias))
+    else:
+        opcheck(torch.ops._C.cutlass_scaled_mm_azp,
+                (out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
+                 func_bias))
 
 # Test working with a subset of A and B
 def test_cutlass_subset():

+ 14 - 0
tests/kernels/test_int8_quant.py

@@ -3,6 +3,7 @@ import torch
 
 from aphrodite._custom_ops import scaled_int8_quant
 from tests.kernels.quant_utils import ref_dynamic_per_token_quant
+from tests.kernels.utils import opcheck
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
@@ -12,6 +13,16 @@ SEEDS = [0]
 SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]
 
 
+def opcheck_int8_quant(output, input, scale=None):
+    if scale is not None:
+        opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale))
+    else:
+        scale = torch.empty((input.numel() // input.shape[-1], 1),
+                            device=input.device,
+                            dtype=torch.float32)
+        opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale))
+
+
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
@@ -34,6 +45,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
         ops_out, ref_out, atol=1,
         rtol=0.0)  # big atol to account for rounding errors
 
+    opcheck_int8_quant(ops_out, x)
 
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@@ -58,3 +70,5 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
     torch.testing.assert_close(
         out1, out2, atol=1,
         rtol=0.0)  # big atol to account for rounding errors
+
+    opcheck_int8_quant(out2, x, scale)

+ 8 - 0
tests/kernels/test_layernorm.py

@@ -2,6 +2,7 @@ import pytest
 import torch
 
 from aphrodite.modeling.layers.layernorm import RMSNorm
+from tests.kernels.utils import opcheck
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
@@ -52,3 +53,10 @@ def test_rms_norm(
         torch.testing.assert_close(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
     else:
         torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
+
+    if residual is not None:
+        opcheck(torch.ops._C.fused_add_rms_norm,
+                (x, residual, layer.weight.data, layer.variance_epsilon))
+    else:
+        opcheck(torch.ops._C.rms_norm,
+                (out, x, layer.weight.data, layer.variance_epsilon))

+ 5 - 0
tests/kernels/test_machete_gemm.py

@@ -13,6 +13,7 @@ from aphrodite.platforms import current_platform
 from aphrodite.quantization.utils.quant_utils import (pack_rows,
                                                       quantize_weights)
 from aphrodite.scalar_type import ScalarType, scalar_types
+from tests.kernels.utils import opcheck
 
 CUDA_DEVICES = [
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
@@ -77,6 +78,7 @@ def machete_quantize_and_pack(w: torch.Tensor,
     w_q = w_q.t().contiguous().t()  # convert to col major
     w_q_machete = ops.machete_prepack_B(w_q, wtype)
 
+    opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype))
     return w_ref, w_q_machete, w_s, w_zp
 
 
@@ -148,6 +150,9 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
             schedule=schedule,
         )
 
+        opcheck(torch.ops._C.machete_gemm,
+                (a, w_q_machete, wtype, w_s, maybe_convert_zeropoints(
+                    w_zp, w_s), group_size, None, None, None, schedule))
         # Relax atol as our reduction dim becomes larger (more rounding error)
         # Relax atol when we have zeropoints since the way machete applies
         #  zeropoints (after scales) causes noise around 0

+ 27 - 26
tests/kernels/test_marlin_gemm.py

@@ -29,6 +29,7 @@ from aphrodite.quantization.utils.quant_utils import (awq_pack, gptq_pack,
                                                       gptq_quantize_weights,
                                                       quantize_weights,
                                                       sort_weights)
+from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
 from tests.quantization.utils import is_quant_method_supported
 
 ACT_ORDER_OPTS = [False, True]
@@ -75,12 +76,9 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
                             act_order, mnk_factors):
     m_factor, n_factor, k_factor = mnk_factors
 
-    size_m = m_factor
     size_k = k_chunk * k_factor
     size_n = n_chunk * n_factor
 
-    print(f"MNK = {size_m} {size_n} {size_k}")
-
     # Filter act_order
     if act_order:
         if group_size == -1:
@@ -114,6 +112,8 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
     marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
                                   weight_perm)
 
+    opcheck(torch.ops._C.gptq_marlin_repack,
+            (q_w_gptq, sort_indices, size_k, size_n, quant_type.size_bits))
     # Run Marlin repack GPU kernel
     marlin_q_w_2 = ops.gptq_marlin_repack(
         q_w_gptq,
@@ -139,12 +139,9 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
                            mnk_factors):
     m_factor, n_factor, k_factor = mnk_factors
 
-    size_m = m_factor
     size_k = k_chunk * k_factor
     size_n = n_chunk * n_factor
 
-    print(f"MNK = {size_m} {size_n} {size_k}")
-
     # Normalize group_size
     if group_size == -1:
         group_size = size_k
@@ -167,6 +164,8 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
     marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, quant_type.size_bits,
                                   weight_perm)
 
+    opcheck(torch.ops._C.awq_marlin_repack,
+            (q_w_awq, size_k, size_n, quant_type.size_bits))
     # Run Marlin repack GPU kernel
     marlin_q_w_2 = ops.awq_marlin_repack(
         q_w_awq,
@@ -206,9 +205,6 @@ def test_gptq_marlin_gemm(
     size_k = k_chunk * k_factor
     size_n = n_chunk * n_factor
 
-    print(f"MNK = {size_m} {size_n} {size_k}")
-    print(f"groupsize = {group_size}")
-
     if act_order:
         if group_size == -1:
             return
@@ -226,6 +222,13 @@ def test_gptq_marlin_gemm(
     workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                 GPTQ_MARLIN_MAX_PARALLEL)
 
+    opcheck(
+        torch.ops._C.gptq_marlin_gemm,
+        (a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
+         workspace.scratch, quant_type, a_input.shape[0], b_weight.shape[1],
+         a_input.shape[1], is_k_full, False, use_fp32_reduce),
+        test_utils=DEFAULT_OPCHECK_TEST_UTILS)
+
     output = ops.gptq_marlin_gemm(
         a_input,
         marlin_q_w,
@@ -247,7 +250,6 @@ def test_gptq_marlin_gemm(
     torch.cuda.synchronize()
 
     max_diff = compute_max_diff(output, output_ref)
-    print("max_diff = {}".format(max_diff))
 
     assert max_diff < 0.04
 
@@ -267,9 +269,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
     size_k = k_chunk * k_factor
     size_n = n_chunk * n_factor
 
-    print(f"MNK = {size_m} {size_n} {size_k}")
-    print(f"groupsize = {group_size}")
-
     a_input = rand_data((size_m, size_k))
     b_weight = rand_data((size_k, size_n))
 
@@ -281,6 +280,12 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
 
     output_ref = torch.matmul(a_input, w_24_ref)
 
+    opcheck(torch.ops._C.gptq_marlin_24_gemm,
+            (a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
+             workspace_24.scratch, quant_type, a_input.shape[0],
+             b_weight.shape[1], a_input.shape[1]),
+            test_utils=DEFAULT_OPCHECK_TEST_UTILS)
+
     output = ops.gptq_marlin_24_gemm(
         a_input,
         marlin_24_q_w_comp,
@@ -296,7 +301,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
     torch.cuda.synchronize()
 
     max_diff = compute_max_diff(output, output_ref)
-    print("max_diff = {}".format(max_diff))
 
     assert max_diff < 0.04
 
@@ -323,9 +327,6 @@ def test_fp8_marlin_gemm(
     size_k = k_chunk * k_factor
     size_n = n_chunk * n_factor
 
-    print(f"MNK = {size_m} {size_n} {size_k}")
-    print(f"groupsize = {group_size}")
-
     a_input = rand_data((size_m, size_k), dtype=dtype)
     b_weight = rand_data((size_k, size_n), dtype=dtype)
 
@@ -355,6 +356,10 @@ def test_fp8_marlin_gemm(
     workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                 GPTQ_MARLIN_MAX_PARALLEL)
 
+    opcheck(torch.ops._C.fp8_marlin_gemm,
+            (a_input, marlin_qweight, marlin_scales, workspace.scratch,
+             num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1]))
+
     output = ops.fp8_marlin_gemm(
         a=a_input,
         b_q_weight=marlin_qweight,
@@ -370,7 +375,6 @@ def test_fp8_marlin_gemm(
     torch.cuda.synchronize()
 
     max_diff = compute_max_diff(output, output_ref)
-    print("max_diff = {}".format(max_diff))
 
     assert max_diff < 0.04
 
@@ -398,9 +402,6 @@ def test_awq_marlin_gemm(
     size_k = k_chunk * k_factor
     size_n = n_chunk * n_factor
 
-    print(f"MNK = {size_m} {size_n} {size_k}")
-    print(f"groupsize = {group_size}")
-
     a_input = rand_data((size_m, size_k))
     b_weight = rand_data((size_k, size_n))
 
@@ -436,7 +437,6 @@ def test_awq_marlin_gemm(
     torch.cuda.synchronize()
 
     max_diff = compute_max_diff(output, output_ref)
-    print("max_diff = {}".format(max_diff))
 
     assert max_diff < 0.04
 
@@ -462,9 +462,6 @@ def test_marlin_qqq_gemm(
     size_k = k_chunk * k_factor
     size_n = n_chunk * n_factor
 
-    print(f"MNK = {size_m} {size_n} {size_k}")
-    print(f"groupsize = {group_size}")
-
     a_input = rand_data((size_m, size_k))
     b_weight = rand_data((size_k, size_n))
 
@@ -481,6 +478,11 @@ def test_marlin_qqq_gemm(
     workspace = MarlinWorkspace(size_n, MARLIN_QQQ_MIN_THREAD_N,
                                 MARLIN_QQQ_MAX_PARALLEL)
 
+    opcheck(torch.ops._C.marlin_qqq_gemm,
+            (q_a, marlin_qqq_q_w, s_a, marlin_qqq_s_channel,
+             marlin_qqq_s_group, workspace.scratch, a_input.shape[0],
+             b_weight.shape[1], a_input.shape[1]))
+
     output = ops.marlin_qqq_gemm(
         q_a,
         marlin_qqq_q_w,
@@ -497,6 +499,5 @@ def test_marlin_qqq_gemm(
     torch.cuda.synchronize()
 
     max_diff = compute_max_diff(output, output_ref)
-    print("max_diff = {}".format(max_diff))
 
     assert max_diff < 0.04

+ 32 - 1
tests/kernels/utils.py

@@ -3,7 +3,8 @@
 import itertools
 import random
 from numbers import Number
-from typing import Any, List, NamedTuple, Optional, Tuple, Union
+from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
+                    Union)
 
 import pytest
 import torch
@@ -14,6 +15,20 @@ from aphrodite.attention.backends.xformers import XFormersBackend
 from aphrodite.common.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
                                     make_tensor_with_pad)
 
+# For now, disable "test_aot_dispatch_dynamic" since there are some
+# bugs related to this test in PyTorch 2.4.
+DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
+    "test_schema",
+    "test_autograd_registration",
+    "test_faketensor",
+)
+ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
+    "test_schema",
+    "test_autograd_registration",
+    "test_faketensor",
+    "test_aot_dispatch_dynamic",
+)
+
 
 class QKVInputs(NamedTuple):
     '''
@@ -927,3 +942,19 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
     ideal_output = test_params.packed_qkvo.ideal_output
     torch.testing.assert_close(ideal_output,
                                output_under_test.view_as(ideal_output))
+
+
+def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
+                      torch._library.custom_ops.CustomOpDef],
+            args: Tuple[Any, ...],
+            kwargs: Optional[Dict[str, Any]] = None,
+            *,
+            test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
+            raise_exception: bool = True,
+            cond: bool = True) -> Dict[str, str]:
+    return torch.library.opcheck(
+        op,
+        args,
+        kwargs,
+        test_utils=test_utils,
+        raise_exception=raise_exception) if cond else {}
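
A usage sketch for the helper above (hypothetical arguments; needs a CUDA device and the built _C extension). The cond argument is how the tests in this diff restrict opcheck to a single parameterization, e.g. one head size, to keep runtime down.

import torch

from aphrodite import _custom_ops as ops  # noqa: F401  -- ensures torch.ops._C is populated
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck

x = torch.randn(16, 256, device="cuda")     # 2 * d input columns
out = torch.empty(16, 128, device="cuda")   # d output columns
opcheck(torch.ops._C.silu_and_mul, (out, x),
        test_utils=DEFAULT_OPCHECK_TEST_UTILS)            # schema/faketensor/autograd checks
opcheck(torch.ops._C.silu_and_mul, (out, x), cond=False)  # skipped, returns {}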