feat: add cuda sampling kernels for top_k and top_p

AlpinDale 4 months ago
parent
commit
22422d962b

+ 2 - 1
CMakeLists.txt

@@ -218,7 +218,8 @@ if(APHRODITE_GPU_LANG STREQUAL "CUDA")
     "kernels/quantization/gguf/gguf_kernel.cu"
     "kernels/quantization/gptq_marlin/awq_marlin_repack.cu"
     "kernels/quantization/fp8/fp8_marlin.cu"
-    "kernels/all_reduce/custom_all_reduce.cu")
+    "kernels/all_reduce/custom_all_reduce.cu"
+    "kernels/sampling/sampling.cu")
 
   # Add CUTLASS and GPTQ Marlin kernels if not MSVC
   if(NOT MSVC)

+ 117 - 1
aphrodite/_custom_ops.py

@@ -1,6 +1,6 @@
 import contextlib
 import functools
-from typing import List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type, Union
 
 import torch
 from loguru import logger
@@ -632,6 +632,122 @@ def register_graph_buffers(fa: int, handles: List[str],
     torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
 
 
+# Sampling Kernels
+def sampling_from_probs(probs: torch.Tensor,
+                        uniform_samples: torch.Tensor,
+                        deterministic: bool = True,
+                        check_nan: bool = False) -> torch.Tensor:
+    if check_nan and torch.any(torch.isnan(probs)):
+        raise ValueError("NaN detected in probs")
+    return torch.ops._C.sampling_from_probs(probs, uniform_samples,
+                                            deterministic)
+
+def _to_tensor_scalar_tuple(x):
+    if isinstance(x, torch.Tensor):
+        return (x, 0)
+    else:
+        return (None, x)
+
+def top_p_sampling_from_probs(
+        probs: torch.Tensor,
+        uniform_samples: torch.Tensor,
+        top_p: Union[torch.Tensor, float],
+        deterministic: bool = True,
+        check_nan: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
+    if check_nan and torch.any(torch.isnan(probs)):
+        raise ValueError("NaN detected in probs")
+    return torch.ops._C.top_p_sampling_from_probs(
+        probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic)
+
+def top_k_sampling_from_probs(
+        probs: torch.Tensor,
+        uniform_samples: torch.Tensor,
+        top_k: Union[torch.Tensor, int],
+        deterministic: bool = True,
+        check_nan: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
+    if check_nan and torch.any(torch.isnan(probs)):
+        raise ValueError("NaN detected in probs")
+    return torch.ops._C.top_k_sampling_from_probs(
+        probs, uniform_samples, *_to_tensor_scalar_tuple(top_k), deterministic)
+
+def min_p_sampling_from_probs(
+        probs: torch.Tensor,
+        uniform_samples: torch.Tensor,
+        min_p: Union[torch.Tensor, float],
+        deterministic: bool = True,
+        check_nan: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
+    if check_nan and torch.any(torch.isnan(probs)):
+        raise ValueError("NaN detected in probs")
+    return torch.ops._C.min_p_sampling_from_probs(
+        probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic)
+
+def top_k_mask_logits(
+    logits: torch.Tensor,
+    top_k: Union[torch.Tensor, int],
+) -> torch.Tensor:
+    return torch.ops._C.top_k_mask_logits(logits,
+                                          *_to_tensor_scalar_tuple(top_k))
+
+def top_p_renorm_prob(
+    probs: torch.Tensor,
+    top_p: Union[torch.Tensor, float],
+) -> torch.Tensor:
+    return torch.ops._C.top_p_renorm_prob(probs,
+                                          *_to_tensor_scalar_tuple(top_p))
+
+def top_k_renorm_prob(
+    probs: torch.Tensor,
+    top_k: Union[torch.Tensor, int],
+) -> torch.Tensor:
+    return torch.ops._C.top_k_renorm_prob(probs,
+                                          *_to_tensor_scalar_tuple(top_k))
+
+def top_k_top_p_sampling_from_logits(
+    logits: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_k: Union[torch.Tensor, int],
+    top_p: Union[torch.Tensor, float],
+    filter_apply_order: str = "top_k_first",
+    deterministic: bool = True,
+    check_nan: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if filter_apply_order == "top_k_first":
+        masked_logits = top_k_mask_logits(logits, top_k)
+        probs = torch.softmax(masked_logits, dim=-1)
+        return top_p_sampling_from_probs(probs, uniform_samples, top_p,
+                                         deterministic, check_nan)
+    elif filter_apply_order == "joint":
+        probs = torch.softmax(logits, dim=-1)
+        if check_nan and torch.any(torch.isnan(probs)):
+            raise ValueError("NaN detected in probs")
+        return torch.ops._C.top_k_top_p_sampling_from_probs(
+            probs, uniform_samples, *_to_tensor_scalar_tuple(top_k),
+            *_to_tensor_scalar_tuple(top_p), deterministic)
+    else:
+        raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
+
+def top_k_top_p_sampling_from_probs(
+    probs: torch.Tensor,
+    uniform_samples: torch.Tensor,
+    top_k: Union[torch.Tensor, int],
+    top_p: Union[torch.Tensor, float],
+    filter_apply_order: str = "top_k_first",
+    deterministic: bool = True,
+    check_nan: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if filter_apply_order == "top_k_first":
+        renorm_probs = top_k_renorm_prob(probs, top_k)
+        return top_p_sampling_from_probs(renorm_probs, uniform_samples, top_p,
+                                         deterministic, check_nan)
+    elif filter_apply_order == "joint":
+        if check_nan and torch.any(torch.isnan(probs)):
+            raise ValueError("NaN detected in probs")
+        return torch.ops._C.top_k_top_p_sampling_from_probs(
+            probs, uniform_samples, *_to_tensor_scalar_tuple(top_k),
+            *_to_tensor_scalar_tuple(top_p), deterministic)
+    else:
+        raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}")
+
+
 # TODO: remove this later
 names_and_values = globals()
 names_and_values_to_update = {}
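
Below is a minimal usage sketch of the new wrappers (illustrative only, not part of the commit). It assumes the _C extension was built with kernels/sampling/sampling.cu and that a CUDA device is available; tensor shapes and parameter values are made up for the example.

    import torch
    import aphrodite._custom_ops as ops

    batch_size, vocab_size, max_rounds = 4, 32000, 32
    logits = torch.randn(batch_size, vocab_size, device="cuda")
    probs = torch.softmax(logits, dim=-1)
    # One row of uniform samples per rejection-sampling round.
    uniform_samples = torch.rand(max_rounds, batch_size, device="cuda")

    # Scalars are broadcast across the batch; per-request tensors also work.
    samples, success = ops.top_k_top_p_sampling_from_probs(
        probs, uniform_samples, top_k=50, top_p=0.9)
    if not success.all():
        # Rejection sampling may exhaust its rounds; renormalize and
        # sample once from the truncated distribution instead.
        probs = ops.top_k_renorm_prob(probs, 50)
        probs = ops.top_p_renorm_prob(probs, 0.9)
        samples = ops.sampling_from_probs(probs, uniform_samples[0])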

+ 76 - 22
aphrodite/modeling/layers/sampler.py

@@ -1,11 +1,14 @@
 """A layer that samples the next tokens from the model's outputs."""
 import itertools
+import os
+import warnings
 from math import inf
 from typing import Dict, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
 
+import aphrodite._custom_ops as ops
 from aphrodite.common.sampling_params import SamplingType
 from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
                                        PromptLogprobs, SampleLogprobs,
@@ -27,6 +30,11 @@ SampleResultType = List[Tuple[List[int], List[int]]]
 # that this temperature well-uses the fp16 space after the logits are offset.
 _TEMPERATURE_MINIMUM = 2e-5
 
+# If enabled, we switch to the more performant CUDA kernel implementation
+# of top-k and top-p sampling.
+APHRODITE_USE_SAMPLING_KERNELS = bool(int(
+    os.getenv("APHRODITE_USE_SAMPLING_KERNELS", "0")))
+
 
 class Sampler(nn.Module):
     """Samples the next tokens from the model's outputs.
@@ -155,7 +163,7 @@ class Sampler(nn.Module):
         if do_nsigmas:
             logits = _apply_top_nsigma(logits, sampling_tensors.nsigmas)
 
-        if do_top_p_top_k:
+        if do_top_p_top_k and not APHRODITE_USE_SAMPLING_KERNELS:
             logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                         sampling_tensors.top_ks)
 
@@ -816,14 +824,7 @@ def _multinomial(
     seq_groups: Optional[List[SequenceGroupToSample]] = None,
 ) -> torch.Tensor:
     if num_samples > 1:
-        # This is equivalent to torch.repeat_interleaved (which also
-        # forces a GPU<->CPU sync).
-        # This allows us to do sampling with replacement by creating
-        # num_samples copies of each row in the tensor, and then
-        # batch sampling the resulting tensor.
-        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
-                                         probs.shape[1]).contiguous().view(
-                                             -1, probs.shape[1])
+        probs = probs.repeat_interleave(num_samples, dim=0)
     q = torch.empty_like(probs)
     if seq_groups is None:
         q.exponential_()
@@ -831,17 +832,57 @@ def _multinomial(
         sample_idx = 0
         for seq_group in seq_groups:
             seq_ids = seq_group.seq_ids
-            next_sample_idx = sample_idx + len(seq_ids) * num_samples
-            q[sample_idx:next_sample_idx].exponential_(
-                generator=seq_group.generator)
-            sample_idx = next_sample_idx
+            stride = len(seq_ids) * num_samples
+            assert seq_group.generator is not None
+            q[sample_idx:sample_idx +
+              stride].exponential_(generator=seq_group.generator)
+            sample_idx += stride
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
 
 
+def _top_k_top_p_multinomial_with_kernels(
+        probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
+        num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]):
+    max_top_k_round = 32
+    if num_samples > 1:
+        probs = probs.repeat_interleave(num_samples, dim=0)
+        top_ks = top_ks.repeat_interleave(num_samples)
+        top_ps = top_ps.repeat_interleave(num_samples)
+    batch_size = probs.shape[0]
+    uniform_samples = torch.empty((max_top_k_round, batch_size),
+                                  device=probs.device)
+    if seq_groups is None:
+        uniform_samples.uniform_()
+    else:
+        sample_idx = 0
+        for seq_group in seq_groups:
+            seq_ids = seq_group.seq_ids
+            stride = len(seq_ids) * num_samples
+            assert seq_group.generator is not None
+            uniform_samples[:, sample_idx:sample_idx +
+                            stride].uniform_(generator=seq_group.generator)
+            sample_idx += stride
+    batch_next_token_ids, success = ops.top_k_top_p_sampling_from_probs(
+        probs,
+        uniform_samples,
+        top_ks,
+        top_ps,
+    )
+    if not success.all():
+        warnings.warn("CUDA rejection sampling failed; falling back to "
+                      "sampling from the renormalized distribution.",
+                      stacklevel=1)
+        probs = ops.top_k_renorm_prob(probs, top_ks)
+        probs = ops.top_p_renorm_prob(probs, top_ps)
+        batch_next_token_ids = ops.sampling_from_probs(
+            probs, uniform_samples[0])
+    return batch_next_token_ids.view(-1, num_samples)
+
+
 def _sample_with_torch(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
+    sampling_tensors: SamplingTensors,
     include_gpu_probs_tensor: bool,
     modify_greedy_probs: bool,
 ) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
@@ -897,17 +938,29 @@ def _sample_with_torch(
                     sampling_params = seq_group.sampling_params
                     max_best_of_in_batch = max(max_best_of_in_batch,
                                                sampling_params.best_of)
-            seeded_args = {} if sampling_type == SamplingType.RANDOM else {
-                "seq_groups": seq_groups,
-            }
 
-            multinomial_samples[sampling_type] = _multinomial(
-                probs[long_sample_indices], max_best_of_in_batch,
-                **seeded_args)
-            if include_gpu_probs_tensor:
+            seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
+                              seq_groups)
+            if APHRODITE_USE_SAMPLING_KERNELS:
+                multinomial_samples[
+                    sampling_type] = _top_k_top_p_multinomial_with_kernels(
+                        probs[long_sample_indices],
+                        sampling_tensors.top_ks[long_sample_indices],
+                        sampling_tensors.top_ps[long_sample_indices],
+                        max_best_of_in_batch,
+                        seq_groups_arg,
+                    )
+            else:
+                multinomial_samples[sampling_type] = _multinomial(
+                    probs[long_sample_indices],
+                    max_best_of_in_batch,
+                    seq_groups=seq_groups_arg)
+
+            if sampled_token_ids_tensor is not None:
                 # Store sampled tokens in output tensor.
-                sampled_token_ids_tensor[
-                    long_sample_indices] = multinomial_samples[sampling_type]
+                sampled_token_ids_tensor[long_sample_indices] = \
+                    multinomial_samples[sampling_type].to(torch.long)
+
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
         else:
@@ -1035,6 +1088,7 @@ def _sample(
         probs,
         logprobs,
         sampling_metadata,
+        sampling_tensors,
         include_gpu_probs_tensor=include_gpu_probs_tensor,
         modify_greedy_probs=modify_greedy_probs,
     )
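
The new path is gated behind an environment variable that is read when the sampler module is imported. A small sketch of opting in (assuming a CUDA build that includes the sampling kernels):

    import os
    # Must be set before aphrodite.modeling.layers.sampler is imported,
    # since the flag is evaluated at module import time.
    os.environ["APHRODITE_USE_SAMPLING_KERNELS"] = "1"

    from aphrodite.modeling.layers.sampler import Sampler

    sampler = Sampler()
    # With the flag set, the torch-based _apply_top_k_top_p masking is
    # skipped and random sampling goes through
    # _top_k_top_p_multinomial_with_kernels instead of _multinomial.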

+ 32 - 0
kernels/ops.h

@@ -102,4 +102,36 @@ at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
                              const c10::optional<at::Tensor>& initial_states_,
                              const c10::optional<at::Tensor>& final_states_out_,
                              bool silu_activation);
+
+// Sampling kernels
+torch::Tensor sampling_from_probs(torch::Tensor probs,
+                                  torch::Tensor uniform_samples,
+                                  bool deterministic);
+std::vector<torch::Tensor> top_p_sampling_from_probs(
+    torch::Tensor probs, torch::Tensor uniform_samples,
+    std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
+    bool deterministic);
+std::vector<torch::Tensor> top_k_sampling_from_probs(
+    torch::Tensor probs, torch::Tensor uniform_samples,
+    std::optional<torch::Tensor> maybe_top_k_arr, int64_t top_k_val,
+    bool deterministic);
+std::vector<torch::Tensor> min_p_sampling_from_probs(
+    torch::Tensor probs, torch::Tensor uniform_samples,
+    std::optional<torch::Tensor> maybe_min_p_arr, double min_p_val,
+    bool deterministic);
+std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(
+    torch::Tensor probs, torch::Tensor uniform_samples,
+    std::optional<torch::Tensor> maybe_top_k_arr, double top_k_val,
+    std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
+    bool deterministic);
+torch::Tensor top_p_renorm_prob(torch::Tensor probs,
+                                std::optional<torch::Tensor> maybe_top_p_arr,
+                                double top_p_val);
+torch::Tensor top_k_renorm_prob(torch::Tensor probs,
+                                std::optional<torch::Tensor> maybe_top_k_arr,
+                                int64_t top_k_val);
+torch::Tensor top_k_mask_logits(torch::Tensor logits,
+                                std::optional<torch::Tensor> maybe_top_k_arr,
+                                int64_t top_k_val);
+
 #endif

+ 159 - 0
kernels/sampling/math.cuh

@@ -0,0 +1,159 @@
+/*
+ * Copyright (c) 2024 by PygmalionAI team.
+ * Copyright (c) 2023 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef APHRODITE_MATH_CUH_
+#define APHRODITE_MATH_CUH_
+
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+namespace aphrodite {
+namespace math {
+
+// log2(e)
+constexpr float log2e = 1.44269504088896340736f;
+
+__forceinline__ __device__ half2 uint32_as_half2(uint32_t x) {
+  return *(half2*)&x;
+}
+
+__forceinline__ __device__ uint32_t half2_as_uint32(half2 x) {
+  return *(uint32_t*)&x;
+}
+
+/*!
+ * \brief Wrapper of PTX ex2.approx instruction, which computes 2^x
+ * \param x input
+ */
+__forceinline__ __device__ float ptx_exp2(float x) {
+  float y;
+  asm volatile("ex2.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
+  return y;
+}
+
+/*!
+ * \brief Wrapper of PTX lg2.approx instruction, which computes log2(x)
+ * \param x input
+ */
+__forceinline__ __device__ float ptx_log2(float x) {
+  float y;
+  asm volatile("lg2.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
+  return y;
+}
+
+/*!
+ * \brief Wrapper of PTX ex2.approx.f16x2 instruction, which computes 2^x
+ * \param x input
+ */
+__forceinline__ __device__ half2 ptx_exp2(half2 x) {
+  uint32_t y_u32;
+  uint32_t x_u32 = half2_as_uint32(x);
+  asm volatile("ex2.approx.f16x2 %0, %1;" : "=r"(y_u32) : "r"(x_u32));
+  return uint32_as_half2(y_u32);
+}
+
+/*!
+ * \brief Wrapper of PTX ex2.approx.f16 instruction, which computes 2^x
+ * \param x input
+ */
+__forceinline__ __device__ half ptx_exp2(half x) {
+  ushort y_u16;
+  asm volatile("ex2.approx.f16 %0, %1;"
+               : "=h"(y_u16)
+               : "h"(__half_as_ushort(x)));
+  return __ushort_as_half(y_u16);
+}
+
+/*!
+ * \brief Wrapper of PTX rcp.approx instruction, which computes 1/x
+ * \param x input
+ */
+__forceinline__ __device__ float ptx_rcp(float x) {
+  float y;
+  asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
+  return y;
+}
+
+/*!
+ * \brief Wrapper of PTX shfl.sync.bfly instruction, which performs a butterfly
+ * shuffle between threads in a warp.
+ * \param x The value in the source lane
+ * \param lane_mask The mask to perform thread index xor with: y[i] <- x[i ^
+ * lane_mask]
+ */
+__forceinline__ __device__ float shfl_xor_sync(float x, int lane_mask) {
+  float y;
+  asm volatile("shfl.sync.bfly.b32 %0, %1, %2, 0x1f, 0xffffffff;"
+               : "=f"(y)
+               : "f"(x), "r"(lane_mask));
+  return y;
+}
+
+/*!
+ * \brief Wrapper of PTX shfl.sync.bfly instruction on half2, which performs a
+ * butterfly shuffle between threads in a warp.
+ * \param x The value in the source lane
+ * \param lane_mask The mask to perform thread index xor with: y[i] <- x[i ^
+ * lane_mask]
+ */
+__forceinline__ __device__ half2 shfl_xor_sync(half2 x, int lane_mask) {
+  return __shfl_xor_sync(0xffffffff, x, lane_mask);
+}
+
+/*!
+ * \brief Wrapper of PTX rsqrt approximation instruction, which computes
+ * 1/sqrt(x)
+ * \param x input
+ */
+__forceinline__ __device__ float rsqrt(float x) {
+  float y;
+  asm volatile("rsqrt.approx.ftz.f32 %0, %1;" : "=f"(y) : "f"(x));
+  return y;
+}
+
+/*!
+ * \brief Wrapper of PTX tanh.approx.f32 instruction, which computes tanh(x)
+ * \param x input
+ */
+__forceinline__ __device__ float tanh(float x) {
+  float y;
+  asm volatile("tanh.approx.f32 %0, %1;" : "=f"(y) : "f"(x));
+  return y;
+}
+
+/*!
+ * \brief Wrapper of PTX tanh.approx.f16x2 instruction, which computes tanh(x)
+ * \param x input
+ */
+__forceinline__ __device__ half2 tanh(half2 x) {
+  uint32_t y_u32;
+  uint32_t x_u32 = half2_as_uint32(x);
+  asm volatile("tanh.approx.f16x2 %0, %1;" : "=r"(y_u32) : "r"(x_u32));
+  return uint32_as_half2(y_u32);
+}
+
+/*!
+ * \brief Wrapper of PTX tanh.approx.f16 instruction, which computes tanh(x)
+ * \param x input
+ */
+__forceinline__ __device__ half tanh(half x) {
+  ushort y_u16;
+  asm volatile("tanh.approx.f16 %0, %1;"
+               : "=h"(y_u16)
+               : "h"(__half_as_ushort(x)));
+  return __ushort_as_half(y_u16);
+}
+
+}  // namespace math
+}  // namespace aphrodite
+#endif  // APHRODITE_MATH_CUH_

+ 391 - 0
kernels/sampling/sampling.cu

@@ -0,0 +1,391 @@
+/*
+ * Copyright (c) 2024 by PygmalionAI team.
+ * Copyright (c) 2024 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include <c10/cuda/CUDAStream.h>
+
+#include "sampling.cuh"
+#include "../ops.h"
+#include "utils.cuh"
+
+// Check utils
+#define CUDA_CHECK(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
+
+#define CHECK_CONTIGUOUS(x) \
+  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+
+#define CHECK_INPUT(x) \
+  CUDA_CHECK(x);       \
+  CHECK_CONTIGUOUS(x)
+
+#define CHECK_EQ(a, b) \
+  TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
+
+#define CHECK_GE(a, b) \
+  TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
+
+#define CHECK_DIM(d, x) \
+  TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
+
+using namespace aphrodite;
+
+torch::Tensor sampling_from_probs(torch::Tensor probs,
+                                  torch::Tensor uniform_samples,
+                                  bool deterministic) {
+  CHECK_INPUT(probs);
+  CHECK_INPUT(uniform_samples);
+  auto device = probs.device();
+  CHECK_EQ(uniform_samples.device(), device);
+  CHECK_DIM(2, probs);            // probs: (batch_size, vocab_size)
+  CHECK_DIM(1, uniform_samples);  // uniform_samples: (batch_size)
+  CHECK_EQ(probs.size(0), uniform_samples.size(0));
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  probs = probs.to(torch::kFloat32);
+  uniform_samples = uniform_samples.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream =
+      c10::cuda::getCurrentCUDAStream(device.index());
+  auto samples =
+      torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
+
+  cudaError_t status = sampling::SamplingFromProb(
+      static_cast<float*>(probs.data_ptr()),
+      static_cast<float*>(uniform_samples.data_ptr()),
+      static_cast<int*>(samples.data_ptr()), batch_size, vocab_size,
+      deterministic, torch_current_stream);
+  TORCH_CHECK(status == cudaSuccess,
+              "SamplingFromProbs failed with error code " +
+                  std::string(cudaGetErrorString(status)));
+  return samples;
+}
+
+std::vector<torch::Tensor> top_p_sampling_from_probs(
+    torch::Tensor probs, torch::Tensor uniform_samples,
+    std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
+    bool deterministic) {
+  CHECK_INPUT(probs);
+  CHECK_INPUT(uniform_samples);
+  auto device = probs.device();
+  CHECK_EQ(uniform_samples.device(), device);
+  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
+  CHECK_DIM(
+      2, uniform_samples);  // uniform_samples: (max_top_p_rounds, batch_size)
+  CHECK_EQ(probs.size(0), uniform_samples.size(1));
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  unsigned int max_top_p_rounds = uniform_samples.size(0);
+  bool has_top_p_arr = maybe_top_p_arr.has_value();
+  auto top_p_arr = maybe_top_p_arr.value_or(
+      torch::empty({0}, torch::dtype(torch::kFloat32)));
+  if (has_top_p_arr) {
+    CHECK_INPUT(top_p_arr);
+    CHECK_DIM(1, top_p_arr);  // top_p_arr: (batch_size,)
+    CHECK_EQ(top_p_arr.size(0), batch_size);
+    CHECK_EQ(top_p_arr.device(), device);
+  }
+  probs = probs.to(torch::kFloat32);
+  uniform_samples = uniform_samples.to(torch::kFloat32);
+  top_p_arr = top_p_arr.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream =
+      c10::cuda::getCurrentCUDAStream(device.index());
+  auto samples =
+      torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
+  auto success =
+      torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
+
+  cudaError_t status = sampling::TopPSamplingFromProb<float, int>(
+      static_cast<float*>(probs.data_ptr()),
+      static_cast<float*>(uniform_samples.data_ptr()),
+      static_cast<int*>(samples.data_ptr()),
+      static_cast<bool*>(success.data_ptr()),
+      has_top_p_arr ? static_cast<float*>(top_p_arr.data_ptr()) : nullptr,
+      batch_size, top_p_val, vocab_size, max_top_p_rounds, deterministic,
+      torch_current_stream);
+  TORCH_CHECK(status == cudaSuccess,
+              "TopPSamplingFromProbs failed with error code " +
+                  std::string(cudaGetErrorString(status)));
+
+  return {samples, success};
+}
+
+std::vector<torch::Tensor> top_k_sampling_from_probs(
+    torch::Tensor probs, torch::Tensor uniform_samples,
+    std::optional<torch::Tensor> maybe_top_k_arr, int64_t top_k_val,
+    bool deterministic) {
+  CHECK_INPUT(probs);
+  CHECK_INPUT(uniform_samples);
+  auto device = probs.device();
+  CHECK_EQ(uniform_samples.device(), device);
+  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
+  CHECK_DIM(
+      2, uniform_samples);  // uniform_samples: (max_top_k_rounds, batch_size)
+  CHECK_EQ(probs.size(0), uniform_samples.size(1));
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  unsigned int max_top_k_rounds = uniform_samples.size(0);
+  bool has_top_k_arr = maybe_top_k_arr.has_value();
+  auto top_k_arr =
+      maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32)));
+  if (has_top_k_arr) {
+    CHECK_INPUT(top_k_arr);
+    CHECK_DIM(1, top_k_arr);  // top_k_arr: (batch_size,)
+    CHECK_EQ(top_k_arr.size(0), batch_size);
+    CHECK_EQ(top_k_arr.device(), device);
+  }
+  probs = probs.to(torch::kFloat32);
+  uniform_samples = uniform_samples.to(torch::kFloat32);
+  top_k_arr = top_k_arr.to(torch::kInt32);
+
+  cudaStream_t torch_current_stream =
+      c10::cuda::getCurrentCUDAStream(device.index());
+  auto samples =
+      torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
+  auto success =
+      torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
+
+  cudaError_t status = sampling::TopKSamplingFromProb<float, int>(
+      static_cast<float*>(probs.data_ptr()),
+      static_cast<float*>(uniform_samples.data_ptr()),
+      static_cast<int*>(samples.data_ptr()),
+      static_cast<bool*>(success.data_ptr()),
+      has_top_k_arr ? static_cast<float*>(top_k_arr.data_ptr()) : nullptr,
+      batch_size, top_k_val, vocab_size, max_top_k_rounds, deterministic,
+      torch_current_stream);
+  TORCH_CHECK(status == cudaSuccess,
+              "TopKSamplingFromProbs failed with error code " +
+                  std::string(cudaGetErrorString(status)));
+
+  return {samples, success};
+}
+
+std::vector<torch::Tensor> min_p_sampling_from_probs(
+    torch::Tensor probs, torch::Tensor uniform_samples,
+    std::optional<torch::Tensor> maybe_min_p_arr, double min_p_val,
+    bool deterministic) {
+  CHECK_INPUT(probs);
+  CHECK_INPUT(uniform_samples);
+  auto device = probs.device();
+  CHECK_EQ(uniform_samples.device(), device);
+  CHECK_DIM(2, probs);            // probs: (batch_size, vocab_size)
+  CHECK_DIM(2, uniform_samples);  // uniform_samples: (max_rounds, batch_size)
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  unsigned int max_rounds = uniform_samples.size(0);
+  CHECK_EQ(uniform_samples.size(1), batch_size);
+  bool has_min_p_arr = maybe_min_p_arr.has_value();
+  auto min_p_arr = maybe_min_p_arr.value_or(
+      torch::empty({0}, torch::dtype(torch::kFloat32)));
+  if (has_min_p_arr) {
+    CHECK_INPUT(min_p_arr);
+    CHECK_DIM(1, min_p_arr);  // min_p_arr: (batch_size,)
+    CHECK_EQ(min_p_arr.size(0), batch_size);
+    CHECK_EQ(min_p_arr.device(), device);
+  }
+  min_p_arr = min_p_arr.to(torch::kFloat32);
+  probs = probs.to(torch::kFloat32);
+  uniform_samples = uniform_samples.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream =
+      c10::cuda::getCurrentCUDAStream(device.index());
+  auto samples =
+      torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
+  auto success =
+      torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
+
+  cudaError_t status = sampling::MinPSamplingFromProb<float, int>(
+      static_cast<float*>(probs.data_ptr()),
+      static_cast<float*>(uniform_samples.data_ptr()),
+      has_min_p_arr ? static_cast<float*>(min_p_arr.data_ptr()) : nullptr,
+      static_cast<int*>(samples.data_ptr()),
+      static_cast<bool*>(success.data_ptr()), batch_size, min_p_val, vocab_size,
+      max_rounds, deterministic, torch_current_stream);
+  TORCH_CHECK(status == cudaSuccess,
+              "MinPSamplingFromProb failed with error code " +
+                  std::string(cudaGetErrorString(status)));
+
+  return {samples, success};
+}
+
+std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(
+    torch::Tensor probs, torch::Tensor uniform_samples,
+    std::optional<torch::Tensor> maybe_top_k_arr, double top_k_val,
+    std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
+    bool deterministic) {
+  CHECK_INPUT(probs);
+  CHECK_INPUT(uniform_samples);
+  auto device = probs.device();
+  CHECK_EQ(uniform_samples.device(), device);
+  CHECK_DIM(2, probs);            // probs: (batch_size, vocab_size)
+  CHECK_DIM(2, uniform_samples);  // uniform_samples: (max_rounds, batch_size)
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  unsigned int max_rounds = uniform_samples.size(0);
+  CHECK_EQ(uniform_samples.size(1), batch_size);
+  bool has_top_k_arr = maybe_top_k_arr.has_value();
+  auto top_k_arr =
+      maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32)));
+  if (has_top_k_arr) {
+    CHECK_INPUT(top_k_arr);
+    CHECK_DIM(1, top_k_arr);  // top_k_arr: (batch_size,)
+    CHECK_EQ(top_k_arr.size(0), batch_size);
+    CHECK_EQ(top_k_arr.device(), device);
+  }
+  top_k_arr = top_k_arr.to(torch::kInt32);
+  bool has_top_p_arr = maybe_top_p_arr.has_value();
+  auto top_p_arr = maybe_top_p_arr.value_or(
+      torch::empty({0}, torch::dtype(torch::kFloat32)));
+  if (has_top_p_arr) {
+    CHECK_INPUT(top_p_arr);
+    CHECK_DIM(1, top_p_arr);  // top_p_arr: (batch_size,)
+    CHECK_EQ(top_p_arr.size(0), batch_size);
+    CHECK_EQ(top_p_arr.device(), device);
+  }
+  top_p_arr = top_p_arr.to(torch::kFloat32);
+  probs = probs.to(torch::kFloat32);
+  uniform_samples = uniform_samples.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream =
+      c10::cuda::getCurrentCUDAStream(device.index());
+  auto samples =
+      torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
+  auto success =
+      torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
+
+  cudaError_t status = sampling::TopKTopPSamplingFromProb<float, int>(
+      static_cast<float*>(probs.data_ptr()),
+      static_cast<float*>(uniform_samples.data_ptr()),
+      has_top_k_arr ? static_cast<int*>(top_k_arr.data_ptr()) : nullptr,
+      has_top_p_arr ? static_cast<float*>(top_p_arr.data_ptr()) : nullptr,
+      static_cast<int*>(samples.data_ptr()),
+      static_cast<bool*>(success.data_ptr()), batch_size, top_k_val, top_p_val,
+      vocab_size, max_rounds, deterministic, torch_current_stream);
+  TORCH_CHECK(status == cudaSuccess,
+              "TopKTopPSamplingFromProbs failed with error code " +
+                  std::string(cudaGetErrorString(status)));
+
+  return {samples, success};
+}
+
+torch::Tensor top_p_renorm_prob(torch::Tensor probs,
+                                std::optional<torch::Tensor> maybe_top_p_arr,
+                                double top_p_val) {
+  CHECK_INPUT(probs);
+  auto device = probs.device();
+  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  bool has_top_p_arr = maybe_top_p_arr.has_value();
+  auto top_p_arr = maybe_top_p_arr.value_or(
+      torch::empty({0}, torch::dtype(torch::kFloat32)));
+  if (has_top_p_arr) {
+    CHECK_INPUT(top_p_arr);
+    CHECK_DIM(1, top_p_arr);  // top_p_arr: (batch_size,)
+    CHECK_EQ(top_p_arr.size(0), batch_size);
+    CHECK_EQ(top_p_arr.device(), device);
+  }
+  top_p_arr = top_p_arr.to(torch::kFloat32);
+  probs = probs.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream =
+      c10::cuda::getCurrentCUDAStream(device.index());
+  auto renorm_probs = torch::empty(
+      {batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device));
+
+  cudaError_t status = sampling::TopPRenormProb<float>(
+      static_cast<float*>(probs.data_ptr()),
+      static_cast<float*>(renorm_probs.data_ptr()),
+      has_top_p_arr ? static_cast<float*>(top_p_arr.data_ptr()) : nullptr,
+      batch_size, top_p_val, vocab_size, torch_current_stream);
+  TORCH_CHECK(status == cudaSuccess,
+              "TopPRenormProb failed with error code " +
+                  std::string(cudaGetErrorString(status)));
+  return renorm_probs;
+}
+
+torch::Tensor top_k_renorm_prob(torch::Tensor probs,
+                                std::optional<torch::Tensor> maybe_top_k_arr,
+                                int64_t top_k_val) {
+  CHECK_INPUT(probs);
+  auto device = probs.device();
+  CHECK_DIM(2, probs);  // probs: (batch_size, vocab_size)
+  unsigned int batch_size = probs.size(0);
+  unsigned int vocab_size = probs.size(1);
+  bool has_top_k_arr = maybe_top_k_arr.has_value();
+  auto top_k_arr =
+      maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32)));
+  if (has_top_k_arr) {
+    CHECK_INPUT(top_k_arr);
+    CHECK_DIM(1, top_k_arr);  // top_k_arr: (batch_size,)
+    CHECK_EQ(top_k_arr.size(0), batch_size);
+    CHECK_EQ(top_k_arr.device(), device);
+  }
+  top_k_arr = top_k_arr.to(torch::kInt32);
+  probs = probs.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream =
+      c10::cuda::getCurrentCUDAStream(device.index());
+  auto renorm_probs = torch::empty(
+      {batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device));
+
+  cudaError_t status = sampling::TopKRenormProb<float>(
+      static_cast<float*>(probs.data_ptr()),
+      static_cast<float*>(renorm_probs.data_ptr()),
+      has_top_k_arr ? static_cast<int*>(top_k_arr.data_ptr()) : nullptr,
+      batch_size, top_k_val, vocab_size, torch_current_stream);
+
+  TORCH_CHECK(status == cudaSuccess,
+              "TopKRenormProb failed with error code " +
+                  std::string(cudaGetErrorString(status)));
+  return renorm_probs;
+}
+
+torch::Tensor top_k_mask_logits(torch::Tensor logits,
+                                std::optional<torch::Tensor> maybe_top_k_arr,
+                                int64_t top_k_val) {
+  CHECK_INPUT(logits);
+  auto device = logits.device();
+  CHECK_DIM(2, logits);  // logits: (batch_size, vocab_size)
+  unsigned int batch_size = logits.size(0);
+  unsigned int vocab_size = logits.size(1);
+  bool has_top_k_arr = maybe_top_k_arr.has_value();
+  auto top_k_arr =
+      maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32)));
+  if (has_top_k_arr) {
+    CHECK_INPUT(top_k_arr);
+    CHECK_DIM(1, top_k_arr);  // top_k_arr: (batch_size,)
+    CHECK_EQ(top_k_arr.size(0), batch_size);
+    CHECK_EQ(top_k_arr.device(), device);
+  }
+  top_k_arr = top_k_arr.to(torch::kInt32);
+  logits = logits.to(torch::kFloat32);
+
+  cudaStream_t torch_current_stream =
+      c10::cuda::getCurrentCUDAStream(device.index());
+  auto mask_logits = torch::empty({batch_size, vocab_size},
+                                  torch::dtype(torch::kFloat32).device(device));
+
+  cudaError_t status = sampling::TopKMaskLogits<float>(
+      static_cast<float*>(logits.data_ptr()),
+      static_cast<float*>(mask_logits.data_ptr()),
+      has_top_k_arr ? static_cast<int*>(top_k_arr.data_ptr()) : nullptr,
+      batch_size, top_k_val, vocab_size, torch_current_stream);
+
+  TORCH_CHECK(status == cudaSuccess,
+              "TopKMaskLogits failed with error code " +
+                  std::string(cudaGetErrorString(status)));
+  return mask_logits;
+}
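
The host wrappers above expect contiguous (batch_size, vocab_size) CUDA tensors. The rejection-sampling variants also take uniform_samples of shape (max_rounds, batch_size) and return int32 token ids together with a per-row bool success flag (plain sampling_from_probs takes a single (batch_size,) row and returns only the ids), while the renorm/mask helpers return a float32 tensor of the same (batch_size, vocab_size) shape. A short illustrative sketch using the Python wrappers (values are arbitrary):

    import torch
    import aphrodite._custom_ops as ops

    logits = torch.randn(2, 128, device="cuda")
    masked = ops.top_k_mask_logits(logits, 10)     # (2, 128); non-top-k masked
    probs = torch.softmax(masked, dim=-1)
    renormed = ops.top_p_renorm_prob(probs, 0.95)  # rows sum to 1 again
    u = torch.rand(32, 2, device="cuda")           # (max_rounds, batch_size)
    token_ids, ok = ops.top_p_sampling_from_probs(renormed, u, 0.95)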

+ 1398 - 0
kernels/sampling/sampling.cuh

@@ -0,0 +1,1398 @@
+/*
+ * Copyright (c) 2024 by PygmalionAI team.
+ * Copyright (c) 2024 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef APHRODITE_SAMPLING_CUH_
+#define APHRODITE_SAMPLING_CUH_
+
+#include <cub/block/block_adjacent_difference.cuh>
+#include <cub/block/block_reduce.cuh>
+#include <cub/block/block_scan.cuh>
+#include <numeric>
+
+#include "math.cuh"
+#include "utils.cuh"
+#include "vec_dtypes.cuh"
+
+namespace aphrodite {
+
+namespace sampling {
+
+using namespace cub;
+
+#define DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, ...) \
+  if (deterministic) {                                            \
+    constexpr bool DETERMINISTIC = true;                          \
+    __VA_ARGS__                                                   \
+  } else {                                                        \
+    constexpr bool DETERMINISTIC = false;                         \
+    __VA_ARGS__                                                   \
+  }
+
+constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS;
+constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS;
+
+#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120100)
+  #define APHRODITE_CUB_SUBTRACTLEFT_DEFINED
+#endif
+
+template <typename T>
+struct Pair {
+  T value;
+  int count;
+
+  __device__ Pair operator+(const Pair& other) const {
+    return {value + other.value, count + other.count};
+  }
+  __device__ Pair& operator+=(const Pair& other) {
+    value += other.value;
+    count += other.count;
+    return *this;
+  }
+};
+
+struct BoolDiffOp {
+  __device__ __forceinline__ bool operator()(const bool& lhs,
+                                             const bool& rhs) const {
+    return lhs != rhs;
+  }
+};
+
+template <typename T, uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM>
+struct SamplingTempStorage {
+  union {
+    T deterministic_scan[BLOCK_THREADS / 32];
+    typename BlockScan<T, BLOCK_THREADS, SCAN_ALGORITHM>::TempStorage scan;
+    typename BlockReduce<T, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
+        reduce;
+    typename BlockReduce<Pair<T>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
+        reduce_pair;
+    typename BlockAdjacentDifference<bool, BLOCK_THREADS>::TempStorage adj_diff;
+  } block_prim;
+  struct {
+    int32_t sampled_id;
+    union {
+      T value;
+      Pair<T> pair;
+      T max_p;
+    } block_aggregate;
+  } data;
+};
+
+/*!
+ * \brief Deterministic inclusive scan implementation, using the Blelloch scan
+ * algorithm.
+ * \note This implementation is slower than cub::BlockScan, but it is
+ * deterministic.
+ */
+template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS,
+          BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, typename T>
+__device__ __forceinline__ void DeterministicInclusiveSum(
+    const T* in_data, T* out_data,
+    SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>*
+        temp_storage) {
+  T* smem_prefix_sum = temp_storage->block_prim.deterministic_scan;
+  T thread_data[VEC_SIZE];
+  T thread_sum = 0;
+#pragma unroll
+  for (uint32_t i = 0; i < VEC_SIZE; ++i) {
+    thread_sum += in_data[i];
+    thread_data[i] = thread_sum;
+  }
+
+  T thread_exclusive_prefix_sum = thread_sum;
+
+#pragma unroll
+  for (uint32_t offset = 1; offset < 32; offset *= 2) {
+    T tmp = __shfl_up_sync(0xffffffff, thread_exclusive_prefix_sum, offset);
+    if ((threadIdx.x + 1) % (offset * 2) == 0) {
+      thread_exclusive_prefix_sum += tmp;
+    }
+  }
+
+  T warp_sum = __shfl_sync(0xffffffff, thread_exclusive_prefix_sum,
+                           threadIdx.x | 0xffffffff);
+  if (threadIdx.x % 32 == 31) {
+    thread_exclusive_prefix_sum = 0;
+  }
+
+#pragma unroll
+  for (uint32_t offset = 16; offset >= 1; offset /= 2) {
+    T tmp = __shfl_xor_sync(0xffffffff, thread_exclusive_prefix_sum, offset);
+    if ((threadIdx.x + 1) % (offset * 2) == 0) {
+      thread_exclusive_prefix_sum = tmp + thread_exclusive_prefix_sum;
+    }
+    if ((threadIdx.x + 1) % (offset * 2) == offset) {
+      thread_exclusive_prefix_sum = tmp;
+    }
+  }
+
+  smem_prefix_sum[threadIdx.x / 32] = warp_sum;
+  __syncthreads();
+
+  if (threadIdx.x < 32) {
+    T warp_exclusive_prefix_sum =
+        (threadIdx.x < BLOCK_THREADS / 32) ? smem_prefix_sum[threadIdx.x] : 0;
+
+#pragma unroll
+    for (uint32_t offset = 1; offset < 32; offset *= 2) {
+      T tmp = __shfl_up_sync(0xffffffff, warp_exclusive_prefix_sum, offset);
+      if ((threadIdx.x + 1) % (offset * 2) == 0) {
+        warp_exclusive_prefix_sum += tmp;
+      }
+    }
+
+    if (threadIdx.x % 32 == 31) {
+      warp_exclusive_prefix_sum = 0;
+    }
+
+#pragma unroll
+    for (uint32_t offset = 16; offset >= 1; offset /= 2) {
+      T tmp = __shfl_xor_sync(0xffffffff, warp_exclusive_prefix_sum, offset);
+      if ((threadIdx.x + 1) % (offset * 2) == 0) {
+        warp_exclusive_prefix_sum = tmp + warp_exclusive_prefix_sum;
+      }
+      if ((threadIdx.x + 1) % (offset * 2) == offset) {
+        warp_exclusive_prefix_sum = tmp;
+      }
+    }
+    if (threadIdx.x < BLOCK_THREADS / 32) {
+      smem_prefix_sum[threadIdx.x] = warp_exclusive_prefix_sum;
+    }
+  }
+  __syncthreads();
+
+#pragma unroll
+  for (uint32_t i = 0; i < VEC_SIZE; ++i) {
+    out_data[i] = smem_prefix_sum[threadIdx.x / 32] +
+                  thread_exclusive_prefix_sum + thread_data[i];
+  }
+}
+
+template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS,
+          BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, bool DETERMINISTIC, typename T>
+__device__ __forceinline__ void DeviceSamplingFromProb(
+    uint32_t i, uint32_t d, T threshold, T u, vec_t<T, VEC_SIZE> prob_vec,
+    T& aggregate,
+    SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>*
+        temp_storage) {
+  const uint32_t tx = threadIdx.x;
+  T prob_greater_than_threshold[VEC_SIZE];
+  T inclusive_cdf[VEC_SIZE];
+  bool greater_than_u[VEC_SIZE], valid[VEC_SIZE];
+#pragma unroll
+  for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+    prob_greater_than_threshold[j] =
+        (prob_vec[j] > threshold) ? prob_vec[j] : T(0);
+    valid[j] =
+        prob_vec[j] > threshold && (i * BLOCK_THREADS + tx) * VEC_SIZE < d;
+  }
+  T aggregate_local = BlockReduce<T, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                          temp_storage->block_prim.reduce)
+                          .Sum<VEC_SIZE>(prob_greater_than_threshold);
+  if (tx == 0) {
+    temp_storage->data.block_aggregate.value = aggregate_local;
+  }
+  __syncthreads();
+  aggregate_local = temp_storage->data.block_aggregate.value;
+
+  if (aggregate + aggregate_local > u) {
+    if constexpr (DETERMINISTIC) {
+      DeterministicInclusiveSum<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
+                                REDUCE_ALGORITHM, T>(
+          prob_greater_than_threshold, inclusive_cdf, temp_storage);
+    } else {
+      BlockScan<T, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
+          .InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);
+
+      __syncthreads();
+    }
+
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      greater_than_u[j] = inclusive_cdf[j] + aggregate > u;
+    }
+
+    bool greater_than_u_diff[VEC_SIZE];
+#ifdef APHRODITE_CUB_SUBTRACTLEFT_DEFINED
+    BlockAdjacentDifference<bool, BLOCK_THREADS>(
+        temp_storage->block_prim.adj_diff)
+        .SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff,
+                                BoolDiffOp());
+#else
+    BlockAdjacentDifference<bool, BLOCK_THREADS>(
+        temp_storage->block_prim.adj_diff)
+        .FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(),
+                             0);
+#endif
+    __syncthreads();
+
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      if (greater_than_u_diff[j] && valid[j]) {
+        if constexpr (DETERMINISTIC) {
+          temp_storage->data.sampled_id =
+              (i * BLOCK_THREADS + tx) * VEC_SIZE + j;
+        } else {
+          // cub's block scan result might not be monotonic, so we need to find
+          // the first element
+          atomicMin(&(temp_storage->data.sampled_id),
+                    (i * BLOCK_THREADS + tx) * VEC_SIZE + j);
+        }
+      }
+    }
+    __syncthreads();
+  }
+  aggregate += aggregate_local;
+}
+
+template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
+          bool DETERMINISTIC, typename DType, typename IdType>
+__global__ void SamplingFromProbKernel(DType* probs, DType* uniform_samples,
+                                       IdType* output, IdType* row_indices,
+                                       uint32_t d) {
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];
+
+  extern __shared__ __align__(
+      alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
+                                  REDUCE_ALGORITHM>)) uint8_t smem_sampling[];
+  auto& temp_storage =
+      reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
+                                           REDUCE_ALGORITHM>&>(smem_sampling);
+  temp_storage.data.sampled_id = d - 1;
+  __syncthreads();
+
+  vec_t<DType, VEC_SIZE> probs_vec;
+  DType aggregate(0);
+  float u = uniform_samples[bx];
+
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
+                     tx * VEC_SIZE);
+    }
+
+    DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
+                           REDUCE_ALGORITHM, DETERMINISTIC, DType>(
+        i, d, DType(0), u, probs_vec, aggregate, &temp_storage);
+    if (float(aggregate) > u) {
+      break;
+    }
+  }
+  output[bx] = temp_storage.data.sampled_id;
+}
+
+template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
+          bool DETERMINISTIC, typename DType, typename IdType>
+__global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
+                                           IdType* output, bool* success,
+                                           IdType* top_k_arr,
+                                           uint32_t top_k_val, uint32_t d,
+                                           uint32_t max_top_k_rounds) {
+  const uint32_t batch_size = gridDim.x;
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
+
+  extern __shared__ __align__(
+      alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
+                                  REDUCE_ALGORITHM>)) uint8_t smem_sampling[];
+  auto& temp_storage =
+      reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
+                                           REDUCE_ALGORITHM>&>(smem_sampling);
+
+  vec_t<DType, VEC_SIZE> probs_vec;
+  DType aggregate;
+  DType q = DType(1);
+  DType pivot = DType(0);
+  IdType sampled_id;
+  for (uint32_t round = 0; round < max_top_k_rounds; ++round) {
+    temp_storage.data.sampled_id = d - 1;
+    __syncthreads();
+    DType u = uniform_samples[round * batch_size + bx] * q;
+    aggregate = DType(0);
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
+                             REDUCE_ALGORITHM, DETERMINISTIC, DType>(
+          i, d, pivot, u, probs_vec, aggregate, &temp_storage);
+      if (aggregate > u) {
+        break;
+      }
+    }
+    __syncthreads();
+    sampled_id = temp_storage.data.sampled_id;
+    pivot = max(pivot, probs[bx * d + sampled_id]);
+
+    Pair<DType> aggregate_gt_pivot{DType(0), 0};
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      Pair<DType> probs_gt_pivot[VEC_SIZE];
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        probs_gt_pivot[j] = {(probs_vec[j] > pivot) ? probs_vec[j] : DType(0),
+                             (probs_vec[j] > pivot &&
+                              (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+      }
+
+      aggregate_gt_pivot +=
+          BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+              temp_storage.block_prim.reduce_pair)
+              .Sum<VEC_SIZE>(probs_gt_pivot);
+      if (tx == 0) {
+        temp_storage.data.block_aggregate.pair = aggregate_gt_pivot;
+      }
+      __syncthreads();
+    }
+    q = temp_storage.data.block_aggregate.pair.value;
+    if (temp_storage.data.block_aggregate.pair.count < k) {
+      break;
+    }
+  }
+  __syncthreads();
+  if (tx == 0) {
+    output[bx] = sampled_id;
+    if (temp_storage.data.block_aggregate.pair.count >= k) {
+      // failed to sample within max_top_k_rounds
+      if (success != nullptr) {
+        success[bx] = false;
+      }
+    } else {
+      if (success != nullptr) {
+        success[bx] = true;
+      }
+    }
+  }
+}
+
+template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
+          bool DETERMINISTIC, typename DType, typename IdType>
+__global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
+                                           IdType* output, bool* success,
+                                           IdType* row_indices,
+                                           float* top_p_arr, float top_p_val,
+                                           uint32_t d,
+                                           uint32_t max_top_p_rounds) {
+  const uint32_t batch_size = gridDim.x;
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  float top_p = (top_p_arr == nullptr) ? top_p_val : top_p_arr[bx];
+
+  const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];
+
+  extern __shared__ __align__(
+      alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
+                                  REDUCE_ALGORITHM>)) uint8_t smem_sampling[];
+  auto& temp_storage =
+      reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
+                                           REDUCE_ALGORITHM>&>(smem_sampling);
+
+  vec_t<DType, VEC_SIZE> probs_vec;
+  DType aggregate;
+  DType q = DType(1);
+  DType pivot = DType(0);
+  IdType sampled_id;
+  for (uint32_t round = 0; round < max_top_p_rounds; ++round) {
+    temp_storage.data.sampled_id = d - 1;
+    __syncthreads();
+    DType u = uniform_samples[round * batch_size + bx] * q;
+    aggregate = DType(0);
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + row_idx * d +
+                       (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
+                             REDUCE_ALGORITHM, DETERMINISTIC, DType>(
+          i, d, pivot, u, probs_vec, aggregate, &temp_storage);
+      if (aggregate > u) {
+        break;
+      }
+    }
+    __syncthreads();
+    sampled_id = temp_storage.data.sampled_id;
+    pivot = max(pivot, probs[row_idx * d + sampled_id]);
+
+    DType aggregate_gt_pivot = DType(0);
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + row_idx * d +
+                       (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      DType probs_gt_pivot[VEC_SIZE];
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0);
+      }
+
+      aggregate_gt_pivot +=
+          BlockReduce<DType, BLOCK_THREADS>(temp_storage.block_prim.reduce)
+              .Sum<VEC_SIZE>(probs_gt_pivot);
+      if (tx == 0) {
+        temp_storage.data.block_aggregate.value = aggregate_gt_pivot;
+      }
+      __syncthreads();
+    }
+    q = temp_storage.data.block_aggregate.value;
+    if (float(q) < top_p) {
+      break;
+    }
+  }
+  __syncthreads();
+  if (tx == 0) {
+    output[bx] = sampled_id;
+    if (float(q) >= top_p) {
+      // failed to sample within MAX_TOP_P_ROUNDS
+      if (success != nullptr) {
+        success[bx] = false;
+      }
+    } else {
+      if (success != nullptr) {
+        success[bx] = true;
+      }
+    }
+  }
+}
+
+template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
+          bool DETERMINISTIC, typename DType, typename IdType>
+__global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
+                                           DType* min_p_arr, IdType* output,
+                                           bool* success, float min_p_val,
+                                           uint32_t d,
+                                           uint32_t max_min_p_rounds) {
+  const uint32_t batch_size = gridDim.x;
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  DType p = (min_p_arr == nullptr) ? min_p_val : min_p_arr[bx];
+
+  extern __shared__ __align__(
+      alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
+                                  REDUCE_ALGORITHM>)) uint8_t smem_sampling[];
+  auto& temp_storage =
+      reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
+                                           REDUCE_ALGORITHM>&>(smem_sampling);
+
+  vec_t<DType, VEC_SIZE> probs_vec;
+  DType aggregate;
+  DType q = DType(1);
+  DType pivot = DType(0);
+
+  DType max_p = 0;
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+    }
+    DType probs_[VEC_SIZE];
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      probs_[j] = probs_vec[j];
+    }
+    max_p = max(
+        max_p, BlockReduce<DType, BLOCK_THREADS>(temp_storage.block_prim.reduce)
+                   .Reduce<VEC_SIZE>(probs_, cub::Max()));
+    __syncthreads();
+  }
+  if (tx == 0) {
+    temp_storage.data.block_aggregate.max_p = max_p;
+  }
+  __syncthreads();
+  DType scaled_p = temp_storage.data.block_aggregate.max_p * p;
+
+  IdType sampled_id;
+  for (uint32_t round = 0; round < max_min_p_rounds; ++round) {
+    temp_storage.data.sampled_id = d - 1;
+    __syncthreads();
+    DType u = uniform_samples[round * batch_size + bx] * q;
+    aggregate = DType(0);
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
+                             REDUCE_ALGORITHM, DETERMINISTIC, DType>(
+          i, d, pivot, u, probs_vec, aggregate, &temp_storage);
+      if (aggregate > u) {
+        break;
+      }
+    }
+    __syncthreads();
+    sampled_id = temp_storage.data.sampled_id;
+    pivot = max(pivot, probs[bx * d + sampled_id]);
+    if (pivot >= scaled_p) {
+      break;
+    }
+
+    DType aggregate_gt_pivot = DType(0);
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      DType probs_gt_pivot[VEC_SIZE];
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0);
+      }
+
+      aggregate_gt_pivot +=
+          BlockReduce<DType, BLOCK_THREADS>(temp_storage.block_prim.reduce)
+              .Sum<VEC_SIZE>(probs_gt_pivot);
+      if (tx == 0) {
+        temp_storage.data.block_aggregate.value = aggregate_gt_pivot;
+      }
+      __syncthreads();
+    }
+    q = temp_storage.data.block_aggregate.value;
+  }
+  __syncthreads();
+  if (tx == 0) {
+    output[bx] = sampled_id;
+    if (pivot < scaled_p) {
+      // failed to sample within max_min_p_rounds
+      if (success != nullptr) {
+        success[bx] = false;
+      }
+    } else {
+      if (success != nullptr) {
+        success[bx] = true;
+      }
+    }
+  }
+}
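+
+// Worked example (illustrative) of the accept rule above: for a row
+// probs = {0.5, 0.3, 0.15, 0.05} and min_p = 0.2, the kernel first finds
+// max_p = 0.5, so scaled_p = 0.5 * 0.2 = 0.1. A draw is accepted only once
+// its probability reaches scaled_p, i.e. only tokens with prob >= 0.1
+// ({0.5, 0.3, 0.15}) can be emitted; a rejected round restricts the next
+// draw to probabilities above the running pivot.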
+
+template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
+          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
+          bool DETERMINISTIC, typename DType, typename IdType>
+__global__ void TopKTopPSamplingFromProbKernel(
+    DType* probs, DType* uniform_samples, IdType* top_k_arr, DType* top_p_arr,
+    IdType* output, bool* success, IdType top_k_val, DType top_p_val,
+    uint32_t d, uint32_t max_rounds) {
+  const uint32_t batch_size = gridDim.x;
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  IdType k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
+  DType p = top_p_arr == nullptr ? top_p_val : top_p_arr[bx];
+
+  extern __shared__ __align__(
+      alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
+                                  REDUCE_ALGORITHM>)) uint8_t smem_sampling[];
+  auto& temp_storage =
+      reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
+                                           REDUCE_ALGORITHM>&>(smem_sampling);
+
+  vec_t<DType, VEC_SIZE> probs_vec;
+  DType aggregate;
+  DType q = DType(1);
+  DType pivot = DType(0);
+  IdType sampled_id;
+  for (uint32_t round = 0; round < max_rounds; ++round) {
+    temp_storage.data.sampled_id = d - 1;
+    __syncthreads();
+    DType u = uniform_samples[round * batch_size + bx] * q;
+    aggregate = DType(0);
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
+                             REDUCE_ALGORITHM, DETERMINISTIC, DType>(
+          i, d, pivot, u, probs_vec, aggregate, &temp_storage);
+      if (aggregate > u) {
+        break;
+      }
+    }
+    __syncthreads();
+    sampled_id = temp_storage.data.sampled_id;
+    pivot = max(pivot, probs[bx * d + sampled_id]);
+
+    Pair<DType> aggregate_gt_pivot{DType(0), 0};
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
+      }
+
+      Pair<DType> probs_gt_pivot[VEC_SIZE];
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        probs_gt_pivot[j] = {(probs_vec[j] > pivot) ? probs_vec[j] : DType(0),
+                             (probs_vec[j] > pivot &&
+                              (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+      }
+
+      aggregate_gt_pivot +=
+          BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+              temp_storage.block_prim.reduce_pair)
+              .Sum<VEC_SIZE>(probs_gt_pivot);
+      if (tx == 0) {
+        temp_storage.data.block_aggregate.pair = aggregate_gt_pivot;
+      }
+      __syncthreads();
+    }
+    q = temp_storage.data.block_aggregate.pair.value;
+    if (temp_storage.data.block_aggregate.pair.count < k && float(q) < p) {
+      break;
+    }
+  }
+  __syncthreads();
+  if (tx == 0) {
+    output[bx] = sampled_id;
+    if (temp_storage.data.block_aggregate.pair.count >= k || float(q) >= p) {
+      // failed to sample within max_rounds
+      if (success != nullptr) {
+        success[bx] = false;
+      }
+    } else {
+      if (success != nullptr) {
+        success[bx] = true;
+      }
+    }
+  }
+}
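+
+// Informal reading of the accept test above (illustrative numbers): a draw is
+// kept only when fewer than k probabilities and less than p of the total mass
+// lie strictly above the running pivot, i.e. the sampled token already sits
+// inside both the top-k set and the top-p nucleus. For probs =
+// {0.4, 0.3, 0.2, 0.1} with k = 2 and p = 0.6, only 0.4 and 0.3 can be
+// accepted; sampling 0.2 leaves two entries (0.4, 0.3) above the pivot, so
+// that round is rejected and the next draw is restricted to them.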
+
+template <typename T, typename IdType>
+cudaError_t SamplingFromProb(T* probs, T* uniform_samples, IdType* output,
+                             uint32_t batch_size, uint32_t d,
+                             bool deterministic, cudaStream_t stream = 0) {
+  constexpr uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  IdType* row_indices_placeholder = nullptr;
+  void* args[] = {&probs, &uniform_samples, &output, &row_indices_placeholder,
+                  &d};
+  const uint32_t smem_size =
+      sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+
+  DISPATCH_ALIGNED_VEC_SIZE(
+      vec_size, VEC_SIZE,
+      {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
+        auto kernel =
+            SamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
+                                   VEC_SIZE, DETERMINISTIC, T, IdType>;
+        APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
+                                             smem_size, stream));
+      })});
+  return cudaSuccess;
+}
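+
+// Usage sketch (illustrative, caller-side names assumed): drawing one token
+// per row from a [batch_size, vocab_size] probability matrix. All pointers
+// are device buffers prepared by the caller, and d_uniform is assumed to hold
+// one U(0, 1) draw per row.
+//
+//   cudaError_t st = aphrodite::sampling::SamplingFromProb<float, int32_t>(
+//       d_probs, d_uniform, d_output, batch_size, vocab_size,
+//       /*deterministic=*/true, stream);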
+
+template <typename T, typename IdType>
+cudaError_t ParallelSamplingFromProb(T* probs, T* uniform_samples,
+                                     IdType* output, IdType* row_indices,
+                                     uint32_t batch_size, uint32_t d,
+                                     bool deterministic,
+                                     cudaStream_t stream = 0) {
+  constexpr uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&probs, &uniform_samples, &output, &row_indices, &d};
+  const uint32_t smem_size =
+      sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+
+  DISPATCH_ALIGNED_VEC_SIZE(
+      vec_size, VEC_SIZE,
+      {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
+        auto kernel =
+            SamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
+                                   VEC_SIZE, DETERMINISTIC, T, IdType>;
+        APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
+                                             smem_size, stream));
+      })});
+  return cudaSuccess;
+}
+
+template <typename T, typename IdType>
+cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output,
+                                 bool* success, T* top_k_arr,
+                                 uint32_t batch_size, uint32_t top_k_val,
+                                 uint32_t d, uint32_t max_top_k_rounds,
+                                 bool deterministic, cudaStream_t stream = 0) {
+  constexpr uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t smem_size =
+      sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&probs,     &uniform_samples, &output, &success,
+                  &top_k_arr, &top_k_val,       &d,      &max_top_k_rounds};
+
+  DISPATCH_ALIGNED_VEC_SIZE(
+      vec_size, VEC_SIZE,
+      {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
+        auto kernel =
+            TopKSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
+                                       VEC_SIZE, DETERMINISTIC, T, IdType>;
+        APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
+            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+        APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
+                                             smem_size, stream));
+      })});
+  return cudaSuccess;
+}
+
+template <typename T, typename IdType>
+cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output,
+                                 bool* success, T* top_p_arr,
+                                 uint32_t batch_size, T top_p_val, uint32_t d,
+                                 uint32_t max_top_p_rounds, bool deterministic,
+                                 cudaStream_t stream = 0) {
+  constexpr uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t smem_size =
+      sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  IdType* row_indices_placeholder = nullptr;
+  void* args[] = {&probs,
+                  &uniform_samples,
+                  &output,
+                  &success,
+                  &row_indices_placeholder,
+                  &top_p_arr,
+                  &top_p_val,
+                  &d,
+                  &max_top_p_rounds};
+
+  DISPATCH_ALIGNED_VEC_SIZE(
+      vec_size, VEC_SIZE,
+      {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
+        auto kernel =
+            TopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
+                                       VEC_SIZE, DETERMINISTIC, T, IdType>;
+        APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
+            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+        APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
+                                             smem_size, stream));
+      })});
+  return cudaSuccess;
+}
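+
+// Usage sketch (illustrative, caller-side names assumed): rejection-based
+// top-p sampling with a scalar threshold. d_uniform is expected to hold
+// max_top_p_rounds * batch_size draws (indexed as [round, row]); rows whose
+// d_success entry comes back false exhausted the round budget.
+//
+//   cudaError_t st = aphrodite::sampling::TopPSamplingFromProb<float, int32_t>(
+//       d_probs, d_uniform, d_output, d_success, /*top_p_arr=*/nullptr,
+//       batch_size, /*top_p_val=*/0.9f, vocab_size, /*max_top_p_rounds=*/32,
+//       /*deterministic=*/true, stream);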
+
+template <typename T, typename IdType>
+cudaError_t MinPSamplingFromProb(T* probs, T* uniform_samples, T* min_p_arr,
+                                 IdType* output, bool* success,
+                                 uint32_t batch_size, float min_p_val,
+                                 uint32_t d, uint32_t max_rounds,
+                                 bool deterministic, cudaStream_t stream = 0) {
+  constexpr uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t smem_size =
+      sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&probs,   &uniform_samples, &min_p_arr, &output,
+                  &success, &min_p_val,       &d,         &max_rounds};
+
+  DISPATCH_ALIGNED_VEC_SIZE(
+      vec_size, VEC_SIZE,
+      {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
+        auto kernel =
+            MinPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
+                                       VEC_SIZE, DETERMINISTIC, T, IdType>;
+        APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
+            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+        APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
+                                             smem_size, stream));
+      })});
+  return cudaSuccess;
+}
+
+template <typename T, typename IdType>
+cudaError_t TopKTopPSamplingFromProb(T* probs, T* uniform_samples,
+                                     IdType* top_k_arr, T* top_p_arr,
+                                     IdType* output, bool* success,
+                                     uint32_t batch_size, IdType top_k_val,
+                                     T top_p_val, uint32_t d,
+                                     uint32_t max_rounds, bool deterministic,
+                                     cudaStream_t stream = 0) {
+  constexpr uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t smem_size =
+      sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&probs,  &uniform_samples, &top_k_arr, &top_p_arr,
+                  &output, &success,         &top_k_val, &top_p_val,
+                  &d,      &max_rounds};
+
+  DISPATCH_ALIGNED_VEC_SIZE(
+      vec_size, VEC_SIZE,
+      {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
+        auto kernel = TopKTopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO,
+                                                     REDUCE_ALGO, VEC_SIZE,
+                                                     DETERMINISTIC, T, IdType>;
+        APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
+            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+        APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
+                                             smem_size, stream));
+      })});
+  return cudaSuccess;
+}
+
+template <typename T, uint32_t BLOCK_THREADS,
+          BlockReduceAlgorithm REDUCE_ALGORITHM>
+struct RenormTempStorage {
+  union {
+    typename BlockReduce<T, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
+        reduce;
+    typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
+        reduce_int;
+    typename BlockReduce<Pair<T>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
+        reduce_pair;
+  } block_prim;
+  struct {
+    T max_val;
+    T min_val;
+    union {
+      T value;
+      int count;
+      Pair<T> pair;
+    } block_aggregate;
+  } data;
+};
+
+template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
+          uint32_t VEC_SIZE, typename DType>
+__global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob,
+                                     DType* top_p_arr, float top_p_val,
+                                     uint32_t d) {
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  const uint32_t row_idx = bx;
+  float p = top_p_arr == nullptr ? top_p_val : top_p_arr[bx];
+
+  extern __shared__ __align__(
+      alignof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
+      uint8_t smem_renorm[];
+  auto& temp_storage =
+      reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(
+          smem_renorm);
+  temp_storage.data.max_val = DType(0);
+  vec_t<DType, VEC_SIZE> probs_vec;
+  DType probs_greater_than_pivot[VEC_SIZE];  // pivot initialized to 0
+
+  DType threadlocal_max_val = DType(0);
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
+                     tx * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      probs_greater_than_pivot[j] = probs_vec[j];
+    }
+    threadlocal_max_val =
+        max(threadlocal_max_val,
+            BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                temp_storage.block_prim.reduce)
+                .Reduce<VEC_SIZE>(probs_greater_than_pivot, cub::Max()));
+    __syncthreads();
+  }
+  if (tx == 0) {
+    temp_storage.data.max_val = threadlocal_max_val;
+  }
+  __syncthreads();
+  threadlocal_max_val = temp_storage.data.max_val;
+
+  float low = 0, high = threadlocal_max_val;
+  DType min_gt_low, max_le_high;
+  DType sum_low(1);
+  // f(x) = sum(probs[probs > x]), f(x) is non-increasing
+  // min_gt_low = min{p \in probs | p > low}
+  // max_le_high = max{p \in probs | p <= high}
+  // loop invariant:
+  // - f(low) >= p, f(high) < p
+  // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
+  // stopping condition: min_gt_low == max_le_high
+  // - f(low) >= p, f(min_gt_low) == f(max_le_high) == f(high) < p
+  do {
+    DType threadlocal_sum(0);
+    float mid = (low + high) / 2;
+    min_gt_low = high;
+    max_le_high = low;
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
+                       tx * VEC_SIZE);
+      }
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        probs_greater_than_pivot[j] =
+            (probs_vec[j] > mid) ? probs_vec[j] : DType(0);
+        if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+          min_gt_low = min(min_gt_low, probs_vec[j]);
+        }
+        if (probs_vec[j] <= high &&
+            (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+          max_le_high = max(max_le_high, probs_vec[j]);
+        }
+      }
+      threadlocal_sum += BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                             temp_storage.block_prim.reduce)
+                             .Sum<VEC_SIZE>(probs_greater_than_pivot);
+      __syncthreads();
+    }
+    min_gt_low = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                     temp_storage.block_prim.reduce)
+                     .Reduce(min_gt_low, cub::Min());
+    __syncthreads();
+    max_le_high = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                      temp_storage.block_prim.reduce)
+                      .Reduce(max_le_high, cub::Max());
+    if (tx == 0) {
+      temp_storage.data.block_aggregate.value = threadlocal_sum;
+      temp_storage.data.min_val = min_gt_low;
+      temp_storage.data.max_val = max_le_high;
+    }
+    __syncthreads();
+    threadlocal_sum = temp_storage.data.block_aggregate.value;
+    min_gt_low = temp_storage.data.min_val;
+    max_le_high = temp_storage.data.max_val;
+    if (threadlocal_sum >= p) {
+      low = mid;
+      sum_low = float(threadlocal_sum);
+    } else {
+      high = min(mid, max_le_high);
+    }
+  } while (min_gt_low != max_le_high);
+
+  DType normalizer = math::ptx_rcp(max(sum_low, 1e-8));
+
+  // normalize
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
+                     tx * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      probs_vec[j] =
+          (probs_vec[j] > low) ? probs_vec[j] * normalizer : DType(0);
+    }
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.store(renormed_prob + row_idx * d +
+                      i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+    }
+  }
+}
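+
+// Worked example (illustrative) of the pivot search above: for a row
+// probs = {0.4, 0.3, 0.2, 0.1} and top_p = 0.6, f(0.2) = 0.7 >= 0.6 while
+// f(0.3) = 0.4 < 0.6, so the bisection settles on a pivot in [0.2, 0.3).
+// Only {0.4, 0.3} stay above it, sum_low = 0.7, and the renormalized row is
+// approximately {0.571, 0.429, 0, 0}.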
+
+template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
+          uint32_t VEC_SIZE, typename DType, typename IdType>
+__global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits,
+                                     IdType* top_k_arr, uint32_t top_k_val,
+                                     uint32_t d) {
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  const uint32_t row_idx = bx;
+  uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
+  float pivot = -std::numeric_limits<float>::infinity();
+  vec_t<DType, VEC_SIZE> logits_vec;
+  if (k < d) {
+    extern __shared__ __align__(
+        alignof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
+        uint8_t smem_renorm[];
+    auto& temp_storage =
+        reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(
+            smem_renorm);
+    DType logits_greater_than_pivot[VEC_SIZE];  // pivot initialized to -inf
+
+    DType threadlocal_max_val = DType(-std::numeric_limits<float>::infinity()),
+          threadlocal_min_val = DType(std::numeric_limits<float>::infinity());
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      logits_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        logits_vec.load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
+                        tx * VEC_SIZE);
+      }
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        logits_greater_than_pivot[j] = logits_vec[j];
+      }
+      threadlocal_max_val =
+          max(threadlocal_max_val,
+              BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                  temp_storage.block_prim.reduce)
+                  .Reduce<VEC_SIZE>(logits_greater_than_pivot, cub::Max()));
+      __syncthreads();
+      threadlocal_min_val =
+          min(threadlocal_min_val,
+              BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                  temp_storage.block_prim.reduce)
+                  .Reduce<VEC_SIZE>(logits_greater_than_pivot, cub::Min()));
+      __syncthreads();
+    }
+    if (tx == 0) {
+      temp_storage.data.max_val = threadlocal_max_val;
+      temp_storage.data.min_val = threadlocal_min_val;
+    }
+    __syncthreads();
+    threadlocal_max_val = temp_storage.data.max_val;
+    threadlocal_min_val = temp_storage.data.min_val;
+
+    float low = threadlocal_min_val - 1, high = threadlocal_max_val;
+    DType min_gt_low, max_le_high;
+    // f(x) = len(nonzero(logits > x)), f(x) is non-increasing
+    // min_gt_low = min{v \in logits | v > low}
+    // max_le_high = max{v \in logits | v <= high}
+    // loop invariant:
+    // - f(low) >= k, f(high) < k
+    // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
+    // stopping condition: min_gt_low == max_le_high
+    // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
+    do {
+      int threadlocal_count_sum = 0;
+      int probs_greater_than_pivot_count[VEC_SIZE];  // pivot initialized to 0
+      float mid = (low + high) / 2;
+      min_gt_low = high;
+      max_le_high = low;
+      for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+        logits_vec.fill(DType(0));
+        if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+          logits_vec.load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
+                          tx * VEC_SIZE);
+        }
+#pragma unroll
+        for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+          probs_greater_than_pivot_count[j] =
+              logits_vec[j] > mid &&
+              (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
+          if (logits_vec[j] > low &&
+              (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+            min_gt_low = min(min_gt_low, logits_vec[j]);
+          }
+          if (logits_vec[j] <= high &&
+              (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+            max_le_high = max(max_le_high, logits_vec[j]);
+          }
+        }
+        threadlocal_count_sum +=
+            BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                temp_storage.block_prim.reduce_int)
+                .Sum<VEC_SIZE>(probs_greater_than_pivot_count);
+        __syncthreads();
+      }
+      min_gt_low = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                       temp_storage.block_prim.reduce)
+                       .Reduce(min_gt_low, cub::Min());
+      __syncthreads();
+      max_le_high = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                        temp_storage.block_prim.reduce)
+                        .Reduce(max_le_high, cub::Max());
+      if (tx == 0) {
+        temp_storage.data.block_aggregate.count = threadlocal_count_sum;
+        temp_storage.data.min_val = min_gt_low;
+        temp_storage.data.max_val = max_le_high;
+      }
+      __syncthreads();
+      threadlocal_count_sum = temp_storage.data.block_aggregate.count;
+      min_gt_low = temp_storage.data.min_val;
+      max_le_high = temp_storage.data.max_val;
+      if (threadlocal_count_sum >= k) {
+        low = mid;
+      } else {
+        high = min(mid, max_le_high);
+      }
+    } while (min_gt_low != max_le_high);
+    pivot = low;
+  }
+
+  // masking
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    logits_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      logits_vec.load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
+                      tx * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      logits_vec[j] = (logits_vec[j] > pivot)
+                          ? logits_vec[j]
+                          : DType(-std::numeric_limits<float>::infinity());
+    }
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      logits_vec.store(masked_logits + row_idx * d +
+                       i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+    }
+  }
+}
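+
+// Worked example (illustrative): for logits = {2.0, 1.0, 0.5, -1.0} and
+// k = 2, f(x) = |{logits > x}| drops below k once x >= 1.0, so the search
+// lands on a pivot in [0.5, 1.0); 2.0 and 1.0 pass through unchanged and the
+// remaining entries are set to -inf.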
+
+template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
+          uint32_t VEC_SIZE, typename DType, typename IdType>
+__global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob,
+                                     IdType* top_k_arr, uint32_t top_k_val,
+                                     uint32_t d) {
+  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  const uint32_t row_idx = bx;
+  uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
+  float pivot = -std::numeric_limits<float>::infinity(), normalizer = 1;
+  vec_t<DType, VEC_SIZE> probs_vec;
+  if (k < d) {
+    extern __shared__ __align__(
+        alignof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
+        uint8_t smem_renorm[];
+    auto& temp_storage =
+        reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(
+            smem_renorm);
+    temp_storage.data.max_val = DType(0);
+    DType probs_greater_than_pivot[VEC_SIZE];  // pivot initialized to 0
+
+    DType threadlocal_max_val = DType(0);
+    for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+      probs_vec.fill(DType(0));
+      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+        probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
+                       tx * VEC_SIZE);
+      }
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        probs_greater_than_pivot[j] = probs_vec[j];
+      }
+      threadlocal_max_val =
+          max(threadlocal_max_val,
+              BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                  temp_storage.block_prim.reduce)
+                  .Reduce<VEC_SIZE>(probs_greater_than_pivot, cub::Max()));
+      __syncthreads();
+    }
+    if (tx == 0) {
+      temp_storage.data.max_val = threadlocal_max_val;
+    }
+    __syncthreads();
+    threadlocal_max_val = temp_storage.data.max_val;
+
+    float low = 0, high = threadlocal_max_val;
+    DType min_gt_low, max_le_high;
+    DType sum_low(1);
+    // f(x) = len(nonzero(probs > x)), f(x) is non-increasing
+    // min_gt_low = min{p \in probs | p > low}
+    // max_le_high = max{p \in probs | p <= high}
+    // loop invariant:
+    // - f(low) >= k, f(high) < k
+    // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
+    // stopping condition: min_gt_low == max_le_high
+    // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
+    do {
+      Pair<DType> threadlocal_sum{DType(0), 0};
+      Pair<DType>
+          probs_greater_than_pivot_pair[VEC_SIZE];  // pivot initialized to 0
+      float mid = (low + high) / 2;
+      min_gt_low = high;
+      max_le_high = low;
+      for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+        probs_vec.fill(DType(0));
+        if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+          probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
+                         tx * VEC_SIZE);
+        }
+#pragma unroll
+        for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+          probs_greater_than_pivot_pair[j] = {
+              (probs_vec[j] > mid) ? probs_vec[j] : DType(0),
+              (probs_vec[j] > mid &&
+               (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
+          if (probs_vec[j] > low &&
+              (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+            min_gt_low = min(min_gt_low, probs_vec[j]);
+          }
+          if (probs_vec[j] <= high &&
+              (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
+            max_le_high = max(max_le_high, probs_vec[j]);
+          }
+        }
+        threadlocal_sum +=
+            BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                temp_storage.block_prim.reduce_pair)
+                .Sum<VEC_SIZE>(probs_greater_than_pivot_pair);
+        __syncthreads();
+      }
+      min_gt_low = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                       temp_storage.block_prim.reduce)
+                       .Reduce(min_gt_low, cub::Min());
+      __syncthreads();
+      max_le_high = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
+                        temp_storage.block_prim.reduce)
+                        .Reduce(max_le_high, cub::Max());
+      if (tx == 0) {
+        temp_storage.data.block_aggregate.pair = threadlocal_sum;
+        temp_storage.data.min_val = min_gt_low;
+        temp_storage.data.max_val = max_le_high;
+      }
+      __syncthreads();
+      threadlocal_sum = temp_storage.data.block_aggregate.pair;
+      min_gt_low = temp_storage.data.min_val;
+      max_le_high = temp_storage.data.max_val;
+      if (threadlocal_sum.count >= k) {
+        low = mid;
+        sum_low = float(threadlocal_sum.value);
+      } else {
+        high = min(mid, max_le_high);
+      }
+    } while (min_gt_low != max_le_high);
+
+    normalizer = math::ptx_rcp(max(sum_low, 1e-8));
+    pivot = low;
+  }
+
+  // normalize
+  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    probs_vec.fill(DType(0));
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
+                     tx * VEC_SIZE);
+    }
+#pragma unroll
+    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      probs_vec[j] =
+          (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : DType(0);
+    }
+    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
+      probs_vec.store(renormed_prob + row_idx * d +
+                      i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
+    }
+  }
+}
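+
+// Design note: the Pair<DType> reduction above carries (mass above mid,
+// count above mid) together, so a single bisection both locates the k-th
+// largest probability (via the count) and produces the normalizer for the
+// surviving mass (via sum_low), without a separate summation pass.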
+
+template <typename DType>
+cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr,
+                           uint32_t batch_size, float top_p_val, uint32_t d,
+                           cudaStream_t stream = 0) {
+  const uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
+
+  const uint32_t smem_size =
+      sizeof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &d};
+  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+    auto kernel =
+        TopPRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType>;
+    APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
+        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+    APHRODITE_CUDA_CALL(
+        cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
+  return cudaSuccess;
+}
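+
+// Usage sketch (illustrative, caller-side names assumed): renormalizing a
+// batch of probability rows so that only each row's top-p nucleus keeps mass.
+//
+//   cudaError_t st = aphrodite::sampling::TopPRenormProb<float>(
+//       d_probs, d_renormed, /*top_p_arr=*/nullptr, batch_size,
+//       /*top_p_val=*/0.9f, vocab_size, stream);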
+
+template <typename DType, typename IdType>
+cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob,
+                           IdType* top_k_arr, uint32_t batch_size,
+                           uint32_t top_k_val, uint32_t d,
+                           cudaStream_t stream = 0) {
+  const uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
+
+  const uint32_t smem_size =
+      sizeof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d};
+  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+    auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE,
+                                       DType, IdType>;
+    APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
+        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+    APHRODITE_CUDA_CALL(
+        cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
+  return cudaSuccess;
+}
+
+template <typename DType, typename IdType>
+cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits,
+                           IdType* top_k_arr, uint32_t batch_size,
+                           uint32_t top_k_val, uint32_t d,
+                           cudaStream_t stream = 0) {
+  const uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
+
+  const uint32_t smem_size =
+      sizeof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &d};
+  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
+    auto kernel = TopKMaskLogitsKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE,
+                                       DType, IdType>;
+    APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
+        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+    APHRODITE_CUDA_CALL(
+        cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
+  return cudaSuccess;
+}
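+
+// Usage sketch (illustrative, caller-side names assumed): masking every logit
+// outside the per-row top-k to -inf before softmax.
+//
+//   cudaError_t st = aphrodite::sampling::TopKMaskLogits<float, int32_t>(
+//       d_logits, d_masked_logits, /*top_k_arr=*/nullptr, batch_size,
+//       /*top_k_val=*/40, vocab_size, stream);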
+
+template <typename T, typename IdType>
+cudaError_t ParallelTopPSamplingFromProb(
+    T* probs, T* uniform_samples, IdType* output, bool* success,
+    IdType* row_indices, T* top_p_arr, uint32_t batch_size, uint32_t d,
+    uint32_t max_top_p_rounds, bool deterministic, cudaStream_t stream = 0) {
+  constexpr uint32_t BLOCK_THREADS = 1024;
+  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+
+  const uint32_t smem_size =
+      sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+  dim3 nblks(batch_size);
+  dim3 nthrs(BLOCK_THREADS);
+  T top_p_placeholder = 0;
+  void* args[] = {
+      &probs,     &uniform_samples,   &output, &success,         &row_indices,
+      &top_p_arr, &top_p_placeholder, &d,      &max_top_p_rounds};
+
+  DISPATCH_ALIGNED_VEC_SIZE(
+      vec_size, VEC_SIZE,
+      {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
+        auto kernel =
+            TopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
+                                       VEC_SIZE, DETERMINISTIC, T, IdType>;
+        APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
+            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+        APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
+                                             smem_size, stream));
+      })});
+  return cudaSuccess;
+}
+
+}  // namespace sampling
+
+}  // namespace aphrodite
+
+#endif  // APHRODITE_SAMPLING_CUH_

+ 273 - 0
kernels/sampling/utils.cuh

@@ -0,0 +1,273 @@
+/*
+ * Copyright (c) 2024 by PygmalionAI team.
+ * Copyright (c) 2023 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef APHRODITE_UTILS_CUH_
+#define APHRODITE_UTILS_CUH_
+#include <cuda_runtime.h>
+
+#include <iostream>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <vector>
+#include <torch/all.h>
+
+#define STR_HELPER(x) #x
+#define STR(x) STR_HELPER(x)
+
+// macro to turn off fp16 qk reduction to reduce binary size
+#ifndef APHRODITE_ALWAYS_DISALLOW_FP16_QK_REDUCTION
+  #define APHRODITE_ALWAYS_DISALLOW_FP16_QK_REDUCTION 0
+#endif
+
+#ifndef NDEBUG
+  #define APHRODITE_CUDA_CALL(func, ...)                                  \
+    {                                                                     \
+      cudaError_t e = (func);                                             \
+      if (e != cudaSuccess) {                                             \
+        std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e \
+                  << ") " << __FILE__ << ": line " << __LINE__            \
+                  << " at function " << STR(func) << std::endl;           \
+        return e;                                                         \
+      }                                                                   \
+    }
+#else
+  #define APHRODITE_CUDA_CALL(func, ...) \
+    {                                    \
+      cudaError_t e = (func);            \
+      if (e != cudaSuccess) {            \
+        return e;                        \
+      }                                  \
+    }
+#endif
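+
+// Usage sketch (illustrative): the macro early-returns the cudaError_t from
+// the enclosing function, so it is intended for helpers that themselves
+// return cudaError_t, e.g.
+//
+//   cudaError_t CopyRow(float* dst, const float* src, size_t n,
+//                       cudaStream_t stream) {
+//     APHRODITE_CUDA_CALL(cudaMemcpyAsync(dst, src, n * sizeof(float),
+//                                         cudaMemcpyDeviceToDevice, stream));
+//     return cudaSuccess;
+//   }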
+
+#define DISPATCH_ALLOW_FP16_QK_REDUCTION(allow_fp16_qk_reduction,           \
+                                         ALLOW_FP16_QK_REDUCTION, ...)      \
+  if (allow_fp16_qk_reduction) {                                            \
+    throw std::runtime_error("FP16_QK_REDUCTION disabled at compile time"); \
+  } else {                                                                  \
+    constexpr bool ALLOW_FP16_QK_REDUCTION = false;                         \
+    __VA_ARGS__                                                             \
+  }
+
+#define DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, ...) \
+  if (num_frags_x == 1) {                                   \
+    constexpr size_t NUM_FRAGS_X = 1;                       \
+    __VA_ARGS__                                             \
+  } else if (num_frags_x == 2) {                            \
+    constexpr size_t NUM_FRAGS_X = 2;                       \
+    __VA_ARGS__                                             \
+  } else {                                                  \
+    std::ostringstream err_msg;                             \
+    err_msg << "Unsupported num_frags_x: " << num_frags_x;  \
+    throw std::invalid_argument(err_msg.str());             \
+  }
+
+#define DISPATCH_NUM_FRAGS_Z(max_frags_z, NUM_FRAGS_Z, ...) \
+  if (max_frags_z >= 8) {                                   \
+    constexpr size_t NUM_FRAGS_Z = 8;                       \
+    __VA_ARGS__                                             \
+  } else if (max_frags_z >= 4) {                            \
+    constexpr size_t NUM_FRAGS_Z = 4;                       \
+    __VA_ARGS__                                             \
+  } else if (max_frags_z >= 2) {                            \
+    constexpr size_t NUM_FRAGS_Z = 2;                       \
+    __VA_ARGS__                                             \
+  } else if (max_frags_z >= 1) {                            \
+    constexpr size_t NUM_FRAGS_Z = 1;                       \
+    __VA_ARGS__                                             \
+  } else {                                                  \
+    std::ostringstream err_msg;                             \
+    err_msg << "Unsupported max_frags_z: " << max_frags_z;  \
+    throw std::invalid_argument(err_msg.str());             \
+  }
+
+#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
+  if (group_size == 1) {                                     \
+    constexpr size_t GROUP_SIZE = 1;                         \
+    __VA_ARGS__                                              \
+  } else if (group_size == 2) {                              \
+    constexpr size_t GROUP_SIZE = 2;                         \
+    __VA_ARGS__                                              \
+  } else if (group_size == 4) {                              \
+    constexpr size_t GROUP_SIZE = 4;                         \
+    __VA_ARGS__                                              \
+  } else if (group_size == 8) {                              \
+    constexpr size_t GROUP_SIZE = 8;                         \
+    __VA_ARGS__                                              \
+  } else {                                                   \
+    std::ostringstream err_msg;                              \
+    err_msg << "Unsupported group_size: " << group_size;     \
+    throw std::invalid_argument(err_msg.str());              \
+  }
+
+#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...)         \
+  switch (mask_mode) {                                        \
+    case MaskMode::kNone: {                                   \
+      constexpr MaskMode MASK_MODE = MaskMode::kNone;         \
+      __VA_ARGS__                                             \
+      break;                                                  \
+    }                                                         \
+    case MaskMode::kCausal: {                                 \
+      constexpr MaskMode MASK_MODE = MaskMode::kCausal;       \
+      __VA_ARGS__                                             \
+      break;                                                  \
+    }                                                         \
+    case MaskMode::kCustom: {                                 \
+      constexpr MaskMode MASK_MODE = MaskMode::kCustom;       \
+      __VA_ARGS__                                             \
+      break;                                                  \
+    }                                                         \
+    default: {                                                \
+      std::ostringstream err_msg;                             \
+      err_msg << "Unsupported mask_mode: " << int(mask_mode); \
+      throw std::invalid_argument(err_msg.str());             \
+    }                                                         \
+  }
+
+#define DISPATCH_LOGITS_POST_HOOK(logits_soft_cap, LOGITS_POST_HOOK, ...) \
+  if (logits_soft_cap > 0.f) {                                            \
+    constexpr LogitsPostHook LOGITS_POST_HOOK = LogitsPostHook::kSoftCap; \
+    __VA_ARGS__                                                           \
+  } else if (logits_soft_cap == 0.f) {                                    \
+    constexpr LogitsPostHook LOGITS_POST_HOOK = LogitsPostHook::kNone;    \
+    __VA_ARGS__                                                           \
+  } else {                                                                \
+    std::ostringstream err_msg;                                           \
+    err_msg << "Invalid logits_soft_cap (should be >= 0): "               \
+            << logits_soft_cap;                                           \
+    throw std::invalid_argument(err_msg.str());                           \
+  }
+
+#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...)     \
+  switch (head_dim) {                                  \
+    case 64: {                                         \
+      constexpr size_t HEAD_DIM = 64;                  \
+      __VA_ARGS__                                      \
+      break;                                           \
+    }                                                  \
+    case 128: {                                        \
+      constexpr size_t HEAD_DIM = 128;                 \
+      __VA_ARGS__                                      \
+      break;                                           \
+    }                                                  \
+    case 256: {                                        \
+      constexpr size_t HEAD_DIM = 256;                 \
+      __VA_ARGS__                                      \
+      break;                                           \
+    }                                                  \
+    default: {                                         \
+      std::ostringstream err_msg;                      \
+      err_msg << "Unsupported head_dim: " << head_dim; \
+      throw std::invalid_argument(err_msg.str());      \
+    }                                                  \
+  }
+
+#define DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, ...) \
+  switch (pos_encoding_mode) {                                                \
+    case PosEncodingMode::kNone: {                                            \
+      constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone;   \
+      __VA_ARGS__                                                             \
+      break;                                                                  \
+    }                                                                         \
+    case PosEncodingMode::kRoPELlama: {                                       \
+      constexpr PosEncodingMode POS_ENCODING_MODE =                           \
+          PosEncodingMode::kRoPELlama;                                        \
+      __VA_ARGS__                                                             \
+      break;                                                                  \
+    }                                                                         \
+    case PosEncodingMode::kALiBi: {                                           \
+      constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kALiBi;  \
+      __VA_ARGS__                                                             \
+      break;                                                                  \
+    }                                                                         \
+    default: {                                                                \
+      std::ostringstream err_msg;                                             \
+      err_msg << "Unsupported pos_encoding_mode: " << int(pos_encoding_mode); \
+      throw std::invalid_argument(err_msg.str());                             \
+    }                                                                         \
+  }
+
+#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \
+  switch (aligned_vec_size) {                                              \
+    case 16: {                                                             \
+      constexpr size_t ALIGNED_VEC_SIZE = 16;                              \
+      __VA_ARGS__                                                          \
+      break;                                                               \
+    }                                                                      \
+    case 8: {                                                              \
+      constexpr size_t ALIGNED_VEC_SIZE = 8;                               \
+      __VA_ARGS__                                                          \
+      break;                                                               \
+    }                                                                      \
+    case 4: {                                                              \
+      constexpr size_t ALIGNED_VEC_SIZE = 4;                               \
+      __VA_ARGS__                                                          \
+      break;                                                               \
+    }                                                                      \
+    case 2: {                                                              \
+      constexpr size_t ALIGNED_VEC_SIZE = 2;                               \
+      __VA_ARGS__                                                          \
+      break;                                                               \
+    }                                                                      \
+    case 1: {                                                              \
+      constexpr size_t ALIGNED_VEC_SIZE = 1;                               \
+      __VA_ARGS__                                                          \
+      break;                                                               \
+    }                                                                      \
+    default: {                                                             \
+      std::ostringstream err_msg;                                          \
+      err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size;     \
+      throw std::invalid_argument(err_msg.str());                          \
+    }                                                                      \
+  }
+
+namespace aphrodite {
+
+template <typename T1, typename T2>
+__forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) {
+  return (x + y - 1) / y;
+}
+
+template <typename T>
+inline void DebugPrintCUDAArray(T* device_ptr, size_t size,
+                                std::string prefix = "") {
+  std::vector<T> host_array(size);
+  std::cout << prefix;
+  cudaMemcpy(host_array.data(), device_ptr, size * sizeof(T),
+             cudaMemcpyDeviceToHost);
+  for (size_t i = 0; i < size; ++i) {
+    std::cout << host_array[i] << " ";
+  }
+  std::cout << std::endl;
+}
+
+/*!
+ * \brief Return x - y if x > y, otherwise return 0.
+ */
+__device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x,
+                                                           uint32_t y) {
+  return (x > y) ? x - y : 0U;
+}
+
+__device__ __forceinline__ void swap(uint32_t& a, uint32_t& b) {
+  uint32_t tmp = a;
+  a = b;
+  b = tmp;
+}
+
+}  // namespace aphrodite
+
+#endif  // APHRODITE_UTILS_CUH_

+ 1501 - 0
kernels/sampling/vec_dtypes.cuh

@@ -0,0 +1,1501 @@
+/*
+ * Copyright (c) 2024 by PygmalionAI team.
+ * Copyright (c) 2023 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef VEC_DTYPES_CUH_
+#define VEC_DTYPES_CUH_
+
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda_fp8.h>
+#include <cuda_runtime.h>
+
+#include <type_traits>
+
+namespace aphrodite {
+
+#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 900))
+  #define APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
+#endif
+
+#define APHRODITE_INLINE inline __attribute__((always_inline)) __device__
+
+/******************* vec_t type cast *******************/
+
+template <typename dst_t, typename src_t>
+struct vec_cast {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(dst_t* dst, const src_t* src) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size; ++i) {
+      dst[i] = (dst_t)src[i];
+    }
+  }
+};
+
+template <>
+struct vec_cast<float, half> {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(float* dst, const half* src) {
+    if constexpr (vec_size == 1) {
+      dst[0] = (float)src[0];
+    } else {
+#pragma unroll
+      for (size_t i = 0; i < vec_size / 2; ++i) {
+        ((float2*)dst)[i] = __half22float2(((half2*)src)[i]);
+      }
+    }
+  }
+};
+
+template <>
+struct vec_cast<half, float> {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(half* dst, const float* src) {
+    if constexpr (vec_size == 1) {
+      dst[0] = __float2half(src[0]);
+    } else {
+#pragma unroll
+      for (size_t i = 0; i < vec_size / 2; ++i) {
+        ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]);
+      }
+    }
+  }
+};
+
+template <typename T>
+constexpr APHRODITE_INLINE int get_exponent_bits() {
+  if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
+    return 4;
+  } else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
+    return 5;
+  } else if constexpr (std::is_same<T, half>::value) {
+    return 5;
+  } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
+    return 8;
+  }
+}
+
+template <typename T>
+constexpr APHRODITE_INLINE int get_mantissa_bits() {
+  if constexpr (std::is_same<T, __nv_fp8_e4m3>::value) {
+    return 3;
+  } else if constexpr (std::is_same<T, __nv_fp8_e5m2>::value) {
+    return 2;
+  } else if constexpr (std::is_same<T, half>::value) {
+    return 11;
+  } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
+    return 7;
+  }
+}
+
+/*!
+ * \brief Fallback to software fast dequant implementation if hardware
+ * dequantization is not available. \note Inspired by Marlin's fast
+ * dequantization, but here we don't have to permute weights order. \ref
+ * https://github.com/vllm-project/vllm/blob/6dffa4b0a6120159ef2fe44d695a46817aff65bc/csrc/quantization/fp8/fp8_marlin.cu#L120
+ */
+template <typename fp8_dtype, typename fp16_dtype>
+__device__ void fast_dequant_f8f16x4(uint32_t* input, uint2* output) {
+  uint32_t q = *input;
+  if constexpr (std::is_same<fp8_dtype, __nv_fp8_e5m2>::value &&
+                std::is_same<fp16_dtype, half>::value) {
+    output->x = __byte_perm(0U, q, 0x5140);
+    output->y = __byte_perm(0U, q, 0x7362);
+  } else {
+    constexpr int FP8_EXPONENT = get_exponent_bits<fp8_dtype>();
+    constexpr int FP8_MANTISSA = get_mantissa_bits<fp8_dtype>();
+    constexpr int FP16_EXPONENT = get_exponent_bits<fp16_dtype>();
+
+    constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
+    // Calculate MASK for extracting mantissa and exponent
+    constexpr int MASK1 = 0x80000000;
+    constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
+    constexpr int MASK3 = MASK2 & 0x7fffffff;
+    constexpr int MASK = MASK3 | (MASK3 >> 16);
+    q = __byte_perm(q, q, 0x1302);
+
+    // Extract and shift FP8 values to FP16 format
+    uint32_t Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
+    uint32_t Out2 =
+        ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
+
+    constexpr int BIAS_OFFSET =
+        (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
+    // Construct and apply exponent bias
+    if constexpr (std::is_same<fp16_dtype, half>::value) {
+      const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
+
+      // Convert to half2 and apply bias
+      *(half2*)&(output->x) =
+          __hmul2(*reinterpret_cast<const half2*>(&Out1), bias_reg);
+      *(half2*)&(output->y) =
+          __hmul2(*reinterpret_cast<const half2*>(&Out2), bias_reg);
+    } else {
+      constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
+      const nv_bfloat162 bias_reg =
+          __float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
+      // Convert to bfloat162 and apply bias
+      *(nv_bfloat162*)&(output->x) =
+          __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out1), bias_reg);
+      *(nv_bfloat162*)&(output->y) =
+          __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out2), bias_reg);
+    }
+  }
+}
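+
+// Worked instantiation (for reference): when converting __nv_fp8_e4m3 to half,
+// FP8_EXPONENT = 4, FP8_MANTISSA = 3 and FP16_EXPONENT = 5, so RIGHT_SHIFT = 1
+// realigns the exponent/mantissa bits and BIAS_OFFSET = (1 << 4) - (1 << 3) =
+// 8; multiplying by bias_reg = 2^8 = 256 compensates for the exponent-bias
+// difference between the two formats (15 vs. 7).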
+
+template <>
+struct vec_cast<nv_bfloat16, __nv_fp8_e4m3> {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(nv_bfloat16* dst,
+                                    const __nv_fp8_e4m3* src) {
+    if constexpr (vec_size == 1) {
+      dst[0] = nv_bfloat16(src[0]);
+    } else if constexpr (vec_size == 2) {
+      dst[0] = nv_bfloat16(src[0]);
+      dst[1] = nv_bfloat16(src[1]);
+    } else {
+      static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
+#pragma unroll
+      for (uint32_t i = 0; i < vec_size / 4; ++i) {
+        fast_dequant_f8f16x4<__nv_fp8_e4m3, nv_bfloat16>((uint32_t*)&src[i * 4],
+                                                         (uint2*)&dst[i * 4]);
+      }
+    }
+  }
+};
+
+template <>
+struct vec_cast<nv_bfloat16, __nv_fp8_e5m2> {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(nv_bfloat16* dst,
+                                    const __nv_fp8_e5m2* src) {
+    if constexpr (vec_size == 1) {
+      dst[0] = nv_bfloat16(src[0]);
+    } else if constexpr (vec_size == 2) {
+      dst[0] = nv_bfloat16(src[0]);
+      dst[1] = nv_bfloat16(src[1]);
+    } else {
+      static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
+#pragma unroll
+      for (uint32_t i = 0; i < vec_size / 4; ++i) {
+        fast_dequant_f8f16x4<__nv_fp8_e5m2, nv_bfloat16>((uint32_t*)&src[i * 4],
+                                                         (uint2*)&dst[i * 4]);
+      }
+    }
+  }
+};
+
+template <>
+struct vec_cast<__nv_fp8_e4m3, half> {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(__nv_fp8_e4m3* dst, const half* src) {
+#ifdef APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
+    if constexpr (vec_size == 1) {
+      dst[0] = __nv_fp8_e4m3(src[0]);
+    } else {
+  #pragma unroll
+      for (size_t i = 0; i < vec_size / 2; ++i) {
+        uint16_t y;
+        uint32_t x = *(uint32_t*)&src[i * 2];
+        asm volatile("cvt.rn.satfinite.e4m3x2.f16x2 %0, %1;"
+                     : "=h"(y)
+                     : "r"(x));
+        *(uint16_t*)&dst[i * 2] = y;
+      }
+    }
+#else
+  #pragma unroll
+    for (size_t i = 0; i < vec_size; ++i) {
+      dst[i] = __nv_fp8_e4m3(src[i]);
+    }
+#endif  // APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
+  }
+};
+
+template <>
+struct vec_cast<__nv_fp8_e5m2, half> {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(__nv_fp8_e5m2* dst, const half* src) {
+#ifdef APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
+    if constexpr (vec_size == 1) {
+      dst[0] = __nv_fp8_e5m2(src[0]);
+    } else {
+  #pragma unroll
+      for (size_t i = 0; i < vec_size / 2; ++i) {
+        uint16_t y;
+        uint32_t x = *(uint32_t*)&src[i * 2];
+        asm volatile("cvt.rn.satfinite.e5m2x2.f16x2 %0, %1;"
+                     : "=h"(y)
+                     : "r"(x));
+        *(uint16_t*)&dst[i * 2] = y;
+      }
+    }
+#else
+  #pragma unroll
+    for (size_t i = 0; i < vec_size; ++i) {
+      dst[i] = __nv_fp8_e5m2(src[i]);
+    }
+#endif  // APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
+  }
+};
+
+template <>
+struct vec_cast<half, __nv_fp8_e4m3> {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(half* dst, const __nv_fp8_e4m3* src) {
+#ifdef APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
+    if constexpr (vec_size == 1) {
+      dst[0] = half(src[0]);
+    } else {
+  #pragma unroll
+      for (size_t i = 0; i < vec_size / 2; ++i) {
+        uint32_t y;
+        uint16_t x = *(uint16_t*)&src[i * 2];
+        asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;" : "=r"(y) : "h"(x));
+        *(uint32_t*)&dst[i * 2] = y;
+      }
+    }
+#else
+    if constexpr (vec_size == 1) {
+      dst[0] = half(src[0]);
+    } else if constexpr (vec_size == 2) {
+      dst[0] = half(src[0]);
+      dst[1] = half(src[1]);
+    } else {
+      static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
+  #pragma unroll
+      for (uint32_t i = 0; i < vec_size / 4; ++i) {
+        fast_dequant_f8f16x4<__nv_fp8_e4m3, half>((uint32_t*)&src[i * 4],
+                                                  (uint2*)&dst[i * 4]);
+      }
+    }
+#endif  // APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
+  }
+};
+
+template <>
+struct vec_cast<half, __nv_fp8_e5m2> {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(half* dst, const __nv_fp8_e5m2* src) {
+#ifdef APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
+    if constexpr (vec_size == 1) {
+      dst[0] = half(src[0]);
+    } else {
+  #pragma unroll
+      for (size_t i = 0; i < vec_size / 2; ++i) {
+        uint32_t y;
+        uint16_t x = *(uint16_t*)&src[i * 2];
+        asm volatile("cvt.rn.f16x2.e5m2x2 %0, %1;" : "=r"(y) : "h"(x));
+        *(uint32_t*)&dst[i * 2] = y;
+      }
+    }
+#else
+    if constexpr (vec_size == 1) {
+      dst[0] = half(src[0]);
+    } else if constexpr (vec_size == 2) {
+      dst[0] = half(src[0]);
+      dst[1] = half(src[1]);
+    } else {
+      static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4");
+  #pragma unroll
+      for (uint32_t i = 0; i < vec_size / 4; ++i) {
+        fast_dequant_f8f16x4<__nv_fp8_e5m2, half>((uint32_t*)&src[i * 4],
+                                                  (uint2*)&dst[i * 4]);
+      }
+    }
+#endif  // APHRODITE_HARDWARE_FP8_CONVERSION_ENABLED
+  }
+};
+
+template <>
+struct vec_cast<float, nv_bfloat16> {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(float* dst, const nv_bfloat16* src) {
+    if constexpr (vec_size == 1) {
+      dst[0] = (float)src[0];
+    } else {
+#pragma unroll
+      for (size_t i = 0; i < vec_size / 2; ++i) {
+        ((float2*)dst)[i] = __bfloat1622float2(((nv_bfloat162*)src)[i]);
+      }
+    }
+  }
+};
+
+template <>
+struct vec_cast<nv_bfloat16, float> {
+  template <size_t vec_size>
+  APHRODITE_INLINE static void cast(nv_bfloat16* dst, const float* src) {
+    if constexpr (vec_size == 1) {
+      dst[0] = nv_bfloat16(src[0]);
+    } else {
+#pragma unroll
+      for (size_t i = 0; i < vec_size / 2; ++i) {
+        ((nv_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]);
+      }
+    }
+  }
+};
+
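+// vec_t<T, vec_size> is a small register-resident vector with vectorized
+// fill/load/store plus cross-dtype casting; cast_load/cast_store route
+// through the vec_cast specializations above. The specializations below pack
+// elements into half2 / nv_bfloat162 / float2 / float4 / uint2 / uint4 so
+// that global memory is accessed with wide, aligned transactions.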
+template <typename float_t, size_t vec_size>
+struct vec_t {
+  APHRODITE_INLINE float_t& operator[](size_t i);
+  APHRODITE_INLINE const float_t& operator[](size_t i) const;
+  APHRODITE_INLINE void fill(float_t val);
+  APHRODITE_INLINE void load(const float_t* ptr);
+  APHRODITE_INLINE void store(float_t* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src);
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr);
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const;
+  APHRODITE_INLINE static void memcpy(float_t* dst, const float_t* src);
+  APHRODITE_INLINE float_t* ptr();
+};
+
+template <typename src_float_t, typename tgt_float_t, size_t vec_size>
+APHRODITE_INLINE void cast_from_impl(vec_t<tgt_float_t, vec_size>& dst,
+                                     const vec_t<src_float_t, vec_size>& src) {
+  vec_cast<tgt_float_t, src_float_t>::cast<vec_size>(
+      dst.ptr(), const_cast<vec_t<src_float_t, vec_size>*>(&src)->ptr());
+}
+
+template <typename src_float_t, typename tgt_float_t, size_t vec_size>
+APHRODITE_INLINE void cast_load_impl(vec_t<tgt_float_t, vec_size>& dst,
+                                     const src_float_t* src_ptr) {
+  if constexpr (std::is_same<src_float_t, tgt_float_t>::value) {
+    dst.load(src_ptr);
+  } else {
+    vec_t<src_float_t, vec_size> tmp;
+    tmp.load(src_ptr);
+    dst.cast_from(tmp);
+  }
+}
+
+template <typename src_float_t, typename tgt_float_t, size_t vec_size>
+APHRODITE_INLINE void cast_store_impl(tgt_float_t* dst_ptr,
+                                      const vec_t<src_float_t, vec_size>& src) {
+  if constexpr (std::is_same<src_float_t, tgt_float_t>::value) {
+    src.store(dst_ptr);
+  } else {
+    vec_t<tgt_float_t, vec_size> tmp;
+    tmp.cast_from(src);
+    tmp.store(dst_ptr);
+  }
+}
+
+/******************* vec_t<__nv_fp8_e4m3> *******************/
+
+// __nv_fp8_e4m3 x 1
+template <>
+struct vec_t<__nv_fp8_e4m3, 1> {
+  __nv_fp8_e4m3 data;
+
+  APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) {
+    return ((__nv_fp8_e4m3*)(&data))[i];
+  }
+  APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
+    return ((const __nv_fp8_e4m3*)(&data))[i];
+  }
+  APHRODITE_INLINE __nv_fp8_e4m3* ptr() {
+    return reinterpret_cast<__nv_fp8_e4m3*>(&data);
+  }
+  APHRODITE_INLINE void fill(__nv_fp8_e4m3 val);
+  APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr);
+  APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 1>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+
+  APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst,
+                                      const __nv_fp8_e4m3* src);
+};
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) {
+  data = val;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3* ptr) {
+  data = *ptr;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 1>::store(__nv_fp8_e4m3* ptr) const {
+  *ptr = data;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy(
+    __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) {
+  *dst = *src;
+}
+
+// __nv_fp8_e4m3 x 2
+template <>
+struct vec_t<__nv_fp8_e4m3, 2> {
+  __nv_fp8x2_e4m3 data;
+
+  APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) {
+    return ((__nv_fp8_e4m3*)(&data))[i];
+  }
+  APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
+    return ((const __nv_fp8_e4m3*)(&data))[i];
+  }
+  APHRODITE_INLINE __nv_fp8_e4m3* ptr() {
+    return reinterpret_cast<__nv_fp8_e4m3*>(&data);
+  }
+  APHRODITE_INLINE void fill(__nv_fp8_e4m3 val);
+  APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr);
+  APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 2>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst,
+                                      const __nv_fp8_e4m3* src);
+};
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) {
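+  // Broadcast the 8-bit FP8 payload into both byte lanes of the packed word.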
+  data.__x =
+      (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3* ptr) {
+  data = *((__nv_fp8x2_e4m3*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 2>::store(__nv_fp8_e4m3* ptr) const {
+  *((__nv_fp8x2_e4m3*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy(
+    __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) {
+  *((__nv_fp8x2_e4m3*)dst) = *((__nv_fp8x2_e4m3*)src);
+}
+
+// __nv_fp8_e4m3 x 4
+
+template <>
+struct vec_t<__nv_fp8_e4m3, 4> {
+  __nv_fp8x4_e4m3 data;
+
+  APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) {
+    return ((__nv_fp8_e4m3*)(&data))[i];
+  }
+  APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
+    return ((const __nv_fp8_e4m3*)(&data))[i];
+  }
+  APHRODITE_INLINE __nv_fp8_e4m3* ptr() {
+    return reinterpret_cast<__nv_fp8_e4m3*>(&data);
+  }
+  APHRODITE_INLINE void fill(__nv_fp8_e4m3 val);
+  APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr);
+  APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 4>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+
+  APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst,
+                                      const __nv_fp8_e4m3* src);
+};
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) {
+  data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
+             (__nv_fp8x4_storage_t(val.__x) << 16) |
+             (__nv_fp8x4_storage_t(val.__x) << 8) |
+             __nv_fp8x4_storage_t(val.__x);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3* ptr) {
+  data = *((__nv_fp8x4_e4m3*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 4>::store(__nv_fp8_e4m3* ptr) const {
+  *((__nv_fp8x4_e4m3*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy(
+    __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) {
+  *((__nv_fp8x4_e4m3*)dst) = *((__nv_fp8x4_e4m3*)src);
+}
+
+// __nv_fp8_e4m3 x 8
+
+template <>
+struct vec_t<__nv_fp8_e4m3, 8> {
+  uint2 data;
+
+  APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) {
+    return ((__nv_fp8_e4m3*)(&data))[i];
+  }
+  APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
+    return ((const __nv_fp8_e4m3*)(&data))[i];
+  }
+  APHRODITE_INLINE __nv_fp8_e4m3* ptr() {
+    return reinterpret_cast<__nv_fp8_e4m3*>(&data);
+  }
+  APHRODITE_INLINE void fill(__nv_fp8_e4m3 val);
+  APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr);
+  APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 8>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+
+  APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst,
+                                      const __nv_fp8_e4m3* src);
+};
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) {
+  ((__nv_fp8x4_e4m3*)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
+                                       (__nv_fp8x4_storage_t(val.__x) << 16) |
+                                       (__nv_fp8x4_storage_t(val.__x) << 8) |
+                                       __nv_fp8x4_storage_t(val.__x);
+  ((__nv_fp8x4_e4m3*)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
+                                       (__nv_fp8x4_storage_t(val.__x) << 16) |
+                                       (__nv_fp8x4_storage_t(val.__x) << 8) |
+                                       __nv_fp8x4_storage_t(val.__x);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3* ptr) {
+  data = *((uint2*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 8>::store(__nv_fp8_e4m3* ptr) const {
+  *((uint2*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy(
+    __nv_fp8_e4m3* dst, const __nv_fp8_e4m3* src) {
+  *((uint2*)dst) = *((uint2*)src);
+}
+
+// __nv_fp8_e4m3 x 16 or more
+template <size_t vec_size>
+struct vec_t<__nv_fp8_e4m3, vec_size> {
+  uint4 data[vec_size / 16];
+
+  APHRODITE_INLINE __nv_fp8_e4m3& operator[](size_t i) {
+    return ((__nv_fp8_e4m3*)data)[i];
+  }
+  APHRODITE_INLINE const __nv_fp8_e4m3& operator[](size_t i) const {
+    return ((const __nv_fp8_e4m3*)data)[i];
+  }
+  APHRODITE_INLINE __nv_fp8_e4m3* ptr() {
+    return reinterpret_cast<__nv_fp8_e4m3*>(&data);
+  }
+  APHRODITE_INLINE void fill(__nv_fp8_e4m3 val) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      ((__nv_fp8x4_e4m3*)(&(data[i].x)))->__x =
+          (__nv_fp8x4_storage_t(val.__x) << 24) |
+          (__nv_fp8x4_storage_t(val.__x) << 16) |
+          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+      ((__nv_fp8x4_e4m3*)(&(data[i].y)))->__x =
+          (__nv_fp8x4_storage_t(val.__x) << 24) |
+          (__nv_fp8x4_storage_t(val.__x) << 16) |
+          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+      ((__nv_fp8x4_e4m3*)(&(data[i].z)))->__x =
+          (__nv_fp8x4_storage_t(val.__x) << 24) |
+          (__nv_fp8x4_storage_t(val.__x) << 16) |
+          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+      ((__nv_fp8x4_e4m3*)(&(data[i].w)))->__x =
+          (__nv_fp8x4_storage_t(val.__x) << 24) |
+          (__nv_fp8x4_storage_t(val.__x) << 16) |
+          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+    }
+  }
+  APHRODITE_INLINE void load(const __nv_fp8_e4m3* ptr) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      data[i] = ((uint4*)ptr)[i];
+    }
+  }
+  APHRODITE_INLINE void store(__nv_fp8_e4m3* ptr) const {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      ((uint4*)ptr)[i] = data[i];
+    }
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+
+  APHRODITE_INLINE static void memcpy(__nv_fp8_e4m3* dst,
+                                      const __nv_fp8_e4m3* src) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      ((uint4*)dst)[i] = ((uint4*)src)[i];
+    }
+  }
+};
+
+/******************* vec_t<__nv_fp8_e5m2> *******************/
+
+// __nv_fp8_e5m2 x 1
+template <>
+struct vec_t<__nv_fp8_e5m2, 1> {
+  __nv_fp8_e5m2 data;
+
+  APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) {
+    return ((__nv_fp8_e5m2*)(&data))[i];
+  }
+  APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
+    return ((const __nv_fp8_e5m2*)(&data))[i];
+  }
+  APHRODITE_INLINE __nv_fp8_e5m2* ptr() {
+    return reinterpret_cast<__nv_fp8_e5m2*>(&data);
+  }
+  APHRODITE_INLINE void fill(__nv_fp8_e5m2 val);
+  APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr);
+  APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 1>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+
+  APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst,
+                                      const __nv_fp8_e5m2* src);
+};
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) {
+  data = val;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2* ptr) {
+  data = *ptr;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 1>::store(__nv_fp8_e5m2* ptr) const {
+  *ptr = data;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy(
+    __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) {
+  *dst = *src;
+}
+
+// __nv_fp8_e5m2 x 2
+template <>
+struct vec_t<__nv_fp8_e5m2, 2> {
+  __nv_fp8x2_e5m2 data;
+
+  APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) {
+    return ((__nv_fp8_e5m2*)(&data))[i];
+  }
+  APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
+    return ((const __nv_fp8_e5m2*)(&data))[i];
+  }
+  APHRODITE_INLINE __nv_fp8_e5m2* ptr() {
+    return reinterpret_cast<__nv_fp8_e5m2*>(&data);
+  }
+  APHRODITE_INLINE void fill(__nv_fp8_e5m2 val);
+  APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr);
+  APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 2>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+
+  APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst,
+                                      const __nv_fp8_e5m2* src);
+};
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) {
+  data.__x =
+      (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2* ptr) {
+  data = *((__nv_fp8x2_e5m2*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 2>::store(__nv_fp8_e5m2* ptr) const {
+  *((__nv_fp8x2_e5m2*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy(
+    __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) {
+  *((__nv_fp8x2_e5m2*)dst) = *((__nv_fp8x2_e5m2*)src);
+}
+
+// __nv_fp8_e5m2 x 4
+
+template <>
+struct vec_t<__nv_fp8_e5m2, 4> {
+  __nv_fp8x4_e5m2 data;
+
+  APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) {
+    return ((__nv_fp8_e5m2*)(&data))[i];
+  }
+  APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
+    return ((const __nv_fp8_e5m2*)(&data))[i];
+  }
+  APHRODITE_INLINE __nv_fp8_e5m2* ptr() {
+    return reinterpret_cast<__nv_fp8_e5m2*>(&data);
+  }
+  APHRODITE_INLINE void fill(__nv_fp8_e5m2 val);
+  APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr);
+  APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 4>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+
+  APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst,
+                                      const __nv_fp8_e5m2* src);
+};
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) {
+  data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
+             (__nv_fp8x4_storage_t(val.__x) << 16) |
+             (__nv_fp8x4_storage_t(val.__x) << 8) |
+             __nv_fp8x4_storage_t(val.__x);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2* ptr) {
+  data = *((__nv_fp8x4_e5m2*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 4>::store(__nv_fp8_e5m2* ptr) const {
+  *((__nv_fp8x4_e5m2*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy(
+    __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) {
+  *((__nv_fp8x4_e5m2*)dst) = *((__nv_fp8x4_e5m2*)src);
+}
+
+// __nv_fp8_e5m2 x 8
+
+template <>
+struct vec_t<__nv_fp8_e5m2, 8> {
+  uint2 data;
+
+  APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) {
+    return ((__nv_fp8_e5m2*)(&data))[i];
+  }
+  APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
+    return ((const __nv_fp8_e5m2*)(&data))[i];
+  }
+  APHRODITE_INLINE __nv_fp8_e5m2* ptr() {
+    return reinterpret_cast<__nv_fp8_e5m2*>(&data);
+  }
+  APHRODITE_INLINE void fill(__nv_fp8_e5m2 val);
+  APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr);
+  APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 8>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst,
+                                      const __nv_fp8_e5m2* src);
+};
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) {
+  ((__nv_fp8x4_e5m2*)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
+                                       (__nv_fp8x4_storage_t(val.__x) << 16) |
+                                       (__nv_fp8x4_storage_t(val.__x) << 8) |
+                                       __nv_fp8x4_storage_t(val.__x);
+  ((__nv_fp8x4_e5m2*)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
+                                       (__nv_fp8x4_storage_t(val.__x) << 16) |
+                                       (__nv_fp8x4_storage_t(val.__x) << 8) |
+                                       __nv_fp8x4_storage_t(val.__x);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2* ptr) {
+  data = *((uint2*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 8>::store(__nv_fp8_e5m2* ptr) const {
+  *((uint2*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy(
+    __nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) {
+  *((uint2*)dst) = *((uint2*)src);
+}
+
+// __nv_fp8_e5m2 x 16 or more
+
+template <size_t vec_size>
+struct vec_t<__nv_fp8_e5m2, vec_size> {
+  uint4 data[vec_size / 16];
+
+  APHRODITE_INLINE __nv_fp8_e5m2& operator[](size_t i) {
+    return ((__nv_fp8_e5m2*)data)[i];
+  }
+  APHRODITE_INLINE const __nv_fp8_e5m2& operator[](size_t i) const {
+    return ((const __nv_fp8_e5m2*)data)[i];
+  }
+  APHRODITE_INLINE __nv_fp8_e5m2* ptr() {
+    return reinterpret_cast<__nv_fp8_e5m2*>(&data);
+  }
+  APHRODITE_INLINE void fill(__nv_fp8_e5m2 val) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      ((__nv_fp8x4_e5m2*)(&(data[i].x)))->__x =
+          (__nv_fp8x4_storage_t(val.__x) << 24) |
+          (__nv_fp8x4_storage_t(val.__x) << 16) |
+          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+      ((__nv_fp8x4_e5m2*)(&(data[i].y)))->__x =
+          (__nv_fp8x4_storage_t(val.__x) << 24) |
+          (__nv_fp8x4_storage_t(val.__x) << 16) |
+          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+      ((__nv_fp8x4_e5m2*)(&(data[i].z)))->__x =
+          (__nv_fp8x4_storage_t(val.__x) << 24) |
+          (__nv_fp8x4_storage_t(val.__x) << 16) |
+          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+      ((__nv_fp8x4_e5m2*)(&(data[i].w)))->__x =
+          (__nv_fp8x4_storage_t(val.__x) << 24) |
+          (__nv_fp8x4_storage_t(val.__x) << 16) |
+          (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
+    }
+  }
+  APHRODITE_INLINE void load(const __nv_fp8_e5m2* ptr) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      data[i] = ((uint4*)ptr)[i];
+    }
+  }
+  APHRODITE_INLINE void store(__nv_fp8_e5m2* ptr) const {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      ((uint4*)ptr)[i] = data[i];
+    }
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(__nv_fp8_e5m2* dst,
+                                      const __nv_fp8_e5m2* src) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 16; ++i) {
+      ((uint4*)dst)[i] = ((uint4*)src)[i];
+    }
+  }
+};
+
+/******************* vec_t<half> *******************/
+
+// half x 1
+template <>
+struct vec_t<half, 1> {
+  half data;
+
+  APHRODITE_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; }
+  APHRODITE_INLINE const half& operator[](size_t i) const {
+    return ((const half*)(&data))[i];
+  }
+  APHRODITE_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
+  APHRODITE_INLINE void fill(half val);
+  APHRODITE_INLINE void load(const half* ptr);
+  APHRODITE_INLINE void store(half* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 1>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+
+  APHRODITE_INLINE static void memcpy(half* dst, const half* src);
+};
+
+APHRODITE_INLINE void vec_t<half, 1>::fill(half val) { data = val; }
+
+APHRODITE_INLINE void vec_t<half, 1>::load(const half* ptr) { data = *ptr; }
+
+APHRODITE_INLINE void vec_t<half, 1>::store(half* ptr) const { *ptr = data; }
+
+APHRODITE_INLINE void vec_t<half, 1>::memcpy(half* dst, const half* src) {
+  *dst = *src;
+}
+
+// half x 2
+template <>
+struct vec_t<half, 2> {
+  half2 data;
+
+  APHRODITE_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; }
+  APHRODITE_INLINE const half& operator[](size_t i) const {
+    return ((const half*)(&data))[i];
+  }
+  APHRODITE_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
+  APHRODITE_INLINE void fill(half val);
+  APHRODITE_INLINE void load(const half* ptr);
+  APHRODITE_INLINE void store(half* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 2>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+
+  APHRODITE_INLINE static void memcpy(half* dst, const half* src);
+};
+
+APHRODITE_INLINE void vec_t<half, 2>::fill(half val) {
+  data = make_half2(val, val);
+}
+
+APHRODITE_INLINE void vec_t<half, 2>::load(const half* ptr) {
+  data = *((half2*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<half, 2>::store(half* ptr) const {
+  *((half2*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<half, 2>::memcpy(half* dst, const half* src) {
+  *((half2*)dst) = *((half2*)src);
+}
+
+// half x 4
+
+template <>
+struct vec_t<half, 4> {
+  uint2 data;
+
+  APHRODITE_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; }
+  APHRODITE_INLINE const half& operator[](size_t i) const {
+    return ((const half*)(&data))[i];
+  }
+  APHRODITE_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
+  APHRODITE_INLINE void fill(half val);
+  APHRODITE_INLINE void load(const half* ptr);
+  APHRODITE_INLINE void store(half* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 4>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(half* dst, const half* src);
+};
+
+APHRODITE_INLINE void vec_t<half, 4>::fill(half val) {
+  *(half2*)(&data.x) = make_half2(val, val);
+  *(half2*)(&data.y) = make_half2(val, val);
+}
+
+APHRODITE_INLINE void vec_t<half, 4>::load(const half* ptr) {
+  data = *((uint2*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<half, 4>::store(half* ptr) const {
+  *((uint2*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<half, 4>::memcpy(half* dst, const half* src) {
+  *((uint2*)dst) = *((uint2*)src);
+}
+
+// half x 8 or more
+
+template <size_t vec_size>
+struct vec_t<half, vec_size> {
+  uint4 data[vec_size / 8];
+  APHRODITE_INLINE half& operator[](size_t i) { return ((half*)data)[i]; }
+  APHRODITE_INLINE const half& operator[](size_t i) const {
+    return ((const half*)data)[i];
+  }
+  APHRODITE_INLINE half* ptr() { return reinterpret_cast<half*>(&data); }
+  APHRODITE_INLINE void fill(half val) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      *(half2*)(&(data[i].x)) = make_half2(val, val);
+      *(half2*)(&(data[i].y)) = make_half2(val, val);
+      *(half2*)(&(data[i].z)) = make_half2(val, val);
+      *(half2*)(&(data[i].w)) = make_half2(val, val);
+    }
+  }
+  APHRODITE_INLINE void load(const half* ptr) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      data[i] = ((uint4*)ptr)[i];
+    }
+  }
+  APHRODITE_INLINE void store(half* ptr) const {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      ((uint4*)ptr)[i] = data[i];
+    }
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(half* dst, const half* src) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      ((uint4*)dst)[i] = ((uint4*)src)[i];
+    }
+  }
+};
+
+/******************* vec_t<nv_bfloat16> *******************/
+
+// nv_bfloat16 x 1
+template <>
+struct vec_t<nv_bfloat16, 1> {
+  nv_bfloat16 data;
+  APHRODITE_INLINE nv_bfloat16& operator[](size_t i) {
+    return ((nv_bfloat16*)(&data))[i];
+  }
+  APHRODITE_INLINE const nv_bfloat16& operator[](size_t i) const {
+    return ((const nv_bfloat16*)(&data))[i];
+  }
+  APHRODITE_INLINE nv_bfloat16* ptr() {
+    return reinterpret_cast<nv_bfloat16*>(&data);
+  }
+  APHRODITE_INLINE void fill(nv_bfloat16 val);
+  APHRODITE_INLINE void load(const nv_bfloat16* ptr);
+  APHRODITE_INLINE void store(nv_bfloat16* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 1>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(nv_bfloat16* dst, const nv_bfloat16* src);
+};
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 1>::fill(nv_bfloat16 val) {
+  data = val;
+}
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 1>::load(const nv_bfloat16* ptr) {
+  data = *ptr;
+}
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 1>::store(nv_bfloat16* ptr) const {
+  *ptr = data;
+}
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 1>::memcpy(nv_bfloat16* dst,
+                                                    const nv_bfloat16* src) {
+  *dst = *src;
+}
+
+// nv_bfloat16 x 2
+template <>
+struct vec_t<nv_bfloat16, 2> {
+  nv_bfloat162 data;
+
+  APHRODITE_INLINE nv_bfloat16& operator[](size_t i) {
+    return ((nv_bfloat16*)(&data))[i];
+  }
+  APHRODITE_INLINE const nv_bfloat16& operator[](size_t i) const {
+    return ((const nv_bfloat16*)(&data))[i];
+  }
+  APHRODITE_INLINE nv_bfloat16* ptr() {
+    return reinterpret_cast<nv_bfloat16*>(&data);
+  }
+  APHRODITE_INLINE void fill(nv_bfloat16 val);
+  APHRODITE_INLINE void load(const nv_bfloat16* ptr);
+  APHRODITE_INLINE void store(nv_bfloat16* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 2>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(nv_bfloat16* dst, const nv_bfloat16* src);
+};
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 2>::fill(nv_bfloat16 val) {
+  data = make_bfloat162(val, val);
+}
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 2>::load(const nv_bfloat16* ptr) {
+  data = *((nv_bfloat162*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 2>::store(nv_bfloat16* ptr) const {
+  *((nv_bfloat162*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 2>::memcpy(nv_bfloat16* dst,
+                                                    const nv_bfloat16* src) {
+  *((nv_bfloat162*)dst) = *((nv_bfloat162*)src);
+}
+
+// nv_bfloat16 x 4
+
+template <>
+struct vec_t<nv_bfloat16, 4> {
+  uint2 data;
+
+  APHRODITE_INLINE nv_bfloat16& operator[](size_t i) {
+    return ((nv_bfloat16*)(&data))[i];
+  }
+  APHRODITE_INLINE const nv_bfloat16& operator[](size_t i) const {
+    return ((const nv_bfloat16*)(&data))[i];
+  }
+  APHRODITE_INLINE nv_bfloat16* ptr() {
+    return reinterpret_cast<nv_bfloat16*>(&data);
+  }
+  APHRODITE_INLINE void fill(nv_bfloat16 val);
+  APHRODITE_INLINE void load(const nv_bfloat16* ptr);
+  APHRODITE_INLINE void store(nv_bfloat16* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 4>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(nv_bfloat16* dst, const nv_bfloat16* src);
+};
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 4>::fill(nv_bfloat16 val) {
+  *(nv_bfloat162*)(&data.x) = make_bfloat162(val, val);
+  *(nv_bfloat162*)(&data.y) = make_bfloat162(val, val);
+}
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 4>::load(const nv_bfloat16* ptr) {
+  data = *((uint2*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 4>::store(nv_bfloat16* ptr) const {
+  *((uint2*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<nv_bfloat16, 4>::memcpy(nv_bfloat16* dst,
+                                                    const nv_bfloat16* src) {
+  *((uint2*)dst) = *((uint2*)src);
+}
+
+// nv_bfloat16 x 8 or more
+
+template <size_t vec_size>
+struct vec_t<nv_bfloat16, vec_size> {
+  uint4 data[vec_size / 8];
+
+  APHRODITE_INLINE nv_bfloat16& operator[](size_t i) {
+    return ((nv_bfloat16*)data)[i];
+  }
+  APHRODITE_INLINE const nv_bfloat16& operator[](size_t i) const {
+    return ((const nv_bfloat16*)data)[i];
+  }
+  APHRODITE_INLINE nv_bfloat16* ptr() {
+    return reinterpret_cast<nv_bfloat16*>(&data);
+  }
+  APHRODITE_INLINE void fill(nv_bfloat16 val) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      *(nv_bfloat162*)(&(data[i].x)) = make_bfloat162(val, val);
+      *(nv_bfloat162*)(&(data[i].y)) = make_bfloat162(val, val);
+      *(nv_bfloat162*)(&(data[i].z)) = make_bfloat162(val, val);
+      *(nv_bfloat162*)(&(data[i].w)) = make_bfloat162(val, val);
+    }
+  }
+  APHRODITE_INLINE void load(const nv_bfloat16* ptr) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      data[i] = ((uint4*)ptr)[i];
+    }
+  }
+  APHRODITE_INLINE void store(nv_bfloat16* ptr) const {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      ((uint4*)ptr)[i] = data[i];
+    }
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(nv_bfloat16* dst,
+                                      const nv_bfloat16* src) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 8; ++i) {
+      ((uint4*)dst)[i] = ((uint4*)src)[i];
+    }
+  }
+};
+
+/******************* vec_t<float> *******************/
+
+// float x 1
+
+template <>
+struct vec_t<float, 1> {
+  float data;
+
+  APHRODITE_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; }
+  APHRODITE_INLINE const float& operator[](size_t i) const {
+    return ((const float*)(&data))[i];
+  }
+  APHRODITE_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
+  APHRODITE_INLINE void fill(float val);
+  APHRODITE_INLINE void load(const float* ptr);
+  APHRODITE_INLINE void store(float* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 1>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(float* dst, const float* src);
+};
+
+APHRODITE_INLINE void vec_t<float, 1>::fill(float val) { data = val; }
+
+APHRODITE_INLINE void vec_t<float, 1>::load(const float* ptr) { data = *ptr; }
+
+APHRODITE_INLINE void vec_t<float, 1>::store(float* ptr) const { *ptr = data; }
+
+APHRODITE_INLINE void vec_t<float, 1>::memcpy(float* dst, const float* src) {
+  *dst = *src;
+}
+
+// float x 2
+
+template <>
+struct vec_t<float, 2> {
+  float2 data;
+
+  APHRODITE_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; }
+  APHRODITE_INLINE const float& operator[](size_t i) const {
+    return ((const float*)(&data))[i];
+  }
+  APHRODITE_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
+  APHRODITE_INLINE void fill(float val);
+  APHRODITE_INLINE void load(const float* ptr);
+  APHRODITE_INLINE void store(float* ptr) const;
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, 2>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(float* dst, const float* src);
+};
+
+APHRODITE_INLINE void vec_t<float, 2>::fill(float val) {
+  data = make_float2(val, val);
+}
+
+APHRODITE_INLINE void vec_t<float, 2>::load(const float* ptr) {
+  data = *((float2*)ptr);
+}
+
+APHRODITE_INLINE void vec_t<float, 2>::store(float* ptr) const {
+  *((float2*)ptr) = data;
+}
+
+APHRODITE_INLINE void vec_t<float, 2>::memcpy(float* dst, const float* src) {
+  *((float2*)dst) = *((float2*)src);
+}
+
+// float x 4 or more
+template <size_t vec_size>
+struct vec_t<float, vec_size> {
+  float4 data[vec_size / 4];
+
+  APHRODITE_INLINE float& operator[](size_t i) { return ((float*)(data))[i]; }
+  APHRODITE_INLINE const float& operator[](size_t i) const {
+    return ((const float*)(data))[i];
+  }
+  APHRODITE_INLINE float* ptr() { return reinterpret_cast<float*>(&data); }
+  APHRODITE_INLINE void fill(float val) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 4; ++i) {
+      data[i] = make_float4(val, val, val, val);
+    }
+  }
+  APHRODITE_INLINE void load(const float* ptr) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 4; ++i) {
+      data[i] = ((float4*)ptr)[i];
+    }
+  }
+  APHRODITE_INLINE void store(float* ptr) const {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 4; ++i) {
+      ((float4*)ptr)[i] = data[i];
+    }
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_from(const vec_t<T, vec_size>& src) {
+    cast_from_impl(*this, src);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_load(const T* ptr) {
+    cast_load_impl(*this, ptr);
+  }
+  template <typename T>
+  APHRODITE_INLINE void cast_store(T* ptr) const {
+    cast_store_impl(ptr, *this);
+  }
+  APHRODITE_INLINE static void memcpy(float* dst, const float* src) {
+#pragma unroll
+    for (size_t i = 0; i < vec_size / 4; ++i) {
+      ((float4*)dst)[i] = ((float4*)src)[i];
+    }
+  }
+};
+
+}  // namespace aphrodite
+
+#endif  // VEC_DTYPES_CUH_
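
A minimal usage sketch of the new vec_t interface (not part of this commit): a
kernel that dequantizes FP8 e4m3 to half through cast_load/store. It assumes
vec_dtypes.cuh is on the include path, that APHRODITE_INLINE expands to a
__device__-capable qualifier, that the input/output pointers are 8/16-byte
aligned, and that the kernel name dequant_e4m3_to_half is illustrative only.

#include <cuda_fp8.h>
#include <cuda_fp16.h>
#include "vec_dtypes.cuh"

using namespace aphrodite;

// Each thread converts 8 contiguous FP8 values: cast_load does the 8-byte
// FP8 load plus the vec_cast dequant, store issues one 16-byte half write.
__global__ void dequant_e4m3_to_half(const __nv_fp8_e4m3* __restrict__ in,
                                     half* __restrict__ out, size_t n) {
  constexpr size_t kVec = 8;
  size_t idx = (size_t(blockIdx.x) * blockDim.x + threadIdx.x) * kVec;
  if (idx + kVec <= n) {
    vec_t<half, kVec> v;
    v.cast_load(in + idx);
    v.store(out + idx);
  }
}

Launched with a 1-D grid covering n / 8 threads; a tail where n is not a
multiple of 8 would need a scalar fallback, which the sketch omits.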

+ 22 - 0
kernels/torch_bindings.cpp

@@ -207,6 +207,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("fp_eXmY_linear_forward_cuda", torch::kCUDA,
            &fp_eXmY_linear_forward_cuda);
 
+  // Sampling Kernels
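+  // ops.def infers each schema from the C++ signature; ops.impl binds the
+  // CUDA implementation under the kCUDA dispatch key.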
+  ops.def("sampling_from_probs", &sampling_from_probs);
+  ops.impl("sampling_from_probs", torch::kCUDA, &sampling_from_probs);
+  ops.def("top_k_sampling_from_probs", &top_k_sampling_from_probs);
+  ops.impl("top_k_sampling_from_probs", torch::kCUDA,
+           &top_k_sampling_from_probs);
+  ops.def("min_p_sampling_from_probs", &min_p_sampling_from_probs);
+  ops.impl("min_p_sampling_from_probs", torch::kCUDA,
+           &min_p_sampling_from_probs);
+  ops.def("top_p_sampling_from_probs", &top_p_sampling_from_probs);
+  ops.impl("top_p_sampling_from_probs", torch::kCUDA,
+           &top_p_sampling_from_probs);
+  ops.def("top_k_top_p_sampling_from_probs", &top_k_top_p_sampling_from_probs);
+  ops.impl("top_k_top_p_sampling_from_probs", torch::kCUDA,
+           &top_k_top_p_sampling_from_probs);
+  ops.def("top_k_renorm_prob", &top_k_renorm_prob);
+  ops.impl("top_k_renorm_prob", torch::kCUDA, &top_k_renorm_prob);
+  ops.def("top_p_renorm_prob", &top_p_renorm_prob);
+  ops.impl("top_p_renorm_prob", torch::kCUDA, &top_p_renorm_prob);
+  ops.def("top_k_mask_logits", &top_k_mask_logits);
+  ops.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits);
+
 #endif
 
   // Quantized GEMM for GPTQ.

+ 9 - 2
tests/benchmarks/engine/throughput.py

@@ -75,6 +75,7 @@ def run_aphrodite(
     dtype: str,
     max_model_len: Optional[int],
     enforce_eager: bool,
+    max_seq_len_to_capture: Optional[int],
     kv_cache_dtype: str,
     quantization_param_path: Optional[str],
     device: str,
@@ -100,6 +101,7 @@ def run_aphrodite(
         max_model_len=max_model_len,
         gpu_memory_utilization=gpu_memory_utilization,
         enforce_eager=enforce_eager,
+        max_seq_len_to_capture=max_seq_len_to_capture,
         kv_cache_dtype=kv_cache_dtype,
         quantization_param_path=quantization_param_path,
         device=device,
@@ -233,8 +235,8 @@ def main(args: argparse.Namespace):
             args.quant_llm_fp_bits,
             args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
             args.trust_remote_code, args.dtype, args.max_model_len,
-            args.enforce_eager, args.kv_cache_dtype,
-            args.quantization_param_path, args.device,
+            args.enforce_eager, args.max_seq_len_to_capture,
+            args.kv_cache_dtype, args.quantization_param_path, args.device,
             args.enable_prefix_caching, args.enable_chunked_prefill,
             args.max_num_batched_tokens, args.distributed_executor_backend,
             args.gpu_memory_utilization, args.download_dir, args.load_format,
@@ -344,6 +346,11 @@ if __name__ == "__main__":
     parser.add_argument("--enforce-eager",
                         action="store_true",
                         help="enforce eager execution")
+    parser.add_argument("--max-seq-len-to-capture",
+                        type=int,
+                        default=None,
+                        help="The maximum sequence length to capture for "
+                        "CUDA graphs.")
     parser.add_argument(
         '--kv-cache-dtype',
         type=str,