
feat: add cuda sampling kernels for top_k and top_p

AlpinDale 4 months ago
Parent
Current commit
22422d962b

+ 2 - 1
CMakeLists.txt

@@ -218,7 +218,8 @@ if(APHRODITE_GPU_LANG STREQUAL "CUDA")
     "kernels/quantization/gguf/gguf_kernel.cu"
     "kernels/quantization/gguf/gguf_kernel.cu"
     "kernels/quantization/gptq_marlin/awq_marlin_repack.cu"
     "kernels/quantization/gptq_marlin/awq_marlin_repack.cu"
     "kernels/quantization/fp8/fp8_marlin.cu"
     "kernels/quantization/fp8/fp8_marlin.cu"
-    "kernels/all_reduce/custom_all_reduce.cu")
+    "kernels/all_reduce/custom_all_reduce.cu"
+    "kernels/sampling/sampling.cu")
 
 
   # Add CUTLASS and GPTQ Marlin kernels if not MSVC
   if(NOT MSVC)

+ 117 - 1
aphrodite/_custom_ops.py

@@ -1,6 +1,6 @@
 import contextlib
 import functools
-from typing import List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type, Union
 
 
 import torch
 from loguru import logger
@@ -632,6 +632,122 @@ def register_graph_buffers(fa: int, handles: List[str],
     torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
 
 
 
 
+# Sampling Kernels
+def sampling_from_probs(probs: torch.Tensor,
+                        uniform_samples: torch.Tensor,
+                        deterministic: bool = True,
+                        check_nan: bool = False) -> torch.Tensor:
+    if check_nan and torch.any(torch.isnan(probs)):
+        raise ValueError("NaN detected in probs")
+    return torch.ops._C.sampling_from_probs(probs, uniform_samples,
+                                            deterministic)
+
+def _to_tensor_scalar_tuple(x):
+    if isinstance(x, torch.Tensor):
+        return (x, 0)
+    else:
+        return (None, x)
+def top_p_sampling_from_probs(
+        probs: torch.Tensor,
+        uniform_samples: torch.Tensor,
+        top_p: Union[torch.Tensor, float],
+        deterministic: bool = True,
+        check_nan: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
+    if check_nan and torch.any(torch.isnan(probs)):
+        raise ValueError("NaN detected in probs")
+    return torch.ops._C.top_p_sampling_from_probs(
+        probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic)
+
+def top_k_sampling_from_probs(
+        probs: torch.Tensor,
+        uniform_samples: torch.Tensor,
+        top_k: Union[torch.Tensor, int],
+        deterministic: bool = True,
+        check_nan: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
+    if check_nan and torch.any(torch.isnan(probs)):
+        raise ValueError("NaN detected in probs")
+    return torch.ops._C.top_k_sampling_from_probs(
+        probs, uniform_samples, *_to_tensor_scalar_tuple(top_k), deterministic)
+
+def min_p_sampling_from_probs(
+        probs: torch.Tensor,
+        uniform_samples: torch.Tensor,
+        min_p: Union[torch.Tensor, float],
+        deterministic: bool = True,
+        check_nan: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
+    if check_nan and torch.any(torch.isnan(probs)):
+        raise ValueError("NaN detected in probs")
+    return torch.ops._C.min_p_sampling_from_probs(
+        probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic)
+
+def top_k_mask_logits(
+    logits: torch.Tensor,
+    top_k: Union[torch.Tensor, int],
+) -> torch.Tensor:
+    return torch.ops._C.top_k_mask_logits(logits,
+                                          *_to_tensor_scalar_tuple(top_k))
+
+def top_p_renorm_prob(
+    probs: torch.Tensor,
+    top_p: Union[torch.Tensor, float],
+) -> torch.Tensor:
+    return torch.ops._C.top_p_renorm_prob(probs,
+                                          *_to_tensor_scalar_tuple(top_p))
+
+def top_k_renorm_prob(
+    probs: torch.Tensor,
+    top_k: Union[torch.Tensor, int],
+) -> torch.Tensor:
+    return torch.ops._C.top_k_renorm_prob(probs,
+                                          *_to_tensor_scalar_tuple(top_k))
+
+def top_k_top_p_sampling_from_logits(
+    logits: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_k: Union[torch.Tensor, int],
+    top_p: Union[torch.Tensor, float],
+    filter_apply_order: str = "top_k_first",
+    deterministic: bool = True,
+    check_nan: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if filter_apply_order == "top_k_first":
+        masked_logits = top_k_mask_logits(logits, top_k)
+        probs = torch.softmax(masked_logits, dim=-1)
+        return top_p_sampling_from_probs(probs, uniform_samples, top_p,
+                                         deterministic, check_nan)
+    elif filter_apply_order == "joint":
+        probs = torch.softmax(logits, dim=-1)
+        if check_nan and torch.any(torch.isnan(probs)):
+            raise ValueError("NaN detected in probs")
+        return torch.ops._C.top_k_top_p_sampling_from_probs(
+            probs, uniform_samples, *_to_tensor_scalar_tuple(top_k),
+            *_to_tensor_scalar_tuple(top_p), deterministic)
+    else:
+        raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
+
+def top_k_top_p_sampling_from_probs(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_k: Union[torch.Tensor, int],
+    top_p: Union[torch.Tensor, float],
+    filter_apply_order: str = "top_k_first",
+    deterministic: bool = True,
+    check_nan: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if filter_apply_order == "top_k_first":
+        renorm_probs = top_k_renorm_prob(probs, top_k)
+        return top_p_sampling_from_probs(renorm_probs, uniform_samples, top_p,
+                                         deterministic, check_nan)
+    elif filter_apply_order == "joint":
+        if check_nan and torch.any(torch.isnan(probs)):
+            raise ValueError("NaN detected in probs")
+        return torch.ops._C.top_k_top_p_sampling_from_probs(
+            probs, uniform_samples, *_to_tensor_scalar_tuple(top_k),
+            *_to_tensor_scalar_tuple(top_p), deterministic)
+    else:
+        raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
+
+
 # TODO: remove this later
 names_and_values = globals()
 names_and_values_to_update = {}
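
The wrappers above are thin bindings over torch.ops._C. A rough usage sketch, not taken from this commit and assuming the extension was built with kernels/sampling/sampling.cu: uniform_samples carries one row of noise per rejection-sampling round, and each sampling call returns a (samples, success) pair.

# Hypothetical usage sketch; shapes follow the CHECK_DIM assertions in
# kernels/sampling/sampling.cu.
import torch

import aphrodite._custom_ops as ops

batch_size, vocab_size, max_rounds = 4, 32000, 32
logits = torch.randn(batch_size, vocab_size, device="cuda")
probs = torch.softmax(logits, dim=-1)
# One row of uniform noise per rejection-sampling round.
uniform_samples = torch.rand(max_rounds, batch_size, device="cuda")

samples, success = ops.top_k_top_p_sampling_from_probs(
    probs, uniform_samples, top_k=20, top_p=0.9)
if not success.all():
    # Mirrors the fallback used in sampler.py: renormalize, then sample once.
    probs = ops.top_k_renorm_prob(probs, 20)
    probs = ops.top_p_renorm_prob(probs, 0.9)
    samples = ops.sampling_from_probs(probs, uniform_samples[0])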

+ 76 - 22
aphrodite/modeling/layers/sampler.py

@@ -1,11 +1,14 @@
 """A layer that samples the next tokens from the model's outputs."""
 """A layer that samples the next tokens from the model's outputs."""
 import itertools
 import itertools
+import os
+import warnings
 from math import inf
 from typing import Dict, List, Optional, Tuple
 
 
 import torch
 import torch.nn as nn
 
 
+import aphrodite._custom_ops as ops
 from aphrodite.common.sampling_params import SamplingType
 from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
                                        PromptLogprobs, SampleLogprobs,
@@ -27,6 +30,11 @@ SampleResultType = List[Tuple[List[int], List[int]]]
 # that this temperature well-uses the fp16 space after the logits are offset.
 _TEMPERATURE_MINIMUM = 2e-5
 
 
+# If enabled, we switch to a more performant implementation
+# of top-k and top-p
+APHRODITE_USE_SAMPLING_KERNELS = bool(int(
+    os.getenv("APHRODITE_USE_SAMPLING_KERNELS", "0")))
+
 
 
 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.
@@ -155,7 +163,7 @@ class Sampler(nn.Module):
         if do_nsigmas:
             logits = _apply_top_nsigma(logits, sampling_tensors.nsigmas)
 
 
-        if do_top_p_top_k:
+        if do_top_p_top_k and not APHRODITE_USE_SAMPLING_KERNELS:
             logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                         sampling_tensors.top_ks)
 
 
@@ -816,14 +824,7 @@ def _multinomial(
     seq_groups: Optional[List[SequenceGroupToSample]] = None,
 ) -> torch.Tensor:
     if num_samples > 1:
-        # This is equivalent to torch.repeat_interleaved (which also
-        # forces a GPU<->CPU sync).
-        # This allows us to do sampling with replacement by creating
-        # num_samples copies of each row in the tensor, and then
-        # batch sampling the resulting tensor.
-        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
-                                         probs.shape[1]).contiguous().view(
-                                             -1, probs.shape[1])
+        probs = probs.repeat_interleave(num_samples, dim=0)
     q = torch.empty_like(probs)
     if seq_groups is None:
         q.exponential_()
@@ -831,17 +832,57 @@ def _multinomial(
         sample_idx = 0
         for seq_group in seq_groups:
             seq_ids = seq_group.seq_ids
-            next_sample_idx = sample_idx + len(seq_ids) * num_samples
-            q[sample_idx:next_sample_idx].exponential_(
-                generator=seq_group.generator)
-            sample_idx = next_sample_idx
+            stride = len(seq_ids) * num_samples
+            assert seq_group.generator is not None
+            q[sample_idx:sample_idx +
+              stride].exponential_(generator=seq_group.generator)
+            sample_idx += stride
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
 
 
 
 
+def _top_k_top_p_multinomial_with_kernels(
+        probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
+        num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
+    max_top_k_round = 32
+    if num_samples > 1:
+        probs = probs.repeat_interleave(num_samples, dim=0)
+        top_ks = top_ks.repeat_interleave(num_samples)
+        top_ps = top_ps.repeat_interleave(num_samples)
+    batch_size = probs.shape[0]
+    uniform_samples = torch.empty((max_top_k_round, batch_size),
+                                  device=probs.device)
+    if seq_groups is None:
+        uniform_samples.uniform_()
+    else:
+        sample_idx = 0
+        for seq_group in seq_groups:
+            seq_ids = seq_group.seq_ids
+            stride = len(seq_ids) * num_samples
+            assert seq_group.generator is not None
+            uniform_samples[:, sample_idx:sample_idx +
+                            stride].uniform_(generator=seq_group.generator)
+            sample_idx += stride
+    batch_next_token_ids, success = ops.top_k_top_p_sampling_from_probs(
+        probs,
+        uniform_samples,
+        top_ks,
+        top_ps,
+    )
+    if not success.all():
+        warnings.warn("CUDA rejection sampling failed, fallback.",
+                      stacklevel=1)
+        probs = ops.top_k_renorm_prob(probs, top_ks)
+        probs = ops.top_p_renorm_prob(probs, top_ps)
+        batch_next_token_ids = ops.sampling_from_probs(
+            probs, uniform_samples[0])
+    return batch_next_token_ids.view(-1, num_samples)
+
+
 def _sample_with_torch(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
+    sampling_tensors: SamplingTensors,
     include_gpu_probs_tensor: bool,
     modify_greedy_probs: bool,
 ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
@@ -897,17 +938,29 @@ def _sample_with_torch(
                     sampling_params = seq_group.sampling_params
                     max_best_of_in_batch = max(max_best_of_in_batch,
                                                sampling_params.best_of)
-            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
-                "seq_groups": seq_groups,
-            }
 
 
-            multinomial_samples[sampling_type] = _multinomial(
-                probs[long_sample_indices], max_best_of_in_batch,
-                **seeded_args)
-            if include_gpu_probs_tensor:
+            seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
+                              seq_groups)
+            if APHRODITE_USE_SAMPLING_KERNELS:
+                multinomial_samples[
+                    sampling_type] = _top_k_top_p_multinomial_with_kernels(
+                        probs[long_sample_indices],
+                        sampling_tensors.top_ks[long_sample_indices],
+                        sampling_tensors.top_ps[long_sample_indices],
+                        max_best_of_in_batch,
+                        seq_groups_arg,
+                    )
+            else:
+                multinomial_samples[sampling_type] = _multinomial(
+                    probs[long_sample_indices],
+                    max_best_of_in_batch,
+                    seq_groups=seq_groups_arg)
+
+            if sampled_token_ids_tensor is not None:
                 # Store sampled tokens in output tensor.
-                sampled_token_ids_tensor[
-                    long_sample_indices] = multinomial_samples[sampling_type]
+                sampled_token_ids_tensor[long_sample_indices] = \
+                    multinomial_samples[sampling_type].to(torch.long)
+
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
         else:
@@ -1035,6 +1088,7 @@ def _sample(
         probs,
         logprobs,
         sampling_metadata,
+        sampling_tensors,
         include_gpu_probs_tensor=include_gpu_probs_tensor,
         modify_greedy_probs=modify_greedy_probs,
     )
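
The sampler changes are gated at module import time by the APHRODITE_USE_SAMPLING_KERNELS environment variable. A minimal opt-in sketch, assuming a CUDA build that compiled the new kernels:

# Sketch only: the flag is read when the sampler module is imported, so it
# must be set beforehand.
import os

os.environ["APHRODITE_USE_SAMPLING_KERNELS"] = "1"

from aphrodite.modeling.layers import sampler

assert sampler.APHRODITE_USE_SAMPLING_KERNELS
# With the flag set, _apply_top_k_top_p is skipped during logits processing
# and top-k/top-p is instead applied inside
# _top_k_top_p_multinomial_with_kernels when random sampling runs.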

+ 32 - 0
kernels/ops.h

@@ -102,4 +102,36 @@ at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
                              const c10::optional<at::Tensor>& initial_states_,
                              const c10::optional<at::Tensor>& final_states_out_,
                              bool silu_activation);
+
+// Sampling kernels
+torch::Tensor sampling_from_probs(torch::Tensor probs,
+                                  torch::Tensor uniform_samples,
+                                  bool deterministic);
+std::vector<torch::Tensor> top_p_sampling_from_probs(
+    torch::Tensor probs, torch::Tensor uniform_samples,
+    std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
+    bool deterministic);
+std::vector<torch::Tensor> top_k_sampling_from_probs(
+    torch::Tensor probs, torch::Tensor uniform_samples,
+    std::optional<torch::Tensor> maybe_top_k_arr, int64_t top_k_val,
+    bool deterministic);
+std::vector<torch::Tensor> min_p_sampling_from_probs(
+    torch::Tensor probs, torch::Tensor uniform_samples,
+    std::optional<torch::Tensor> maybe_min_p_arr, double min_p_val,
+    bool deterministic);
+std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(
+    torch::Tensor probs, torch::Tensor uniform_samples,
+    std::optional<torch::Tensor> maybe_top_k_arr, double top_k_val,
+    std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
+    bool deterministic);
+torch::Tensor top_p_renorm_prob(torch::Tensor probs,
+                                std::optional<torch::Tensor> maybe_top_p_arr,
+                                double top_p_val);
+torch::Tensor top_k_renorm_prob(torch::Tensor probs,
+                                std::optional<torch::Tensor> maybe_top_k_arr,
+                                int64_t top_k_val);
+torch::Tensor top_k_mask_logits(torch::Tensor logits,
+                                std::optional<torch::Tensor> maybe_top_k_arr,
+                                int64_t top_k_val);
+
 #endif
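
Each declaration pairs an optional per-request tensor (maybe_*_arr) with a scalar default (*_val); on the Python side, _to_tensor_scalar_tuple expands a single argument into that pair. A hedged sketch of both calling styles through the Python wrapper:

# Sketch of the tensor-or-scalar convention behind (maybe_top_k_arr, top_k_val).
import torch

import aphrodite._custom_ops as ops

probs = torch.softmax(torch.randn(2, 32000, device="cuda"), dim=-1)

# Scalar path: maybe_top_k_arr is absent and top_k_val applies to every row.
renormed = ops.top_k_renorm_prob(probs, 40)

# Tensor path: a per-request top_k tensor is passed and top_k_val (0) is unused.
per_request_k = torch.tensor([20, 100], dtype=torch.int32, device="cuda")
renormed = ops.top_k_renorm_prob(probs, per_request_k)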

+ 159 - 0
kernels/sampling/math.cuh

@@ -0,0 +1,159 @@
+/*
+ * Copyright (c) 2024 by PygmalionAI team.
+ * Copyright (c) 2023 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef APHRODITE_MATH_CUH_
+#define APHRODITE_MATH_CUH_
+
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+namespace aphrodite {
+namespace math {
+
+// log2(e)
+constexpr float log2e = 1.44269504088896340736f;
+
+__forceinline__ __device__ half2 uint32_as_half2(uint32_t x) {
+  return *(half2*)&x;
+}
+
+__forceinline__ __device__ uint32_t half2_as_uint32(half2 x) {
+  return *(uint32_t*)&x;
+}
+
+/*!
+ * \brief Wrapper of PTX ex2.approx instruction, which computes 2^x
+ * \param x input
+ */
+__forceinline__ __device__ float ptx_exp2(float x) {
+  float y;
+  asm volatile("ex2.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
+  return y;
+}
+
+/*!
+ * \brief Wrapper of PTX lg2.approx instruction, which computes log2(x)
+ * \param x input
+ */
+__forceinline__ __device__ float ptx_log2(float x) {
+  float y;
+  asm volatile("lg2.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
+  return y;
+}
+
+/*!
+ * \brief Wrapper of PTX ex2.approx.f16x2 instruction, which computes 2^x
+ * \param x input
+ */
+__forceinline__ __device__ half2 ptx_exp2(half2 x) {
+  uint32_t y_u32;
+  uint32_t x_u32 = half2_as_uint32(x);
+  asm volatile("ex2.approx.f16x2 %0, %1;" : "=r"(y_u32) : "r"(x_u32));
+  return uint32_as_half2(y_u32);
+}
+
+/*!
+ * \brief Wrapper of PTX ex2.approx.f16 instruction, which computes 2^x
+ * \param x input
+ */
+__forceinline__ __device__ half ptx_exp2(half x) {
+  ushort y_u16;
+  asm volatile("ex2.approx.f16 %0, %1;"
+               : "=h"(y_u16)
+               : "h"(__half_as_ushort(x)));
+  return __ushort_as_half(y_u16);
+}
+
+/*!
+ * \brief Wrapper of PTX rcp.approx instruction, which computes 1/x
+ * \param x input
+ */
+__forceinline__ __device__ float ptx_rcp(float x) {
+  float y;
+  asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
+  return y;
+}
+
+/*!
+ * \brief Wrapper of PTX shfl.sync.bfly instruction, which performs a butterfly
+ * shuffle between threads in a warp.
+ * \param x The value in the source lane
+ * \param lane_mask The mask to xor with the thread index: y[i] <- x[i ^ lane_mask]
+ */
+__forceinline__ __device__ float shfl_xor_sync(float x, int lane_mask) {
+  float y;
+  asm volatile("shfl.sync.bfly.b32 %0, %1, %2, 0x1f, 0xffffffff;"
+               : "=f"(y)
+               : "f"(x), "r"(lane_mask));
+  return y;
+}
+
+/*!
+ * \brief Wrapper of PTX shfl.sync.bfly instruction on half2, which performs a
+ * butterfly shuffle between threads in a warp.
+ * \param x The value in the source lane
+ * \param lane_mask The mask to xor with the thread index: y[i] <- x[i ^ lane_mask]
+ */
+__forceinline__ __device__ half2 shfl_xor_sync(half2 x, int lane_mask) {
+  return __shfl_xor_sync(0xffffffff, x, lane_mask);
+}
+
+/*!
+ * \brief Wrapper of PTX rsqrt approximation instruction, which computes
+ * 1/sqrt(x)
+ * \param x input
+ */
+__forceinline__ __device__ float rsqrt(float x) {
+  float y;
+  asm volatile("rsqrt.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
+  return y;
+}
+
+/*!
+ * \brief Wrapper of PTX tanh.approx.f32 instruction, which computes tanh(x)
+ * \param x input
+ */
+__forceinline__ __device__ float tanh(float x) {
+  float y;
+  asm volatile("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x));
+  return y;
+}
+
+/*!
+ * \brief Wrapper of PTX tanh.approx.f16x2 instruction, which computes tanh(x)
+ * \param x input
+ */
+__forceinline__ __device__ half2 tanh(half2 x) {
+  uint32_t y_u32;
+  uint32_t x_u32 = half2_as_uint32(x);
+  asm volatile("tanh.approx.f16x2 %0, %1;" : "=r"(y_u32) : "r"(x_u32));
+  return uint32_as_half2(y_u32);
+}
+
+/*!
+ * \brief Wrapper of PTX tanh.approx.f16 instruction, which computes tanh(x)
+ * \param x input
+ */
+__forceinline__ __device__ half tanh(half x) {
+  ushort y_u16;
+  asm volatile("tanh.approx.f16 %0, %1;"
+               : "=h"(y_u16)
+               : "h"(__half_as_ushort(x)));
+  return __ushort_as_half(y_u16);
+}
+
+}  // namespace math
+}  // namespace aphrodite
+#endif  // APHRODITE_MATH_CUH_

+ 391 - 0
kernels/sampling/sampling.cu

@@ -0,0 +1,391 @@
+/*
+ * Copyright (c) 2024 by PygmalionAI team.
+ * Copyright (c) 2024 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <c10/cuda/CUDAStream.h>
+
+#include "sampling.cuh"
+#include "../ops.h"
+#include "utils.cuh"
+
+// Check utils
+#define CUDA_CHECK(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
+
+#define CHECK_CONTIGUOUS(x) \
+  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+
+#define CHECK_INPUT(x) \
+  CUDA_CHECK(x);       \
+  CHECK_CONTIGUOUS(x)
+
+#define CHECK_EQ(a, b) \
+  TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
+
+#define CHECK_GE(a, b) \
+  TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
+
+#define CHECK_DIM(d, x) \
+  TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
+
+using namespace aphrodite;
+
+torch::Tensor sampling_from_probs(torch::Tensor probs,
+                                  torch::Tensor uniform_samples,
+                                  bool deterministic) {
+  CHECK_INPUT(probs);
+  CHECK_INPUT(uniform_samples);
+  auto device = probs.device();
+  CHECK_EQ(uniform_samples.device(), device);
+  CHECK_DIM(2, probs);            // probs: (batch_size, vocab_size)
+  CHECK_DIM(1, uniform_samples);  // uniform_samples: (batch_size)
+  CHECK_EQ(probs.size(0), uniform_samples.size(0));
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  probs = probs.to(torch::kFloat32);
+  uniform_samples = uniform_samples.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream =
+      c10::cuda::getCurrentCUDAStream(device.index());
+  auto samples =
+      torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
+
+  cudaError_t status = sampling::SamplingFromProb(
+      static_cast<float*>(probs.data_ptr()),
+      static_cast<float*>(uniform_samples.data_ptr()),
+      static_cast<int*>(samples.data_ptr()), batch_size, vocab_size,
+      deterministic, torch_current_stream);
+  TORCH_CHECK(status == cudaSuccess,
+              "SamplingFromProbs failed with error code " +
+                  std::string(cudaGetErrorString(status)));
+  return samples;
+}
+
+std::vector<torch::Tensor> top_p_sampling_from_probs(
+    torch::Tensor probs, torch::Tensor uniform_samples,
+    std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
+    bool deterministic) {
+  CHECK_INPUT(probs);
+  CHECK_INPUT(uniform_samples);
+  auto device = probs.device();
+  CHECK_EQ(uniform_samples.device(), device);
+  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
+  CHECK_DIM(
+      2, uniform_samples);  // uniform_samples: (max_top_p_rounds, batch_size)
+  CHECK_EQ(probs.size(0), uniform_samples.size(1));
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  unsigned int max_top_p_rounds = uniform_samples.size(0);
+  bool has_top_p_arr = maybe_top_p_arr.has_value();
+  auto top_p_arr = maybe_top_p_arr.value_or(
+      torch::empty({0}, torch::dtype(torch::kFloat32)));
+  if (has_top_p_arr) {
+    CHECK_INPUT(top_p_arr);
+    CHECK_DIM(1, top_p_arr);  // top_p_arr: (batch_size,)
+    CHECK_EQ(top_p_arr.size(0), batch_size);
+    CHECK_EQ(top_p_arr.device(), device);
+  }
+  probs = probs.to(torch::kFloat32);
+  uniform_samples = uniform_samples.to(torch::kFloat32);
+  top_p_arr = top_p_arr.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream =
+      c10::cuda::getCurrentCUDAStream(device.index());
+  auto samples =
+      torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
+  auto success =
+      torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
+
+  cudaError_t status = sampling::TopPSamplingFromProb<float, int>(
+      static_cast<float*>(probs.data_ptr()),
+      static_cast<float*>(uniform_samples.data_ptr()),
+      static_cast<int*>(samples.data_ptr()),
+      static_cast<bool*>(success.data_ptr()),
+      has_top_p_arr ? static_cast<float*>(top_p_arr.data_ptr()) : nullptr,
+      batch_size, top_p_val, vocab_size, max_top_p_rounds, deterministic,
+      torch_current_stream);
+  TORCH_CHECK(status == cudaSuccess,
+              "TopPSamplingFromProbs failed with error code " +
+                  std::string(cudaGetErrorString(status)));
+
+  return {samples, success};
+}
+
+std::vector<torch::Tensor> top_k_sampling_from_probs(
+    torch::Tensor probs, torch::Tensor uniform_samples,
+    std::optional<torch::Tensor> maybe_top_k_arr, int64_t top_k_val,
+    bool deterministic) {
+  CHECK_INPUT(probs);
+  CHECK_INPUT(uniform_samples);
+  auto device = probs.device();
+  CHECK_EQ(uniform_samples.device(), device);
+  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
+  CHECK_DIM(
+      2, uniform_samples);  // uniform_samples: (max_top_k_rounds, batch_size)
+  CHECK_EQ(probs.size(0), uniform_samples.size(1));
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  unsigned int max_top_k_rounds = uniform_samples.size(0);
+  bool has_top_k_arr = maybe_top_k_arr.has_value();
+  auto top_k_arr =
+      maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32)));
+  if (has_top_k_arr) {
+    CHECK_INPUT(top_k_arr);
+    CHECK_DIM(1, top_k_arr);  // top_k_arr: (batch_size,)
+    CHECK_EQ(top_k_arr.size(0), batch_size);
+    CHECK_EQ(top_k_arr.device(), device);
+  }
+  probs = probs.to(torch::kFloat32);
+  uniform_samples = uniform_samples.to(torch::kFloat32);
+  top_k_arr = top_k_arr.to(torch::kInt32);
+
+  cudaStream_t torch_current_stream =
+      c10::cuda::getCurrentCUDAStream(device.index());
+  auto samples =
+      torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
+  auto success =
+      torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
+
+  cudaError_t status = sampling::TopKSamplingFromProb<float, int>(
+      static_cast<float*>(probs.data_ptr()),
+      static_cast<float*>(uniform_samples.data_ptr()),
+      static_cast<int*>(samples.data_ptr()),
+      static_cast<bool*>(success.data_ptr()),
+      has_top_k_arr ? static_cast<float*>(top_k_arr.data_ptr()) : nullptr,
+      batch_size, top_k_val, vocab_size, max_top_k_rounds, deterministic,
+      torch_current_stream);
+  TORCH_CHECK(status == cudaSuccess,
+              "TopKSamplingFromProbs failed with error code " +
+                  std::string(cudaGetErrorString(status)));
+
+  return {samples, success};
+}
+
+std::vector<torch::Tensor> min_p_sampling_from_probs(
+    torch::Tensor probs, torch::Tensor uniform_samples,
+    std::optional<torch::Tensor> maybe_min_p_arr, double min_p_val,
+    bool deterministic) {
+  CHECK_INPUT(probs);
+  CHECK_INPUT(uniform_samples);
+  auto device = probs.device();
+  CHECK_EQ(uniform_samples.device(), device);
+  CHECK_DIM(2, probs);            // probs: (batch_size, vocab_size)
+  CHECK_DIM(2, uniform_samples);  // uniform_samples: (max_rounds, batch_size)
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  unsigned int max_rounds = uniform_samples.size(0);
+  CHECK_EQ(uniform_samples.size(1), batch_size);
+  bool has_min_p_arr = maybe_min_p_arr.has_value();
+  auto min_p_arr = maybe_min_p_arr.value_or(
+      torch::empty({0}, torch::dtype(torch::kFloat32)));
+  if (has_min_p_arr) {
+    CHECK_INPUT(min_p_arr);
+    CHECK_DIM(1, min_p_arr);  // min_p_arr: (batch_size,)
+    CHECK_EQ(min_p_arr.size(0), batch_size);
+    CHECK_EQ(min_p_arr.device(), device);
+  }
+  min_p_arr = min_p_arr.to(torch::kFloat32);
+  probs = probs.to(torch::kFloat32);
+  uniform_samples = uniform_samples.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream =
+      c10::cuda::getCurrentCUDAStream(device.index());
+  auto samples =
+      torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
+  auto success =
+      torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
+
+  cudaError_t status = sampling::MinPSamplingFromProb<float, int>(
+      static_cast<float*>(probs.data_ptr()),
+      static_cast<float*>(uniform_samples.data_ptr()),
+      has_min_p_arr ? static_cast<float*>(min_p_arr.data_ptr()) : nullptr,
+      static_cast<int*>(samples.data_ptr()),
+      static_cast<bool*>(success.data_ptr()), batch_size, min_p_val, vocab_size,
+      max_rounds, deterministic, torch_current_stream);
+  TORCH_CHECK(status == cudaSuccess,
+              "MinPSamplingFromProb failed with error code " +
+                  std::string(cudaGetErrorString(status)));
+
+  return {samples, success};
+}
+
+std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(
+    torch::Tensor probs, torch::Tensor uniform_samples,
+    std::optional<torch::Tensor> maybe_top_k_arr, double top_k_val,
+    std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
+    bool deterministic) {
+  CHECK_INPUT(probs);
+  CHECK_INPUT(uniform_samples);
+  auto device = probs.device();
+  CHECK_EQ(uniform_samples.device(), device);
+  CHECK_DIM(2, probs);            // probs: (batch_size, vocab_size)
+  CHECK_DIM(2, uniform_samples);  // uniform_samples: (max_rounds, batch_size)
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  unsigned int max_rounds = uniform_samples.size(0);
+  CHECK_EQ(uniform_samples.size(1), batch_size);
+  bool has_top_k_arr = maybe_top_k_arr.has_value();
+  auto top_k_arr =
+      maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32)));
+  if (has_top_k_arr) {
+    CHECK_INPUT(top_k_arr);
+    CHECK_DIM(1, top_k_arr);  // top_k_arr: (batch_size,)
+    CHECK_EQ(top_k_arr.size(0), batch_size);
+    CHECK_EQ(top_k_arr.device(), device);
+  }
+  top_k_arr = top_k_arr.to(torch::kInt32);
+  bool has_top_p_arr = maybe_top_p_arr.has_value();
+  auto top_p_arr = maybe_top_p_arr.value_or(
+      torch::empty({0}, torch::dtype(torch::kFloat32)));
+  if (has_top_p_arr) {
+    CHECK_INPUT(top_p_arr);
+    CHECK_DIM(1, top_p_arr);  // top_p_arr: (batch_size,)
+    CHECK_EQ(top_p_arr.size(0), batch_size);
+    CHECK_EQ(top_p_arr.device(), device);
+  }
+  top_p_arr = top_p_arr.to(torch::kFloat32);
+  probs = probs.to(torch::kFloat32);
+  uniform_samples = uniform_samples.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream =
+      c10::cuda::getCurrentCUDAStream(device.index());
+  auto samples =
+      torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
+  auto success =
+      torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
+
+  cudaError_t status = sampling::TopKTopPSamplingFromProb<float, int>(
+      static_cast<float*>(probs.data_ptr()),
+      static_cast<float*>(uniform_samples.data_ptr()),
+      has_top_k_arr ? static_cast<int*>(top_k_arr.data_ptr()) : nullptr,
+      has_top_p_arr ? static_cast<float*>(top_p_arr.data_ptr()) : nullptr,
+      static_cast<int*>(samples.data_ptr()),
+      static_cast<bool*>(success.data_ptr()), batch_size, top_k_val, top_p_val,
+      vocab_size, max_rounds, deterministic, torch_current_stream);
+  TORCH_CHECK(status == cudaSuccess,
+              "TopKTopPSamplingFromProbs failed with error code " +
+                  std::string(cudaGetErrorString(status)));
+
+  return {samples, success};
+}
+
+torch::Tensor top_p_renorm_prob(torch::Tensor probs,
+                                std::optional<torch::Tensor> maybe_top_p_arr,
+                                double top_p_val) {
+  CHECK_INPUT(probs);
+  auto device = probs.device();
+  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  bool has_top_p_arr = maybe_top_p_arr.has_value();
+  auto top_p_arr = maybe_top_p_arr.value_or(
+      torch::empty({0}, torch::dtype(torch::kFloat32)));
+  if (has_top_p_arr) {
+    CHECK_INPUT(top_p_arr);
+    CHECK_DIM(1, top_p_arr);  // top_p_arr: (batch_size,)
+    CHECK_EQ(top_p_arr.size(0), batch_size);
+    CHECK_EQ(top_p_arr.device(), device);
+  }
+  top_p_arr = top_p_arr.to(torch::kFloat32);
+  probs = probs.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream =
+      c10::cuda::getCurrentCUDAStream(device.index());
+  auto renorm_probs = torch::empty(
+      {batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device));
+
+  cudaError_t status = sampling::TopPRenormProb<float>(
+      static_cast<float*>(probs.data_ptr()),
+      static_cast<float*>(renorm_probs.data_ptr()),
+      has_top_p_arr ? static_cast<float*>(top_p_arr.data_ptr()) : nullptr,
+      batch_size, top_p_val, vocab_size, torch_current_stream);
+  TORCH_CHECK(status == cudaSuccess,
+              "TopPRenormProb failed with error code " +
+                  std::string(cudaGetErrorString(status)));
+  return renorm_probs;
+}
+
+torch::Tensor top_k_renorm_prob(torch::Tensor probs,
+                                std::optional<torch::Tensor> maybe_top_k_arr,
+                                int64_t top_k_val) {
+  CHECK_INPUT(probs);
+  auto device = probs.device();
+  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  bool has_top_k_arr = maybe_top_k_arr.has_value();
+  auto top_k_arr =
+      maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32)));
+  if (has_top_k_arr) {
+    CHECK_INPUT(top_k_arr);
+    CHECK_DIM(1, top_k_arr);  // top_k_arr: (batch_size,)
+    CHECK_EQ(top_k_arr.size(0), batch_size);
+    CHECK_EQ(top_k_arr.device(), device);
+  }
+  top_k_arr = top_k_arr.to(torch::kInt32);
+  probs = probs.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream =
+      c10::cuda::getCurrentCUDAStream(device.index());
+  auto renorm_probs = torch::empty(
+      {batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device));
+
+  cudaError_t status = sampling::TopKRenormProb<float>(
+      static_cast<float*>(probs.data_ptr()),
+      static_cast<float*>(renorm_probs.data_ptr()),
+      has_top_k_arr ? static_cast<int*>(top_k_arr.data_ptr()) : nullptr,
+      batch_size, top_k_val, vocab_size, torch_current_stream);
+
+  TORCH_CHECK(status == cudaSuccess,
+              "TopKRenormProb failed with error code " +
+                  std::string(cudaGetErrorString(status)));
+  return renorm_probs;
+}
+
+torch::Tensor top_k_mask_logits(torch::Tensor logits,
+                                std::optional<torch::Tensor> maybe_top_k_arr,
+                                int64_t top_k_val) {
+  CHECK_INPUT(logits);
+  auto device = logits.device();
+  CHECK_DIM(2, logits);  // logits: (batch_size, vocab_size)
+  unsigned int batch_size = logits.size(0);
+  unsigned int vocab_size = logits.size(1);
+  bool has_top_k_arr = maybe_top_k_arr.has_value();
+  auto top_k_arr =
+      maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32)));
+  if (has_top_k_arr) {
+    CHECK_INPUT(top_k_arr);
+    CHECK_DIM(1, top_k_arr);  // top_k_arr: (batch_size,)
+    CHECK_EQ(top_k_arr.size(0), batch_size);
+    CHECK_EQ(top_k_arr.device(), device);
+  }
+  top_k_arr = top_k_arr.to(torch::kInt32);
+  logits = logits.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream =
+      c10::cuda::getCurrentCUDAStream(device.index());
+  auto mask_logits = torch::empty({batch_size, vocab_size},
+                                  torch::dtype(torch::kFloat32).device(device));
+
+  cudaError_t status = sampling::TopKMaskLogits<float>(
+      static_cast<float*>(logits.data_ptr()),
+      static_cast<float*>(mask_logits.data_ptr()),
+      has_top_k_arr ? static_cast<int*>(top_k_arr.data_ptr()) : nullptr,
+      batch_size, top_k_val, vocab_size, torch_current_stream);
+
+  TORCH_CHECK(status == cudaSuccess,
+              "TopKMaskLogits failed with error code " +
+                  std::string(cudaGetErrorString(status)));
+  return mask_logits;
+}
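
These host wrappers all follow the same pattern: validate shapes, cast inputs to fp32/int32, launch the rejection-sampling kernel on the current CUDA stream, and return samples plus a per-row success flag that is false when a row did not converge within uniform_samples.size(0) rounds. A short, illustrative sketch of that contract using the min_p wrapper:

# Illustrative only: exercises the (samples, success) contract described above.
import torch

import aphrodite._custom_ops as ops

probs = torch.softmax(torch.randn(8, 32000, device="cuda"), dim=-1)
uniform_samples = torch.rand(32, 8, device="cuda")  # (max_rounds, batch_size)

samples, success = ops.min_p_sampling_from_probs(probs, uniform_samples,
                                                 min_p=0.05)
assert samples.dtype == torch.int32 and samples.shape == (8,)
if not success.all():
    # Rows that exhaust every round report success=False; callers are expected
    # to fall back (see _top_k_top_p_multinomial_with_kernels in sampler.py).
    pass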

+ 1398 - 0
kernels/sampling/sampling.cuh

@@ -0,0 +1,1398 @@
+/*
+ * Copyright (c) 2024 by PygmalionAI team.
+ * Copyright (c) 2024 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef APHRODITE_SAMPLING_CUH_
+#define APHRODITE_SAMPLING_CUH_
+
+#include <cub/block/block_adjacent_difference.cuh>
+#include <cub/block/block_reduce.cuh>
+#include <cub/block/block_scan.cuh>
+#include <numeric>
+
+#include "math.cuh"
+#include "utils.cuh"
+#include "vec_dtypes.cuh"
+
+namespace aphrodite {
+
+namespace sampling {
+
+using namespace cub;
+
+#define DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, ...) \
+  if (deterministic) {                                            \
+    constexpr bool DETERMINISTIC = true;                          \
+    __VA_ARGS__                                                   \
+  } else {                                                        \
+    constexpr bool DETERMINISTIC = false;                         \
+    __VA_ARGS__                                                   \
+  }
+
+constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS;
+constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS;
+
+#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120100)
+  #define APHRODITE_CUB_SUBTRACTLEFT_DEFINED
+#endif
+
+template <typename T>
+struct Pair {
+  T value;
+  int count;
+
+  __device__ Pair operator+(const Pair& other) const {
+    return {value + other.value, count + other.count};
+  }
+  __device__ Pair& operator+=(const Pair& other) {
+    value += other.value;
+    count += other.count;
+    return *this;
+  }
+};
+
+struct BoolDiffOp {
+  __device__ __forceinline__ bool operator()(const bool& lhs,
+                                             const bool& rhs) const {
+    return lhs != rhs;
+  }
+};
+
+template <typename T, uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM>
+struct SamplingTempStorage {
+  union {
+    T deterministic_scan[BLOCK_THREADS / 32];
+    typename BlockScan<T, BLOCK_THREADS, SCAN_ALGORITHM>::TempStorage scan;
+    typename BlockReduce<T, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
+        reduce;
+    typename BlockReduce<Pair<T>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
+        reduce_pair;
+    typename BlockAdjacentDifference<bool, BLOCK_THREADS>::TempStorage adj_diff;
+  } block_prim;
+  struct {
+    int32_t sampled_id;
+    union {
+      T value;
+      Pair<T> pair;
+      T max_p;
+    } block_aggregate;
+  } data;
+};
+
+/*!
+ * \brief Deterministic inclusive scan implementation using the Blelloch scan
+ * algorithm.
+ * \note This implementation is slower than cub::BlockScan, but it is
+ * deterministic.
+ */
+template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS,
+          BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, typename T>
+__device__ __forceinline__ void DeterministicInclusiveSum(
+    const T* in_data, T* out_data,
+    SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>*
+        temp_storage) {
+  T* smem_prefix_sum = temp_storage->block_prim.deterministic_scan;
+  T thread_data[VEC_SIZE];
+  T thread_sum = 0;
+#pragma unroll
+  for (uint32_t i = 0; i < VEC_SIZE; ++i) {
+    thread_sum += in_data[i];
+    thread_data[i] = thread_sum;
+  }
+
+  T thread_exclusive_prefix_sum = thread_sum;
+
+#pragma unroll
+  for (uint32_t offset = 1; offset < 32; offset *= 2) {
+    T tmp = __shfl_up_sync(0xffffffff, thread_exclusive_prefix_sum, offset);
+    if ((threadIdx.x + 1) % (offset * 2) == 0) {
+      thread_exclusive_prefix_sum += tmp;
+    }
+  }
+
+  T warp_sum = __shfl_sync(0xffffffff, thread_exclusive_prefix_sum,
+                           threadIdx.x | 0xffffffff);
+  if (threadIdx.x % 32 == 31) {
+    thread_exclusive_prefix_sum = 0;
+  }
+
+#pragma unroll
+  for (uint32_t offset = 16; offset >= 1; offset /= 2) {
+    T tmp = __shfl_xor_sync(0xffffffff, thread_exclusive_prefix_sum, offset);
+    if ((threadIdx.x + 1) % (offset * 2) == 0) {
+      thread_exclusive_prefix_sum = tmp + thread_exclusive_prefix_sum;
+    }
+    if ((threadIdx.x + 1) % (offset * 2) == offset) {
+      thread_exclusive_prefix_sum = tmp;
+    }
+  }
+
+  smem_prefix_sum[threadIdx.x / 32] = warp_sum;
+  __syncthreads();
+
+  if (threadIdx.x < 32) {
+    T warp_exclusive_prefix_sum =
+        (threadIdx.x < BLOCK_THREADS / 32) ? smem_prefix_sum[threadIdx.x] : 0;
+
+#pragma unroll
+    for (uint32_t offset = 1; offset < 32; offset *= 2) {
+      T tmp = __shfl_up_sync(0xffffffff, warp_exclusive_prefix_sum, offset);
+      if ((threadIdx.x + 1) % (offset * 2) == 0) {
+        warp_exclusive_prefix_sum += tmp;
+      }
+    }
+
+    if (threadIdx.x % 32 == 31) {
+      warp_exclusive_prefix_sum = 0;
+    }
+
+#pragma unroll
+    for (uint32_t offset = 16; offset >= 1; offset /= 2) {
+      T tmp = __shfl_xor_sync(0xffffffff, warp_exclusive_prefix_sum, offset);
+      if ((threadIdx.x + 1) % (offset * 2) == 0) {
+        warp_exclusive_prefix_sum = tmp + warp_exclusive_prefix_sum;
+      }
+      if ((threadIdx.x + 1) % (offset * 2) == offset) {
+        warp_exclusive_prefix_sum = tmp;
+      }
+    }
+    if (threadIdx.x < BLOCK_THREADS / 32) {
+      smem_prefix_sum[threadIdx.x] = warp_exclusive_prefix_sum;
+    }
+  }
+  __syncthreads();
+
+#pragma unroll
+  for (uint32_t i = 0; i < VEC_SIZE; ++i) {
+    out_data[i] = smem_prefix_sum[threadIdx.x / 32] +
+                  thread_exclusive_prefix_sum + thread_data[i];
+  }
+}
+
+template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS,
+          BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, bool DETERMINISTIC, typename T>
+__device__ __forceinline__ void DeviceSamplingFromProb(
+    uint32_t i, uint32_t d, T threshold, T u, vec_t<T, VEC_SIZE> prob_vec,
+    T& aggregate,
+    SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>*
+        temp_storage) {
+  const uint32_t tx = threadIdx.x;
+  T prob_greater_than_threshold[VEC_SIZE];
+  T inclusive_cdf[VEC_SIZE];
+  bool greater_than_u[VEC_SIZE], valid[VEC_SIZE];
+#pragma unroll
+  for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+    prob_greater_than_threshold[j] =
+        (prob_vec[j] > threshold) ? prob_vec[j] : T(0);
+    valid[j] =
+        prob_vec[j] > threshold && (i * BLOCK_THREADS + tx) * VEC_SIZE < d;
+  }
+  T aggregate_local = BlockReduce<T, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                          temp_storage->block_prim.reduce)
+                          .Sum<VEC_SIZE>(prob_greater_than_threshold);
+  if (tx == 0) {
+    temp_storage->data.block_aggregate.value = aggregate_local;
+  }
+  __syncthreads();
+  aggregate_local = temp_storage->data.block_aggregate.value;
+
+  if (aggregate + aggregate_local > u) {
+    if constexpr (DETERMINISTIC) {
+      DeterministicInclusiveSum<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
+                                REDUCE_ALGORITHM, T>(
+          prob_greater_than_threshold, inclusive_cdf, temp_storage);
+    } else {
+      BlockScan<T, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
+          .InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);
+
+      __syncthreads();
+    }
+
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      greater_than_u[j] = inclusive_cdf[j] + aggregate > u;
+    }
+
+    bool greater_than_u_diff[VEC_SIZE];
+#ifdef APHRODITE_CUB_SUBTRACTLEFT_DEFINED
+    BlockAdjacentDifference<bool, BLOCK_THREADS>(
+        temp_storage->block_prim.adj_diff)
+        .SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff,
+                                BoolDiffOp());
+#else
+    BlockAdjacentDifference<bool, BLOCK_THREADS>(
+        temp_storage->block_prim.adj_diff)
+        .FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(),
+                             0);
+#endif
+    __syncthreads();
+
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      if (greater_than_u_diff[j] && valid[j]) {
+        if constexpr (DETERMINISTIC) {
+          temp_storage->data.sampled_id =
+              (i * BLOCK_THREADS + tx) * VEC_SIZE + j;
+        } else {
+          // cub's block scan result might not be monotonic, so we need to find
+          // the first element
+          atomicMin(&(temp_storage->data.sampled_id),
+                    (i * BLOCK_THREADS + tx) * VEC_SIZE + j);
+        }
+      }
+    }
+    __syncthreads();
+  }
+  aggregate += aggregate_local;
+}
+
+template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
+          bool DETERMINISTIC, typename DType, typename IdType>
+__global__ void SamplingFromProbKernel(DType* probs, DType* uniform_samples,
+                                       IdType* output, IdType* row_indices,
+                                       uint32_t d) {
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];
+
+  extern __shared__ __align__(
+      alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
+                                  REDUCE_ALGORITHM>)) uint8_t smem_sampling[];
+  auto& temp_storage =
+      reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
+                                           REDUCE_ALGORITHM>&>(smem_sampling);
+  temp_storage.data.sampled_id = d - 1;
+  __syncthreads();
+
+  vec_t<DType, VEC_SIZE> probs_vec;
+  DType aggregate(0);
+  float u = uniform_samples[bx];
+
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
+                     tx * VEC_SIZE);
+    }
+
+    DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
+                           REDUCE_ALGORITHM, DETERMINISTIC, DType>(
+        i, d, DType(0), u, probs_vec, aggregate, &temp_storage);
+    if (float(aggregate) > u) {
+      break;
+    }
+  }
+  output[bx] = temp_storage.data.sampled_id;
+}
+
+template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
+          bool DETERMINISTIC, typename DType, typename IdType>
+__global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
+                                           IdType* output, bool* success,
+                                           IdType* top_k_arr,
+                                           uint32_t top_k_val, uint32_t d,
+                                           uint32_t max_top_k_rounds) {
+  const uint32_t batch_size = gridDim.x;
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
+
+  extern __shared__ __align__(
+      alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
+                                  REDUCE_ALGORITHM>)) uint8_t smem_sampling[];
+  auto& temp_storage =
+      reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
+                                           REDUCE_ALGORITHM>&>(smem_sampling);
+
+  vec_t<DType, VEC_SIZE> probs_vec;
+  DType aggregate;
+  DType q = DType(1);
+  DType pivot = DType(0);
+  IdType sampled_id;
+  for (uint32_t round = 0; round < max_top_k_rounds; ++round) {
+    temp_storage.data.sampled_id = d - 1;
+    __syncthreads();
+    DType u = uniform_samples[round * batch_size + bx] * q;
+    aggregate = DType(0);
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
+                             REDUCE_ALGORITHM, DETERMINISTIC, DType>(
+          i, d, pivot, u, probs_vec, aggregate, &temp_storage);
+      if (aggregate > u) {
+        break;
+      }
+    }
+    __syncthreads();
+    sampled_id = temp_storage.data.sampled_id;
+    pivot = max(pivot, probs[bx * d + sampled_id]);
+
+    Pair<DType> aggregate_gt_pivot{DType(0), 0};
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      Pair<DType> probs_gt_pivot[VEC_SIZE];
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        probs_gt_pivot[j] = {(probs_vec[j] > pivot) ? probs_vec[j] : DType(0),
+                             (probs_vec[j] > pivot &&
+                              (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+      }
+
+      aggregate_gt_pivot +=
+          BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+              temp_storage.block_prim.reduce_pair)
+              .Sum<VEC_SIZE>(probs_gt_pivot);
+      if (tx == 0) {
+        temp_storage.data.block_aggregate.pair = aggregate_gt_pivot;
+      }
+      __syncthreads();
+    }
+    q = temp_storage.data.block_aggregate.pair.value;
+    if (temp_storage.data.block_aggregate.pair.count < k) {
+      break;
+    }
+  }
+  __syncthreads();
+  if (tx == 0) {
+    output[bx] = sampled_id;
+    if (temp_storage.data.block_aggregate.pair.count >= k) {
+      // failed to sample within max_top_k_rounds
+      if (success != nullptr) {
+        success[bx] = false;
+      }
+    } else {
+      if (success != nullptr) {
+        success[bx] = true;
+      }
+    }
+  }
+}
+
+template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
+          bool DETERMINISTIC, typename DType, typename IdType>
+__global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
+                                           IdType* output, bool* success,
+                                           IdType* row_indices,
+                                           float* top_p_arr, float top_p_val,
+                                           uint32_t d,
+                                           uint32_t max_top_p_rounds) {
+  const uint32_t batch_size = gridDim.x;
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  float top_p = (top_p_arr == nullptr) ? top_p_val : top_p_arr[bx];
+
+  const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];
+
+  extern __shared__ __align__(
+      alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
+                                  REDUCE_ALGORITHM>)) uint8_t smem_sampling[];
+  auto& temp_storage =
+      reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
+                                           REDUCE_ALGORITHM>&>(smem_sampling);
+
+  vec_t<DType, VEC_SIZE> probs_vec;
+  DType aggregate;
+  DType q = DType(1);
+  DType pivot = DType(0);
+  IdType sampled_id;
+  for (uint32_t round = 0; round < max_top_p_rounds; ++round) {
+    temp_storage.data.sampled_id = d - 1;
+    __syncthreads();
+    DType u = uniform_samples[round * batch_size + bx] * q;
+    aggregate = DType(0);
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + row_idx * d +
+                       (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
+                             REDUCE_ALGORITHM, DETERMINISTIC, DType>(
+          i, d, pivot, u, probs_vec, aggregate, &temp_storage);
+      if (aggregate > u) {
+        break;
+      }
+    }
+    __syncthreads();
+    sampled_id = temp_storage.data.sampled_id;
+    pivot = max(pivot, probs[row_idx * d + sampled_id]);
+
+    DType aggregate_gt_pivot = DType(0);
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + row_idx * d +
+                       (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      DType probs_gt_pivot[VEC_SIZE];
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0);
+      }
+
+      aggregate_gt_pivot +=
+          BlockReduce<DType, BLOCK_THREADS>(temp_storage.block_prim.reduce)
+              .Sum<VEC_SIZE>(probs_gt_pivot);
+      if (tx == 0) {
+        temp_storage.data.block_aggregate.value = aggregate_gt_pivot;
+      }
+      __syncthreads();
+    }
+    q = temp_storage.data.block_aggregate.value;
+    if (float(q) < top_p) {
+      break;
+    }
+  }
+  __syncthreads();
+  if (tx == 0) {
+    output[bx] = sampled_id;
+    if (float(q) >= top_p) {
+      // failed to sample within max_top_p_rounds
+      if (success != nullptr) {
+        success[bx] = false;
+      }
+    } else {
+      if (success != nullptr) {
+        success[bx] = true;
+      }
+    }
+  }
+}
+
+template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
+          bool DETERMINISTIC, typename DType, typename IdType>
+__global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
+                                           DType* min_p_arr, IdType* output,
+                                           bool* success, float min_p_val,
+                                           uint32_t d,
+                                           uint32_t max_min_p_rounds) {
+  const uint32_t batch_size = gridDim.x;
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  DType p = (min_p_arr == nullptr) ? min_p_val : min_p_arr[bx];
+
+  extern __shared__ __align__(
+      alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
+                                  REDUCE_ALGORITHM>)) uint8_t smem_sampling[];
+  auto& temp_storage =
+      reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
+                                           REDUCE_ALGORITHM>&>(smem_sampling);
+
+  vec_t<DType, VEC_SIZE> probs_vec;
+  DType aggregate;
+  DType q = DType(1);
+  DType pivot = DType(0);
+
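+  // First pass: find the row maximum probability so the min-p threshold can
+  // be expressed as scaled_p = max_p * p.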
+  DType max_p = 0;
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+    }
+    DType probs_[VEC_SIZE];
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      probs_[j] = probs_vec[j];
+    }
+    max_p = max(
+        max_p, BlockReduce<DType, BLOCK_THREADS>(temp_storage.block_prim.reduce)
+                   .Reduce<VEC_SIZE>(probs_, cub::Max()));
+    __syncthreads();
+  }
+  if (tx == 0) {
+    temp_storage.data.block_aggregate.max_p = max_p;
+  }
+  __syncthreads();
+  DType scaled_p = temp_storage.data.block_aggregate.max_p * p;
+
+  IdType sampled_id;
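+  // Rejection sampling: accept the first candidate whose probability reaches
+  // scaled_p; otherwise restrict the next draw to probs > pivot.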
+  for (uint32_t round = 0; round < max_min_p_rounds; ++round) {
+    temp_storage.data.sampled_id = d - 1;
+    __syncthreads();
+    DType u = uniform_samples[round * batch_size + bx] * q;
+    aggregate = DType(0);
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
+                             REDUCE_ALGORITHM, DETERMINISTIC, DType>(
+          i, d, pivot, u, probs_vec, aggregate, &temp_storage);
+      if (aggregate > u) {
+        break;
+      }
+    }
+    __syncthreads();
+    sampled_id = temp_storage.data.sampled_id;
+    pivot = max(pivot, probs[bx * d + sampled_id]);
+    if (pivot >= scaled_p) {
+      break;
+    }
+
+    DType aggregate_gt_pivot = DType(0);
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      DType probs_gt_pivot[VEC_SIZE];
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0);
+      }
+
+      aggregate_gt_pivot +=
+          BlockReduce<DType, BLOCK_THREADS>(temp_storage.block_prim.reduce)
+              .Sum<VEC_SIZE>(probs_gt_pivot);
+      if (tx == 0) {
+        temp_storage.data.block_aggregate.value = aggregate_gt_pivot;
+      }
+      __syncthreads();
+    }
+    q = temp_storage.data.block_aggregate.value;
+  }
+  __syncthreads();
+  if (tx == 0) {
+    output[bx] = sampled_id;
+    if (pivot < scaled_p) {
+      // failed to sample within max_min_p_rounds
+      if (success != nullptr) {
+        success[bx] = false;
+      }
+    } else {
+      if (success != nullptr) {
+        success[bx] = true;
+      }
+    }
+  }
+}
+
+template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
+          bool DETERMINISTIC, typename DType, typename IdType>
+__global__ void TopKTopPSamplingFromProbKernel(
+    DType* probs, DType* uniform_samples, IdType* top_k_arr, DType* top_p_arr,
+    IdType* output, bool* success, IdType top_k_val, DType top_p_val,
+    uint32_t d, uint32_t max_rounds) {
+  const uint32_t batch_size = gridDim.x;
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  IdType k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
+  DType p = top_p_arr == nullptr ? top_p_val : top_p_arr[bx];
+
+  extern __shared__ __align__(
+      alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
+                                  REDUCE_ALGORITHM>)) uint8_t smem_sampling[];
+  auto& temp_storage =
+      reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
+                                           REDUCE_ALGORITHM>&>(smem_sampling);
+
+  vec_t<DType, VEC_SIZE> probs_vec;
+  DType aggregate;
+  DType q = DType(1);
+  DType pivot = DType(0);
+  IdType sampled_id;
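+  // Joint top-k/top-p rejection sampling: a candidate is accepted only when
+  // fewer than k probabilities and less than p total mass remain above the
+  // pivot.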
+  for (uint32_t round = 0; round < max_rounds; ++round) {
+    temp_storage.data.sampled_id = d - 1;
+    __syncthreads();
+    DType u = uniform_samples[round * batch_size + bx] * q;
+    aggregate = DType(0);
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
+                             REDUCE_ALGORITHM, DETERMINISTIC, DType>(
+          i, d, pivot, u, probs_vec, aggregate, &temp_storage);
+      if (aggregate > u) {
+        break;
+      }
+    }
+    __syncthreads();
+    sampled_id = temp_storage.data.sampled_id;
+    pivot = max(pivot, probs[bx * d + sampled_id]);
+
+    Pair<DType> aggregate_gt_pivot{DType(0), 0};
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      Pair<DType> probs_gt_pivot[VEC_SIZE];
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        probs_gt_pivot[j] = {(probs_vec[j] > pivot) ? probs_vec[j] : DType(0),
+                             (probs_vec[j] > pivot &&
+                              (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+      }
+
+      aggregate_gt_pivot +=
+          BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+              temp_storage.block_prim.reduce_pair)
+              .Sum<VEC_SIZE>(probs_gt_pivot);
+      if (tx == 0) {
+        temp_storage.data.block_aggregate.pair = aggregate_gt_pivot;
+      }
+      __syncthreads();
+    }
+    q = temp_storage.data.block_aggregate.pair.value;
+    if (temp_storage.data.block_aggregate.pair.count < k && float(q) < p) {
+      break;
+    }
+  }
+  __syncthreads();
+  if (tx == 0) {
+    output[bx] = sampled_id;
+    if (temp_storage.data.block_aggregate.pair.count >= k || float(q) >= p) {
+      // failed to sample within max_rounds
+      if (success != nullptr) {
+        success[bx] = false;
+      }
+    } else {
+      if (success != nullptr) {
+        success[bx] = true;
+      }
+    }
+  }
+}
+
+template <typename T, typename IdType>
+cudaError_t SamplingFromProb(T* probs, T* uniform_samples, IdType* output,
+                             uint32_t batch_size, uint32_t d,
+                             bool deterministic, cudaStream_t stream = 0) {
+  constexpr uint32_t BLOCK_THREADS = 1024;
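+  // Use the widest per-thread vector that both divides d and fits in a
+  // 16-byte load.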
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  IdType* row_indices_placeholder = nullptr;
+  void* args[] = {&probs, &uniform_samples, &output, &row_indices_placeholder,
+                  &d};
+  const uint32_t smem_size =
+      sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+
+  DISPATCH_ALIGNED_VEC_SIZE(
+      vec_size, VEC_SIZE,
+      {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
+        auto kernel =
+            SamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
+                                   VEC_SIZE, DETERMINISTIC, T, IdType>;
+        APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
+                                             smem_size, stream));
+      })});
+  return cudaSuccess;
+}
+
+template <typename T, typename IdType>
+cudaError_t ParallelSamplingFromProb(T* probs, T* uniform_samples,
+                                     IdType* output, IdType* row_indices,
+                                     uint32_t batch_size, uint32_t d,
+                                     bool deterministic,
+                                     cudaStream_t stream = 0) {
+  constexpr uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&probs, &uniform_samples, &output, &row_indices, &d};
+  const uint32_t smem_size =
+      sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+
+  DISPATCH_ALIGNED_VEC_SIZE(
+      vec_size, VEC_SIZE,
+      {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
+        auto kernel =
+            SamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
+                                   VEC_SIZE, DETERMINISTIC, T, IdType>;
+        APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
+                                             smem_size, stream));
+      })});
+  return cudaSuccess;
+}
+
+template <typename T, typename IdType>
+cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output,
+                                 bool* success, T* top_k_arr,
+                                 uint32_t batch_size, uint32_t top_k_val,
+                                 uint32_t d, uint32_t max_top_k_rounds,
+                                 bool deterministic, cudaStream_t stream = 0) {
+  constexpr uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t smem_size =
+      sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&probs,     &uniform_samples, &output, &success,
+                  &top_k_arr, &top_k_val,       &d,      &max_top_k_rounds};
+
+  DISPATCH_ALIGNED_VEC_SIZE(
+      vec_size, VEC_SIZE,
+      {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
+        auto kernel =
+            TopKSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
+                                       VEC_SIZE, DETERMINISTIC, T, IdType>;
+        APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
+            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+        APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
+                                             smem_size, stream));
+      })});
+  return cudaSuccess;
+}
+
+template <typename T, typename IdType>
+cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output,
+                                 bool* success, T* top_p_arr,
+                                 uint32_t batch_size, T top_p_val, uint32_t d,
+                                 uint32_t max_top_p_rounds, bool deterministic,
+                                 cudaStream_t stream = 0) {
+  constexpr uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t smem_size =
+      sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  IdType* row_indices_placeholder = nullptr;
+  void* args[] = {&probs,
+                  &uniform_samples,
+                  &output,
+                  &success,
+                  &row_indices_placeholder,
+                  &top_p_arr,
+                  &top_p_val,
+                  &d,
+                  &max_top_p_rounds};
+
+  DISPATCH_ALIGNED_VEC_SIZE(
+      vec_size, VEC_SIZE,
+      {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
+        auto kernel =
+            TopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
+                                       VEC_SIZE, DETERMINISTIC, T, IdType>;
+        APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
+            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+        APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
+                                             smem_size, stream));
+      })});
+  return cudaSuccess;
+}
+
+template <typename T, typename IdType>
+cudaError_t MinPSamplingFromProb(T* probs, T* uniform_samples, T* min_p_arr,
+                                 IdType* output, bool* success,
+                                 uint32_t batch_size, float min_p_val,
+                                 uint32_t d, uint32_t max_rounds,
+                                 bool deterministic, cudaStream_t stream = 0) {
+  constexpr uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t smem_size =
+      sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&probs,   &uniform_samples, &min_p_arr, &output,
+                  &success, &min_p_val,       &d,         &max_rounds};
+
+  DISPATCH_ALIGNED_VEC_SIZE(
+      vec_size, VEC_SIZE,
+      {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
+        auto kernel =
+            MinPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
+                                       VEC_SIZE, DETERMINISTIC, T, IdType>;
+        APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
+            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+        APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
+                                             smem_size, stream));
+      })});
+  return cudaSuccess;
+}
+
+template <typename T, typename IdType>
+cudaError_t TopKTopPSamplingFromProb(T* probs, T* uniform_samples,
+                                     IdType* top_k_arr, T* top_p_arr,
+                                     IdType* output, bool* success,
+                                     uint32_t batch_size, IdType top_k_val,
+                                     T top_p_val, uint32_t d,
+                                     uint32_t max_rounds, bool deterministic,
+                                     cudaStream_t stream = 0) {
+  constexpr uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t smem_size =
+      sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&probs,  &uniform_samples, &top_k_arr, &top_p_arr,
+                  &output, &success,         &top_k_val, &top_p_val,
+                  &d,      &max_rounds};
+
+  DISPATCH_ALIGNED_VEC_SIZE(
+      vec_size, VEC_SIZE,
+      {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
+        auto kernel = TopKTopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO,
+                                                     REDUCE_ALGO, VEC_SIZE,
+                                                     DETERMINISTIC, T, IdType>;
+        APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
+            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+        APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
+                                             smem_size, stream));
+      })});
+  return cudaSuccess;
+}
+
+template <typename T, uint32_t BLOCK_THREADS,
+          BlockReduceAlgorithm REDUCE_ALGORITHM>
+struct RenormTempStorage {
+  union {
+    typename BlockReduce<T, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
+        reduce;
+    typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
+        reduce_int;
+    typename BlockReduce<Pair<T>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
+        reduce_pair;
+  } block_prim;
+  struct {
+    T max_val;
+    T min_val;
+    union {
+      T value;
+      int count;
+      Pair<T> pair;
+    } block_aggregate;
+  } data;
+};
+
+template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
+          uint32_t VEC_SIZE, typename DType>
+__global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob,
+                                     DType* top_p_arr, float top_p_val,
+                                     uint32_t d) {
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  const uint32_t row_idx = bx;
+  float p = top_p_arr == nullptr ? top_p_val : top_p_arr[bx];
+
+  extern __shared__ __align__(
+      alignof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
+      uint8_t smem_renorm[];
+  auto& temp_storage =
+      reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(
+          smem_renorm);
+  temp_storage.data.max_val = DType(0);
+  vec_t<DType, VEC_SIZE> probs_vec;
+  DType probs_greater_than_pivot[VEC_SIZE];  // pivot initialized to 0
+
+  DType threadlocal_max_val = DType(0);
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
+                     tx * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      probs_greater_than_pivot[j] = probs_vec[j];
+    }
+    threadlocal_max_val =
+        max(threadlocal_max_val,
+            BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                temp_storage.block_prim.reduce)
+                .Reduce<VEC_SIZE>(probs_greater_than_pivot, cub::Max()));
+    __syncthreads();
+  }
+  if (tx == 0) {
+    temp_storage.data.max_val = threadlocal_max_val;
+  }
+  __syncthreads();
+  threadlocal_max_val = temp_storage.data.max_val;
+
+  float low = 0, high = threadlocal_max_val;
+  DType min_gt_low, max_le_high;
+  DType sum_low(1);
+  // f(x) = sum(probs[probs > x]), f(x) is non-increasing
+  // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p
+  // <= high} loop invariant:
+  // - f(low) >= p, f(high) < p
+  // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
+  // stopping condition
+  // - f(low) >= p, f(min_gt_low) == f(max_le_high) == f(high) < p
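+  // Binary search for the threshold `low`: the mass strictly above it is at
+  // least p, while the mass above the next larger probability value falls
+  // below p; entries above `low` are renormalized by 1 / sum_low afterwards.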
+  do {
+    DType threadlocal_sum(0);
+    float mid = (low + high) / 2;
+    min_gt_low = high;
+    max_le_high = low;
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
+                       tx * VEC_SIZE);
+      }
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        probs_greater_than_pivot[j] =
+            (probs_vec[j] > mid) ? probs_vec[j] : DType(0);
+        if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+          min_gt_low = min(min_gt_low, probs_vec[j]);
+        }
+        if (probs_vec[j] <= high &&
+            (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+          max_le_high = max(max_le_high, probs_vec[j]);
+        }
+      }
+      threadlocal_sum += BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                             temp_storage.block_prim.reduce)
+                             .Sum<VEC_SIZE>(probs_greater_than_pivot);
+      __syncthreads();
+    }
+    min_gt_low = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                     temp_storage.block_prim.reduce)
+                     .Reduce(min_gt_low, cub::Min());
+    __syncthreads();
+    max_le_high = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                      temp_storage.block_prim.reduce)
+                      .Reduce(max_le_high, cub::Max());
+    if (tx == 0) {
+      temp_storage.data.block_aggregate.value = threadlocal_sum;
+      temp_storage.data.min_val = min_gt_low;
+      temp_storage.data.max_val = max_le_high;
+    }
+    __syncthreads();
+    threadlocal_sum = temp_storage.data.block_aggregate.value;
+    min_gt_low = temp_storage.data.min_val;
+    max_le_high = temp_storage.data.max_val;
+    if (threadlocal_sum >= p) {
+      low = mid;
+      sum_low = float(threadlocal_sum);
+    } else {
+      high = min(mid, max_le_high);
+    }
+  } while (min_gt_low != max_le_high);
+
+  DType normalizer = math::ptx_rcp(max(sum_low, 1e-8));
+
+  // normalize
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
+                     tx * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      probs_vec[j] =
+          (probs_vec[j] > low) ? probs_vec[j] * normalizer : DType(0);
+    }
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.store(renormed_prob + row_idx * d +
+                      i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+    }
+  }
+}
+
+template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
+          uint32_t VEC_SIZE, typename DType, typename IdType>
+__global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits,
+                                     IdType* top_k_arr, uint32_t top_k_val,
+                                     uint32_t d) {
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  const uint32_t row_idx = bx;
+  uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
+  float pivot = -std::numeric_limits<float>::infinity();
+  vec_t<DType, VEC_SIZE> logits_vec;
+  if (k < d) {
+    extern __shared__ __align__(
+        alignof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
+        uint8_t smem_renorm[];
+    auto& temp_storage =
+        reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(
+            smem_renorm);
+    DType logits_greater_than_pivot[VEC_SIZE];  // pivot initialized to -inf
+
+    DType threadlocal_max_val = DType(-std::numeric_limits<float>::infinity()),
+          threadlocal_min_val = DType(std::numeric_limits<float>::infinity());
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      logits_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        logits_vec.load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
+                        tx * VEC_SIZE);
+      }
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        logits_greater_than_pivot[j] = logits_vec[j];
+      }
+      threadlocal_max_val =
+          max(threadlocal_max_val,
+              BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                  temp_storage.block_prim.reduce)
+                  .Reduce<VEC_SIZE>(logits_greater_than_pivot, cub::Max()));
+      __syncthreads();
+      threadlocal_min_val =
+          min(threadlocal_min_val,
+              BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                  temp_storage.block_prim.reduce)
+                  .Reduce<VEC_SIZE>(logits_greater_than_pivot, cub::Min()));
+      __syncthreads();
+    }
+    if (tx == 0) {
+      temp_storage.data.max_val = threadlocal_max_val;
+      temp_storage.data.min_val = threadlocal_min_val;
+    }
+    __syncthreads();
+    threadlocal_max_val = temp_storage.data.max_val;
+    threadlocal_min_val = temp_storage.data.min_val;
+
+    float low = threadlocal_min_val - 1, high = threadlocal_max_val;
+    DType min_gt_low, max_le_high;
+    // f(x) = len(nonzero(logits > x)), f(x) is non-increasing
+    // min_gt_low = min{l \in logits | l > low},
+    // max_le_high = max{l \in logits | l <= high}
+    // loop invariant:
+    // - f(low) >= k, f(high) < k
+    // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
+    // stopping condition: min_gt_low == max_le_high
+    // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
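+    // Binary search for the pivot logit: at least k logits stay strictly
+    // above it; everything at or below the pivot is masked to -inf below.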
+    do {
+      int threadlocal_count_sum = 0;
+      int probs_greater_than_pivot_count[VEC_SIZE];  // logits above the
+                                                     // current pivot candidate
+      float mid = (low + high) / 2;
+      min_gt_low = high;
+      max_le_high = low;
+      for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+        logits_vec.fill(DType(0));
+        if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+          logits_vec.load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
+                          tx * VEC_SIZE);
+        }
+#pragma unroll
+        for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+          probs_greater_than_pivot_count[j] =
+              logits_vec[j] > mid &&
+              (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
+          if (logits_vec[j] > low &&
+              (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+            min_gt_low = min(min_gt_low, logits_vec[j]);
+          }
+          if (logits_vec[j] <= high &&
+              (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+            max_le_high = max(max_le_high, logits_vec[j]);
+          }
+        }
+        threadlocal_count_sum +=
+            BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                temp_storage.block_prim.reduce_int)
+                .Sum<VEC_SIZE>(probs_greater_than_pivot_count);
+        __syncthreads();
+      }
+      min_gt_low = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                       temp_storage.block_prim.reduce)
+                       .Reduce(min_gt_low, cub::Min());
+      __syncthreads();
+      max_le_high = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                        temp_storage.block_prim.reduce)
+                        .Reduce(max_le_high, cub::Max());
+      if (tx == 0) {
+        temp_storage.data.block_aggregate.count = threadlocal_count_sum;
+        temp_storage.data.min_val = min_gt_low;
+        temp_storage.data.max_val = max_le_high;
+      }
+      __syncthreads();
+      threadlocal_count_sum = temp_storage.data.block_aggregate.count;
+      min_gt_low = temp_storage.data.min_val;
+      max_le_high = temp_storage.data.max_val;
+      if (threadlocal_count_sum >= k) {
+        low = mid;
+      } else {
+        high = min(mid, max_le_high);
+      }
+    } while (min_gt_low != max_le_high);
+    pivot = low;
+  }
+
+  // masking
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    logits_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      logits_vec.load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
+                      tx * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      logits_vec[j] = (logits_vec[j] > pivot)
+                          ? logits_vec[j]
+                          : DType(-std::numeric_limits<float>::infinity());
+    }
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      logits_vec.store(masked_logits + row_idx * d +
+                       i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+    }
+  }
+}
+
+template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
+          uint32_t VEC_SIZE, typename DType, typename IdType>
+__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob,
+                                     IdType* top_k_arr, uint32_t top_k_val,
+                                     uint32_t d) {
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  const uint32_t row_idx = bx;
+  uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
+  float pivot = -std::numeric_limits<float>::infinity(), normalizer = 1;
+  vec_t<DType, VEC_SIZE> probs_vec;
+  if (k < d) {
+    extern __shared__ __align__(
+        alignof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
+        uint8_t smem_renorm[];
+    auto& temp_storage =
+        reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(
+            smem_renorm);
+    temp_storage.data.max_val = DType(0);
+    DType probs_greater_than_pivot[VEC_SIZE];  // pivot initialized to 0
+
+    DType threadlocal_max_val = DType(0);
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
+                       tx * VEC_SIZE);
+      }
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        probs_greater_than_pivot[j] = probs_vec[j];
+      }
+      threadlocal_max_val =
+          max(threadlocal_max_val,
+              BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                  temp_storage.block_prim.reduce)
+                  .Reduce<VEC_SIZE>(probs_greater_than_pivot, cub::Max()));
+      __syncthreads();
+    }
+    if (tx == 0) {
+      temp_storage.data.max_val = threadlocal_max_val;
+    }
+    __syncthreads();
+    threadlocal_max_val = temp_storage.data.max_val;
+
+    float low = 0, high = threadlocal_max_val;
+    DType min_gt_low, max_le_high;
+    DType sum_low(1);
+    // f(x) = len(nonzero(probs > x)), f(x) is non-increasing
+    // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs |
+    // p <= high} loop invariant:
+    // - f(low) >= k, f(high) < k
+    // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
+    // stopping condition: min_gt_low == max_le_high
+    // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
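+    // Binary search for the pivot: at least k probabilities stay strictly
+    // above it; they are renormalized by 1 / sum_low and the rest are zeroed.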
+    do {
+      Pair<DType> threadlocal_sum{DType(0), 0};
+      Pair<DType>
+          probs_greater_than_pivot_pair[VEC_SIZE];  // pivot initialized to 0
+      float mid = (low + high) / 2;
+      min_gt_low = high;
+      max_le_high = low;
+      for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+        probs_vec.fill(DType(0));
+        if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+          probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
+                         tx * VEC_SIZE);
+        }
+#pragma unroll
+        for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+          probs_greater_than_pivot_pair[j] = {
+              (probs_vec[j] > mid) ? probs_vec[j] : DType(0),
+              (probs_vec[j] > mid &&
+               (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+          if (probs_vec[j] > low &&
+              (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+            min_gt_low = min(min_gt_low, probs_vec[j]);
+          }
+          if (probs_vec[j] <= high &&
+              (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+            max_le_high = max(max_le_high, probs_vec[j]);
+          }
+        }
+        threadlocal_sum +=
+            BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                temp_storage.block_prim.reduce_pair)
+                .Sum<VEC_SIZE>(probs_greater_than_pivot_pair);
+        __syncthreads();
+      }
+      min_gt_low = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                       temp_storage.block_prim.reduce)
+                       .Reduce(min_gt_low, cub::Min());
+      __syncthreads();
+      max_le_high = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                        temp_storage.block_prim.reduce)
+                        .Reduce(max_le_high, cub::Max());
+      if (tx == 0) {
+        temp_storage.data.block_aggregate.pair = threadlocal_sum;
+        temp_storage.data.min_val = min_gt_low;
+        temp_storage.data.max_val = max_le_high;
+      }
+      __syncthreads();
+      threadlocal_sum = temp_storage.data.block_aggregate.pair;
+      min_gt_low = temp_storage.data.min_val;
+      max_le_high = temp_storage.data.max_val;
+      if (threadlocal_sum.count >= k) {
+        low = mid;
+        sum_low = float(threadlocal_sum.value);
+      } else {
+        high = min(mid, max_le_high);
+      }
+    } while (min_gt_low != max_le_high);
+
+    normalizer = math::ptx_rcp(max(sum_low, 1e-8));
+    pivot = low;
+  }
+
+  // normalize
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
+                     tx * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      probs_vec[j] =
+          (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : DType(0);
+    }
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.store(renormed_prob + row_idx * d +
+                      i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+    }
+  }
+}
+
+template <typename DType>
+cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr,
+                           uint32_t batch_size, float top_p_val, uint32_t d,
+                           cudaStream_t stream = 0) {
+  const uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
+
+  const uint32_t smem_size =
+      sizeof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &d};
+  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+    auto kernel =
+        TopPRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType>;
+    APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
+        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+    APHRODITE_CUDA_CALL(
+        cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
+  return cudaSuccess;
+}
+
+template <typename DType, typename IdType>
+cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob,
+                           IdType* top_k_arr, uint32_t batch_size,
+                           uint32_t top_k_val, uint32_t d,
+                           cudaStream_t stream = 0) {
+  const uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
+
+  const uint32_t smem_size =
+      sizeof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d};
+  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+    auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE,
+                                       DType, IdType>;
+    APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
+        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+    APHRODITE_CUDA_CALL(
+        cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
+  return cudaSuccess;
+}
+
+template <typename DType, typename IdType>
+cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits,
+                           IdType* top_k_arr, uint32_t batch_size,
+                           uint32_t top_k_val, uint32_t d,
+                           cudaStream_t stream = 0) {
+  const uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
+
+  const uint32_t smem_size =
+      sizeof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &d};
+  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+    auto kernel = TopKMaskLogitsKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE,
+                                       DType, IdType>;
+    APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
+        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+    APHRODITE_CUDA_CALL(
+        cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
+  return cudaSuccess;
+}
+
+template <typename T, typename IdType>
+cudaError_t ParallelTopPSamplingFromProb(
+    T* probs, T* uniform_samples, IdType* output, bool* success,
+    IdType* row_indices, T* top_p_arr, uint32_t batch_size, uint32_t d,
+    uint32_t max_top_p_rounds, bool deterministic, cudaStream_t stream = 0) {
+  constexpr uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t smem_size =
+      sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  T top_p_placeholder = 0;
+  void* args[] = {
+      &probs,     &uniform_samples,   &output, &success,         &row_indices,
+      &top_p_arr, &top_p_placeholder, &d,      &max_top_p_rounds};
+
+  DISPATCH_ALIGNED_VEC_SIZE(
+      vec_size, VEC_SIZE,
+      {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
+        auto kernel =
+            TopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
+                                       VEC_SIZE, DETERMINISTIC, T, IdType>;
+        APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
+            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+        APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
+                                             smem_size, stream));
+      })});
+  return cudaSuccess;
+}
+
+}  // namespace sampling
+
+}  // namespace aphrodite
+
+#endif  // APHRODITE_SAMPLING_CUH_

+ 273 - 0
kernels/sampling/utils.cuh

@@ -0,0 +1,273 @@
+/*
+ * Copyright (c) 2024 by PygmalionAI team.
+ * Copyright (c) 2023 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef APHRODITE_UTILS_CUH_
+#define APHRODITE_UTILS_CUH_
+#include <cuda_runtime.h>
+
+#include <iostream>
+#include <sstream>
+#include <stdexcept>
+#include <vector>
+#include <torch/all.h>
+
+#define STR_HELPER(x) #x
+#define STR(x) STR_HELPER(x)
+
+// macro to turn off fp16 qk reduction to reduce binary size
+#ifndef APHRODITE_ALWAYS_DISALLOW_FP16_QK_REDUCTION
+  #define APHRODITE_ALWAYS_DISALLOW_FP16_QK_REDUCTION 0
+#endif
+
+#ifndef NDEBUG
+  #define APHRODITE_CUDA_CALL(func, ...)                                  \
+    {                                                                     \
+      cudaError_t e = (func);                                             \
+      if (e != cudaSuccess) {                                             \
+        std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e \
+                  << ") " << __FILE__ << ": line " << __LINE__            \
+                  << " at function " << STR(func) << std::endl;           \
+        return e;                                                         \
+      }                                                                   \
+    }
+#else
+  #define APHRODITE_CUDA_CALL(func, ...) \
+    {                                    \
+      cudaError_t e = (func);            \
+      if (e != cudaSuccess) {            \
+        return e;                        \
+      }                                  \
+    }
+#endif
+
+#define DISPATCH_ALLOW_FP16_QK_REDUCTION(allow_fp16_qk_reduction,           \
+                                         ALLOW_FP16_QK_REDUCTION, ...)      \
+  if (allow_fp16_qk_reduction) {                                            \
+    throw std::runtime_error("FP16_QK_REDUCTION disabled at compile time"); \
+  } else {                                                                  \
+    constexpr bool ALLOW_FP16_QK_REDUCTION = false;                         \
+    __VA_ARGS__                                                             \
+  }
+
+#define DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, ...) \
+  if (num_frags_x == 1) {                                   \
+    constexpr size_t NUM_FRAGS_X = 1;                       \
+    __VA_ARGS__                                             \
+  } else if (num_frags_x == 2) {                            \
+    constexpr size_t NUM_FRAGS_X = 2;                       \
+    __VA_ARGS__                                             \
+  } else {                                                  \
+    std::ostringstream err_msg;                             \
+    err_msg << "Unsupported num_frags_x: " << num_frags_x;  \
+    throw std::invalid_argument(err_msg.str());             \
+  }
+
+#define DISPATCH_NUM_FRAGS_Z(max_frags_z, NUM_FRAGS_Z, ...) \
+  if (max_frags_z >= 8) {                                   \
+    constexpr size_t NUM_FRAGS_Z = 8;                       \
+    __VA_ARGS__                                             \
+  } else if (max_frags_z >= 4) {                            \
+    constexpr size_t NUM_FRAGS_Z = 4;                       \
+    __VA_ARGS__                                             \
+  } else if (max_frags_z >= 2) {                            \
+    constexpr size_t NUM_FRAGS_Z = 2;                       \
+    __VA_ARGS__                                             \
+  } else if (max_frags_z >= 1) {                            \
+    constexpr size_t NUM_FRAGS_Z = 1;                       \
+    __VA_ARGS__                                             \
+  } else {                                                  \
+    std::ostringstream err_msg;                             \
+    err_msg << "Unsupported max_frags_z: " << max_frags_z;  \
+    throw std::invalid_argument(err_msg.str());             \
+  }
+
+#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
+  if (group_size == 1) {                                     \
+    constexpr size_t GROUP_SIZE = 1;                         \
+    __VA_ARGS__                                              \
+  } else if (group_size == 2) {                              \
+    constexpr size_t GROUP_SIZE = 2;                         \
+    __VA_ARGS__                                              \
+  } else if (group_size == 4) {                              \
+    constexpr size_t GROUP_SIZE = 4;                         \
+    __VA_ARGS__                                              \
+  } else if (group_size == 8) {                              \
+    constexpr size_t GROUP_SIZE = 8;                         \
+    __VA_ARGS__                                              \
+  } else {                                                   \
+    std::ostringstream err_msg;                              \
+    err_msg << "Unsupported group_size: " << group_size;     \
+    throw std::invalid_argument(err_msg.str());              \
+  }
+
+#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...)         \
+  switch (mask_mode) {                                        \
+    case MaskMode::kNone: {                                   \
+      constexpr MaskMode MASK_MODE = MaskMode::kNone;         \
+      __VA_ARGS__                                             \
+      break;                                                  \
+    }                                                         \
+    case MaskMode::kCausal: {                                 \
+      constexpr MaskMode MASK_MODE = MaskMode::kCausal;       \
+      __VA_ARGS__                                             \
+      break;                                                  \
+    }                                                         \
+    case MaskMode::kCustom: {                                 \
+      constexpr MaskMode MASK_MODE = MaskMode::kCustom;       \
+      __VA_ARGS__                                             \
+      break;                                                  \
+    }                                                         \
+    default: {                                                \
+      std::ostringstream err_msg;                             \
+      err_msg << "Unsupported mask_mode: " << int(mask_mode); \
+      throw std::invalid_argument(err_msg.str());             \
+    }                                                         \
+  }
+
+#define DISPATCH_LOGITS_POST_HOOK(logits_soft_cap, LOGITS_POST_HOOK, ...) \
+  if (logits_soft_cap > 0.f) {                                            \
+    constexpr LogitsPostHook LOGITS_POST_HOOK = LogitsPostHook::kSoftCap; \
+    __VA_ARGS__                                                           \
+  } else if (logits_soft_cap == 0.f) {                                    \
+    constexpr LogitsPostHook LOGITS_POST_HOOK = LogitsPostHook::kNone;    \
+    __VA_ARGS__                                                           \
+  } else {                                                                \
+    std::ostringstream err_msg;                                           \
+    err_msg << "Invalid logits_soft_cap (should be >= 0): "               \
+            << logits_soft_cap;                                           \
+    throw std::invalid_argument(err_msg.str());                           \
+  }
+
+#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...)     \
+  switch (head_dim) {                                  \
+    case 64: {                                         \
+      constexpr size_t HEAD_DIM = 64;                  \
+      __VA_ARGS__                                      \
+      break;                                           \
+    }                                                  \
+    case 128: {                                        \
+      constexpr size_t HEAD_DIM = 128;                 \
+      __VA_ARGS__                                      \
+      break;                                           \
+    }                                                  \
+    case 256: {                                        \
+      constexpr size_t HEAD_DIM = 256;                 \
+      __VA_ARGS__                                      \
+      break;                                           \
+    }                                                  \
+    default: {                                         \
+      std::ostringstream err_msg;                      \
+      err_msg << "Unsupported head_dim: " << head_dim; \
+      throw std::invalid_argument(err_msg.str());      \
+    }                                                  \
+  }
+
+#define DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, ...) \
+  switch (pos_encoding_mode) {                                                \
+    case PosEncodingMode::kNone: {                                            \
+      constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone;   \
+      __VA_ARGS__                                                             \
+      break;                                                                  \
+    }                                                                         \
+    case PosEncodingMode::kRoPELlama: {                                       \
+      constexpr PosEncodingMode POS_ENCODING_MODE =                           \
+          PosEncodingMode::kRoPELlama;                                        \
+      __VA_ARGS__                                                             \
+      break;                                                                  \
+    }                                                                         \
+    case PosEncodingMode::kALiBi: {                                           \
+      constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kALiBi;  \
+      __VA_ARGS__                                                             \
+      break;                                                                  \
+    }                                                                         \
+    default: {                                                                \
+      std::ostringstream err_msg;                                             \
+      err_msg << "Unsupported pos_encoding_mode: " << int(pos_encoding_mode); \
+      throw std::invalid_argument(err_msg.str());                             \
+    }                                                                         \
+  }
+
+#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \
+  switch (aligned_vec_size) {                                              \
+    case 16: {                                                             \
+      constexpr size_t ALIGNED_VEC_SIZE = 16;                              \
+      __VA_ARGS__                                                          \
+      break;                                                               \
+    }                                                                      \
+    case 8: {                                                              \
+      constexpr size_t ALIGNED_VEC_SIZE = 8;                               \
+      __VA_ARGS__                                                          \
+      break;                                                               \
+    }                                                                      \
+    case 4: {                                                              \
+      constexpr size_t ALIGNED_VEC_SIZE = 4;                               \
+      __VA_ARGS__                                                          \
+      break;                                                               \
+    }                                                                      \
+    case 2: {                                                              \
+      constexpr size_t ALIGNED_VEC_SIZE = 2;                               \
+      __VA_ARGS__                                                          \
+      break;                                                               \
+    }                                                                      \
+    case 1: {                                                              \
+      constexpr size_t ALIGNED_VEC_SIZE = 1;                               \
+      __VA_ARGS__                                                          \
+      break;                                                               \
+    }                                                                      \
+    default: {                                                             \
+      std::ostringstream err_msg;                                          \
+      err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size;     \
+      throw std::invalid_argument(err_msg.str());                          \
+    }                                                                      \
+  }
+
+namespace aphrodite {
+
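+// Integer ceiling division: smallest integer >= x / y for non-negative inputs.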
+template <typename T1, typename T2>
+__forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) {
+  return (x + y - 1) / y;
+}
+
+template <typename T>
+inline void DebugPrintCUDAArray(T* device_ptr, size_t size,
+                                std::string prefix = "") {
+  std::vector<T> host_array(size);
+  std::cout << prefix;
+  cudaMemcpy(host_array.data(), device_ptr, size * sizeof(T),
+             cudaMemcpyDeviceToHost);
+  for (size_t i = 0; i < size; ++i) {
+    std::cout << host_array[i] << " ";
+  }
+  std::cout << std::endl;
+}
+
+/*!
+ * \brief Return x - y if x > y, otherwise return 0.
+ */
+__device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x,
+                                                           uint32_t y) {
+  return (x > y) ? x - y : 0U;
+}
+
+__device__ __forceinline__ void swap(uint32_t& a, uint32_t& b) {
+  uint32_t tmp = a;
+  a = b;
+  b = tmp;
+}
+
+}  // namespace aphrodite
+
+#endif  // APHRODITE_UTILS_CUH_

+ 1501 - 0
kernels/sampling/vec_dtypes.cuh

@@ -0,0 +1,1501 @@
+/*
+ * Copyright (c) 2024 by PygmalionAI team.
+ * Copyright (c) 2023 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef VEC_DTYPES_CUH_
+#define VEC_DTYPES_CUH_
+
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda_fp8.h>
+#include <cuda_runtime.h>
+
+#include <type_traits>
+
+namespace aphrodite {
+
+#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 900))
+  #define APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
+#endif
+
+#define APHRODITE_INLINE inline __attribute__((always_inline)) __device__
+
+/******************* vec_t type cast *******************/
+
+template <typename dst_t, typename src_t>
+struct vec_cast {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(dst_t* dst, const src_t* src) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size; ++i) {
+      dst[i] = (dst_t)src[i];
+    }
+  }
+};
+
+template <>
+struct vec_cast<float, half> {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(float* dst, const half* src) {
+    if constexpr (vec_size == 1) {
+      dst[0] = (float)src[0];
+    } else {
+#pragma unroll
+      for (size_t i = 0; i < vec_size / 2; ++i) {
+        ((float2*)dst)[i] = __half22float2(((half2*)src)[i]);
+      }
+    }
+  }
+};
+
+template <>
+struct vec_cast<half, float> {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(half* dst, const float* src) {
+    if constexpr (vec_size == 1) {
+      dst[0] = __float2half(src[0]);
+    } else {
+#pragma unroll
+      for (size_t i = 0; i < vec_size / 2; ++i) {
+        ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]);
+      }
+    }
+  }
+};
+
+template <typename T>
+constexpr APHRODITE_INLINE int get_exponent_bits() {
+  if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
+    return 4;
+  } else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
+    return 5;
+  } else if constexpr (std::is_same<T, half>::value) {
+    return 5;
+  } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
+    return 8;
+  }
+}
+
+template <typename T>
+constexpr APHRODITE_INLINE int get_mantissa_bits() {
+  if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
+    return 3;
+  } else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
+    return 2;
+  } else if constexpr (std::is_same<T, half>::value) {
+    return 11;
+  } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
+    return 7;
+  }
+}
+
+/*!
+ * \brief Fallback to software fast dequant implementation if hardware
+ * dequantization is not available. \note Inspired by Marlin's fast
+ * dequantization, but here we don't have to permute weights order. \ref
+ * https://github.com/vllm-project/vllm/blob/6dffa4b0a6120159ef2fe44d695a46817aff65bc/csrc/quantization/fp8/fp8_marlin.cu#L120
+ */
+template <typename fp8_dtype, typename fp16_dtype>
+__device__ void fast_dequant_f8f16x4(uint32_t* input, uint2* output) {
+  uint32_t q = *input;
+  if constexpr (std::is_same<fp8_dtype, __nv_fp8_e5m2>::value &&
+                std::is_same<fp16_dtype, half>::value) {
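+    // e5m2 shares the 5-bit exponent layout of half, so dequantization is a
+    // byte shuffle that places each fp8 byte in the high byte of a half.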
+    output->x = __byte_perm(0U, q, 0x5140);
+    output->y = __byte_perm(0U, q, 0x7362);
+  } else {
+    constexpr int FP8_EXPONENT = get_exponent_bits<fp8_dtype>();
+    constexpr int FP8_MANTISSA = get_mantissa_bits<fp8_dtype>();
+    constexpr int FP16_EXPONENT = get_exponent_bits<fp16_dtype>();
+
+    constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
+    // Calculate MASK for extracting mantissa and exponent
+    constexpr int MASK1 = 0x80000000;
+    constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
+    constexpr int MASK3 = MASK2 & 0x7fffffff;
+    constexpr int MASK = MASK3 | (MASK3 >> 16);
+    q = __byte_perm(q, q, 0x1302);
+
+    // Extract and shift FP8 values to FP16 format
+    uint32_t Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
+    uint32_t Out2 =
+        ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
+
+    constexpr int BIAS_OFFSET =
+        (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
+    // Construct and apply exponent bias
+    if constexpr (std::is_same<fp16_dtype, half>::value) {
+      const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
+
+      // Convert to half2 and apply bias
+      *(half2*)&(output->x) =
+          __hmul2(*reinterpret_cast<const half2*>(&Out1), bias_reg);
+      *(half2*)&(output->y) =
+          __hmul2(*reinterpret_cast<const half2*>(&Out2), bias_reg);
+    } else {
+      constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
+      const nv_bfloat162 bias_reg =
+          __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
+      // Convert to bfloat162 and apply bias
+      *(nv_bfloat162*)&(output->x) =
+          __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out1), bias_reg);
+      *(nv_bfloat162*)&(output->y) =
+          __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out2), bias_reg);
+    }
+  }
+}
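+
+// Usage sketch: the vec_cast specializations below call this on groups of
+// four packed fp8 bytes, e.g.
+//
+//   uint32_t packed;  // 4x fp8 values, e.g. loaded from a quantized buffer
+//   uint2 out;        // receives 4x half (two half2 lanes)
+//   fast_dequant_f8f16x4<__nv_fp8_e4m3, half>(&packed, &out);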
+
+template <>
+struct vec_cast<nv_bfloat16, __nv_fp8_e4m3> {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(nv_bfloat16* dst,
+                                    const __nv_fp8_e4m3* src) {
+    if constexpr (vec_size == 1) {
+      dst[0] = nv_bfloat16(src[0]);
+    } else if constexpr (vec_size == 2) {
+      dst[0] = nv_bfloat16(src[0]);
+      dst[1] = nv_bfloat16(src[1]);
+    } else {
+      static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
+#pragma unroll
+      for (uint32_t i = 0; i < vec_size / 4; ++i) {
+        fast_dequant_f8f16x4<__nv_fp8_e4m3, nv_bfloat16>((uint32_t*)&src[i * 4],
+                                                         (uint2*)&dst[i * 4]);
+      }
+    }
+  }
+};
+
+template <>
+struct vec_cast<nv_bfloat16, __nv_fp8_e5m2> {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(nv_bfloat16* dst,
+                                    const __nv_fp8_e5m2* src) {
+    if constexpr (vec_size == 1) {
+      dst[0] = nv_bfloat16(src[0]);
+    } else if constexpr (vec_size == 2) {
+      dst[0] = nv_bfloat16(src[0]);
+      dst[1] = nv_bfloat16(src[1]);
+    } else {
+      static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
+#pragma unroll
+      for (uint32_t i = 0; i < vec_size / 4; ++i) {
+        fast_dequant_f8f16x4<__nv_fp8_e5m2, nv_bfloat16>((uint32_t*)&src[i * 4],
+                                                         (uint2*)&dst[i * 4]);
+      }
+    }
+  }
+};
+
+template <>
+struct vec_cast<__nv_fp8_e4m3, half> {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(__nv_fp8_e4m3* dst, const half* src) {
+#ifdef APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
+    if constexpr (vec_size == 1) {
+      dst[0] = __nv_fp8_e4m3(src[0]);
+    } else {
+  #pragma unroll
+      for (size_t i = 0; i < vec_size / 2; ++i) {
+        uint16_t y;
+        uint32_t x = *(uint32_t*)&src[i * 2];
+        asm volatile("cvt.rn.satfinite.e4m3x2.f16x2 %0, %1;"
+                     : "=h"(y)
+                     : "r"(x));
+        *(uint16_t*)&dst[i * 2] = y;
+      }
+    }
+#else
+  #pragma unroll
+    for (size_t i = 0; i < vec_size; ++i) {
+      dst[i] = __nv_fp8_e4m3(src[i]);
+    }
+#endif  // APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
+  }
+};
+
+template <>
+struct vec_cast<__nv_fp8_e5m2, half> {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(__nv_fp8_e5m2* dst, const half* src) {
+#ifdef APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
+    if constexpr (vec_size == 1) {
+      dst[0] = __nv_fp8_e5m2(src[0]);
+    } else {
+  #pragma unroll
+      for (size_t i = 0; i < vec_size / 2; ++i) {
+        uint16_t y;
+        uint32_t x = *(uint32_t*)&src[i * 2];
+        asm volatile("cvt.rn.satfinite.e5m2x2.f16x2 %0, %1;"
+                     : "=h"(y)
+                     : "r"(x));
+        *(uint16_t*)&dst[i * 2] = y;
+      }
+    }
+#else
+  #pragma unroll
+    for (size_t i = 0; i < vec_size; ++i) {
+      dst[i] = __nv_fp8_e5m2(src[i]);
+    }
+#endif  // APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
+  }
+};
+
+template <>
+struct vec_cast<half, __nv_fp8_e4m3> {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(half* dst, const __nv_fp8_e4m3* src) {
+#ifdef APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
+    if constexpr (vec_size == 1) {
+      dst[0] = half(src[0]);
+    } else {
+  #pragma unroll
+      for (size_t i = 0; i < vec_size / 2; ++i) {
+        uint32_t y;
+        uint16_t x = *(uint16_t*)&src[i * 2];
+        asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;" : "=r"(y) : "h"(x));
+        *(uint32_t*)&dst[i * 2] = y;
+      }
+    }
+#else
+    if constexpr (vec_size == 1) {
+      dst[0] = half(src[0]);
+    } else if constexpr (vec_size == 2) {
+      dst[0] = half(src[0]);
+      dst[1] = half(src[1]);
+    } else {
+      static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
+  #pragma unroll
+      for (uint32_t i = 0; i < vec_size / 4; ++i) {
+        fast_dequant_f8f16x4<__nv_fp8_e4m3, half>((uint32_t*)&src[i * 4],
+                                                  (uint2*)&dst[i * 4]);
+      }
+    }
+#endif  // APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
+  }
+};
+
+template <>
+struct vec_cast<half, __nv_fp8_e5m2> {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(half* dst, const __nv_fp8_e5m2* src) {
+#ifdef APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
+    if constexpr (vec_size == 1) {
+      dst[0] = half(src[0]);
+    } else {
+  #pragma unroll
+      for (size_t i = 0; i < vec_size / 2; ++i) {
+        uint32_t y;
+        uint16_t x = *(uint16_t*)&src[i * 2];
+        asm volatile("cvt.rn.f16x2.e5m2x2 %0, %1;" : "=r"(y) : "h"(x));
+        *(uint32_t*)&dst[i * 2] = y;
+      }
+    }
+#else
+    if constexpr (vec_size == 1) {
+      dst[0] = half(src[0]);
+    } else if constexpr (vec_size == 2) {
+      dst[0] = half(src[0]);
+      dst[1] = half(src[1]);
+    } else {
+      static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
+  #pragma unroll
+      for (uint32_t i = 0; i < vec_size / 4; ++i) {
+        fast_dequant_f8f16x4<__nv_fp8_e5m2, half>((uint32_t*)&src[i * 4],
+                                                  (uint2*)&dst[i * 4]);
+      }
+    }
+#endif  // APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
+  }
+};
+
+template <>
+struct vec_cast<float, nv_bfloat16> {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(float* dst, const nv_bfloat16* src) {
+    if constexpr (vec_size == 1) {
+      dst[0] = (float)src[0];
+    } else {
+#pragma unroll
+      for (size_t i = 0; i < vec_size / 2; ++i) {
+        ((float2*)dst)[i] = __bfloat1622float2(((nv_bfloat162*)src)[i]);
+      }
+    }
+  }
+};
+
+template <>
+struct vec_cast<nv_bfloat16, float> {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(nv_bfloat16* dst, const float* src) {
+    if constexpr (vec_size == 1) {
+      dst[0] = nv_bfloat16(src[0]);
+    } else {
+#pragma unroll
+      for (size_t i = 0; i < vec_size / 2; ++i) {
+        ((nv_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]);
+      }
+    }
+  }
+};
+
+template <typename float_t, size_t vec_size>
+struct vec_t {
+  APHRODITE_INLINE float_t& operator[](size_t i);
+  APHRODITE_INLINE const float_t& operator[](size_t i) const;
+  APHRODITE_INLINE void fill(float_t val);
+  APHRODITE_INLINE void load(const float_t* ptr);
+  APHRODITE_INLINE void store(float_t* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src);
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr);
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const;
+  APHRODITE_INLINE static void memcpy(float_t* dst, const float_t* src);
+  APHRODITE_INLINE float_t* ptr();
+};
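+
+// Note: the primary template above only declares the interface; storage and
+// definitions come from the explicit/partial specializations below, so only
+// the specialized element types (__nv_fp8_e4m3, __nv_fp8_e5m2, half,
+// nv_bfloat16, float) are usable.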
+
+template <typename src_float_t, typename tgt_float_t, size_t vec_size>
+APHRODITE_INLINE void cast_from_impl(vec_t<tgt_float_t, vec_size>& dst,
+                                     const vec_t<src_float_t, vec_size>& src) {
+  vec_cast<tgt_float_t, src_float_t>::cast<vec_size>(
+      dst.ptr(), const_cast<vec_t<src_float_t, vec_size>*>(&src)->ptr());
+}
+
+template <typename src_float_t, typename tgt_float_t, size_t vec_size>
+APHRODITE_INLINE void cast_load_impl(vec_t<tgt_float_t, vec_size>& dst,
+                                     const src_float_t* src_ptr) {
+  if constexpr (std::is_same<src_float_t, tgt_float_t>::value) {
+    dst.load(src_ptr);
+  } else {
+    vec_t<src_float_t, vec_size> tmp;
+    tmp.load(src_ptr);
+    dst.cast_from(tmp);
+  }
+}
+
+template <typename src_float_t, typename tgt_float_t, size_t vec_size>
+APHRODITE_INLINE void cast_store_impl(tgt_float_t* dst_ptr,
+                                      const vec_t<src_float_t, vec_size>& src) {
+  if constexpr (std::is_same<src_float_t, tgt_float_t>::value) {
+    src.store(dst_ptr);
+  } else {
+    vec_t<tgt_float_t, vec_size> tmp;
+    tmp.cast_from(src);
+    tmp.store(dst_ptr);
+  }
+}
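+
+// cast_load/cast_store degenerate to a plain load/store when the source and
+// target element types already match, so the homogeneous case pays no extra
+// conversion cost.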
+
+/******************* vec_t<__nv_fp8_e4m3> *******************/
+
+// __nv_fp8_e4m3 x 1
+template <>
+struct vec_t<__nv_fp8_e4m3, 1> {
+  __nv_fp8_e4m3 data;
+
+  APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) {
+    return ((__nv_fp8_e4m3*)(&data))[i];
+  }
+  APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
+    return ((const __nv_fp8_e4m3*)(&data))[i];
+  }
+  APHRODITE_INLINE __nv_fp8_e4m3* ptr() {
+    return reinterpret_cast<__nv_fp8_e4m3*>(&data);
+  }
+  APHRODITE_INLINE void fill(__nv_fp8_e4m3 val);
+  APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr);
+  APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 1>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+
+  APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst,
+                                      const __nv_fp8_e4m3* src);
+};
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) {
+  data = val;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3* ptr) {
+  data = *ptr;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 1>::store(__nv_fp8_e4m3* ptr) const {
+  *ptr = data;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy(
+    __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) {
+  *dst = *src;
+}
+
+// __nv_fp8_e4m3 x 2
+template <>
+struct vec_t<__nv_fp8_e4m3, 2> {
+  __nv_fp8x2_e4m3 data;
+
+  APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) {
+    return ((__nv_fp8_e4m3*)(&data))[i];
+  }
+  APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
+    return ((const __nv_fp8_e4m3*)(&data))[i];
+  }
+  APHRODITE_INLINE __nv_fp8_e4m3* ptr() {
+    return reinterpret_cast<__nv_fp8_e4m3*>(&data);
+  }
+  APHRODITE_INLINE void fill(__nv_fp8_e4m3 val);
+  APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr);
+  APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 2>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst,
+                                      const __nv_fp8_e4m3* src);
+};
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) {
+  data.__x =
+      (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3* ptr) {
+  data = *((__nv_fp8x2_e4m3*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 2>::store(__nv_fp8_e4m3* ptr) const {
+  *((__nv_fp8x2_e4m3*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy(
+    __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) {
+  *((__nv_fp8x2_e4m3*)dst) = *((__nv_fp8x2_e4m3*)src);
+}
+
+// __nv_fp8_e4m3 x 4
+
+template <>
+struct vec_t<__nv_fp8_e4m3, 4> {
+  __nv_fp8x4_e4m3 data;
+
+  APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) {
+    return ((__nv_fp8_e4m3*)(&data))[i];
+  }
+  APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
+    return ((const __nv_fp8_e4m3*)(&data))[i];
+  }
+  APHRODITE_INLINE __nv_fp8_e4m3* ptr() {
+    return reinterpret_cast<__nv_fp8_e4m3*>(&data);
+  }
+  APHRODITE_INLINE void fill(__nv_fp8_e4m3 val);
+  APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr);
+  APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 4>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+
+  APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst,
+                                      const __nv_fp8_e4m3* src);
+};
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) {
+  data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
+             (__nv_fp8x4_storage_t(val.__x) << 16) |
+             (__nv_fp8x4_storage_t(val.__x) << 8) |
+             __nv_fp8x4_storage_t(val.__x);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3* ptr) {
+  data = *((__nv_fp8x4_e4m3*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 4>::store(__nv_fp8_e4m3* ptr) const {
+  *((__nv_fp8x4_e4m3*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy(
+    __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) {
+  *((__nv_fp8x4_e4m3*)dst) = *((__nv_fp8x4_e4m3*)src);
+}
+
+// __nv_fp8_e4m3 x 8
+
+template <>
+struct vec_t<__nv_fp8_e4m3, 8> {
+  uint2 data;
+
+  APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) {
+    return ((__nv_fp8_e4m3*)(&data))[i];
+  }
+  APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
+    return ((const __nv_fp8_e4m3*)(&data))[i];
+  }
+  APHRODITE_INLINE __nv_fp8_e4m3* ptr() {
+    return reinterpret_cast<__nv_fp8_e4m3*>(&data);
+  }
+  APHRODITE_INLINE void fill(__nv_fp8_e4m3 val);
+  APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr);
+  APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 8>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+
+  APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst,
+                                      const __nv_fp8_e4m3* src);
+};
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) {
+  ((__nv_fp8x4_e4m3*)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
+                                       (__nv_fp8x4_storage_t(val.__x) << 16) |
+                                       (__nv_fp8x4_storage_t(val.__x) << 8) |
+                                       __nv_fp8x4_storage_t(val.__x);
+  ((__nv_fp8x4_e4m3*)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
+                                       (__nv_fp8x4_storage_t(val.__x) << 16) |
+                                       (__nv_fp8x4_storage_t(val.__x) << 8) |
+                                       __nv_fp8x4_storage_t(val.__x);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3* ptr) {
+  data = *((uint2*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 8>::store(__nv_fp8_e4m3* ptr) const {
+  *((uint2*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy(
+    __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) {
+  *((uint2*)dst) = *((uint2*)src);
+}
+
+// __nv_fp8_e4m3 x 16 or more
+template <size_t vec_size>
+struct vec_t<__nv_fp8_e4m3, vec_size> {
+  uint4 data[vec_size / 16];
+
+  APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) {
+    return ((__nv_fp8_e4m3*)data)[i];
+  }
+  APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
+    return ((const __nv_fp8_e4m3*)data)[i];
+  }
+  APHRODITE_INLINE __nv_fp8_e4m3* ptr() {
+    return reinterpret_cast<__nv_fp8_e4m3*>(&data);
+  }
+  APHRODITE_INLINE void fill(__nv_fp8_e4m3 val) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      ((__nv_fp8x4_e4m3*)(&(data[i].x)))->__x =
+          (__nv_fp8x4_storage_t(val.__x) << 24) |
+          (__nv_fp8x4_storage_t(val.__x) << 16) |
+          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+      ((__nv_fp8x4_e4m3*)(&(data[i].y)))->__x =
+          (__nv_fp8x4_storage_t(val.__x) << 24) |
+          (__nv_fp8x4_storage_t(val.__x) << 16) |
+          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+      ((__nv_fp8x4_e4m3*)(&(data[i].z)))->__x =
+          (__nv_fp8x4_storage_t(val.__x) << 24) |
+          (__nv_fp8x4_storage_t(val.__x) << 16) |
+          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+      ((__nv_fp8x4_e4m3*)(&(data[i].w)))->__x =
+          (__nv_fp8x4_storage_t(val.__x) << 24) |
+          (__nv_fp8x4_storage_t(val.__x) << 16) |
+          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+    }
+  }
+  APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      data[i] = ((uint4*)ptr)[i];
+    }
+  }
+  APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      ((uint4*)ptr)[i] = data[i];
+    }
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+
+  APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst,
+                                      const __nv_fp8_e4m3* src) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      ((uint4*)dst)[i] = ((uint4*)src)[i];
+    }
+  }
+};
+
+/******************* vec_t<__nv_fp8_e5m2> *******************/
+
+// __nv_fp8_e5m2 x 1
+template <>
+struct vec_t<__nv_fp8_e5m2, 1> {
+  __nv_fp8_e5m2 data;
+
+  APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) {
+    return ((__nv_fp8_e5m2*)(&data))[i];
+  }
+  APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
+    return ((const __nv_fp8_e5m2*)(&data))[i];
+  }
+  APHRODITE_INLINE __nv_fp8_e5m2* ptr() {
+    return reinterpret_cast<__nv_fp8_e5m2*>(&data);
+  }
+  APHRODITE_INLINE void fill(__nv_fp8_e5m2 val);
+  APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr);
+  APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 1>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+
+  APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst,
+                                      const __nv_fp8_e5m2* src);
+};
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) {
+  data = val;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2* ptr) {
+  data = *ptr;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 1>::store(__nv_fp8_e5m2* ptr) const {
+  *ptr = data;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy(
+    __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) {
+  *dst = *src;
+}
+
+// __nv_fp8_e5m2 x 2
+template <>
+struct vec_t<__nv_fp8_e5m2, 2> {
+  __nv_fp8x2_e5m2 data;
+
+  APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) {
+    return ((__nv_fp8_e5m2*)(&data))[i];
+  }
+  APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
+    return ((const __nv_fp8_e5m2*)(&data))[i];
+  }
+  APHRODITE_INLINE __nv_fp8_e5m2* ptr() {
+    return reinterpret_cast<__nv_fp8_e5m2*>(&data);
+  }
+  APHRODITE_INLINE void fill(__nv_fp8_e5m2 val);
+  APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr);
+  APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 2>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+
+  APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst,
+                                      const __nv_fp8_e5m2* src);
+};
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) {
+  data.__x =
+      (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2* ptr) {
+  data = *((__nv_fp8x2_e5m2*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 2>::store(__nv_fp8_e5m2* ptr) const {
+  *((__nv_fp8x2_e5m2*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy(
+    __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) {
+  *((__nv_fp8x2_e5m2*)dst) = *((__nv_fp8x2_e5m2*)src);
+}
+
+// __nv_fp8_e5m2 x 4
+
+template <>
+struct vec_t<__nv_fp8_e5m2, 4> {
+  __nv_fp8x4_e5m2 data;
+
+  APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) {
+    return ((__nv_fp8_e5m2*)(&data))[i];
+  }
+  APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
+    return ((const __nv_fp8_e5m2*)(&data))[i];
+  }
+  APHRODITE_INLINE __nv_fp8_e5m2* ptr() {
+    return reinterpret_cast<__nv_fp8_e5m2*>(&data);
+  }
+  APHRODITE_INLINE void fill(__nv_fp8_e5m2 val);
+  APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr);
+  APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 4>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+
+  APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst,
+                                      const __nv_fp8_e5m2* src);
+};
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) {
+  data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
+             (__nv_fp8x4_storage_t(val.__x) << 16) |
+             (__nv_fp8x4_storage_t(val.__x) << 8) |
+             __nv_fp8x4_storage_t(val.__x);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2* ptr) {
+  data = *((__nv_fp8x4_e5m2*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 4>::store(__nv_fp8_e5m2* ptr) const {
+  *((__nv_fp8x4_e5m2*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy(
+    __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) {
+  *((__nv_fp8x4_e5m2*)dst) = *((__nv_fp8x4_e5m2*)src);
+}
+
+// __nv_fp8_e5m2 x 8
+
+template <>
+struct vec_t<__nv_fp8_e5m2, 8> {
+  uint2 data;
+
+  APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) {
+    return ((__nv_fp8_e5m2*)(&data))[i];
+  }
+  APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
+    return ((const __nv_fp8_e5m2*)(&data))[i];
+  }
+  APHRODITE_INLINE __nv_fp8_e5m2* ptr() {
+    return reinterpret_cast<__nv_fp8_e5m2*>(&data);
+  }
+  APHRODITE_INLINE void fill(__nv_fp8_e5m2 val);
+  APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr);
+  APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 8>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst,
+                                      const __nv_fp8_e5m2* src);
+};
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) {
+  ((__nv_fp8x4_e5m2*)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
+                                       (__nv_fp8x4_storage_t(val.__x) << 16) |
+                                       (__nv_fp8x4_storage_t(val.__x) << 8) |
+                                       __nv_fp8x4_storage_t(val.__x);
+  ((__nv_fp8x4_e5m2*)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
+                                       (__nv_fp8x4_storage_t(val.__x) << 16) |
+                                       (__nv_fp8x4_storage_t(val.__x) << 8) |
+                                       __nv_fp8x4_storage_t(val.__x);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2* ptr) {
+  data = *((uint2*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 8>::store(__nv_fp8_e5m2* ptr) const {
+  *((uint2*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy(
+    __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) {
+  *((uint2*)dst) = *((uint2*)src);
+}
+
+// __nv_fp8_e5m2 x 16 or more
+
+template <size_t vec_size>
+struct vec_t<__nv_fp8_e5m2, vec_size> {
+  uint4 data[vec_size / 16];
+
+  APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) {
+    return ((__nv_fp8_e5m2*)data)[i];
+  }
+  APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
+    return ((const __nv_fp8_e5m2*)data)[i];
+  }
+  APHRODITE_INLINE __nv_fp8_e5m2* ptr() {
+    return reinterpret_cast<__nv_fp8_e5m2*>(&data);
+  }
+  APHRODITE_INLINE void fill(__nv_fp8_e5m2 val) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      ((__nv_fp8x4_e5m2*)(&(data[i].x)))->__x =
+          (__nv_fp8x4_storage_t(val.__x) << 24) |
+          (__nv_fp8x4_storage_t(val.__x) << 16) |
+          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+      ((__nv_fp8x4_e5m2*)(&(data[i].y)))->__x =
+          (__nv_fp8x4_storage_t(val.__x) << 24) |
+          (__nv_fp8x4_storage_t(val.__x) << 16) |
+          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+      ((__nv_fp8x4_e5m2*)(&(data[i].z)))->__x =
+          (__nv_fp8x4_storage_t(val.__x) << 24) |
+          (__nv_fp8x4_storage_t(val.__x) << 16) |
+          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+      ((__nv_fp8x4_e5m2*)(&(data[i].w)))->__x =
+          (__nv_fp8x4_storage_t(val.__x) << 24) |
+          (__nv_fp8x4_storage_t(val.__x) << 16) |
+          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+    }
+  }
+  APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      data[i] = ((uint4*)ptr)[i];
+    }
+  }
+  APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      ((uint4*)ptr)[i] = data[i];
+    }
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst,
+                                      const __nv_fp8_e5m2* src) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      ((uint4*)dst)[i] = ((uint4*)src)[i];
+    }
+  }
+};
+
+/******************* vec_t<half> *******************/
+
+// half x 1
+template <>
+struct vec_t<half, 1> {
+  half data;
+
+  APHRODITE_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; }
+  APHRODITE_INLINE const half& operator[](size_t i) const {
+    return ((const half*)(&data))[i];
+  }
+  APHRODITE_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
+  APHRODITE_INLINE void fill(half val);
+  APHRODITE_INLINE void load(const half* ptr);
+  APHRODITE_INLINE void store(half* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 1>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+
+  APHRODITE_INLINE static void memcpy(half* dst, const half* src);
+};
+
+APHRODITE_INLINE void vec_t<half, 1>::fill(half val) { data = val; }
+
+APHRODITE_INLINE void vec_t<half, 1>::load(const half* ptr) { data = *ptr; }
+
+APHRODITE_INLINE void vec_t<half, 1>::store(half* ptr) const { *ptr = data; }
+
+APHRODITE_INLINE void vec_t<half, 1>::memcpy(half* dst, const half* src) {
+  *dst = *src;
+}
+
+// half x 2
+template <>
+struct vec_t<half, 2> {
+  half2 data;
+
+  APHRODITE_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; }
+  APHRODITE_INLINE const half& operator[](size_t i) const {
+    return ((const half*)(&data))[i];
+  }
+  APHRODITE_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
+  APHRODITE_INLINE void fill(half val);
+  APHRODITE_INLINE void load(const half* ptr);
+  APHRODITE_INLINE void store(half* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 2>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+
+  APHRODITE_INLINE static void memcpy(half* dst, const half* src);
+};
+
+APHRODITE_INLINE void vec_t<half, 2>::fill(half val) {
+  data = make_half2(val, val);
+}
+
+APHRODITE_INLINE void vec_t<half, 2>::load(const half* ptr) {
+  data = *((half2*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<half, 2>::store(half* ptr) const {
+  *((half2*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<half, 2>::memcpy(half* dst, const half* src) {
+  *((half2*)dst) = *((half2*)src);
+}
+
+// half x 4
+
+template <>
+struct vec_t<half, 4> {
+  uint2 data;
+
+  APHRODITE_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; }
+  APHRODITE_INLINE const half& operator[](size_t i) const {
+    return ((const half*)(&data))[i];
+  }
+  APHRODITE_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
+  APHRODITE_INLINE void fill(half val);
+  APHRODITE_INLINE void load(const half* ptr);
+  APHRODITE_INLINE void store(half* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 4>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(half* dst, const half* src);
+};
+
+APHRODITE_INLINE void vec_t<half, 4>::fill(half val) {
+  *(half2*)(&data.x) = make_half2(val, val);
+  *(half2*)(&data.y) = make_half2(val, val);
+}
+
+APHRODITE_INLINE void vec_t<half, 4>::load(const half* ptr) {
+  data = *((uint2*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<half, 4>::store(half* ptr) const {
+  *((uint2*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<half, 4>::memcpy(half* dst, const half* src) {
+  *((uint2*)dst) = *((uint2*)src);
+}
+
+// half x 8 or more
+
+template <size_t vec_size>
+struct vec_t<half, vec_size> {
+  uint4 data[vec_size / 8];
+  APHRODITE_INLINE half& operator[](size_t i) { return ((half*)data)[i]; }
+  APHRODITE_INLINE const half& operator[](size_t i) const {
+    return ((const half*)data)[i];
+  }
+  APHRODITE_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
+  APHRODITE_INLINE void fill(half val) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      *(half2*)(&(data[i].x)) = make_half2(val, val);
+      *(half2*)(&(data[i].y)) = make_half2(val, val);
+      *(half2*)(&(data[i].z)) = make_half2(val, val);
+      *(half2*)(&(data[i].w)) = make_half2(val, val);
+    }
+  }
+  APHRODITE_INLINE void load(const half* ptr) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      data[i] = ((uint4*)ptr)[i];
+    }
+  }
+  APHRODITE_INLINE void store(half* ptr) const {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      ((uint4*)ptr)[i] = data[i];
+    }
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(half* dst, const half* src) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      ((uint4*)dst)[i] = ((uint4*)src)[i];
+    }
+  }
+};
+
+/******************* vec_t<nv_bfloat16> *******************/
+
+// nv_bfloat16 x 1
+template <>
+struct vec_t<nv_bfloat16, 1> {
+  nv_bfloat16 data;
+  APHRODITE_INLINE nv_bfloat16& operator[](size_t i) {
+    return ((nv_bfloat16*)(&data))[i];
+  }
+  APHRODITE_INLINE const nv_bfloat16& operator[](size_t i) const {
+    return ((const nv_bfloat16*)(&data))[i];
+  }
+  APHRODITE_INLINE nv_bfloat16* ptr() {
+    return reinterpret_cast<nv_bfloat16*>(&data);
+  }
+  APHRODITE_INLINE void fill(nv_bfloat16 val);
+  APHRODITE_INLINE void load(const nv_bfloat16* ptr);
+  APHRODITE_INLINE void store(nv_bfloat16* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 1>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(nv_bfloat16* dst, const nv_bfloat16* src);
+};
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 1>::fill(nv_bfloat16 val) {
+  data = val;
+}
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 1>::load(const nv_bfloat16* ptr) {
+  data = *ptr;
+}
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 1>::store(nv_bfloat16* ptr) const {
+  *ptr = data;
+}
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 1>::memcpy(nv_bfloat16* dst,
+                                                    const nv_bfloat16* src) {
+  *dst = *src;
+}
+
+// nv_bfloat16 x 2
+template <>
+struct vec_t<nv_bfloat16, 2> {
+  nv_bfloat162 data;
+
+  APHRODITE_INLINE nv_bfloat16& operator[](size_t i) {
+    return ((nv_bfloat16*)(&data))[i];
+  }
+  APHRODITE_INLINE const nv_bfloat16& operator[](size_t i) const {
+    return ((const nv_bfloat16*)(&data))[i];
+  }
+  APHRODITE_INLINE nv_bfloat16* ptr() {
+    return reinterpret_cast<nv_bfloat16*>(&data);
+  }
+  APHRODITE_INLINE void fill(nv_bfloat16 val);
+  APHRODITE_INLINE void load(const nv_bfloat16* ptr);
+  APHRODITE_INLINE void store(nv_bfloat16* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 2>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(nv_bfloat16* dst, const nv_bfloat16* src);
+};
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 2>::fill(nv_bfloat16 val) {
+  data = make_bfloat162(val, val);
+}
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 2>::load(const nv_bfloat16* ptr) {
+  data = *((nv_bfloat162*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 2>::store(nv_bfloat16* ptr) const {
+  *((nv_bfloat162*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 2>::memcpy(nv_bfloat16* dst,
+                                                    const nv_bfloat16* src) {
+  *((nv_bfloat162*)dst) = *((nv_bfloat162*)src);
+}
+
+// nv_bfloat16 x 4
+
+template <>
+struct vec_t<nv_bfloat16, 4> {
+  uint2 data;
+
+  APHRODITE_INLINE nv_bfloat16& operator[](size_t i) {
+    return ((nv_bfloat16*)(&data))[i];
+  }
+  APHRODITE_INLINE const nv_bfloat16& operator[](size_t i) const {
+    return ((const nv_bfloat16*)(&data))[i];
+  }
+  APHRODITE_INLINE nv_bfloat16* ptr() {
+    return reinterpret_cast<nv_bfloat16*>(&data);
+  }
+  APHRODITE_INLINE void fill(nv_bfloat16 val);
+  APHRODITE_INLINE void load(const nv_bfloat16* ptr);
+  APHRODITE_INLINE void store(nv_bfloat16* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 4>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(nv_bfloat16* dst, const nv_bfloat16* src);
+};
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 4>::fill(nv_bfloat16 val) {
+  *(nv_bfloat162*)(&data.x) = make_bfloat162(val, val);
+  *(nv_bfloat162*)(&data.y) = make_bfloat162(val, val);
+}
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 4>::load(const nv_bfloat16* ptr) {
+  data = *((uint2*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 4>::store(nv_bfloat16* ptr) const {
+  *((uint2*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 4>::memcpy(nv_bfloat16* dst,
+                                                    const nv_bfloat16* src) {
+  *((uint2*)dst) = *((uint2*)src);
+}
+
+// nv_bfloat16 x 8 or more
+
+template <size_t vec_size>
+struct vec_t<nv_bfloat16, vec_size> {
+  uint4 data[vec_size / 8];
+
+  APHRODITE_INLINE nv_bfloat16& operator[](size_t i) {
+    return ((nv_bfloat16*)data)[i];
+  }
+  APHRODITE_INLINE const nv_bfloat16& operator[](size_t i) const {
+    return ((const nv_bfloat16*)data)[i];
+  }
+  APHRODITE_INLINE nv_bfloat16* ptr() {
+    return reinterpret_cast<nv_bfloat16*>(&data);
+  }
+  APHRODITE_INLINE void fill(nv_bfloat16 val) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      *(nv_bfloat162*)(&(data[i].x)) = make_bfloat162(val, val);
+      *(nv_bfloat162*)(&(data[i].y)) = make_bfloat162(val, val);
+      *(nv_bfloat162*)(&(data[i].z)) = make_bfloat162(val, val);
+      *(nv_bfloat162*)(&(data[i].w)) = make_bfloat162(val, val);
+    }
+  }
+  APHRODITE_INLINE void load(const nv_bfloat16* ptr) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      data[i] = ((uint4*)ptr)[i];
+    }
+  }
+  APHRODITE_INLINE void store(nv_bfloat16* ptr) const {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      ((uint4*)ptr)[i] = data[i];
+    }
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(nv_bfloat16* dst,
+                                      const nv_bfloat16* src) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      ((uint4*)dst)[i] = ((uint4*)src)[i];
+    }
+  }
+};
+
+/******************* vec_t<float> *******************/
+
+// float x 1
+
+template <>
+struct vec_t<float, 1> {
+  float data;
+
+  APHRODITE_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; }
+  APHRODITE_INLINE const float& operator[](size_t i) const {
+    return ((const float*)(&data))[i];
+  }
+  APHRODITE_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
+  APHRODITE_INLINE void fill(float val);
+  APHRODITE_INLINE void load(const float* ptr);
+  APHRODITE_INLINE void store(float* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 1>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(float* dst, const float* src);
+};
+
+APHRODITE_INLINE void vec_t<float, 1>::fill(float val) { data = val; }
+
+APHRODITE_INLINE void vec_t<float, 1>::load(const float* ptr) { data = *ptr; }
+
+APHRODITE_INLINE void vec_t<float, 1>::store(float* ptr) const { *ptr = data; }
+
+APHRODITE_INLINE void vec_t<float, 1>::memcpy(float* dst, const float* src) {
+  *dst = *src;
+}
+
+// float x 2
+
+template <>
+struct vec_t<float, 2> {
+  float2 data;
+
+  APHRODITE_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; }
+  APHRODITE_INLINE const float& operator[](size_t i) const {
+    return ((const float*)(&data))[i];
+  }
+  APHRODITE_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
+  APHRODITE_INLINE void fill(float val);
+  APHRODITE_INLINE void load(const float* ptr);
+  APHRODITE_INLINE void store(float* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 2>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(float* dst, const float* src);
+};
+
+APHRODITE_INLINE void vec_t<float, 2>::fill(float val) {
+  data = make_float2(val, val);
+}
+
+APHRODITE_INLINE void vec_t<float, 2>::load(const float* ptr) {
+  data = *((float2*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<float, 2>::store(float* ptr) const {
+  *((float2*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<float, 2>::memcpy(float* dst, const float* src) {
+  *((float2*)dst) = *((float2*)src);
+}
+
+// float x 4 or more
+template <size_t vec_size>
+struct vec_t<float, vec_size> {
+  float4 data[vec_size / 4];
+
+  APHRODITE_INLINE float& operator[](size_t i) { return ((float*)(data))[i]; }
+  APHRODITE_INLINE const float& operator[](size_t i) const {
+    return ((const float*)(data))[i];
+  }
+  APHRODITE_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
+  APHRODITE_INLINE void fill(float val) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 4; ++i) {
+      data[i] = make_float4(val, val, val, val);
+    }
+  }
+  APHRODITE_INLINE void load(const float* ptr) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 4; ++i) {
+      data[i] = ((float4*)ptr)[i];
+    }
+  }
+  APHRODITE_INLINE void store(float* ptr) const {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 4; ++i) {
+      ((float4*)ptr)[i] = data[i];
+    }
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(float* dst, const float* src) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 4; ++i) {
+      ((float4*)dst)[i] = ((float4*)src)[i];
+    }
+  }
+};
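+
+// Minimal kernel-side usage sketch (illustrative only; the kernel name and
+// shapes are assumptions, not part of this header): load half inputs, widen
+// to float in registers via cast_load, scale, and store as float.
+//
+//   __global__ void scale_kernel(const half* in, float* out, float s) {
+//     constexpr size_t VEC = 8;
+//     const size_t base = (blockIdx.x * blockDim.x + threadIdx.x) * VEC;
+//     vec_t<float, VEC> v;
+//     v.cast_load(in + base);  // vec_t<half, 8> load, then vec_cast to float
+//     #pragma unroll
+//     for (size_t i = 0; i < VEC; ++i) v[i] *= s;
+//     v.store(out + base);
+//   }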
+
+}  // namespace aphrodite
+
+#endif  // VEC_DTYPES_CUH_

+ 22 - 0
kernels/torch_bindings.cpp

@@ -207,6 +207,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("fp_eXmY_linear_forward_cuda", torch::kCUDA,
   ops.impl("fp_eXmY_linear_forward_cuda", torch::kCUDA,
            &fp_eXmY_linear_forward_cuda);
            &fp_eXmY_linear_forward_cuda);
 
 
+  // Sampling Kernels
+  ops.def("sampling_from_probs", &sampling_from_probs);
+  ops.impl("sampling_from_probs", torch::kCUDA, &sampling_from_probs);
+  ops.def("top_k_sampling_from_probs", &top_k_sampling_from_probs);
+  ops.impl("top_k_sampling_from_probs", torch::kCUDA,
+           &top_k_sampling_from_probs);
+  ops.def("min_p_sampling_from_probs", &min_p_sampling_from_probs);
+  ops.impl("min_p_sampling_from_probs", torch::kCUDA,
+           &min_p_sampling_from_probs);
+  ops.def("top_p_sampling_from_probs", &top_p_sampling_from_probs);
+  ops.impl("top_p_sampling_from_probs", torch::kCUDA,
+           &top_p_sampling_from_probs);
+  ops.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs);
+  ops.impl("top_k_top_p_sampling_from_probs", torch::kCUDA,
+           &top_k_top_p_sampling_from_probs);
+  ops.def("top_k_renorm_prob", &top_k_renorm_prob);
+  ops.impl("top_k_renorm_prob", torch::kCUDA, &top_k_renorm_prob);
+  ops.def("top_p_renorm_prob", &top_p_renorm_prob);
+  ops.impl("top_p_renorm_prob", torch::kCUDA, &top_p_renorm_prob);
+  ops.def("top_k_mask_logits", &top_k_mask_logits);
+  ops.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits);
+
 #endif
 
   // Quantized GEMM for GPTQ.

+ 9 - 2
tests/benchmarks/engine/throughput.py

@@ -75,6 +75,7 @@ def run_aphrodite(
     dtype: str,
     max_model_len: Optional[int],
     enforce_eager: bool,
+    max_seq_len_to_capture: int,
     kv_cache_dtype: str,
     quantization_param_path: Optional[str],
     device: str,
@@ -100,6 +101,7 @@ def run_aphrodite(
         max_model_len=max_model_len,
         gpu_memory_utilization=gpu_memory_utilization,
         enforce_eager=enforce_eager,
+        max_seq_len_to_capture=max_seq_len_to_capture,
         kv_cache_dtype=kv_cache_dtype,
         quantization_param_path=quantization_param_path,
         device=device,
@@ -233,8 +235,8 @@ def main(args: argparse.Namespace):
             args.quant_llm_fp_bits,
             args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
             args.trust_remote_code, args.dtype, args.max_model_len,
-            args.enforce_eager, args.kv_cache_dtype,
-            args.quantization_param_path, args.device,
+            args.enforce_eager, args.max_seq_len_to_capture,
+            args.kv_cache_dtype, args.quantization_param_path, args.device,
             args.enable_prefix_caching, args.enable_chunked_prefill,
             args.max_num_batched_tokens, args.distributed_executor_backend,
             args.gpu_memory_utilization, args.download_dir, args.load_format,
@@ -344,6 +346,11 @@ if __name__ == "__main__":
     parser.add_argument("--enforce-eager",
     parser.add_argument("--enforce-eager",
                         action="store_true",
                         action="store_true",
                         help="enforce eager execution")
                         help="enforce eager execution")
+    parser.add_argument("--max-seq-len-to-capture",
+                        type=int,
+                        default=None,
+                        help="The maximum sequence length to capture for "
+                        "CUDA graphs.")
     parser.add_argument(
         '--kv-cache-dtype',
         type=str,