|
@@ -17,6 +17,7 @@ from aphrodite.common.sampling_params import SamplingType
|
|
|
from aphrodite.common.sequence import (CompletionSequenceGroupOutput, Logprob,
|
|
|
PromptLogprobs, SampleLogprobs,
|
|
|
SequenceOutput)
|
|
|
+from aphrodite.common.utils import is_cpu
|
|
|
from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
|
|
|
from aphrodite.triton_utils import HAS_TRITON
|
|
|
|
|
@@ -1464,7 +1465,7 @@ def _sample_with_torch(
|
|
|
|
|
|
seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
|
|
|
seq_groups)
|
|
|
- if APHRODITE_USE_SAMPLING_KERNELS is not None:
|
|
|
+ if APHRODITE_USE_SAMPLING_KERNELS and not is_cpu():
|
|
|
multinomial_samples[
|
|
|
sampling_type] = _top_k_top_p_multinomial_with_kernels(
|
|
|
probs[long_sample_indices],
|