- /*
- * Copyright (c) 2024 by PygmalionAI team.
- * Copyright (c) 2024 by FlashInfer team.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #ifndef APHRODITE_SAMPLING_CUH_
- #define APHRODITE_SAMPLING_CUH_
- #include <cub/block/block_adjacent_difference.cuh>
- #include <cub/block/block_reduce.cuh>
- #include <cub/block/block_scan.cuh>
- #include <cstdint>
- #include <limits>
- #include <numeric>
- #include "math.cuh"
- #include "utils.cuh"
- #include "vec_dtypes.cuh"
- namespace aphrodite {
- namespace sampling {
- using namespace cub;
- #define DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, ...) \
- if (deterministic) { \
- constexpr bool DETERMINISTIC = true; \
- __VA_ARGS__ \
- } else { \
- constexpr bool DETERMINISTIC = false; \
- __VA_ARGS__ \
- }
- constexpr BlockScanAlgorithm SCAN_ALGO = BLOCK_SCAN_WARP_SCANS;
- constexpr BlockReduceAlgorithm REDUCE_ALGO = BLOCK_REDUCE_WARP_REDUCTIONS;
- #if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120100)
- #define APHRODITE_CUB_SUBTRACTLEFT_DEFINED
- #endif
- template <typename T>
- struct Pair {
- T value;
- int count;
- __device__ Pair operator+(const Pair& other) const {
- return {value + other.value, count + other.count};
- }
- __device__ Pair& operator+=(const Pair& other) {
- value += other.value;
- count += other.count;
- return *this;
- }
- };
- struct BoolDiffOp {
- __device__ __forceinline__ bool operator()(const bool& lhs,
- const bool& rhs) const {
- return lhs != rhs;
- }
- };
- template <typename T, uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
- BlockReduceAlgorithm REDUCE_ALGORITHM>
- struct SamplingTempStorage {
- union {
- T deterministic_scan[BLOCK_THREADS / 32];
- typename BlockScan<T, BLOCK_THREADS, SCAN_ALGORITHM>::TempStorage scan;
- typename BlockReduce<T, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
- reduce;
- typename BlockReduce<Pair<T>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
- reduce_pair;
- typename BlockAdjacentDifference<bool, BLOCK_THREADS>::TempStorage adj_diff;
- } block_prim;
- struct {
- int32_t sampled_id;
- union {
- T value;
- Pair<T> pair;
- T max_p;
- } block_aggregate;
- } data;
- };
- /*!
- * \brief Deterministic inclusive scan implementation based on the Blelloch
- * scan algorithm.
- * \note This implementation is slower than cub::BlockScan, but its result is
- * deterministic: the floating-point additions always happen in the same
- * order, so repeated runs produce bit-identical prefix sums.
- */
- template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS,
- BlockScanAlgorithm SCAN_ALGORITHM,
- BlockReduceAlgorithm REDUCE_ALGORITHM, typename T>
- __device__ __forceinline__ void DeterministicInclusiveSum(
- const T* in_data, T* out_data,
- SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>*
- temp_storage) {
- T* smem_prefix_sum = temp_storage->block_prim.deterministic_scan;
- T thread_data[VEC_SIZE];
- T thread_sum = 0;
- #pragma unroll
- for (uint32_t i = 0; i < VEC_SIZE; ++i) {
- thread_sum += in_data[i];
- thread_data[i] = thread_sum;
- }
- T thread_exclusive_prefix_sum = thread_sum;
- #pragma unroll
- for (uint32_t offset = 1; offset < 32; offset *= 2) {
- T tmp = __shfl_up_sync(0xffffffff, thread_exclusive_prefix_sum, offset);
- if ((threadIdx.x + 1) % (offset * 2) == 0) {
- thread_exclusive_prefix_sum += tmp;
- }
- }
- // Broadcast each warp's total: the source lane of __shfl_sync is taken
- // modulo warpSize, so threadIdx.x | 0x1f always reads lane 31, which holds
- // the warp-wide sum after the up-sweep above.
- T warp_sum = __shfl_sync(0xffffffff, thread_exclusive_prefix_sum,
- threadIdx.x | 0x1f);
- if (threadIdx.x % 32 == 31) {
- thread_exclusive_prefix_sum = 0;
- }
- #pragma unroll
- for (uint32_t offset = 16; offset >= 1; offset /= 2) {
- T tmp = __shfl_xor_sync(0xffffffff, thread_exclusive_prefix_sum, offset);
- if ((threadIdx.x + 1) % (offset * 2) == 0) {
- thread_exclusive_prefix_sum = tmp + thread_exclusive_prefix_sum;
- }
- if ((threadIdx.x + 1) % (offset * 2) == offset) {
- thread_exclusive_prefix_sum = tmp;
- }
- }
- smem_prefix_sum[threadIdx.x / 32] = warp_sum;
- __syncthreads();
- if (threadIdx.x < 32) {
- T warp_exclusive_prefix_sum =
- (threadIdx.x < BLOCK_THREADS / 32) ? smem_prefix_sum[threadIdx.x] : 0;
- #pragma unroll
- for (uint32_t offset = 1; offset < 32; offset *= 2) {
- T tmp = __shfl_up_sync(0xffffffff, warp_exclusive_prefix_sum, offset);
- if ((threadIdx.x + 1) % (offset * 2) == 0) {
- warp_exclusive_prefix_sum += tmp;
- }
- }
- if (threadIdx.x % 32 == 31) {
- warp_exclusive_prefix_sum = 0;
- }
- #pragma unroll
- for (uint32_t offset = 16; offset >= 1; offset /= 2) {
- T tmp = __shfl_xor_sync(0xffffffff, warp_exclusive_prefix_sum, offset);
- if ((threadIdx.x + 1) % (offset * 2) == 0) {
- warp_exclusive_prefix_sum = tmp + warp_exclusive_prefix_sum;
- }
- if ((threadIdx.x + 1) % (offset * 2) == offset) {
- warp_exclusive_prefix_sum = tmp;
- }
- }
- if (threadIdx.x < BLOCK_THREADS / 32) {
- smem_prefix_sum[threadIdx.x] = warp_exclusive_prefix_sum;
- }
- }
- __syncthreads();
- #pragma unroll
- for (uint32_t i = 0; i < VEC_SIZE; ++i) {
- out_data[i] = smem_prefix_sum[threadIdx.x / 32] +
- thread_exclusive_prefix_sum + thread_data[i];
- }
- }
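- // Worked example (a sketch of why the scan above is deterministic): each
- // thread first scans its VEC_SIZE elements serially, then the warp up-sweep
- // combines lane partials in a fixed tree order (offsets 1, 2, 4, 8, 16) and
- // the down-sweep distributes exclusive prefixes back. For lane sums
- // [a, b, c, d, ...] the tree always computes (a + b) and (c + d) before
- // ((a + b) + (c + d)), so the floating-point additions associate
- // identically on every run and the prefix sums are bit-reproducible, unlike
- // schedules whose combination order varies between launches.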
- template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS,
- BlockScanAlgorithm SCAN_ALGORITHM,
- BlockReduceAlgorithm REDUCE_ALGORITHM, bool DETERMINISTIC, typename T>
- __device__ __forceinline__ void DeviceSamplingFromProb(
- uint32_t i, uint32_t d, T threshold, T u, vec_t<T, VEC_SIZE> prob_vec,
- T& aggregate,
- SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>*
- temp_storage) {
- const uint32_t tx = threadIdx.x;
- T prob_greater_than_threshold[VEC_SIZE];
- T inclusive_cdf[VEC_SIZE];
- bool greater_than_u[VEC_SIZE], valid[VEC_SIZE];
- #pragma unroll
- for (uint32_t j = 0; j < VEC_SIZE; ++j) {
- prob_greater_than_threshold[j] =
- (prob_vec[j] > threshold) ? prob_vec[j] : T(0);
- valid[j] =
- prob_vec[j] > threshold && (i * BLOCK_THREADS + tx) * VEC_SIZE < d;
- }
- T aggregate_local = BlockReduce<T, BLOCK_THREADS, REDUCE_ALGORITHM>(
- temp_storage->block_prim.reduce)
- .Sum<VEC_SIZE>(prob_greater_than_threshold);
- if (tx == 0) {
- temp_storage->data.block_aggregate.value = aggregate_local;
- }
- __syncthreads();
- aggregate_local = temp_storage->data.block_aggregate.value;
- if (aggregate + aggregate_local > u) {
- if constexpr (DETERMINISTIC) {
- DeterministicInclusiveSum<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
- REDUCE_ALGORITHM, T>(
- prob_greater_than_threshold, inclusive_cdf, temp_storage);
- } else {
- BlockScan<T, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
- .InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);
- __syncthreads();
- }
- #pragma unroll
- for (uint32_t j = 0; j < VEC_SIZE; ++j) {
- greater_than_u[j] = inclusive_cdf[j] + aggregate > u;
- }
- bool greater_than_u_diff[VEC_SIZE];
- #ifdef APHRODITE_CUB_SUBTRACTLEFT_DEFINED
- BlockAdjacentDifference<bool, BLOCK_THREADS>(
- temp_storage->block_prim.adj_diff)
- .SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff,
- BoolDiffOp());
- #else
- BlockAdjacentDifference<bool, BLOCK_THREADS>(
- temp_storage->block_prim.adj_diff)
- .FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(),
- 0);
- #endif
- __syncthreads();
- #pragma unroll
- for (uint32_t j = 0; j < VEC_SIZE; ++j) {
- if (greater_than_u_diff[j] && valid[j]) {
- if constexpr (DETERMINISTIC) {
- temp_storage->data.sampled_id =
- (i * BLOCK_THREADS + tx) * VEC_SIZE + j;
- } else {
- // cub's block scan result might not be monotonic, so we need to find
- // the first element
- atomicMin(&(temp_storage->data.sampled_id),
- (i * BLOCK_THREADS + tx) * VEC_SIZE + j);
- }
- }
- }
- __syncthreads();
- }
- aggregate += aggregate_local;
- }
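- // The helper above is one block-wide step of inverse-CDF sampling: entries
- // at or below `threshold` are zeroed, the block accumulates the surviving
- // mass, and the sampled index is the first position whose inclusive CDF
- // exceeds the uniform draw u. A small numeric sketch (hypothetical values):
- // with probs = [0.1, 0.2, 0.3, 0.4] and u = 0.45, the inclusive CDF is
- // [0.1, 0.3, 0.6, 1.0]; the first entry above 0.45 sits at index 2, so
- // token 2 is sampled. The top-k/top-p/min-p kernels below reuse this with
- // `threshold` set to their current pivot to restrict the support.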
- template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
- BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
- bool DETERMINISTIC, typename DType, typename IdType>
- __global__ void SamplingFromProbKernel(DType* probs, DType* uniform_samples,
- IdType* output, IdType* row_indices,
- uint32_t d) {
- const uint32_t bx = blockIdx.x, tx = threadIdx.x;
- const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];
- extern __shared__ __align__(
- alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
- REDUCE_ALGORITHM>)) uint8_t smem_sampling[];
- auto& temp_storage =
- reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
- REDUCE_ALGORITHM>&>(smem_sampling);
- temp_storage.data.sampled_id = d - 1;
- __syncthreads();
- vec_t<DType, VEC_SIZE> probs_vec;
- DType aggregate(0);
- float u = uniform_samples[bx];
- for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
- probs_vec.fill(DType(0));
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
- tx * VEC_SIZE);
- }
- DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
- REDUCE_ALGORITHM, DETERMINISTIC, DType>(
- i, d, DType(0), u, probs_vec, aggregate, &temp_storage);
- if (float(aggregate) > u) {
- break;
- }
- }
- output[bx] = temp_storage.data.sampled_id;
- }
- template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
- BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
- bool DETERMINISTIC, typename DType, typename IdType>
- __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
- IdType* output, bool* success,
- IdType* top_k_arr,
- uint32_t top_k_val, uint32_t d,
- uint32_t max_top_k_rounds) {
- const uint32_t batch_size = gridDim.x;
- const uint32_t bx = blockIdx.x, tx = threadIdx.x;
- uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
- extern __shared__ __align__(
- alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
- REDUCE_ALGORITHM>)) uint8_t smem_sampling[];
- auto& temp_storage =
- reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
- REDUCE_ALGORITHM>&>(smem_sampling);
- vec_t<DType, VEC_SIZE> probs_vec;
- DType aggregate;
- DType q = DType(1);
- DType pivot = DType(0);
- IdType sampled_id;
- for (uint32_t round = 0; round < max_top_k_rounds; ++round) {
- temp_storage.data.sampled_id = d - 1;
- __syncthreads();
- DType u = uniform_samples[round * batch_size + bx] * q;
- aggregate = DType(0);
- for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
- probs_vec.fill(DType(0));
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
- }
- DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
- REDUCE_ALGORITHM, DETERMINISTIC, DType>(
- i, d, pivot, u, probs_vec, aggregate, &temp_storage);
- if (aggregate > u) {
- break;
- }
- }
- __syncthreads();
- sampled_id = temp_storage.data.sampled_id;
- pivot = max(pivot, probs[bx * d + sampled_id]);
- Pair<DType> aggregate_gt_pivot{DType(0), 0};
- for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
- probs_vec.fill(DType(0));
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
- }
- Pair<DType> probs_gt_pivot[VEC_SIZE];
- #pragma unroll
- for (uint32_t j = 0; j < VEC_SIZE; ++j) {
- probs_gt_pivot[j] = {(probs_vec[j] > pivot) ? probs_vec[j] : DType(0),
- (probs_vec[j] > pivot &&
- (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
- }
- aggregate_gt_pivot +=
- BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
- temp_storage.block_prim.reduce_pair)
- .Sum<VEC_SIZE>(probs_gt_pivot);
- if (tx == 0) {
- temp_storage.data.block_aggregate.pair = aggregate_gt_pivot;
- }
- __syncthreads();
- }
- q = temp_storage.data.block_aggregate.pair.value;
- if (temp_storage.data.block_aggregate.pair.count < k) {
- break;
- }
- }
- __syncthreads();
- if (tx == 0) {
- output[bx] = sampled_id;
- if (temp_storage.data.block_aggregate.pair.count >= k) {
- // failed to sample within max_top_k_rounds
- if (success != nullptr) {
- success[bx] = false;
- }
- } else {
- if (success != nullptr) {
- success[bx] = true;
- }
- }
- }
- }
- template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
- BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
- bool DETERMINISTIC, typename DType, typename IdType>
- __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
- IdType* output, bool* success,
- IdType* row_indices,
- float* top_p_arr, float top_p_val,
- uint32_t d,
- uint32_t max_top_p_rounds) {
- const uint32_t batch_size = gridDim.x;
- const uint32_t bx = blockIdx.x, tx = threadIdx.x;
- float top_p = (top_p_arr == nullptr) ? top_p_val : top_p_arr[bx];
- const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];
- extern __shared__ __align__(
- alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
- REDUCE_ALGORITHM>)) uint8_t smem_sampling[];
- auto& temp_storage =
- reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
- REDUCE_ALGORITHM>&>(smem_sampling);
- vec_t<DType, VEC_SIZE> probs_vec;
- DType aggregate;
- DType q = DType(1);
- DType pivot = DType(0);
- IdType sampled_id;
- for (uint32_t round = 0; round < max_top_p_rounds; ++round) {
- temp_storage.data.sampled_id = d - 1;
- __syncthreads();
- DType u = uniform_samples[round * batch_size + bx] * q;
- aggregate = DType(0);
- for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
- probs_vec.fill(DType(0));
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- probs_vec.load(probs + row_idx * d +
- (i * BLOCK_THREADS + tx) * VEC_SIZE);
- }
- DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
- REDUCE_ALGORITHM, DETERMINISTIC, DType>(
- i, d, pivot, u, probs_vec, aggregate, &temp_storage);
- if (aggregate > u) {
- break;
- }
- }
- __syncthreads();
- sampled_id = temp_storage.data.sampled_id;
- pivot = max(pivot, probs[row_idx * d + sampled_id]);
- DType aggregate_gt_pivot = DType(0);
- for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
- probs_vec.fill(DType(0));
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- probs_vec.load(probs + row_idx * d +
- (i * BLOCK_THREADS + tx) * VEC_SIZE);
- }
- DType probs_gt_pivot[VEC_SIZE];
- #pragma unroll
- for (uint32_t j = 0; j < VEC_SIZE; ++j) {
- probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0);
- }
- aggregate_gt_pivot +=
- BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
- temp_storage.block_prim.reduce)
- .Sum<VEC_SIZE>(probs_gt_pivot);
- if (tx == 0) {
- temp_storage.data.block_aggregate.value = aggregate_gt_pivot;
- }
- __syncthreads();
- }
- q = temp_storage.data.block_aggregate.value;
- if (float(q) < top_p) {
- break;
- }
- }
- __syncthreads();
- if (tx == 0) {
- output[bx] = sampled_id;
- if (float(q) >= top_p) {
- // failed to sample within max_top_p_rounds
- if (success != nullptr) {
- success[bx] = false;
- }
- } else {
- if (success != nullptr) {
- success[bx] = true;
- }
- }
- }
- }
- template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
- BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
- bool DETERMINISTIC, typename DType, typename IdType>
- __global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
- DType* min_p_arr, IdType* output,
- bool* success, float min_p_val,
- uint32_t d,
- uint32_t max_min_p_rounds) {
- const uint32_t batch_size = gridDim.x;
- const uint32_t bx = blockIdx.x, tx = threadIdx.x;
- DType p = (min_p_arr == nullptr) ? min_p_val : min_p_arr[bx];
- extern __shared__ __align__(
- alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
- REDUCE_ALGORITHM>)) uint8_t smem_sampling[];
- auto& temp_storage =
- reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
- REDUCE_ALGORITHM>&>(smem_sampling);
- vec_t<DType, VEC_SIZE> probs_vec;
- DType aggregate;
- DType q = DType(1);
- DType pivot = DType(0);
- DType max_p = 0;
- for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
- probs_vec.fill(DType(0));
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
- }
- DType probs_[VEC_SIZE];
- #pragma unroll
- for (uint32_t j = 0; j < VEC_SIZE; ++j) {
- probs_[j] = probs_vec[j];
- }
- max_p = max(
- max_p, BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
- temp_storage.block_prim.reduce)
- .Reduce<VEC_SIZE>(probs_, cub::Max()));
- __syncthreads();
- }
- if (tx == 0) {
- temp_storage.data.block_aggregate.max_p = max_p;
- }
- __syncthreads();
- DType scaled_p = temp_storage.data.block_aggregate.max_p * p;
- IdType sampled_id;
- for (uint32_t round = 0; round < max_min_p_rounds; ++round) {
- temp_storage.data.sampled_id = d - 1;
- __syncthreads();
- DType u = uniform_samples[round * batch_size + bx] * q;
- aggregate = DType(0);
- for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
- probs_vec.fill(DType(0));
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
- }
- DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
- REDUCE_ALGORITHM, DETERMINISTIC, DType>(
- i, d, pivot, u, probs_vec, aggregate, &temp_storage);
- if (aggregate > u) {
- break;
- }
- }
- __syncthreads();
- sampled_id = temp_storage.data.sampled_id;
- pivot = max(pivot, probs[bx * d + sampled_id]);
- if (pivot >= scaled_p) {
- break;
- }
- DType aggregate_gt_pivot = DType(0);
- for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
- probs_vec.fill(DType(0));
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
- }
- DType probs_gt_pivot[VEC_SIZE];
- #pragma unroll
- for (uint32_t j = 0; j < VEC_SIZE; ++j) {
- probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0);
- }
- aggregate_gt_pivot +=
- BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
- temp_storage.block_prim.reduce)
- .Sum<VEC_SIZE>(probs_gt_pivot);
- if (tx == 0) {
- temp_storage.data.block_aggregate.value = aggregate_gt_pivot;
- }
- __syncthreads();
- }
- q = temp_storage.data.block_aggregate.value;
- }
- __syncthreads();
- if (tx == 0) {
- output[bx] = sampled_id;
- if (pivot < scaled_p) {
- // failed to sample within max_min_p_rounds
- if (success != nullptr) {
- success[bx] = false;
- }
- } else {
- if (success != nullptr) {
- success[bx] = true;
- }
- }
- }
- }
- template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
- BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
- bool DETERMINISTIC, typename DType, typename IdType>
- __global__ void TopKTopPSamplingFromProbKernel(
- DType* probs, DType* uniform_samples, IdType* top_k_arr, DType* top_p_arr,
- IdType* output, bool* success, IdType top_k_val, DType top_p_val,
- uint32_t d, uint32_t max_rounds) {
- const uint32_t batch_size = gridDim.x;
- const uint32_t bx = blockIdx.x, tx = threadIdx.x;
- IdType k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
- DType p = top_p_arr == nullptr ? top_p_val : top_p_arr[bx];
- extern __shared__ __align__(
- alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
- REDUCE_ALGORITHM>)) uint8_t smem_sampling[];
- auto& temp_storage =
- reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
- REDUCE_ALGORITHM>&>(smem_sampling);
- vec_t<DType, VEC_SIZE> probs_vec;
- DType aggregate;
- DType q = DType(1);
- DType pivot = DType(0);
- IdType sampled_id;
- for (uint32_t round = 0; round < max_rounds; ++round) {
- temp_storage.data.sampled_id = d - 1;
- __syncthreads();
- DType u = uniform_samples[round * batch_size + bx] * q;
- aggregate = DType(0);
- for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
- probs_vec.fill(DType(0));
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
- }
- DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
- REDUCE_ALGORITHM, DETERMINISTIC, DType>(
- i, d, pivot, u, probs_vec, aggregate, &temp_storage);
- if (aggregate > u) {
- break;
- }
- }
- __syncthreads();
- sampled_id = temp_storage.data.sampled_id;
- pivot = max(pivot, probs[bx * d + sampled_id]);
- Pair<DType> aggregate_gt_pivot{DType(0), 0};
- for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
- probs_vec.fill(DType(0));
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
- }
- Pair<DType> probs_gt_pivot[VEC_SIZE];
- #pragma unroll
- for (uint32_t j = 0; j < VEC_SIZE; ++j) {
- probs_gt_pivot[j] = {(probs_vec[j] > pivot) ? probs_vec[j] : DType(0),
- (probs_vec[j] > pivot &&
- (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
- }
- aggregate_gt_pivot +=
- BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
- temp_storage.block_prim.reduce_pair)
- .Sum<VEC_SIZE>(probs_gt_pivot);
- if (tx == 0) {
- temp_storage.data.block_aggregate.pair = aggregate_gt_pivot;
- }
- __syncthreads();
- }
- q = temp_storage.data.block_aggregate.pair.value;
- if (temp_storage.data.block_aggregate.pair.count < k && float(q) < p) {
- break;
- }
- }
- __syncthreads();
- if (tx == 0) {
- output[bx] = sampled_id;
- if (temp_storage.data.block_aggregate.pair.count >= k || float(q) >= p) {
- // failed to sample within max_rounds
- if (success != nullptr) {
- success[bx] = false;
- }
- } else {
- if (success != nullptr) {
- success[bx] = true;
- }
- }
- }
- }
- template <typename T, typename IdType>
- cudaError_t SamplingFromProb(T* probs, T* uniform_samples, IdType* output,
- uint32_t batch_size, uint32_t d,
- bool deterministic, cudaStream_t stream = 0) {
- constexpr uint32_t BLOCK_THREADS = 1024;
- const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
- dim3 nblks(batch_size);
- dim3 nthrs(BLOCK_THREADS);
- IdType* row_indices_placeholder = nullptr;
- void* args[] = {&probs, &uniform_samples, &output, &row_indices_placeholder,
- &d};
- const uint32_t smem_size =
- sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
- DISPATCH_ALIGNED_VEC_SIZE(
- vec_size, VEC_SIZE,
- {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
- auto kernel =
- SamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
- VEC_SIZE, DETERMINISTIC, T, IdType>;
- APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
- smem_size, stream));
- })});
- return cudaSuccess;
- }
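- // Example usage (a sketch, not part of the API): assumes device buffers are
- // already allocated and populated, with `probs` holding batch_size x d
- // row-normalized probabilities and `uniform_samples` one U(0, 1) draw per
- // row (e.g. from curand).
- //
- //   float* probs;            // [batch_size, d], each row sums to 1
- //   float* uniform_samples;  // [batch_size]
- //   int32_t* output;         // [batch_size] sampled token ids
- //   cudaError_t status =
- //       aphrodite::sampling::SamplingFromProb<float, int32_t>(
- //           probs, uniform_samples, output, batch_size, d,
- //           /*deterministic=*/true, stream);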
- template <typename T, typename IdType>
- cudaError_t ParallelSamplingFromProb(T* probs, T* uniform_samples,
- IdType* output, IdType* row_indices,
- uint32_t batch_size, uint32_t d,
- bool deterministic,
- cudaStream_t stream = 0) {
- constexpr uint32_t BLOCK_THREADS = 1024;
- const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
- dim3 nblks(batch_size);
- dim3 nthrs(BLOCK_THREADS);
- void* args[] = {&probs, &uniform_samples, &output, &row_indices, &d};
- const uint32_t smem_size =
- sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
- DISPATCH_ALIGNED_VEC_SIZE(
- vec_size, VEC_SIZE,
- {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
- auto kernel =
- SamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
- VEC_SIZE, DETERMINISTIC, T, IdType>;
- APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
- smem_size, stream));
- })});
- return cudaSuccess;
- }
- template <typename T, typename IdType>
- cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output,
- bool* success, IdType* top_k_arr,
- uint32_t batch_size, uint32_t top_k_val,
- uint32_t d, uint32_t max_top_k_rounds,
- bool deterministic, cudaStream_t stream = 0) {
- constexpr uint32_t BLOCK_THREADS = 1024;
- const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
- const uint32_t smem_size =
- sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
- dim3 nblks(batch_size);
- dim3 nthrs(BLOCK_THREADS);
- void* args[] = {&probs, &uniform_samples, &output, &success,
- &top_k_arr, &top_k_val, &d, &max_top_k_rounds};
- DISPATCH_ALIGNED_VEC_SIZE(
- vec_size, VEC_SIZE,
- {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
- auto kernel =
- TopKSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
- VEC_SIZE, DETERMINISTIC, T, IdType>;
- APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
- kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
- APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
- smem_size, stream));
- })});
- return cudaSuccess;
- }
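- // Example usage (sketch): the rejection loop may run up to max_top_k_rounds
- // times, so `uniform_samples` must hold max_top_k_rounds draws per row,
- // laid out as [max_top_k_rounds, batch_size]. `success[bx]` reports whether
- // row bx converged; a caller might fall back to an exact (sorted) top-k for
- // rows that did not.
- //
- //   float* uniform_samples;  // [max_top_k_rounds, batch_size]
- //   bool* success;           // [batch_size]
- //   auto status = aphrodite::sampling::TopKSamplingFromProb<float, int32_t>(
- //       probs, uniform_samples, output, success, /*top_k_arr=*/nullptr,
- //       batch_size, /*top_k_val=*/40, d, /*max_top_k_rounds=*/32,
- //       /*deterministic=*/true, stream);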
- template <typename T, typename IdType>
- cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output,
- bool* success, float* top_p_arr,
- uint32_t batch_size, float top_p_val, uint32_t d,
- uint32_t max_top_p_rounds, bool deterministic,
- cudaStream_t stream = 0) {
- constexpr uint32_t BLOCK_THREADS = 1024;
- const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
- const uint32_t smem_size =
- sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
- dim3 nblks(batch_size);
- dim3 nthrs(BLOCK_THREADS);
- IdType* row_indices_placeholder = nullptr;
- void* args[] = {&probs,
- &uniform_samples,
- &output,
- &success,
- &row_indices_placeholder,
- &top_p_arr,
- &top_p_val,
- &d,
- &max_top_p_rounds};
- DISPATCH_ALIGNED_VEC_SIZE(
- vec_size, VEC_SIZE,
- {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
- auto kernel =
- TopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
- VEC_SIZE, DETERMINISTIC, T, IdType>;
- APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
- kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
- APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
- smem_size, stream));
- })});
- return cudaSuccess;
- }
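- // Note on the rejection loop above: each round samples a token from the
- // probability mass above the current pivot, raises the pivot to that
- // token's probability, and recomputes q = sum(probs > pivot). When q drops
- // below top_p the sample is accepted; otherwise the next round redraws,
- // scaling the fresh uniform draw by q so it targets the remaining mass.
- // Usage mirrors TopKSamplingFromProb, with uniform_samples shaped
- // [max_top_p_rounds, batch_size] and success reporting per-row convergence.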
- template <typename T, typename IdType>
- cudaError_t MinPSamplingFromProb(T* probs, T* uniform_samples, T* min_p_arr,
- IdType* output, bool* success,
- uint32_t batch_size, float min_p_val,
- uint32_t d, uint32_t max_rounds,
- bool deterministic, cudaStream_t stream = 0) {
- constexpr uint32_t BLOCK_THREADS = 1024;
- const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
- const uint32_t smem_size =
- sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
- dim3 nblks(batch_size);
- dim3 nthrs(BLOCK_THREADS);
- void* args[] = {&probs, &uniform_samples, &min_p_arr, &output,
- &success, &min_p_val, &d, &max_rounds};
- DISPATCH_ALIGNED_VEC_SIZE(
- vec_size, VEC_SIZE,
- {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
- auto kernel =
- MinPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
- VEC_SIZE, DETERMINISTIC, T, IdType>;
- APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
- kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
- APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
- smem_size, stream));
- })});
- return cudaSuccess;
- }
- template <typename T, typename IdType>
- cudaError_t TopKTopPSamplingFromProb(T* probs, T* uniform_samples,
- IdType* top_k_arr, T* top_p_arr,
- IdType* output, bool* success,
- uint32_t batch_size, IdType top_k_val,
- T top_p_val, uint32_t d,
- uint32_t max_rounds, bool deterministic,
- cudaStream_t stream = 0) {
- constexpr uint32_t BLOCK_THREADS = 1024;
- const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
- const uint32_t smem_size =
- sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
- dim3 nblks(batch_size);
- dim3 nthrs(BLOCK_THREADS);
- void* args[] = {&probs, &uniform_samples, &top_k_arr, &top_p_arr,
- &output, &success, &top_k_val, &top_p_val,
- &d, &max_rounds};
- DISPATCH_ALIGNED_VEC_SIZE(
- vec_size, VEC_SIZE,
- {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
- auto kernel = TopKTopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO,
- REDUCE_ALGO, VEC_SIZE,
- DETERMINISTIC, T, IdType>;
- APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
- kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
- APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
- smem_size, stream));
- })});
- return cudaSuccess;
- }
- template <typename T, uint32_t BLOCK_THREADS,
- BlockReduceAlgorithm REDUCE_ALGORITHM>
- struct RenormTempStorage {
- union {
- typename BlockReduce<T, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
- reduce;
- typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
- reduce_int;
- typename BlockReduce<Pair<T>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
- reduce_pair;
- } block_prim;
- struct {
- T max_val;
- T min_val;
- union {
- T value;
- int count;
- Pair<T> pair;
- } block_aggregate;
- } data;
- };
- template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
- uint32_t VEC_SIZE, typename DType>
- __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob,
- DType* top_p_arr, float top_p_val,
- uint32_t d) {
- const uint32_t bx = blockIdx.x, tx = threadIdx.x;
- const uint32_t row_idx = bx;
- float p = top_p_arr == nullptr ? top_p_val : top_p_arr[bx];
- extern __shared__ __align__(
- alignof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
- uint8_t smem_renorm[];
- auto& temp_storage =
- reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(
- smem_renorm);
- temp_storage.data.max_val = DType(0);
- vec_t<DType, VEC_SIZE> probs_vec;
- DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
- DType threadlocal_max_val = DType(0);
- for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
- probs_vec.fill(DType(0));
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
- tx * VEC_SIZE);
- }
- #pragma unroll
- for (uint32_t j = 0; j < VEC_SIZE; ++j) {
- probs_greater_than_pivot[j] = probs_vec[j];
- }
- threadlocal_max_val =
- max(threadlocal_max_val,
- BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
- temp_storage.block_prim.reduce)
- .Reduce<VEC_SIZE>(probs_greater_than_pivot, cub::Max()));
- __syncthreads();
- }
- if (tx == 0) {
- temp_storage.data.max_val = threadlocal_max_val;
- }
- __syncthreads();
- threadlocal_max_val = temp_storage.data.max_val;
- float low = 0, high = threadlocal_max_val;
- DType min_gt_low, max_le_high;
- float sum_low = 1.f;
- // f(x) = sum(probs[probs > x]); f(x) is non-increasing
- // min_gt_low = min{q \in probs | q > low}
- // max_le_high = max{q \in probs | q <= high}
- // loop invariant:
- //  - f(low) >= p, f(high) < p
- //  - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
- // stopping condition:
- //  - f(low) >= p, f(min_gt_low) == f(max_le_high) == f(high) < p
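- // Worked example (hypothetical values): probs = [0.5, 0.3, 0.2] and
- // p = 0.7. Any pivot in [0.2, 0.3) gives f(pivot) = 0.8 >= p, while any
- // pivot in [0.3, 0.5) gives f(pivot) = 0.5 < p, so the bisection settles
- // with low in [0.2, 0.3) and sum_low = 0.8; the surviving entries are then
- // renormalized to [0.5/0.8, 0.3/0.8, 0].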
- do {
- DType threadlocal_sum(0);
- float mid = (low + high) / 2;
- min_gt_low = high;
- max_le_high = low;
- for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
- probs_vec.fill(DType(0));
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
- tx * VEC_SIZE);
- }
- #pragma unroll
- for (uint32_t j = 0; j < VEC_SIZE; ++j) {
- probs_greater_than_pivot[j] =
- (probs_vec[j] > mid) ? probs_vec[j] : DType(0);
- if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
- min_gt_low = min(min_gt_low, probs_vec[j]);
- }
- if (probs_vec[j] <= high &&
- (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
- max_le_high = max(max_le_high, probs_vec[j]);
- }
- }
- threadlocal_sum += BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
- temp_storage.block_prim.reduce)
- .Sum<VEC_SIZE>(probs_greater_than_pivot);
- __syncthreads();
- }
- min_gt_low = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
- temp_storage.block_prim.reduce)
- .Reduce(min_gt_low, cub::Min());
- __syncthreads();
- max_le_high = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
- temp_storage.block_prim.reduce)
- .Reduce(max_le_high, cub::Max());
- if (tx == 0) {
- temp_storage.data.block_aggregate.value = threadlocal_sum;
- temp_storage.data.min_val = min_gt_low;
- temp_storage.data.max_val = max_le_high;
- }
- __syncthreads();
- threadlocal_sum = temp_storage.data.block_aggregate.value;
- min_gt_low = temp_storage.data.min_val;
- max_le_high = temp_storage.data.max_val;
- if (threadlocal_sum >= p) {
- low = mid;
- sum_low = float(threadlocal_sum);
- } else {
- high = min(mid, max_le_high);
- }
- } while (min_gt_low != max_le_high);
- DType normalizer = math::ptx_rcp(max(sum_low, 1e-8f));
- // normalize
- for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
- probs_vec.fill(DType(0));
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
- tx * VEC_SIZE);
- }
- #pragma unroll
- for (uint32_t j = 0; j < VEC_SIZE; ++j) {
- probs_vec[j] =
- (probs_vec[j] > low) ? probs_vec[j] * normalizer : DType(0);
- }
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- probs_vec.store(renormed_prob + row_idx * d +
- i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
- }
- }
- }
- template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
- uint32_t VEC_SIZE, typename DType, typename IdType>
- __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits,
- IdType* top_k_arr, uint32_t top_k_val,
- uint32_t d) {
- const uint32_t bx = blockIdx.x, tx = threadIdx.x;
- const uint32_t row_idx = bx;
- uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
- float pivot = -std::numeric_limits<float>::infinity();
- vec_t<DType, VEC_SIZE> logits_vec;
- if (k < d) {
- extern __shared__ __align__(
- alignof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
- uint8_t smem_renorm[];
- auto& temp_storage =
- reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(
- smem_renorm);
- DType logits_greater_than_pivot[VEC_SIZE]; // pivot initialized to -inf
- DType threadlocal_max_val = DType(-std::numeric_limits<float>::infinity()),
- threadlocal_min_val = DType(std::numeric_limits<float>::infinity());
- for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
- logits_vec.fill(DType(0));
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- logits_vec.load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
- tx * VEC_SIZE);
- }
- #pragma unroll
- for (uint32_t j = 0; j < VEC_SIZE; ++j) {
- logits_greater_than_pivot[j] = logits_vec[j];
- }
- threadlocal_max_val =
- max(threadlocal_max_val,
- BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
- temp_storage.block_prim.reduce)
- .Reduce<VEC_SIZE>(logits_greater_than_pivot, cub::Max()));
- __syncthreads();
- threadlocal_min_val =
- min(threadlocal_min_val,
- BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
- temp_storage.block_prim.reduce)
- .Reduce<VEC_SIZE>(logits_greater_than_pivot, cub::Min()));
- __syncthreads();
- }
- if (tx == 0) {
- temp_storage.data.max_val = threadlocal_max_val;
- temp_storage.data.min_val = threadlocal_min_val;
- }
- __syncthreads();
- threadlocal_max_val = temp_storage.data.max_val;
- threadlocal_min_val = temp_storage.data.min_val;
- float low = threadlocal_min_val - 1, high = threadlocal_max_val;
- DType min_gt_low, max_le_high;
- // f(x) = len(nonzero(logits > x)); f(x) is non-increasing
- // min_gt_low = min{q \in logits | q > low}
- // max_le_high = max{q \in logits | q <= high}
- // loop invariant:
- //  - f(low) >= k, f(high) < k
- //  - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
- // stopping condition: min_gt_low == max_le_high
- //  - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
- do {
- int threadlocal_count_sum = 0;
- int probs_greater_than_pivot_count[VEC_SIZE]; // count of logits > mid
- float mid = (low + high) / 2;
- min_gt_low = high;
- max_le_high = low;
- for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
- logits_vec.fill(DType(0));
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- logits_vec.load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
- tx * VEC_SIZE);
- }
- #pragma unroll
- for (uint32_t j = 0; j < VEC_SIZE; ++j) {
- probs_greater_than_pivot_count[j] =
- logits_vec[j] > mid &&
- (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
- if (logits_vec[j] > low &&
- (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
- min_gt_low = min(min_gt_low, logits_vec[j]);
- }
- if (logits_vec[j] <= high &&
- (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
- max_le_high = max(max_le_high, logits_vec[j]);
- }
- }
- threadlocal_count_sum +=
- BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(
- temp_storage.block_prim.reduce_int)
- .Sum<VEC_SIZE>(probs_greater_than_pivot_count);
- __syncthreads();
- }
- min_gt_low = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
- temp_storage.block_prim.reduce)
- .Reduce(min_gt_low, cub::Min());
- __syncthreads();
- max_le_high = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
- temp_storage.block_prim.reduce)
- .Reduce(max_le_high, cub::Max());
- if (tx == 0) {
- temp_storage.data.block_aggregate.count = threadlocal_count_sum;
- temp_storage.data.min_val = min_gt_low;
- temp_storage.data.max_val = max_le_high;
- }
- __syncthreads();
- threadlocal_count_sum = temp_storage.data.block_aggregate.count;
- min_gt_low = temp_storage.data.min_val;
- max_le_high = temp_storage.data.max_val;
- if (threadlocal_count_sum >= k) {
- low = mid;
- } else {
- high = min(mid, max_le_high);
- }
- } while (min_gt_low != max_le_high);
- pivot = low;
- }
- // masking
- for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
- logits_vec.fill(DType(0));
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- logits_vec.load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
- tx * VEC_SIZE);
- }
- #pragma unroll
- for (uint32_t j = 0; j < VEC_SIZE; ++j) {
- logits_vec[j] = (logits_vec[j] > pivot)
- ? logits_vec[j]
- : DType(-std::numeric_limits<float>::infinity());
- }
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- logits_vec.store(masked_logits + row_idx * d +
- i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
- }
- }
- }
- template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
- uint32_t VEC_SIZE, typename DType, typename IdType>
- __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob,
- IdType* top_k_arr, uint32_t top_k_val,
- uint32_t d) {
- const uint32_t bx = blockIdx.x, tx = threadIdx.x;
- const uint32_t row_idx = bx;
- uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
- float pivot = -std::numeric_limits<float>::infinity(), normalizer = 1;
- vec_t<DType, VEC_SIZE> probs_vec;
- if (k < d) {
- extern __shared__ __align__(
- alignof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
- uint8_t smem_renorm[];
- auto& temp_storage =
- reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(
- smem_renorm);
- temp_storage.data.max_val = DType(0);
- DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
- DType threadlocal_max_val = DType(0);
- for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
- probs_vec.fill(DType(0));
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
- tx * VEC_SIZE);
- }
- #pragma unroll
- for (uint32_t j = 0; j < VEC_SIZE; ++j) {
- probs_greater_than_pivot[j] = probs_vec[j];
- }
- threadlocal_max_val =
- max(threadlocal_max_val,
- BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
- temp_storage.block_prim.reduce)
- .Reduce<VEC_SIZE>(probs_greater_than_pivot, cub::Max()));
- __syncthreads();
- }
- if (tx == 0) {
- temp_storage.data.max_val = threadlocal_max_val;
- }
- __syncthreads();
- threadlocal_max_val = temp_storage.data.max_val;
- float low = 0, high = threadlocal_max_val;
- DType min_gt_low, max_le_high;
- float sum_low = 1.f;
- // f(x) = len(nonzero(probs > x)); f(x) is non-increasing
- // min_gt_low = min{q \in probs | q > low}
- // max_le_high = max{q \in probs | q <= high}
- // loop invariant:
- //  - f(low) >= k, f(high) < k
- //  - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
- // stopping condition: min_gt_low == max_le_high
- //  - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
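- // Worked example (hypothetical values): probs = [0.4, 0.3, 0.2, 0.1] and
- // k = 2. f(x) counts entries above x, so any pivot in [0.2, 0.3) gives
- // f(pivot) = 2 >= k while any pivot in [0.3, 0.4) gives f(pivot) = 1 < k;
- // the bisection settles with low in [0.2, 0.3) and sum_low = 0.7, and the
- // surviving entries renormalize to [4/7, 3/7, 0, 0].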
- do {
- Pair<DType> threadlocal_sum{DType(0), 0};
- Pair<DType>
- probs_greater_than_pivot_pair[VEC_SIZE]; // pivot initialized to 0
- float mid = (low + high) / 2;
- min_gt_low = high;
- max_le_high = low;
- for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
- probs_vec.fill(DType(0));
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
- tx * VEC_SIZE);
- }
- #pragma unroll
- for (uint32_t j = 0; j < VEC_SIZE; ++j) {
- probs_greater_than_pivot_pair[j] = {
- (probs_vec[j] > mid) ? probs_vec[j] : DType(0),
- (probs_vec[j] > mid &&
- (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
- if (probs_vec[j] > low &&
- (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
- min_gt_low = min(min_gt_low, probs_vec[j]);
- }
- if (probs_vec[j] <= high &&
- (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
- max_le_high = max(max_le_high, probs_vec[j]);
- }
- }
- threadlocal_sum +=
- BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
- temp_storage.block_prim.reduce_pair)
- .Sum<VEC_SIZE>(probs_greater_than_pivot_pair);
- __syncthreads();
- }
- min_gt_low = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
- temp_storage.block_prim.reduce)
- .Reduce(min_gt_low, cub::Min());
- __syncthreads();
- max_le_high = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
- temp_storage.block_prim.reduce)
- .Reduce(max_le_high, cub::Max());
- if (tx == 0) {
- temp_storage.data.block_aggregate.pair = threadlocal_sum;
- temp_storage.data.min_val = min_gt_low;
- temp_storage.data.max_val = max_le_high;
- }
- __syncthreads();
- threadlocal_sum = temp_storage.data.block_aggregate.pair;
- min_gt_low = temp_storage.data.min_val;
- max_le_high = temp_storage.data.max_val;
- if (threadlocal_sum.count >= k) {
- low = mid;
- sum_low = float(threadlocal_sum.value);
- } else {
- high = min(mid, max_le_high);
- }
- } while (min_gt_low != max_le_high);
- normalizer = math::ptx_rcp(max(sum_low, 1e-8f));
- pivot = low;
- }
- // normalize
- for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
- probs_vec.fill(DType(0));
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
- tx * VEC_SIZE);
- }
- #pragma unroll
- for (uint32_t j = 0; j < VEC_SIZE; ++j) {
- probs_vec[j] =
- (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : DType(0);
- }
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- probs_vec.store(renormed_prob + row_idx * d +
- i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
- }
- }
- }
- template <typename DType>
- cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr,
- uint32_t batch_size, float top_p_val, uint32_t d,
- cudaStream_t stream = 0) {
- const uint32_t BLOCK_THREADS = 1024;
- const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
- const uint32_t smem_size =
- sizeof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
- dim3 nblks(batch_size);
- dim3 nthrs(BLOCK_THREADS);
- void* args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &d};
- DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
- auto kernel =
- TopPRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType>;
- APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
- kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
- APHRODITE_CUDA_CALL(
- cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
- });
- return cudaSuccess;
- }
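- // Example usage (sketch): zeroes every entry outside the per-row top-p
- // nucleus and renormalizes the survivors so each row sums to 1 again; both
- // buffers are shaped [batch_size, d].
- //
- //   float* probs;          // [batch_size, d], each row sums to 1
- //   float* renormed_prob;  // [batch_size, d]
- //   auto status = aphrodite::sampling::TopPRenormProb<float>(
- //       probs, renormed_prob, /*top_p_arr=*/nullptr, batch_size,
- //       /*top_p_val=*/0.9f, d, stream);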
- template <typename DType, typename IdType>
- cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob,
- IdType* top_k_arr, uint32_t batch_size,
- uint32_t top_k_val, uint32_t d,
- cudaStream_t stream = 0) {
- const uint32_t BLOCK_THREADS = 1024;
- const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
- const uint32_t smem_size =
- sizeof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
- dim3 nblks(batch_size);
- dim3 nthrs(BLOCK_THREADS);
- void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d};
- DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
- auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE,
- DType, IdType>;
- APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
- kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
- APHRODITE_CUDA_CALL(
- cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
- });
- return cudaSuccess;
- }
- template <typename DType, typename IdType>
- cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits,
- IdType* top_k_arr, uint32_t batch_size,
- uint32_t top_k_val, uint32_t d,
- cudaStream_t stream = 0) {
- const uint32_t BLOCK_THREADS = 1024;
- const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
- const uint32_t smem_size =
- sizeof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
- dim3 nblks(batch_size);
- dim3 nthrs(BLOCK_THREADS);
- void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &d};
- DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
- auto kernel = TopKMaskLogitsKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE,
- DType, IdType>;
- APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
- kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
- APHRODITE_CUDA_CALL(
- cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
- });
- return cudaSuccess;
- }
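- // Example usage (sketch): masks everything outside the per-row top-k to
- // -inf so a downstream softmax distributes mass only over the k largest
- // logits; passing top_k_arr = nullptr applies the scalar top_k_val to all
- // rows.
- //
- //   float* logits;         // [batch_size, d]
- //   float* masked_logits;  // [batch_size, d]
- //   auto status = aphrodite::sampling::TopKMaskLogits<float, int32_t>(
- //       logits, masked_logits, /*top_k_arr=*/nullptr, batch_size,
- //       /*top_k_val=*/40, d, stream);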
- template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
- BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
- typename DType, typename IdType>
- __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids,
- DType* uniform_samples, DType* target_probs,
- IdType* output_token_ids,
- IdType* output_accepted_token_num,
- IdType* output_emitted_token_num,
- uint32_t num_speculative_tokens, uint32_t d) {
- const uint32_t bx = blockIdx.x, tx = threadIdx.x;
- const uint32_t row_idx = bx;
- extern __shared__ __align__(
- alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
- uint8_t smem_sampling[];
- auto& temp_storage = reinterpret_cast<
- SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
- uint32_t pos = num_speculative_tokens;
- for (uint32_t i = 0; i < num_speculative_tokens; ++i) {
- IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + i];
- float q = target_probs[(row_idx * (num_speculative_tokens + 1) + i) * d + draft_id],
- p = draft_probs[(row_idx * num_speculative_tokens + i) * d + draft_id];
- DType u = uniform_samples[row_idx * (num_speculative_tokens + 1) + i];
- if (u * p < q) {
- // accept the draft model's output
- output_token_ids[row_idx * (num_speculative_tokens + 1) + i] = draft_id;
- } else {
- pos = i;
- break;
- }
- }
- uint32_t emitted_token_num = pos;
- uint32_t accepted_token_num = pos;
- for (uint32_t i = pos; i < num_speculative_tokens; ++i) {
- IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + i];
- float q = target_probs[(row_idx * (num_speculative_tokens + 1) + i) * d + draft_id],
- p = draft_probs[(row_idx * num_speculative_tokens + i) * d + draft_id];
- DType u = uniform_samples[row_idx * (num_speculative_tokens + 1) + i];
- if (u * p < q) {
- ++accepted_token_num;
- }
- }
- if (tx == 0) {
- output_accepted_token_num[row_idx] += accepted_token_num;
- output_emitted_token_num[row_idx] += emitted_token_num;
- }
- // sample from relu(target_probs - draft_probs)
- DType sum_relu_q_minus_p(0);
- vec_t<DType, VEC_SIZE> q_vec, p_vec;
- DType relu_q_minus_p[VEC_SIZE];
- for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
- q_vec.fill(DType(0));
- p_vec.fill(DType(0));
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- q_vec.load(target_probs + (row_idx * (num_speculative_tokens + 1) + pos) * d +
- i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
- if (pos != num_speculative_tokens) {
- // there is no draft_probs for the bonus token
- p_vec.load(draft_probs + (row_idx * num_speculative_tokens + pos) * d +
- i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
- }
- }
- #pragma unroll
- for (uint32_t j = 0; j < VEC_SIZE; ++j) {
- relu_q_minus_p[j] = max(q_vec[j] - p_vec[j], DType(0));
- }
- sum_relu_q_minus_p +=
- BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
- .Sum<VEC_SIZE>(relu_q_minus_p);
- __syncthreads();
- }
- if (tx == 0) {
- temp_storage.data.block_aggregate.value = sum_relu_q_minus_p;
- }
- // init the first rejected token to (d - 1)
- temp_storage.data.sampled_id = d - 1;
- __syncthreads();
- sum_relu_q_minus_p = temp_storage.data.block_aggregate.value;
- DType u = uniform_samples[row_idx * (num_speculative_tokens + 1) +
- min(pos + 1, num_speculative_tokens)] *
- sum_relu_q_minus_p;
- DType aggregate_relu_q_minus_p(0);
- for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
- q_vec.fill(DType(0));
- p_vec.fill(DType(0));
- if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
- q_vec.load(target_probs + (row_idx * (num_speculative_tokens + 1) + pos) * d +
- i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
- if (pos != num_speculative_tokens) {
- // there is no draft_probs for the bonus token
- p_vec.load(draft_probs + (row_idx * num_speculative_tokens + pos) * d +
- i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
- }
- }
- vec_t<DType, VEC_SIZE> relu_q_minus_p_vec;
- #pragma unroll
- for (uint32_t j = 0; j < VEC_SIZE; ++j) {
- relu_q_minus_p_vec[j] = max(q_vec[j] - p_vec[j], DType(0));
- }
- DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC,
- DType>(i, d, DType(0), u, relu_q_minus_p_vec, aggregate_relu_q_minus_p,
- &temp_storage);
- if (aggregate_relu_q_minus_p > u) {
- break;
- }
- }
- __syncthreads();
- // set the first rejected token
- output_token_ids[row_idx * (num_speculative_tokens + 1) + pos] = temp_storage.data.sampled_id;
- // move to the next token
- pos++;
- // pad remaining tokens with -1
- for (; pos < num_speculative_tokens + 1; ++pos) {
- output_token_ids[row_idx * (num_speculative_tokens + 1) + pos] = -1;
- }
- }
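- // Note on the kernel above: it implements chained speculative rejection
- // sampling. A draft token with target probability q and draft probability p
- // survives the check `u * p < q` with probability min(1, q / p). On the
- // first rejection, the replacement token is drawn from the residual
- // distribution proportional to relu(q - p), which keeps the emitted token
- // distributed according to the target model's probabilities. Example
- // (hypothetical values): with q = 0.3 and p = 0.6 the draft token survives
- // half the time; with q = 0.6 and p = 0.3 it always survives.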
- template <typename T, typename IdType>
- cudaError_t ParallelTopPSamplingFromProb(
- T* probs, T* uniform_samples, IdType* output, bool* success,
- IdType* row_indices, float* top_p_arr, uint32_t batch_size, uint32_t d,
- uint32_t max_top_p_rounds, bool deterministic, cudaStream_t stream = 0) {
- constexpr uint32_t BLOCK_THREADS = 1024;
- const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
- const uint32_t smem_size =
- sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
- dim3 nblks(batch_size);
- dim3 nthrs(BLOCK_THREADS);
- float top_p_placeholder = 0; // kernel reads top_p_val as a 4-byte float
- void* args[] = {
- &probs, &uniform_samples, &output, &success, &row_indices,
- &top_p_arr, &top_p_placeholder, &d, &max_top_p_rounds};
- DISPATCH_ALIGNED_VEC_SIZE(
- vec_size, VEC_SIZE,
- {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
- auto kernel =
- TopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
- VEC_SIZE, DETERMINISTIC, T, IdType>;
- APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
- kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
- APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
- smem_size, stream));
- })});
- return cudaSuccess;
- }
- } // namespace sampling
- } // namespace aphrodite
- #endif // APHRODITE_SAMPLING_CUH_