sampling.cuh 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523
  1. /*
  2. * Copyright (c) 2024 by PygmalionAI team.
  3. * Copyright (c) 2024 by FlashInfer team.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. #ifndef APHRODITE_SAMPLING_CUH_
  18. #define APHRODITE_SAMPLING_CUH_
  19. #include <cub/block/block_adjacent_difference.cuh>
  20. #include <cub/block/block_reduce.cuh>
  21. #include <cub/block/block_scan.cuh>
  22. #include <numeric>
  23. #include "math.cuh"
  24. #include "utils.cuh"
  25. #include "vec_dtypes.cuh"
  26. namespace aphrodite {
  27. namespace sampling {
  28. using namespace cub;
  // DISPATCH_DETERMINISTIC(flag, NAME, body...)
  // Expands `body` in both arms of a runtime `if` on `flag`. Inside each copy,
  // `NAME` is a `constexpr bool` holding that arm's value, so `body` can use it
  // as a template argument.
  #define DISPATCH_DETERMINISTIC(flag, NAME, ...) \
    if (!(flag)) {                                \
      constexpr bool NAME = false;                \
      __VA_ARGS__                                 \
    } else {                                      \
      constexpr bool NAME = true;                 \
      __VA_ARGS__                                 \
    }

  // CUB block-scan / block-reduce algorithms used by every launcher below.
  constexpr cub::BlockScanAlgorithm SCAN_ALGO = cub::BLOCK_SCAN_WARP_SCANS;
  constexpr cub::BlockReduceAlgorithm REDUCE_ALGO =
      cub::BLOCK_REDUCE_WARP_REDUCTIONS;

  // With nvcc 12.1 or newer, BlockAdjacentDifference::SubtractLeft is used;
  // otherwise DeviceSamplingFromProb falls back to FlagHeads.
  #if defined(__CUDACC_VER_MAJOR__) && \
      (__CUDACC_VER_MAJOR__ > 12 ||    \
       (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 1))
  #define APHRODITE_CUB_SUBTRACTLEFT_DEFINED
  #endif
  // A (value, count) pair added component-wise. It is used as the element type
  // of block reductions that need both a summed quantity and the number of
  // items that contributed to it.
  template <typename T>
  struct Pair {
    T value;    // summed quantity
    int count;  // number of contributing items

    __device__ Pair& operator+=(const Pair& rhs) {
      value += rhs.value;
      count += rhs.count;
      return *this;
    }

    __device__ Pair operator+(const Pair& rhs) const {
      Pair sum = *this;
      sum += rhs;
      return sum;
    }
  };
  // Difference operator for cub::BlockAdjacentDifference over boolean flags:
  // yields true exactly when two neighbouring flags disagree.
  struct BoolDiffOp {
    __device__ __forceinline__ bool operator()(const bool& lhs,
                                               const bool& rhs) const {
      return !(lhs == rhs);
    }
  };
  // Dynamic shared-memory scratch for the sampling kernels.
  //  - block_prim: storage of the block-wide primitives, overlaid in a union.
  //    Members alias each other, so the block must synchronize before a
  //    different member (primitive) starts using the memory.
  //  - data: values produced by one thread and read by the whole block
  //    (the sampled index and a reduction result).
  template <typename T, uint32_t BLOCK_THREADS,
            cub::BlockScanAlgorithm SCAN_ALGORITHM,
            cub::BlockReduceAlgorithm REDUCE_ALGORITHM>
  struct SamplingTempStorage {
    union {
      // One partial sum per warp, used by DeterministicInclusiveSum.
      T deterministic_scan[BLOCK_THREADS / 32];
      typename cub::BlockScan<T, BLOCK_THREADS, SCAN_ALGORITHM>::TempStorage
          scan;
      typename cub::BlockReduce<T, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
          reduce;
      typename cub::BlockReduce<Pair<T>, BLOCK_THREADS,
                                REDUCE_ALGORITHM>::TempStorage reduce_pair;
      typename cub::BlockAdjacentDifference<bool, BLOCK_THREADS>::TempStorage
          adj_diff;
    } block_prim;

    struct {
      int32_t sampled_id;
      // Reduction result broadcast from thread 0; which member is live depends
      // on the kernel.
      union {
        T value;
        Pair<T> pair;
        T max_p;
      } block_aggregate;
    } data;
  };
  /*!
   * \brief Deterministic block-wide inclusive sum built from warp shuffles,
   *   using the Blelloch (work-efficient up-sweep / down-sweep) scan.
   * \note Slower than cub::BlockScan, but the order in which values are added
   *   is fixed, so the result is reproducible from run to run.
   * \param in_data This thread's VEC_SIZE input values.
   * \param out_data Receives this thread's inclusive prefix sums, where the
   *   scan order is (thread index, item index).
   * \param temp_storage Shared scratch; only block_prim.deterministic_scan
   *   (one slot per warp) is touched.
   * NOTE(review): assumes BLOCK_THREADS is a multiple of 32 and at most 1024
   * (warp 0 combines at most 32 warp totals), and that every thread of the
   * block calls this function — confirm with callers. There is no trailing
   * barrier: callers must __syncthreads() before repurposing block_prim.
   */
  template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS,
            BlockScanAlgorithm SCAN_ALGORITHM,
            BlockReduceAlgorithm REDUCE_ALGORITHM, typename T>
  __device__ __forceinline__ void DeterministicInclusiveSum(
      const T* in_data, T* out_data,
      SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>*
          temp_storage) {
    constexpr uint32_t kNumWarps = BLOCK_THREADS / 32;
    const uint32_t lane = threadIdx.x % 32;
    const uint32_t warp = threadIdx.x / 32;
    T* per_warp = temp_storage->block_prim.deterministic_scan;

    // Up-sweep: afterwards the last lane of every aligned segment of length
    // 2*stride holds that segment's sum; lane 31 holds the warp total.
    auto up_sweep = [lane](T x) -> T {
  #pragma unroll
      for (uint32_t stride = 1; stride < 32; stride *= 2) {
        const T from_left = __shfl_up_sync(0xffffffff, x, stride);
        if ((lane + 1) % (stride * 2) == 0) {
          x += from_left;
        }
      }
      return x;
    };

    // Down-sweep: converts the up-swept tree into exclusive prefix sums
    // (lane k ends with the sum of lanes 0..k-1).
    auto down_sweep = [lane](T x) -> T {
      if (lane == 31) {
        x = 0;
      }
  #pragma unroll
      for (uint32_t stride = 16; stride >= 1; stride /= 2) {
        const T partner = __shfl_xor_sync(0xffffffff, x, stride);
        const uint32_t pos = (lane + 1) % (stride * 2);
        if (pos == 0) {
          x = partner + x;
        } else if (pos == stride) {
          x = partner;
        }
      }
      return x;
    };

    // 1) Serial inclusive scan over this thread's own items.
    T local_inclusive[VEC_SIZE];
    T running = 0;
  #pragma unroll
    for (uint32_t k = 0; k < VEC_SIZE; ++k) {
      running += in_data[k];
      local_inclusive[k] = running;
    }

    // 2) Exclusive scan of per-thread totals inside each warp; the warp total
    //    is read from lane 31 between the two sweeps.
    const T swept = up_sweep(running);
    const T warp_total = __shfl_sync(0xffffffff, swept, 31);
    const T lane_offset = down_sweep(swept);
    per_warp[warp] = warp_total;
    __syncthreads();

    // 3) Warp 0 turns the per-warp totals into exclusive per-warp offsets.
    if (warp == 0) {
      T warp_offset = (lane < kNumWarps) ? per_warp[lane] : T(0);
      warp_offset = down_sweep(up_sweep(warp_offset));
      if (lane < kNumWarps) {
        per_warp[lane] = warp_offset;
      }
    }
    __syncthreads();

    // 4) out = warp offset + lane offset + local inclusive value.
    const T block_offset = per_warp[warp];
  #pragma unroll
    for (uint32_t k = 0; k < VEC_SIZE; ++k) {
      out_data[k] = block_offset + lane_offset + local_inclusive[k];
    }
  }
  /*!
   * \brief Processes one tile of a row for inverse-CDF sampling.
   *
   * Each thread holds VEC_SIZE probabilities of tile `i`; entries that are not
   * greater than `threshold` count as zero. The tile's mass is block-reduced
   * and added to `aggregate`. If `aggregate` plus this tile's mass exceeds `u`,
   * the tile is scanned and the index where the cumulative mass first exceeds
   * `u` is written to temp_storage->data.sampled_id.
   *
   * \param i Tile index; this thread's elements start at
   *   (i * BLOCK_THREADS + threadIdx.x) * VEC_SIZE.
   * \param d Row length; indices >= d are never selected.
   * \param threshold Probabilities <= threshold are ignored.
   * \param u Target cumulative mass.
   * \param prob_vec This thread's VEC_SIZE probabilities for the tile.
   * \param aggregate In/out: mass of all previous tiles.
   * \param temp_storage Shared scratch of the block.
   *
   * Must be called by every thread of the block (contains __syncthreads()).
   * In non-deterministic mode the result is combined with atomicMin, so the
   * caller must pre-initialize data.sampled_id (the kernels in this file use
   * d - 1).
   */
  template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS,
            BlockScanAlgorithm SCAN_ALGORITHM,
            BlockReduceAlgorithm REDUCE_ALGORITHM, bool DETERMINISTIC, typename T>
  __device__ __forceinline__ void DeviceSamplingFromProb(
      uint32_t i, uint32_t d, T threshold, T u, vec_t<T, VEC_SIZE> prob_vec,
      T& aggregate,
      SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>*
          temp_storage) {
    const uint32_t tx = threadIdx.x;
    // Row index of this thread's first element in the tile.
    const uint32_t base_idx = (i * BLOCK_THREADS + tx) * VEC_SIZE;
    T prob_greater_than_threshold[VEC_SIZE];
    T inclusive_cdf[VEC_SIZE];
    bool greater_than_u[VEC_SIZE], valid[VEC_SIZE];
  #pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      prob_greater_than_threshold[j] =
          (prob_vec[j] > threshold) ? prob_vec[j] : T(0);
      // Bound-check each element, matching the index written below.
      valid[j] = prob_vec[j] > threshold && base_idx + j < d;
    }
    // Only thread 0 receives the BlockReduce result; broadcast it via smem.
    T aggregate_local = BlockReduce<T, BLOCK_THREADS, REDUCE_ALGORITHM>(
                            temp_storage->block_prim.reduce)
                            .Sum<VEC_SIZE>(prob_greater_than_threshold);
    if (tx == 0) {
      temp_storage->data.block_aggregate.value = aggregate_local;
    }
    __syncthreads();
    aggregate_local = temp_storage->data.block_aggregate.value;
    // aggregate_local was just broadcast, and the callers in this file build
    // `aggregate` only from such broadcasts, so this branch is block-uniform
    // and the barriers inside it are reached by all threads.
    if (aggregate + aggregate_local > u) {
      if constexpr (DETERMINISTIC) {
        DeterministicInclusiveSum<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
                                  REDUCE_ALGORITHM, T>(
            prob_greater_than_threshold, inclusive_cdf, temp_storage);
      } else {
        BlockScan<T, BLOCK_THREADS, SCAN_ALGORITHM>(temp_storage->block_prim.scan)
            .InclusiveSum<VEC_SIZE>(prob_greater_than_threshold, inclusive_cdf);
      }
      // Both scan paths keep scratch in the block_prim union, which
      // BlockAdjacentDifference repurposes next. CUB requires a barrier before
      // collective storage is reused. DeterministicInclusiveSum ends by reading
      // deterministic_scan without a barrier, so the sync is needed on both
      // paths rather than relying on CUB's private layout to avoid overlap.
      __syncthreads();
  #pragma unroll
      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
        greater_than_u[j] = inclusive_cdf[j] + aggregate > u;
      }
      // greater_than_u_diff[j] is set where the "cdf > u" flag flips.
      bool greater_than_u_diff[VEC_SIZE];
  #ifdef APHRODITE_CUB_SUBTRACTLEFT_DEFINED
      BlockAdjacentDifference<bool, BLOCK_THREADS>(
          temp_storage->block_prim.adj_diff)
          .SubtractLeft<VEC_SIZE>(greater_than_u, greater_than_u_diff,
                                  BoolDiffOp());
  #else
      BlockAdjacentDifference<bool, BLOCK_THREADS>(
          temp_storage->block_prim.adj_diff)
          .FlagHeads<VEC_SIZE>(greater_than_u_diff, greater_than_u, BoolDiffOp(),
                               0);
  #endif
      __syncthreads();
  #pragma unroll
      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
        if (greater_than_u_diff[j] && valid[j]) {
          if constexpr (DETERMINISTIC) {
            temp_storage->data.sampled_id = base_idx + j;
          } else {
            // cub's block scan result might not be monotonic, so we need to find
            // the first element
            atomicMin(&(temp_storage->data.sampled_id), base_idx + j);
          }
        }
      }
      __syncthreads();
    }
    aggregate += aggregate_local;
  }
  /*!
   * \brief Inverse-CDF sampling: each block draws one index from one row of
   *   `probs` (row length d).
   *
   * Expects BLOCK_THREADS threads per block and sizeof(SamplingTempStorage<...>)
   * bytes of dynamic shared memory.
   * \param probs Probabilities, laid out as consecutive rows of d entries.
   * \param uniform_samples One value per block, indexed by blockIdx.x.
   * \param output One sampled index per block, indexed by blockIdx.x.
   * \param row_indices Optional block -> row map; null means block b reads row b.
   * \param d Row length. NOTE(review): loads are VEC_SIZE-wide and only the
   *   first element of each vector is bound-checked, so d is assumed to be a
   *   multiple of VEC_SIZE (the launchers choose VEC_SIZE = gcd(16/sizeof, d)).
   */
  template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
            BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
            bool DETERMINISTIC, typename DType, typename IdType>
  __global__ void SamplingFromProbKernel(DType* probs, DType* uniform_samples,
                                         IdType* output, IdType* row_indices,
                                         uint32_t d) {
    using TempStorage = SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
                                            REDUCE_ALGORITHM>;
    const uint32_t bx = blockIdx.x;
    const uint32_t tx = threadIdx.x;
    const uint32_t row_idx = (row_indices != nullptr) ? row_indices[bx] : bx;

    extern __shared__ __align__(alignof(TempStorage)) uint8_t smem_sampling[];
    auto& temp_storage = reinterpret_cast<TempStorage&>(smem_sampling);
    // Default result when no tile crosses u.
    temp_storage.data.sampled_id = d - 1;
    __syncthreads();

    DType* row = probs + row_idx * d;
    const auto num_tiles = ceil_div(d, BLOCK_THREADS * VEC_SIZE);
    const float u = uniform_samples[bx];
    DType aggregate(0);
    vec_t<DType, VEC_SIZE> probs_vec;
    for (uint32_t i = 0; i < num_tiles; ++i) {
      const uint32_t offset = (i * BLOCK_THREADS + tx) * VEC_SIZE;
      probs_vec.fill(DType(0));
      if (offset < d) {
        probs_vec.load(row + offset);
      }
      DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
                             REDUCE_ALGORITHM, DETERMINISTIC, DType>(
          i, d, DType(0), u, probs_vec, aggregate, &temp_storage);
      // The running mass is identical in every thread, so all exit together.
      if (float(aggregate) > u) {
        break;
      }
    }
    output[bx] = temp_storage.data.sampled_id;
  }
  /*!
   * \brief Top-k sampling by rejection, one block per row (gridDim.x rows).
   *
   * Each round draws an index among the entries strictly above `pivot` (the
   * uniform sample is scaled by the remaining mass q). It then raises `pivot`
   * to that entry's probability and counts the entries strictly above the new
   * pivot. The draw is accepted once fewer than k entries remain above it.
   *
   * \param uniform_samples Indexed as [round * batch_size + row].
   * \param success Optional; false when nothing was accepted within
   *   max_top_k_rounds rounds.
   * \param top_k_arr Optional per-row k; null means top_k_val for all rows.
   * Requires sizeof(SamplingTempStorage<...>) bytes of dynamic shared memory.
   */
  template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
            BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
            bool DETERMINISTIC, typename DType, typename IdType>
  __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples,
                                             IdType* output, bool* success,
                                             IdType* top_k_arr,
                                             uint32_t top_k_val, uint32_t d,
                                             uint32_t max_top_k_rounds) {
    using TempStorage = SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
                                            REDUCE_ALGORITHM>;
    const uint32_t batch_size = gridDim.x;
    const uint32_t bx = blockIdx.x;
    const uint32_t tx = threadIdx.x;
    const uint32_t k = (top_k_arr != nullptr) ? top_k_arr[bx] : top_k_val;

    extern __shared__ __align__(alignof(TempStorage)) uint8_t smem_sampling[];
    auto& temp_storage = reinterpret_cast<TempStorage&>(smem_sampling);

    DType* row = probs + bx * d;
    const auto num_tiles = ceil_div(d, BLOCK_THREADS * VEC_SIZE);
    vec_t<DType, VEC_SIZE> probs_vec;
    DType q = DType(1);      // mass strictly above the pivot
    DType pivot = DType(0);  // rejection threshold
    IdType sampled_id;

    for (uint32_t round = 0; round < max_top_k_rounds; ++round) {
      // --- draw a candidate among entries above the pivot ---
      temp_storage.data.sampled_id = d - 1;
      __syncthreads();
      const DType u = uniform_samples[round * batch_size + bx] * q;
      DType aggregate = DType(0);
      for (uint32_t i = 0; i < num_tiles; ++i) {
        const uint32_t offset = (i * BLOCK_THREADS + tx) * VEC_SIZE;
        probs_vec.fill(DType(0));
        if (offset < d) {
          probs_vec.load(row + offset);
        }
        DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
                               REDUCE_ALGORITHM, DETERMINISTIC, DType>(
            i, d, pivot, u, probs_vec, aggregate, &temp_storage);
        if (aggregate > u) {
          break;
        }
      }
      __syncthreads();
      sampled_id = temp_storage.data.sampled_id;
      pivot = max(pivot, row[sampled_id]);

      // --- mass and count of entries strictly above the new pivot ---
      Pair<DType> aggregate_gt_pivot{DType(0), 0};
      for (uint32_t i = 0; i < num_tiles; ++i) {
        const uint32_t offset = (i * BLOCK_THREADS + tx) * VEC_SIZE;
        probs_vec.fill(DType(0));
        if (offset < d) {
          probs_vec.load(row + offset);
        }
        Pair<DType> probs_gt_pivot[VEC_SIZE];
  #pragma unroll
        for (uint32_t j = 0; j < VEC_SIZE; ++j) {
          const bool above = probs_vec[j] > pivot;
          probs_gt_pivot[j] = {above ? probs_vec[j] : DType(0),
                               above && offset + j < d};
        }
        // Only thread 0's accumulator is meaningful; it is published below.
        aggregate_gt_pivot +=
            BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
                temp_storage.block_prim.reduce_pair)
                .Sum<VEC_SIZE>(probs_gt_pivot);
        if (tx == 0) {
          temp_storage.data.block_aggregate.pair = aggregate_gt_pivot;
        }
        __syncthreads();
      }
      q = temp_storage.data.block_aggregate.pair.value;
      if (temp_storage.data.block_aggregate.pair.count < k) {
        break;  // accepted
      }
    }
    __syncthreads();
    if (tx == 0) {
      output[bx] = sampled_id;
      if (success != nullptr) {
        // failed to sample within max_top_k_rounds <=> count is still >= k
        success[bx] = temp_storage.data.block_aggregate.pair.count < k;
      }
    }
  }
  /*!
   * \brief Top-p sampling by rejection, one block per output (gridDim.x).
   *
   * Each round draws an index among the entries strictly above `pivot` (the
   * uniform sample is scaled by the remaining mass q), raises `pivot` to that
   * entry's probability, and recomputes q as the mass strictly above the new
   * pivot. The draw is accepted once q < top_p.
   *
   * \param uniform_samples Indexed as [round * batch_size + blockIdx.x].
   * \param success Optional; false when nothing was accepted within
   *   max_top_p_rounds rounds.
   * \param row_indices Optional block -> row map; null means block b reads row b.
   * \param top_p_arr Optional per-block top-p; null means top_p_val.
   * Requires sizeof(SamplingTempStorage<...>) bytes of dynamic shared memory.
   */
  template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
            BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
            bool DETERMINISTIC, typename DType, typename IdType>
  __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
                                             IdType* output, bool* success,
                                             IdType* row_indices,
                                             float* top_p_arr, float top_p_val,
                                             uint32_t d,
                                             uint32_t max_top_p_rounds) {
    using TempStorage = SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
                                            REDUCE_ALGORITHM>;
    const uint32_t batch_size = gridDim.x;
    const uint32_t bx = blockIdx.x, tx = threadIdx.x;
    const float top_p = (top_p_arr == nullptr) ? top_p_val : top_p_arr[bx];
    const uint32_t row_idx = row_indices == nullptr ? bx : row_indices[bx];
    extern __shared__ __align__(alignof(TempStorage)) uint8_t smem_sampling[];
    auto& temp_storage = reinterpret_cast<TempStorage&>(smem_sampling);

    DType* row_probs = probs + row_idx * d;
    const auto num_tiles = ceil_div(d, BLOCK_THREADS * VEC_SIZE);
    vec_t<DType, VEC_SIZE> probs_vec;
    DType aggregate;
    DType q = DType(1);
    DType pivot = DType(0);
    // Initialized so the output is defined even when max_top_p_rounds == 0.
    IdType sampled_id = d - 1;
    for (uint32_t round = 0; round < max_top_p_rounds; ++round) {
      temp_storage.data.sampled_id = d - 1;
      __syncthreads();
      DType u = uniform_samples[round * batch_size + bx] * q;
      aggregate = DType(0);
      for (uint32_t i = 0; i < num_tiles; ++i) {
        probs_vec.fill(DType(0));
        if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
          probs_vec.load(row_probs + (i * BLOCK_THREADS + tx) * VEC_SIZE);
        }
        DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
                               REDUCE_ALGORITHM, DETERMINISTIC, DType>(
            i, d, pivot, u, probs_vec, aggregate, &temp_storage);
        if (aggregate > u) {
          break;
        }
      }
      __syncthreads();
      sampled_id = temp_storage.data.sampled_id;
      pivot = max(pivot, row_probs[sampled_id]);
      DType aggregate_gt_pivot = DType(0);
      for (uint32_t i = 0; i < num_tiles; ++i) {
        probs_vec.fill(DType(0));
        if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
          probs_vec.load(row_probs + (i * BLOCK_THREADS + tx) * VEC_SIZE);
        }
        DType probs_gt_pivot[VEC_SIZE];
  #pragma unroll
        for (uint32_t j = 0; j < VEC_SIZE; ++j) {
          probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0);
        }
        // Must use REDUCE_ALGORITHM: temp_storage.block_prim.reduce is the
        // TempStorage of BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>,
        // and different algorithms have distinct TempStorage types.
        aggregate_gt_pivot +=
            BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
                temp_storage.block_prim.reduce)
                .Sum<VEC_SIZE>(probs_gt_pivot);
        if (tx == 0) {
          temp_storage.data.block_aggregate.value = aggregate_gt_pivot;
        }
        __syncthreads();
      }
      q = temp_storage.data.block_aggregate.value;
      if (float(q) < top_p) {
        break;
      }
    }
    __syncthreads();
    if (tx == 0) {
      output[bx] = sampled_id;
      if (float(q) >= top_p) {
        // failed to sample within MAX_TOP_P_ROUNDS
        if (success != nullptr) {
          success[bx] = false;
        }
      } else {
        if (success != nullptr) {
          success[bx] = true;
        }
      }
    }
  }
  /*!
   * \brief Min-p sampling by rejection, one block per row (gridDim.x rows).
   *
   * First finds max_p, the row's largest probability, and the cutoff
   * scaled_p = max_p * p. Each round draws an index among the entries strictly
   * above `pivot` (the uniform sample is scaled by the remaining mass q) and
   * raises `pivot` to that entry's probability. The draw is accepted once
   * pivot >= scaled_p; otherwise q is recomputed as the mass strictly above
   * the new pivot.
   *
   * \param uniform_samples Indexed as [round * batch_size + row].
   * \param min_p_arr Optional per-row p; null means min_p_val.
   * \param success Optional; false when nothing was accepted within
   *   max_min_p_rounds rounds.
   * Requires sizeof(SamplingTempStorage<...>) bytes of dynamic shared memory.
   */
  template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
            BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
            bool DETERMINISTIC, typename DType, typename IdType>
  __global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples,
                                             DType* min_p_arr, IdType* output,
                                             bool* success, float min_p_val,
                                             uint32_t d,
                                             uint32_t max_min_p_rounds) {
    using TempStorage = SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
                                            REDUCE_ALGORITHM>;
    const uint32_t batch_size = gridDim.x;
    const uint32_t bx = blockIdx.x, tx = threadIdx.x;
    DType p = (min_p_arr == nullptr) ? min_p_val : min_p_arr[bx];
    extern __shared__ __align__(alignof(TempStorage)) uint8_t smem_sampling[];
    auto& temp_storage = reinterpret_cast<TempStorage&>(smem_sampling);

    DType* row_probs = probs + bx * d;
    const auto num_tiles = ceil_div(d, BLOCK_THREADS * VEC_SIZE);
    vec_t<DType, VEC_SIZE> probs_vec;
    DType aggregate;
    DType q = DType(1);
    DType pivot = DType(0);

    // Pass 1: row maximum. Only thread 0 holds the BlockReduce result; it is
    // broadcast through shared memory after the loop.
    DType max_p = 0;
    for (uint32_t i = 0; i < num_tiles; ++i) {
      probs_vec.fill(DType(0));
      if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
        probs_vec.load(row_probs + (i * BLOCK_THREADS + tx) * VEC_SIZE);
      }
      DType probs_[VEC_SIZE];
  #pragma unroll
      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
        probs_[j] = probs_vec[j];
      }
      // REDUCE_ALGORITHM must match the one SamplingTempStorage was built with,
      // since block_prim.reduce has that algorithm's TempStorage type.
      max_p = max(max_p, BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
                             temp_storage.block_prim.reduce)
                             .Reduce<VEC_SIZE>(probs_, cub::Max()));
      __syncthreads();
    }
    if (tx == 0) {
      temp_storage.data.block_aggregate.max_p = max_p;
    }
    __syncthreads();
    DType scaled_p = temp_storage.data.block_aggregate.max_p * p;
    // Initialized so the output is defined even when max_min_p_rounds == 0.
    IdType sampled_id = d - 1;
    for (uint32_t round = 0; round < max_min_p_rounds; ++round) {
      temp_storage.data.sampled_id = d - 1;
      __syncthreads();
      DType u = uniform_samples[round * batch_size + bx] * q;
      aggregate = DType(0);
      for (uint32_t i = 0; i < num_tiles; ++i) {
        probs_vec.fill(DType(0));
        if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
          probs_vec.load(row_probs + (i * BLOCK_THREADS + tx) * VEC_SIZE);
        }
        DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
                               REDUCE_ALGORITHM, DETERMINISTIC, DType>(
            i, d, pivot, u, probs_vec, aggregate, &temp_storage);
        if (aggregate > u) {
          break;
        }
      }
      __syncthreads();
      sampled_id = temp_storage.data.sampled_id;
      pivot = max(pivot, row_probs[sampled_id]);
      if (pivot >= scaled_p) {
        break;  // accepted
      }
      DType aggregate_gt_pivot = DType(0);
      for (uint32_t i = 0; i < num_tiles; ++i) {
        probs_vec.fill(DType(0));
        if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
          probs_vec.load(row_probs + (i * BLOCK_THREADS + tx) * VEC_SIZE);
        }
        DType probs_gt_pivot[VEC_SIZE];
  #pragma unroll
        for (uint32_t j = 0; j < VEC_SIZE; ++j) {
          probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0);
        }
        aggregate_gt_pivot +=
            BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
                temp_storage.block_prim.reduce)
                .Sum<VEC_SIZE>(probs_gt_pivot);
        if (tx == 0) {
          temp_storage.data.block_aggregate.value = aggregate_gt_pivot;
        }
        __syncthreads();
      }
      q = temp_storage.data.block_aggregate.value;
    }
    __syncthreads();
    if (tx == 0) {
      output[bx] = sampled_id;
      if (pivot < scaled_p) {
        // failed to sample within MAX_ROUNDS
        if (success != nullptr) {
          success[bx] = false;
        }
      } else {
        if (success != nullptr) {
          success[bx] = true;
        }
      }
    }
  }
  /*!
   * \brief Joint top-k / top-p sampling by rejection, one block per row
   *   (gridDim.x rows).
   *
   * Each round draws an index among the entries strictly above `pivot` (the
   * uniform sample is scaled by the remaining mass q). It then raises `pivot`
   * to that entry's probability and measures the count and mass of the entries
   * strictly above the new pivot. The draw is accepted when fewer than k
   * entries remain above it and their mass is below p.
   *
   * \param uniform_samples Indexed as [round * batch_size + row].
   * \param top_k_arr / top_p_arr Optional per-row parameters; null means
   *   top_k_val / top_p_val.
   * \param success Optional; false when nothing was accepted within
   *   max_rounds rounds.
   * Requires sizeof(SamplingTempStorage<...>) bytes of dynamic shared memory.
   */
  template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
            BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
            bool DETERMINISTIC, typename DType, typename IdType>
  __global__ void TopKTopPSamplingFromProbKernel(
      DType* probs, DType* uniform_samples, IdType* top_k_arr, DType* top_p_arr,
      IdType* output, bool* success, IdType top_k_val, DType top_p_val,
      uint32_t d, uint32_t max_rounds) {
    using TempStorage = SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
                                            REDUCE_ALGORITHM>;
    const uint32_t batch_size = gridDim.x;
    const uint32_t bx = blockIdx.x;
    const uint32_t tx = threadIdx.x;
    const IdType k = (top_k_arr != nullptr) ? top_k_arr[bx] : top_k_val;
    const DType p = (top_p_arr != nullptr) ? top_p_arr[bx] : top_p_val;

    extern __shared__ __align__(alignof(TempStorage)) uint8_t smem_sampling[];
    auto& temp_storage = reinterpret_cast<TempStorage&>(smem_sampling);

    DType* row = probs + bx * d;
    const auto num_tiles = ceil_div(d, BLOCK_THREADS * VEC_SIZE);
    vec_t<DType, VEC_SIZE> probs_vec;
    DType q = DType(1);      // mass strictly above the pivot
    DType pivot = DType(0);  // rejection threshold
    IdType sampled_id;

    for (uint32_t round = 0; round < max_rounds; ++round) {
      // --- draw a candidate among entries above the pivot ---
      temp_storage.data.sampled_id = d - 1;
      __syncthreads();
      const DType u = uniform_samples[round * batch_size + bx] * q;
      DType aggregate = DType(0);
      for (uint32_t i = 0; i < num_tiles; ++i) {
        const uint32_t offset = (i * BLOCK_THREADS + tx) * VEC_SIZE;
        probs_vec.fill(DType(0));
        if (offset < d) {
          probs_vec.load(row + offset);
        }
        DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM,
                               REDUCE_ALGORITHM, DETERMINISTIC, DType>(
            i, d, pivot, u, probs_vec, aggregate, &temp_storage);
        if (aggregate > u) {
          break;
        }
      }
      __syncthreads();
      sampled_id = temp_storage.data.sampled_id;
      pivot = max(pivot, row[sampled_id]);

      // --- mass and count of entries strictly above the new pivot ---
      Pair<DType> aggregate_gt_pivot{DType(0), 0};
      for (uint32_t i = 0; i < num_tiles; ++i) {
        const uint32_t offset = (i * BLOCK_THREADS + tx) * VEC_SIZE;
        probs_vec.fill(DType(0));
        if (offset < d) {
          probs_vec.load(row + offset);
        }
        Pair<DType> probs_gt_pivot[VEC_SIZE];
  #pragma unroll
        for (uint32_t j = 0; j < VEC_SIZE; ++j) {
          const bool above = probs_vec[j] > pivot;
          probs_gt_pivot[j] = {above ? probs_vec[j] : DType(0),
                               above && offset + j < d};
        }
        // Only thread 0's accumulator is meaningful; it is published below.
        aggregate_gt_pivot +=
            BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
                temp_storage.block_prim.reduce_pair)
                .Sum<VEC_SIZE>(probs_gt_pivot);
        if (tx == 0) {
          temp_storage.data.block_aggregate.pair = aggregate_gt_pivot;
        }
        __syncthreads();
      }
      q = temp_storage.data.block_aggregate.pair.value;
      if (temp_storage.data.block_aggregate.pair.count < k && float(q) < p) {
        break;  // accepted
      }
    }
    __syncthreads();
    if (tx == 0) {
      output[bx] = sampled_id;
      // failed to sample within MAX_TOP_P_ROUNDS
      const bool rejected =
          temp_storage.data.block_aggregate.pair.count >= k || float(q) >= p;
      if (success != nullptr) {
        success[bx] = !rejected;
      }
    }
  }
  /*!
   * \brief Launches SamplingFromProbKernel with one block per row of `probs`
   *   (batch_size rows of d entries) and no row remapping.
   * \param deterministic Selects the deterministic scan path of the kernel.
   * \param stream Stream the kernel is launched on.
   * \return cudaSuccess once the launch has been issued. Failing CUDA calls go
   *   through APHRODITE_CUDA_CALL (project macro, defined elsewhere —
   *   presumably it returns the error).
   */
  template <typename T, typename IdType>
  cudaError_t SamplingFromProb(T* probs, T* uniform_samples, IdType* output,
                               uint32_t batch_size, uint32_t d,
                               bool deterministic, cudaStream_t stream = 0) {
    constexpr uint32_t BLOCK_THREADS = 1024;
    const dim3 grid_dim(batch_size);
    const dim3 block_dim(BLOCK_THREADS);
    // Widest vector (at most 16 bytes) whose element count divides d.
    const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
    const uint32_t smem_size =
        sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
    IdType* no_row_indices = nullptr;  // block b samples row b
    // Order must match SamplingFromProbKernel's parameter list.
    void* kernel_args[] = {&probs, &uniform_samples, &output, &no_row_indices,
                           &d};
    DISPATCH_ALIGNED_VEC_SIZE(
        vec_size, VEC_SIZE,
        {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
          auto kernel =
              SamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
                                     VEC_SIZE, DETERMINISTIC, T, IdType>;
          APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, grid_dim,
                                               block_dim, kernel_args, smem_size,
                                               stream));
        })});
    return cudaSuccess;
  }
  /*!
   * \brief Launches SamplingFromProbKernel with batch_size blocks, where block
   *   b samples from row row_indices[b] of `probs` (rows of d entries).
   * \param row_indices Block -> row map forwarded to the kernel; the kernel
   *   treats null as the identity map.
   * \param deterministic Selects the deterministic scan path of the kernel.
   * \return cudaSuccess once the launch has been issued. Failing CUDA calls go
   *   through APHRODITE_CUDA_CALL (project macro, defined elsewhere —
   *   presumably it returns the error).
   */
  template <typename T, typename IdType>
  cudaError_t ParallelSamplingFromProb(T* probs, T* uniform_samples,
                                       IdType* output, IdType* row_indices,
                                       uint32_t batch_size, uint32_t d,
                                       bool deterministic,
                                       cudaStream_t stream = 0) {
    constexpr uint32_t BLOCK_THREADS = 1024;
    const dim3 grid_dim(batch_size);
    const dim3 block_dim(BLOCK_THREADS);
    // Widest vector (at most 16 bytes) whose element count divides d.
    const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
    const uint32_t smem_size =
        sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
    // Order must match SamplingFromProbKernel's parameter list.
    void* kernel_args[] = {&probs, &uniform_samples, &output, &row_indices, &d};
    DISPATCH_ALIGNED_VEC_SIZE(
        vec_size, VEC_SIZE,
        {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
          auto kernel =
              SamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
                                     VEC_SIZE, DETERMINISTIC, T, IdType>;
          APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, grid_dim,
                                               block_dim, kernel_args, smem_size,
                                               stream));
        })});
    return cudaSuccess;
  }
  /*!
   * \brief Launches TopKSamplingFromProbKernel with one block per row of
   *   `probs` (batch_size rows of d entries).
   * \param top_k_arr Optional per-row k; null means top_k_val for every row.
   *   NOTE(review): declared `T*` here, but it is forwarded untyped through
   *   kernel_args and TopKSamplingFromProbKernel reads it as `IdType*`. The
   *   buffer must therefore actually hold IdType values. Retyping it to
   *   `IdType*` would require updating callers — confirm with them.
   * \param max_top_k_rounds Rejection rounds; the kernel reads
   *   uniform_samples[round * batch_size + row] for round < max_top_k_rounds.
   * \param deterministic Selects the deterministic scan path of the kernel.
   * \return cudaSuccess once the launch has been issued. Failing CUDA calls go
   *   through APHRODITE_CUDA_CALL (project macro, defined elsewhere —
   *   presumably it returns the error).
   */
  template <typename T, typename IdType>
  cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output,
                                   bool* success, T* top_k_arr,
                                   uint32_t batch_size, uint32_t top_k_val,
                                   uint32_t d, uint32_t max_top_k_rounds,
                                   bool deterministic, cudaStream_t stream = 0) {
    constexpr uint32_t BLOCK_THREADS = 1024;
    const dim3 grid_dim(batch_size);
    const dim3 block_dim(BLOCK_THREADS);
    // Widest vector (at most 16 bytes) whose element count divides d.
    const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
    const uint32_t smem_size =
        sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
    // Order must match TopKSamplingFromProbKernel's parameter list.
    void* kernel_args[] = {&probs,     &uniform_samples, &output,
                           &success,   &top_k_arr,       &top_k_val,
                           &d,         &max_top_k_rounds};
    DISPATCH_ALIGNED_VEC_SIZE(
        vec_size, VEC_SIZE,
        {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
          auto kernel =
              TopKSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
                                         VEC_SIZE, DETERMINISTIC, T, IdType>;
          APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
              kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
          APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, grid_dim,
                                               block_dim, kernel_args, smem_size,
                                               stream));
        })});
    return cudaSuccess;
  }
  697. template <typename T, typename IdType>
  698. cudaError_t TopPSamplingFromProb(T* probs, T* uniform_samples, IdType* output,
  699. bool* success, T* top_p_arr,
  700. uint32_t batch_size, T top_p_val, uint32_t d,
  701. uint32_t max_top_p_rounds, bool deterministic,
  702. cudaStream_t stream = 0) {
  703. constexpr uint32_t BLOCK_THREADS = 1024;
  704. const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
  705. const uint32_t smem_size =
  706. sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
  707. dim3 nblks(batch_size);
  708. dim3 nthrs(BLOCK_THREADS);
  709. IdType* row_indices_placeholder = nullptr;
  710. void* args[] = {&probs,
  711. &uniform_samples,
  712. &output,
  713. &success,
  714. &row_indices_placeholder,
  715. &top_p_arr,
  716. &top_p_val,
  717. &d,
  718. &max_top_p_rounds};
  719. DISPATCH_ALIGNED_VEC_SIZE(
  720. vec_size, VEC_SIZE,
  721. {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
  722. auto kernel =
  723. TopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
  724. VEC_SIZE, DETERMINISTIC, T, IdType>;
  725. APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
  726. kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
  727. APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
  728. smem_size, stream));
  729. })});
  730. return cudaSuccess;
  731. }
  732. template <typename T, typename IdType>
  733. cudaError_t MinPSamplingFromProb(T* probs, T* uniform_samples, T* min_p_arr,
  734. IdType* output, bool* success,
  735. uint32_t batch_size, float min_p_val,
  736. uint32_t d, uint32_t max_rounds,
  737. bool deterministic, cudaStream_t stream = 0) {
  738. constexpr uint32_t BLOCK_THREADS = 1024;
  739. const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
  740. const uint32_t smem_size =
  741. sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
  742. dim3 nblks(batch_size);
  743. dim3 nthrs(BLOCK_THREADS);
  744. void* args[] = {&probs, &uniform_samples, &min_p_arr, &output,
  745. &success, &min_p_val, &d, &max_rounds};
  746. DISPATCH_ALIGNED_VEC_SIZE(
  747. vec_size, VEC_SIZE,
  748. {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
  749. auto kernel =
  750. MinPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
  751. VEC_SIZE, DETERMINISTIC, T, IdType>;
  752. APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
  753. kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
  754. APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
  755. smem_size, stream));
  756. })});
  757. return cudaSuccess;
  758. }
  759. template <typename T, typename IdType>
  760. cudaError_t TopKTopPSamplingFromProb(T* probs, T* uniform_samples,
  761. IdType* top_k_arr, T* top_p_arr,
  762. IdType* output, bool* success,
  763. uint32_t batch_size, IdType top_k_val,
  764. T top_p_val, uint32_t d,
  765. uint32_t max_rounds, bool deterministic,
  766. cudaStream_t stream = 0) {
  767. constexpr uint32_t BLOCK_THREADS = 1024;
  768. const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
  769. const uint32_t smem_size =
  770. sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
  771. dim3 nblks(batch_size);
  772. dim3 nthrs(BLOCK_THREADS);
  773. void* args[] = {&probs, &uniform_samples, &top_k_arr, &top_p_arr,
  774. &output, &success, &top_k_val, &top_p_val,
  775. &d, &max_rounds};
  776. DISPATCH_ALIGNED_VEC_SIZE(
  777. vec_size, VEC_SIZE,
  778. {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
  779. auto kernel = TopKTopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO,
  780. REDUCE_ALGO, VEC_SIZE,
  781. DETERMINISTIC, T, IdType>;
  782. APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
  783. kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
  784. APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
  785. smem_size, stream));
  786. })});
  787. return cudaSuccess;
  788. }
/*!
 * \brief Dynamic shared-memory layout used by the top-p / top-k renorm and
 *   top-k masking kernels. Kernels reinterpret_cast an `extern __shared__`
 *   byte array to this type, so the host must request sizeof(RenormTempStorage)
 *   bytes of dynamic shared memory.
 */
template <typename T, uint32_t BLOCK_THREADS,
          BlockReduceAlgorithm REDUCE_ALGORITHM>
struct RenormTempStorage {
  // CUB BlockReduce scratch space. The kernels __syncthreads() between
  // reductions, so only one variant is live at a time and they share storage.
  union {
    typename BlockReduce<T, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
        reduce;  // value reductions (max / min / sum)
    typename BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
        reduce_int;  // integer count reductions
    typename BlockReduce<Pair<T>, BLOCK_THREADS, REDUCE_ALGORITHM>::TempStorage
        reduce_pair;  // combined (value, count) reductions
  } block_prim;
  // Broadcast slots: thread 0 stores a reduction result here and, after a
  // __syncthreads(), every thread reads it back.
  struct {
    T max_val;
    T min_val;
    union {
      T value;       // block-wide sum
      int count;     // block-wide count
      Pair<T> pair;  // block-wide (value, count) pair
    } block_aggregate;
  } data;
};
  810. template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
  811. uint32_t VEC_SIZE, typename DType>
  812. __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob,
  813. DType* top_p_arr, float top_p_val,
  814. uint32_t d) {
  815. const uint32_t bx = blockIdx.x, tx = threadIdx.x;
  816. const uint32_t row_idx = bx;
  817. float p = top_p_arr == nullptr ? top_p_val : top_p_arr[bx];
  818. extern __shared__ __align__(
  819. alignof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
  820. uint8_t smem_renorm[];
  821. auto& temp_storage =
  822. reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(
  823. smem_renorm);
  824. temp_storage.data.max_val = DType(0);
  825. vec_t<DType, VEC_SIZE> probs_vec;
  826. DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
  827. DType threadlocal_max_val = DType(0);
  828. for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
  829. probs_vec.fill(DType(0));
  830. if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
  831. probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
  832. tx * VEC_SIZE);
  833. }
  834. #pragma unroll
  835. for (uint32_t j = 0; j < VEC_SIZE; ++j) {
  836. probs_greater_than_pivot[j] = probs_vec[j];
  837. }
  838. threadlocal_max_val =
  839. max(threadlocal_max_val,
  840. BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
  841. temp_storage.block_prim.reduce)
  842. .Reduce<VEC_SIZE>(probs_greater_than_pivot, cub::Max()));
  843. __syncthreads();
  844. }
  845. if (tx == 0) {
  846. temp_storage.data.max_val = threadlocal_max_val;
  847. }
  848. __syncthreads();
  849. threadlocal_max_val = temp_storage.data.max_val;
  850. float low = 0, high = threadlocal_max_val;
  851. DType min_gt_low, max_le_high;
  852. DType sum_low(1);
  853. // f(x) = sum(probs[probs > x]), f(x) is non-increasing
  854. // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p
  855. // <= high} loop invariant:
  856. // - f(low) >= p, f(high) < p
  857. // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
  858. // stopping condition
  859. // - f(low) >= p, f(min_gt_low) == f(max_le_high) == f(high) < p
  860. do {
  861. DType threadlocal_sum(0);
  862. float mid = (low + high) / 2;
  863. min_gt_low = high;
  864. max_le_high = low;
  865. for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
  866. probs_vec.fill(DType(0));
  867. if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
  868. probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
  869. tx * VEC_SIZE);
  870. }
  871. #pragma unroll
  872. for (uint32_t j = 0; j < VEC_SIZE; ++j) {
  873. probs_greater_than_pivot[j] =
  874. (probs_vec[j] > mid) ? probs_vec[j] : DType(0);
  875. if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
  876. min_gt_low = min(min_gt_low, probs_vec[j]);
  877. }
  878. if (probs_vec[j] <= high &&
  879. (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
  880. max_le_high = max(max_le_high, probs_vec[j]);
  881. }
  882. }
  883. threadlocal_sum += BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
  884. temp_storage.block_prim.reduce)
  885. .Sum<VEC_SIZE>(probs_greater_than_pivot);
  886. __syncthreads();
  887. }
  888. min_gt_low = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
  889. temp_storage.block_prim.reduce)
  890. .Reduce(min_gt_low, cub::Min());
  891. __syncthreads();
  892. max_le_high = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
  893. temp_storage.block_prim.reduce)
  894. .Reduce(max_le_high, cub::Max());
  895. if (tx == 0) {
  896. temp_storage.data.block_aggregate.value = threadlocal_sum;
  897. temp_storage.data.min_val = min_gt_low;
  898. temp_storage.data.max_val = max_le_high;
  899. }
  900. __syncthreads();
  901. threadlocal_sum = temp_storage.data.block_aggregate.value;
  902. min_gt_low = temp_storage.data.min_val;
  903. max_le_high = temp_storage.data.max_val;
  904. if (threadlocal_sum >= p) {
  905. low = mid;
  906. sum_low = float(threadlocal_sum);
  907. } else {
  908. high = min(mid, max_le_high);
  909. }
  910. } while (min_gt_low != max_le_high);
  911. DType normalizer = math::ptx_rcp(max(sum_low, 1e-8));
  912. // normalize
  913. for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
  914. probs_vec.fill(DType(0));
  915. if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
  916. probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
  917. tx * VEC_SIZE);
  918. }
  919. #pragma unroll
  920. for (uint32_t j = 0; j < VEC_SIZE; ++j) {
  921. probs_vec[j] =
  922. (probs_vec[j] > low) ? probs_vec[j] * normalizer : DType(0);
  923. }
  924. if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
  925. probs_vec.store(renormed_prob + row_idx * d +
  926. i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
  927. }
  928. }
  929. }
  930. template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
  931. uint32_t VEC_SIZE, typename DType, typename IdType>
  932. __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits,
  933. IdType* top_k_arr, uint32_t top_k_val,
  934. uint32_t d) {
  935. const uint32_t bx = blockIdx.x, tx = threadIdx.x;
  936. const uint32_t row_idx = bx;
  937. uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
  938. float pivot = -std::numeric_limits<float>::infinity();
  939. vec_t<DType, VEC_SIZE> logits_vec;
  940. if (k < d) {
  941. extern __shared__ __align__(
  942. alignof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
  943. uint8_t smem_renorm[];
  944. auto& temp_storage =
  945. reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(
  946. smem_renorm);
  947. DType logits_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
  948. DType threadlocal_max_val = DType(-std::numeric_limits<float>::infinity()),
  949. threadlocal_min_val = DType(std::numeric_limits<float>::infinity());
  950. for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
  951. logits_vec.fill(DType(0));
  952. if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
  953. logits_vec.load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
  954. tx * VEC_SIZE);
  955. }
  956. #pragma unroll
  957. for (uint32_t j = 0; j < VEC_SIZE; ++j) {
  958. logits_greater_than_pivot[j] = logits_vec[j];
  959. }
  960. threadlocal_max_val =
  961. max(threadlocal_max_val,
  962. BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
  963. temp_storage.block_prim.reduce)
  964. .Reduce<VEC_SIZE>(logits_greater_than_pivot, cub::Max()));
  965. __syncthreads();
  966. threadlocal_min_val =
  967. min(threadlocal_min_val,
  968. BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
  969. temp_storage.block_prim.reduce)
  970. .Reduce<VEC_SIZE>(logits_greater_than_pivot, cub::Min()));
  971. __syncthreads();
  972. }
  973. if (tx == 0) {
  974. temp_storage.data.max_val = threadlocal_max_val;
  975. temp_storage.data.min_val = threadlocal_min_val;
  976. }
  977. __syncthreads();
  978. threadlocal_max_val = temp_storage.data.max_val;
  979. threadlocal_min_val = temp_storage.data.min_val;
  980. float low = threadlocal_min_val - 1, high = threadlocal_max_val;
  981. DType min_gt_low, max_le_high;
  982. // f(x) = len(nonzero(probs > x)), f(x) is non-increasing
  983. // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs |
  984. // p <= high} loop invariant:
  985. // - f(low) >= k, f(high) < k
  986. // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
  987. // stopping condition: min_gt_low == max_le_high
  988. // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
  989. do {
  990. int threadlocal_count_sum = 0;
  991. int probs_greater_than_pivot_count[VEC_SIZE]; // pivot initialized to 0
  992. float mid = (low + high) / 2;
  993. min_gt_low = high;
  994. max_le_high = low;
  995. for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
  996. logits_vec.fill(DType(0));
  997. if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
  998. logits_vec.load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
  999. tx * VEC_SIZE);
  1000. }
  1001. #pragma unroll
  1002. for (uint32_t j = 0; j < VEC_SIZE; ++j) {
  1003. probs_greater_than_pivot_count[j] =
  1004. logits_vec[j] > mid &&
  1005. (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d;
  1006. if (logits_vec[j] > low &&
  1007. (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
  1008. min_gt_low = min(min_gt_low, logits_vec[j]);
  1009. }
  1010. if (logits_vec[j] <= high &&
  1011. (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
  1012. max_le_high = max(max_le_high, logits_vec[j]);
  1013. }
  1014. }
  1015. threadlocal_count_sum +=
  1016. BlockReduce<int, BLOCK_THREADS, REDUCE_ALGORITHM>(
  1017. temp_storage.block_prim.reduce_int)
  1018. .Sum<VEC_SIZE>(probs_greater_than_pivot_count);
  1019. __syncthreads();
  1020. }
  1021. min_gt_low = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
  1022. temp_storage.block_prim.reduce)
  1023. .Reduce(min_gt_low, cub::Min());
  1024. __syncthreads();
  1025. max_le_high = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
  1026. temp_storage.block_prim.reduce)
  1027. .Reduce(max_le_high, cub::Max());
  1028. if (tx == 0) {
  1029. temp_storage.data.block_aggregate.count = threadlocal_count_sum;
  1030. temp_storage.data.min_val = min_gt_low;
  1031. temp_storage.data.max_val = max_le_high;
  1032. }
  1033. __syncthreads();
  1034. threadlocal_count_sum = temp_storage.data.block_aggregate.count;
  1035. min_gt_low = temp_storage.data.min_val;
  1036. max_le_high = temp_storage.data.max_val;
  1037. if (threadlocal_count_sum >= k) {
  1038. low = mid;
  1039. } else {
  1040. high = min(mid, max_le_high);
  1041. }
  1042. } while (min_gt_low != max_le_high);
  1043. pivot = low;
  1044. }
  1045. // masking
  1046. for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
  1047. logits_vec.fill(DType(0));
  1048. if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
  1049. logits_vec.load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
  1050. tx * VEC_SIZE);
  1051. }
  1052. #pragma unroll
  1053. for (uint32_t j = 0; j < VEC_SIZE; ++j) {
  1054. logits_vec[j] = (logits_vec[j] > pivot)
  1055. ? logits_vec[j]
  1056. : DType(-std::numeric_limits<float>::infinity());
  1057. }
  1058. if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
  1059. logits_vec.store(masked_logits + row_idx * d +
  1060. i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
  1061. }
  1062. }
  1063. }
  1064. template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
  1065. uint32_t VEC_SIZE, typename DType, typename IdType>
  1066. __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob,
  1067. IdType* top_k_arr, uint32_t top_k_val,
  1068. uint32_t d) {
  1069. const uint32_t bx = blockIdx.x, tx = threadIdx.x;
  1070. const uint32_t row_idx = bx;
  1071. uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
  1072. float pivot = -std::numeric_limits<float>::infinity(), normalizer = 1;
  1073. vec_t<DType, VEC_SIZE> probs_vec;
  1074. if (k < d) {
  1075. extern __shared__ __align__(
  1076. alignof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>))
  1077. uint8_t smem_renorm[];
  1078. auto& temp_storage =
  1079. reinterpret_cast<RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>&>(
  1080. smem_renorm);
  1081. temp_storage.data.max_val = DType(0);
  1082. DType probs_greater_than_pivot[VEC_SIZE]; // pivot initialized to 0
  1083. DType threadlocal_max_val = DType(0);
  1084. for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
  1085. probs_vec.fill(DType(0));
  1086. if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
  1087. probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
  1088. tx * VEC_SIZE);
  1089. }
  1090. #pragma unroll
  1091. for (uint32_t j = 0; j < VEC_SIZE; ++j) {
  1092. probs_greater_than_pivot[j] = probs_vec[j];
  1093. }
  1094. threadlocal_max_val =
  1095. max(threadlocal_max_val,
  1096. BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
  1097. temp_storage.block_prim.reduce)
  1098. .Reduce<VEC_SIZE>(probs_greater_than_pivot, cub::Max()));
  1099. __syncthreads();
  1100. }
  1101. if (tx == 0) {
  1102. temp_storage.data.max_val = threadlocal_max_val;
  1103. }
  1104. __syncthreads();
  1105. threadlocal_max_val = temp_storage.data.max_val;
  1106. float low = 0, high = threadlocal_max_val;
  1107. DType min_gt_low, max_le_high;
  1108. DType sum_low(1);
  1109. // f(x) = len(nonzero(probs > x)), f(x) is non-increasing
  1110. // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs |
  1111. // p <= high} loop invariant:
  1112. // - f(low) >= k, f(high) < k
  1113. // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
  1114. // stopping condition: min_gt_low == max_le_high
  1115. // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
  1116. do {
  1117. Pair<DType> threadlocal_sum{DType(0), 0};
  1118. Pair<DType>
  1119. probs_greater_than_pivot_pair[VEC_SIZE]; // pivot initialized to 0
  1120. float mid = (low + high) / 2;
  1121. min_gt_low = high;
  1122. max_le_high = low;
  1123. for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
  1124. probs_vec.fill(DType(0));
  1125. if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
  1126. probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
  1127. tx * VEC_SIZE);
  1128. }
  1129. #pragma unroll
  1130. for (uint32_t j = 0; j < VEC_SIZE; ++j) {
  1131. probs_greater_than_pivot_pair[j] = {
  1132. (probs_vec[j] > mid) ? probs_vec[j] : DType(0),
  1133. (probs_vec[j] > mid &&
  1134. (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
  1135. if (probs_vec[j] > low &&
  1136. (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
  1137. min_gt_low = min(min_gt_low, probs_vec[j]);
  1138. }
  1139. if (probs_vec[j] <= high &&
  1140. (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
  1141. max_le_high = max(max_le_high, probs_vec[j]);
  1142. }
  1143. }
  1144. threadlocal_sum +=
  1145. BlockReduce<Pair<DType>, BLOCK_THREADS, REDUCE_ALGORITHM>(
  1146. temp_storage.block_prim.reduce_pair)
  1147. .Sum<VEC_SIZE>(probs_greater_than_pivot_pair);
  1148. __syncthreads();
  1149. }
  1150. min_gt_low = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
  1151. temp_storage.block_prim.reduce)
  1152. .Reduce(min_gt_low, cub::Min());
  1153. __syncthreads();
  1154. max_le_high = BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(
  1155. temp_storage.block_prim.reduce)
  1156. .Reduce(max_le_high, cub::Max());
  1157. if (tx == 0) {
  1158. temp_storage.data.block_aggregate.pair = threadlocal_sum;
  1159. temp_storage.data.min_val = min_gt_low;
  1160. temp_storage.data.max_val = max_le_high;
  1161. }
  1162. __syncthreads();
  1163. threadlocal_sum = temp_storage.data.block_aggregate.pair;
  1164. min_gt_low = temp_storage.data.min_val;
  1165. max_le_high = temp_storage.data.max_val;
  1166. if (threadlocal_sum.count >= k) {
  1167. low = mid;
  1168. sum_low = float(threadlocal_sum.value);
  1169. } else {
  1170. high = min(mid, max_le_high);
  1171. }
  1172. } while (min_gt_low != max_le_high);
  1173. normalizer = math::ptx_rcp(max(sum_low, 1e-8));
  1174. pivot = low;
  1175. }
  1176. // normalize
  1177. for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
  1178. probs_vec.fill(DType(0));
  1179. if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
  1180. probs_vec.load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE +
  1181. tx * VEC_SIZE);
  1182. }
  1183. #pragma unroll
  1184. for (uint32_t j = 0; j < VEC_SIZE; ++j) {
  1185. probs_vec[j] =
  1186. (probs_vec[j] > pivot) ? probs_vec[j] * normalizer : DType(0);
  1187. }
  1188. if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
  1189. probs_vec.store(renormed_prob + row_idx * d +
  1190. i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
  1191. }
  1192. }
  1193. }
  1194. template <typename DType>
  1195. cudaError_t TopPRenormProb(DType* probs, DType* renormed_prob, DType* top_p_arr,
  1196. uint32_t batch_size, float top_p_val, uint32_t d,
  1197. cudaStream_t stream = 0) {
  1198. const uint32_t BLOCK_THREADS = 1024;
  1199. const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
  1200. const uint32_t smem_size =
  1201. sizeof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
  1202. dim3 nblks(batch_size);
  1203. dim3 nthrs(BLOCK_THREADS);
  1204. void* args[] = {&probs, &renormed_prob, &top_p_arr, &top_p_val, &d};
  1205. DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
  1206. auto kernel =
  1207. TopPRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE, DType>;
  1208. APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
  1209. kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
  1210. APHRODITE_CUDA_CALL(
  1211. cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
  1212. });
  1213. return cudaSuccess;
  1214. }
  1215. template <typename DType, typename IdType>
  1216. cudaError_t TopKRenormProb(DType* probs, DType* renormed_prob,
  1217. IdType* top_k_arr, uint32_t batch_size,
  1218. uint32_t top_k_val, uint32_t d,
  1219. cudaStream_t stream = 0) {
  1220. const uint32_t BLOCK_THREADS = 1024;
  1221. const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
  1222. const uint32_t smem_size =
  1223. sizeof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
  1224. dim3 nblks(batch_size);
  1225. dim3 nthrs(BLOCK_THREADS);
  1226. void* args[] = {&probs, &renormed_prob, &top_k_arr, &top_k_val, &d};
  1227. DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
  1228. auto kernel = TopKRenormProbKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE,
  1229. DType, IdType>;
  1230. APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
  1231. kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
  1232. APHRODITE_CUDA_CALL(
  1233. cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
  1234. });
  1235. return cudaSuccess;
  1236. }
  1237. template <typename DType, typename IdType>
  1238. cudaError_t TopKMaskLogits(DType* logits, DType* masked_logits,
  1239. IdType* top_k_arr, uint32_t batch_size,
  1240. uint32_t top_k_val, uint32_t d,
  1241. cudaStream_t stream = 0) {
  1242. const uint32_t BLOCK_THREADS = 1024;
  1243. const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
  1244. const uint32_t smem_size =
  1245. sizeof(RenormTempStorage<DType, BLOCK_THREADS, REDUCE_ALGO>);
  1246. dim3 nblks(batch_size);
  1247. dim3 nthrs(BLOCK_THREADS);
  1248. void* args[] = {&logits, &masked_logits, &top_k_arr, &top_k_val, &d};
  1249. DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
  1250. auto kernel = TopKMaskLogitsKernel<BLOCK_THREADS, REDUCE_ALGO, VEC_SIZE,
  1251. DType, IdType>;
  1252. APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
  1253. kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
  1254. APHRODITE_CUDA_CALL(
  1255. cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
  1256. });
  1257. return cudaSuccess;
  1258. }
/*!
 * \brief Chain speculative-decoding verification for one row per block
 *   (row index = blockIdx.x).
 * \details Layouts implied by the indexing below (n = num_speculative_tokens):
 *   draft_probs [rows, n, d], target_probs [rows, n + 1, d],
 *   draft_token_ids [rows, n], uniform_samples [rows, n + 1],
 *   output_token_ids [rows, n + 1].
 *   1. Draft token i is accepted while u_i * p < q, where p and q are the
 *      draft and target probabilities of that token. `pos` is the first
 *      rejected index, or n if every draft token is accepted.
 *   2. Thread 0 adds `pos` to output_emitted_token_num. It adds `pos` plus the
 *      number of later draft tokens that would individually pass the same
 *      test to output_accepted_token_num. Both use +=, so the caller must
 *      initialise these arrays.
 *   3. A replacement token for position `pos` is sampled from
 *      relu(target - draft), with draft treated as 0 for the bonus position
 *      pos == n. It uses uniform sample min(pos + 1, n) scaled by the relu
 *      mass, and defaults to d - 1 if the scan never crosses the threshold.
 *   4. Positions after `pos` are filled with -1.
 *   Requires sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM,
 *   REDUCE_ALGORITHM>) bytes of dynamic shared memory. Assumes d is a multiple
 *   of VEC_SIZE.
 */
template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
          BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
          typename DType, typename IdType>
__global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token_ids,
                                         DType* uniform_samples, DType* target_probs,
                                         IdType* output_token_ids,
                                         IdType* output_accepted_token_num,
                                         IdType* output_emitted_token_num,
                                         uint32_t num_speculative_tokens, uint32_t d) {
  const uint32_t bx = blockIdx.x, tx = threadIdx.x;
  const uint32_t row_idx = bx;
  extern __shared__ __align__(
      alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
      uint8_t smem_sampling[];
  auto& temp_storage = reinterpret_cast<
      SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
  // Index of the first rejected draft token; n means "all accepted".
  uint32_t pos = num_speculative_tokens;
  // Every thread evaluates the same acceptance test and writes the same
  // values; there is no tx guard on these global writes.
  for (uint32_t i = 0; i < num_speculative_tokens; ++i) {
    IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + i];
    float q = target_probs[(row_idx * (num_speculative_tokens + 1) + i) * d + draft_id],
          p = draft_probs[(row_idx * num_speculative_tokens + i) * d + draft_id];
    DType u = uniform_samples[row_idx * (num_speculative_tokens + 1) + i];
    // Equivalent to u < q / p for p > 0, without a division.
    if (u * p < q) {
      // accept the draft models output
      output_token_ids[row_idx * (num_speculative_tokens + 1) + i] = draft_id;
    } else {
      pos = i;
      break;
    }
  }
  uint32_t emitted_token_num = pos;
  uint32_t accepted_token_num = pos;
  // Statistics only: also count later draft tokens that would have passed
  // the acceptance test on their own.
  for (uint32_t i = pos; i < num_speculative_tokens; ++i) {
    IdType draft_id = draft_token_ids[row_idx * num_speculative_tokens + i];
    float q = target_probs[(row_idx * (num_speculative_tokens + 1) + i) * d + draft_id],
          p = draft_probs[(row_idx * num_speculative_tokens + i) * d + draft_id];
    DType u = uniform_samples[row_idx * (num_speculative_tokens + 1) + i];
    if (u * p < q) {
      ++accepted_token_num;
    }
  }
  // Accumulate (+=) rather than overwrite.
  if (tx == 0) {
    output_accepted_token_num[row_idx] += accepted_token_num;
    output_emitted_token_num[row_idx] += emitted_token_num;
  }
  // sample from relu(target_probs - draft_probs)
  // Pass 1: total relu mass (BlockReduce result valid in thread 0 only).
  DType sum_relu_q_minus_p(0);
  vec_t<DType, VEC_SIZE> q_vec, p_vec;
  DType relu_q_minus_p[VEC_SIZE];
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    q_vec.fill(DType(0));
    p_vec.fill(DType(0));
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      q_vec.load(target_probs + (row_idx * (num_speculative_tokens + 1) + pos) * d +
                 i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
      if (pos != num_speculative_tokens) {
        // there is no draft_probs for the bonus token
        p_vec.load(draft_probs + (row_idx * num_speculative_tokens + pos) * d +
                   i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
      }
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      relu_q_minus_p[j] = max(q_vec[j] - p_vec[j], DType(0));
    }
    sum_relu_q_minus_p +=
        BlockReduce<DType, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
            .Sum<VEC_SIZE>(relu_q_minus_p);
    __syncthreads();
  }
  if (tx == 0) {
    temp_storage.data.block_aggregate.value = sum_relu_q_minus_p;
  }
  // init the first rejected token to (d - 1)
  // (fallback if the scan below never crosses u; written by all threads with
  // the same value before the barrier)
  temp_storage.data.sampled_id = d - 1;
  __syncthreads();
  sum_relu_q_minus_p = temp_storage.data.block_aggregate.value;
  // Sampling threshold: uniform sample scaled to the unnormalised relu mass.
  DType u = uniform_samples[row_idx * (num_speculative_tokens + 1) +
                            min(pos + 1, num_speculative_tokens)] *
            sum_relu_q_minus_p;
  DType aggregate_relu_q_minus_p(0);
  // Pass 2: scan the relu distribution chunk by chunk until the running mass
  // exceeds u. aggregate_relu_q_minus_p is updated by DeviceSamplingFromProb,
  // so the break decision is taken by all threads together.
  for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
    q_vec.fill(DType(0));
    p_vec.fill(DType(0));
    if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
      q_vec.load(target_probs + (row_idx * (num_speculative_tokens + 1) + pos) * d +
                 i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
      if (pos != num_speculative_tokens) {
        // there is no draft_probs for the bonus token
        p_vec.load(draft_probs + (row_idx * num_speculative_tokens + pos) * d +
                   i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
      }
    }
    vec_t<DType, VEC_SIZE> relu_q_minus_p_vec;
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; ++j) {
      relu_q_minus_p_vec[j] = max(q_vec[j] - p_vec[j], DType(0));
    }
    DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC,
                           DType>(i, d, DType(0), u, relu_q_minus_p_vec, aggregate_relu_q_minus_p,
                                  &temp_storage);
    if (aggregate_relu_q_minus_p > u) {
      break;
    }
  }
  // Make the sampled_id written during the scan visible to every thread.
  __syncthreads();
  // set the first rejected token
  output_token_ids[row_idx * (num_speculative_tokens + 1) + pos] = temp_storage.data.sampled_id;
  // move to the next token
  pos++;
  // pad remaining tokens with -1
  for (; pos < num_speculative_tokens + 1; ++pos) {
    output_token_ids[row_idx * (num_speculative_tokens + 1) + pos] = -1;
  }
}
  1374. template <typename T, typename IdType>
  1375. cudaError_t ParallelTopPSamplingFromProb(
  1376. T* probs, T* uniform_samples, IdType* output, bool* success,
  1377. IdType* row_indices, T* top_p_arr, uint32_t batch_size, uint32_t d,
  1378. uint32_t max_top_p_rounds, bool deterministic, cudaStream_t stream = 0) {
  1379. constexpr uint32_t BLOCK_THREADS = 1024;
  1380. const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
  1381. const uint32_t smem_size =
  1382. sizeof(SamplingTempStorage<T, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
  1383. dim3 nblks(batch_size);
  1384. dim3 nthrs(BLOCK_THREADS);
  1385. T top_p_placeholder = 0;
  1386. void* args[] = {
  1387. &probs, &uniform_samples, &output, &success, &row_indices,
  1388. &top_p_arr, &top_p_placeholder, &d, &max_top_p_rounds};
  1389. DISPATCH_ALIGNED_VEC_SIZE(
  1390. vec_size, VEC_SIZE,
  1391. {DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, {
  1392. auto kernel =
  1393. TopPSamplingFromProbKernel<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO,
  1394. VEC_SIZE, DETERMINISTIC, T, IdType>;
  1395. APHRODITE_CUDA_CALL(cudaFuncSetAttribute(
  1396. kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
  1397. APHRODITE_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args,
  1398. smem_size, stream));
  1399. })});
  1400. return cudaSuccess;
  1401. }
  1402. } // namespace sampling
  1403. } // namespace aphrodite
  1404. #endif // APHRODITE_SAMPLING_CUH_