1
0

sampling.cu 16 KB


  1. /*
  2. * Copyright (c) 2024 by PygmalionAI team.
  3. * Copyright (c) 2024 by FlashInfer team.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include "sampling.cuh"
#include "../ops.h"
#include "utils.cuh"
  21. // Check utils
  22. #define CUDA_CHECK(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
  23. #define CHECK_CONTIGUOUS(x) \
  24. TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
  25. #define CHECK_INPUT(x) \
  26. CUDA_CHECK(x); \
  27. CHECK_CONTIGUOUS(x)
  28. #define CHECK_EQ(a, b) \
  29. TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
  30. #define CHECK_GE(a, b) \
  31. TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
  32. #define CHECK_DIM(d, x) \
  33. TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
  34. using namespace aphrodite;
  35. torch::Tensor sampling_from_probs(torch::Tensor probs,
  36. torch::Tensor uniform_samples,
  37. bool deterministic) {
  38. CHECK_INPUT(probs);
  39. CHECK_INPUT(uniform_samples);
  40. auto device = probs.device();
  41. CHECK_EQ(uniform_samples.device(), device);
  42. CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
  43. CHECK_DIM(1, uniform_samples); // uniform_samples: (batch_size)
  44. CHECK_EQ(probs.size(0), uniform_samples.size(0));
  45. unsigned int batch_size = probs.size(0);
  46. unsigned int vocab_size = probs.size(1);
  47. probs = probs.to(torch::kFloat32);
  48. uniform_samples = uniform_samples.to(torch::kFloat32);
  49. cudaStream_t torch_current_stream =
  50. c10::cuda::getCurrentCUDAStream(device.index());
  51. auto samples =
  52. torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
  53. cudaError_t status = sampling::SamplingFromProb(
  54. static_cast<float*>(probs.data_ptr()),
  55. static_cast<float*>(uniform_samples.data_ptr()),
  56. static_cast<int*>(samples.data_ptr()), batch_size, vocab_size,
  57. deterministic, torch_current_stream);
  58. TORCH_CHECK(status == cudaSuccess,
  59. "SamplingFromProbs failed with error code " +
  60. std::string(cudaGetErrorString(status)));
  61. return samples;
  62. }
  63. std::vector<torch::Tensor> top_p_sampling_from_probs(
  64. torch::Tensor probs, torch::Tensor uniform_samples,
  65. std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
  66. bool deterministic) {
  67. CHECK_INPUT(probs);
  68. CHECK_INPUT(uniform_samples);
  69. auto device = probs.device();
  70. CHECK_EQ(uniform_samples.device(), device);
  71. CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
  72. CHECK_DIM(
  73. 2, uniform_samples); // uniform_samples: (max_top_p_rounds, batch_size)
  74. CHECK_EQ(probs.size(0), uniform_samples.size(1));
  75. unsigned int batch_size = probs.size(0);
  76. unsigned int vocab_size = probs.size(1);
  77. unsigned int max_top_p_rounds = uniform_samples.size(0);
  78. bool has_top_p_arr = maybe_top_p_arr.has_value();
  79. auto top_p_arr = maybe_top_p_arr.value_or(
  80. torch::empty({0}, torch::dtype(torch::kFloat32)));
  81. if (has_top_p_arr) {
  82. CHECK_INPUT(top_p_arr);
  83. CHECK_DIM(1, top_p_arr); // top_p_arr: (batch_size,)
  84. CHECK_EQ(top_p_arr.size(0), batch_size);
  85. CHECK_EQ(top_p_arr.device(), device);
  86. }
  87. probs = probs.to(torch::kFloat32);
  88. uniform_samples = uniform_samples.to(torch::kFloat32);
  89. top_p_arr = top_p_arr.to(torch::kFloat32);
  90. cudaStream_t torch_current_stream =
  91. c10::cuda::getCurrentCUDAStream(device.index());
  92. auto samples =
  93. torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
  94. auto success =
  95. torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
  96. cudaError_t status = sampling::TopPSamplingFromProb<float, int>(
  97. static_cast<float*>(probs.data_ptr()),
  98. static_cast<float*>(uniform_samples.data_ptr()),
  99. static_cast<int*>(samples.data_ptr()),
  100. static_cast<bool*>(success.data_ptr()),
  101. has_top_p_arr ? static_cast<float*>(top_p_arr.data_ptr()) : nullptr,
  102. batch_size, top_p_val, vocab_size, max_top_p_rounds, deterministic,
  103. torch_current_stream);
  104. TORCH_CHECK(status == cudaSuccess,
  105. "TopPSamplingFromProbs failed with error code " +
  106. std::string(cudaGetErrorString(status)));
  107. return {samples, success};
  108. }
  109. std::vector<torch::Tensor> top_k_sampling_from_probs(
  110. torch::Tensor probs, torch::Tensor uniform_samples,
  111. std::optional<torch::Tensor> maybe_top_k_arr, int64_t top_k_val,
  112. bool deterministic) {
  113. CHECK_INPUT(probs);
  114. CHECK_INPUT(uniform_samples);
  115. auto device = probs.device();
  116. CHECK_EQ(uniform_samples.device(), device);
  117. CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
  118. CHECK_DIM(
  119. 2, uniform_samples); // uniform_samples: (max_top_k_rounds, batch_size)
  120. CHECK_EQ(probs.size(0), uniform_samples.size(1));
  121. unsigned int batch_size = probs.size(0);
  122. unsigned int vocab_size = probs.size(1);
  123. unsigned int max_top_k_rounds = uniform_samples.size(0);
  124. bool has_top_k_arr = maybe_top_k_arr.has_value();
  125. auto top_k_arr =
  126. maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32)));
  127. if (has_top_k_arr) {
  128. CHECK_INPUT(top_k_arr);
  129. CHECK_DIM(1, top_k_arr); // top_k_arr: (batch_size,)
  130. CHECK_EQ(top_k_arr.size(0), batch_size);
  131. CHECK_EQ(top_k_arr.device(), device);
  132. }
  133. probs = probs.to(torch::kFloat32);
  134. uniform_samples = uniform_samples.to(torch::kFloat32);
  135. top_k_arr = top_k_arr.to(torch::kInt32);
  136. cudaStream_t torch_current_stream =
  137. c10::cuda::getCurrentCUDAStream(device.index());
  138. auto samples =
  139. torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
  140. auto success =
  141. torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
  142. cudaError_t status = sampling::TopKSamplingFromProb<float, int>(
  143. static_cast<float*>(probs.data_ptr()),
  144. static_cast<float*>(uniform_samples.data_ptr()),
  145. static_cast<int*>(samples.data_ptr()),
  146. static_cast<bool*>(success.data_ptr()),
  147. has_top_k_arr ? static_cast<float*>(top_k_arr.data_ptr()) : nullptr,
  148. batch_size, top_k_val, vocab_size, max_top_k_rounds, deterministic,
  149. torch_current_stream);
  150. TORCH_CHECK(status == cudaSuccess,
  151. "TopKSamplingFromProbs failed with error code " +
  152. std::string(cudaGetErrorString(status)));
  153. return {samples, success};
  154. }
  155. std::vector<torch::Tensor> min_p_sampling_from_probs(
  156. torch::Tensor probs, torch::Tensor uniform_samples,
  157. std::optional<torch::Tensor> maybe_min_p_arr, double min_p_val,
  158. bool deterministic) {
  159. CHECK_INPUT(probs);
  160. CHECK_INPUT(uniform_samples);
  161. auto device = probs.device();
  162. CHECK_EQ(uniform_samples.device(), device);
  163. CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
  164. CHECK_DIM(2, uniform_samples); // uniform_samples: (max_rounds, batch_size)
  165. unsigned int batch_size = probs.size(0);
  166. unsigned int vocab_size = probs.size(1);
  167. unsigned int max_rounds = uniform_samples.size(0);
  168. CHECK_EQ(uniform_samples.size(1), batch_size);
  169. bool has_min_p_arr = maybe_min_p_arr.has_value();
  170. auto min_p_arr = maybe_min_p_arr.value_or(
  171. torch::empty({0}, torch::dtype(torch::kFloat32)));
  172. if (has_min_p_arr) {
  173. CHECK_INPUT(min_p_arr);
  174. CHECK_DIM(1, min_p_arr); // min_p_arr: (batch_size,)
  175. CHECK_EQ(min_p_arr.size(0), batch_size);
  176. CHECK_EQ(min_p_arr.device(), device);
  177. }
  178. min_p_arr = min_p_arr.to(torch::kFloat32);
  179. probs = probs.to(torch::kFloat32);
  180. uniform_samples = uniform_samples.to(torch::kFloat32);
  181. cudaStream_t torch_current_stream =
  182. c10::cuda::getCurrentCUDAStream(device.index());
  183. auto samples =
  184. torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
  185. auto success =
  186. torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
  187. cudaError_t status = sampling::MinPSamplingFromProb<float, int>(
  188. static_cast<float*>(probs.data_ptr()),
  189. static_cast<float*>(uniform_samples.data_ptr()),
  190. has_min_p_arr ? static_cast<float*>(min_p_arr.data_ptr()) : nullptr,
  191. static_cast<int*>(samples.data_ptr()),
  192. static_cast<bool*>(success.data_ptr()), batch_size, min_p_val, vocab_size,
  193. max_rounds, deterministic, torch_current_stream);
  194. TORCH_CHECK(status == cudaSuccess,
  195. "MinPSamplingFromProb failed with error code " +
  196. std::string(cudaGetErrorString(status)));
  197. return {samples, success};
  198. }
  199. std::vector<torch::Tensor> top_k_top_p_sampling_from_probs(
  200. torch::Tensor probs, torch::Tensor uniform_samples,
  201. std::optional<torch::Tensor> maybe_top_k_arr, double top_k_val,
  202. std::optional<torch::Tensor> maybe_top_p_arr, double top_p_val,
  203. bool deterministic) {
  204. CHECK_INPUT(probs);
  205. CHECK_INPUT(uniform_samples);
  206. auto device = probs.device();
  207. CHECK_EQ(uniform_samples.device(), device);
  208. CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
  209. CHECK_DIM(2, uniform_samples); // uniform_samples: (max_rounds, batch_size)
  210. unsigned int batch_size = probs.size(0);
  211. unsigned int vocab_size = probs.size(1);
  212. unsigned int max_rounds = uniform_samples.size(0);
  213. CHECK_EQ(uniform_samples.size(1), batch_size);
  214. bool has_top_k_arr = maybe_top_k_arr.has_value();
  215. auto top_k_arr =
  216. maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32)));
  217. if (has_top_k_arr) {
  218. CHECK_INPUT(top_k_arr);
  219. CHECK_DIM(1, top_k_arr); // top_k_arr: (batch_size,)
  220. CHECK_EQ(top_k_arr.size(0), batch_size);
  221. CHECK_EQ(top_k_arr.device(), device);
  222. }
  223. top_k_arr = top_k_arr.to(torch::kInt32);
  224. bool has_top_p_arr = maybe_top_p_arr.has_value();
  225. auto top_p_arr = maybe_top_p_arr.value_or(
  226. torch::empty({0}, torch::dtype(torch::kFloat32)));
  227. if (has_top_p_arr) {
  228. CHECK_INPUT(top_p_arr);
  229. CHECK_DIM(1, top_p_arr); // top_p_arr: (batch_size,)
  230. CHECK_EQ(top_p_arr.size(0), batch_size);
  231. CHECK_EQ(top_p_arr.device(), device);
  232. }
  233. top_p_arr = top_p_arr.to(torch::kFloat32);
  234. probs = probs.to(torch::kFloat32);
  235. uniform_samples = uniform_samples.to(torch::kFloat32);
  236. cudaStream_t torch_current_stream =
  237. c10::cuda::getCurrentCUDAStream(device.index());
  238. auto samples =
  239. torch::empty({batch_size}, torch::dtype(torch::kInt32).device(device));
  240. auto success =
  241. torch::empty({batch_size}, torch::dtype(torch::kBool).device(device));
  242. cudaError_t status = sampling::TopKTopPSamplingFromProb<float, int>(
  243. static_cast<float*>(probs.data_ptr()),
  244. static_cast<float*>(uniform_samples.data_ptr()),
  245. has_top_k_arr ? static_cast<int*>(top_k_arr.data_ptr()) : nullptr,
  246. has_top_p_arr ? static_cast<float*>(top_p_arr.data_ptr()) : nullptr,
  247. static_cast<int*>(samples.data_ptr()),
  248. static_cast<bool*>(success.data_ptr()), batch_size, top_k_val, top_p_val,
  249. vocab_size, max_rounds, deterministic, torch_current_stream);
  250. TORCH_CHECK(status == cudaSuccess,
  251. "TopKTopPSamplingFromProbs failed with error code " +
  252. std::string(cudaGetErrorString(status)));
  253. return {samples, success};
  254. }
  255. torch::Tensor top_p_renorm_prob(torch::Tensor probs,
  256. std::optional<torch::Tensor> maybe_top_p_arr,
  257. double top_p_val) {
  258. CHECK_INPUT(probs);
  259. auto device = probs.device();
  260. CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
  261. unsigned int batch_size = probs.size(0);
  262. unsigned int vocab_size = probs.size(1);
  263. bool has_top_p_arr = maybe_top_p_arr.has_value();
  264. auto top_p_arr = maybe_top_p_arr.value_or(
  265. torch::empty({0}, torch::dtype(torch::kFloat32)));
  266. if (has_top_p_arr) {
  267. CHECK_INPUT(top_p_arr);
  268. CHECK_DIM(1, top_p_arr); // top_p_arr: (batch_size,)
  269. CHECK_EQ(top_p_arr.size(0), batch_size);
  270. CHECK_EQ(top_p_arr.device(), device);
  271. }
  272. top_p_arr = top_p_arr.to(torch::kFloat32);
  273. probs = probs.to(torch::kFloat32);
  274. cudaStream_t torch_current_stream =
  275. c10::cuda::getCurrentCUDAStream(device.index());
  276. auto renorm_probs = torch::empty(
  277. {batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device));
  278. cudaError_t status = sampling::TopPRenormProb<float>(
  279. static_cast<float*>(probs.data_ptr()),
  280. static_cast<float*>(renorm_probs.data_ptr()),
  281. has_top_p_arr ? static_cast<float*>(top_p_arr.data_ptr()) : nullptr,
  282. batch_size, top_p_val, vocab_size, torch_current_stream);
  283. TORCH_CHECK(status == cudaSuccess,
  284. "TopPRenormProb failed with error code " +
  285. std::string(cudaGetErrorString(status)));
  286. return renorm_probs;
  287. }
  288. torch::Tensor top_k_renorm_prob(torch::Tensor probs,
  289. std::optional<torch::Tensor> maybe_top_k_arr,
  290. int64_t top_k_val) {
  291. CHECK_INPUT(probs);
  292. auto device = probs.device();
  293. CHECK_DIM(2, probs); // probs: (batch_size, vocab_size)
  294. unsigned int batch_size = probs.size(0);
  295. unsigned int vocab_size = probs.size(1);
  296. bool has_top_k_arr = maybe_top_k_arr.has_value();
  297. auto top_k_arr =
  298. maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32)));
  299. if (has_top_k_arr) {
  300. CHECK_INPUT(top_k_arr);
  301. CHECK_DIM(1, top_k_arr); // top_k_arr: (batch_size,)
  302. CHECK_EQ(top_k_arr.size(0), batch_size);
  303. CHECK_EQ(top_k_arr.device(), device);
  304. }
  305. top_k_arr = top_k_arr.to(torch::kInt32);
  306. probs = probs.to(torch::kFloat32);
  307. cudaStream_t torch_current_stream =
  308. c10::cuda::getCurrentCUDAStream(device.index());
  309. auto renorm_probs = torch::empty(
  310. {batch_size, vocab_size}, torch::dtype(torch::kFloat32).device(device));
  311. cudaError_t status = sampling::TopKRenormProb<float>(
  312. static_cast<float*>(probs.data_ptr()),
  313. static_cast<float*>(renorm_probs.data_ptr()),
  314. has_top_k_arr ? static_cast<int*>(top_k_arr.data_ptr()) : nullptr,
  315. batch_size, top_k_val, vocab_size, torch_current_stream);
  316. TORCH_CHECK(status == cudaSuccess,
  317. "TopKRenormProb failed with error code " +
  318. std::string(cudaGetErrorString(status)));
  319. return renorm_probs;
  320. }
  321. torch::Tensor top_k_mask_logits(torch::Tensor logits,
  322. std::optional<torch::Tensor> maybe_top_k_arr,
  323. int64_t top_k_val) {
  324. CHECK_INPUT(logits);
  325. auto device = logits.device();
  326. CHECK_DIM(2, logits); // logits: (batch_size, vocab_size)
  327. unsigned int batch_size = logits.size(0);
  328. unsigned int vocab_size = logits.size(1);
  329. bool has_top_k_arr = maybe_top_k_arr.has_value();
  330. auto top_k_arr =
  331. maybe_top_k_arr.value_or(torch::empty({0}, torch::dtype(torch::kInt32)));
  332. if (has_top_k_arr) {
  333. CHECK_INPUT(top_k_arr);
  334. CHECK_DIM(1, top_k_arr); // top_k_arr: (batch_size,)
  335. CHECK_EQ(top_k_arr.size(0), batch_size);
  336. CHECK_EQ(top_k_arr.device(), device);
  337. }
  338. top_k_arr = top_k_arr.to(torch::kInt32);
  339. logits = logits.to(torch::kFloat32);
  340. cudaStream_t torch_current_stream =
  341. c10::cuda::getCurrentCUDAStream(device.index());
  342. auto mask_logits = torch::empty({batch_size, vocab_size},
  343. torch::dtype(torch::kFloat32).device(device));
  344. cudaError_t status = sampling::TopKMaskLogits<float>(
  345. static_cast<float*>(logits.data_ptr()),
  346. static_cast<float*>(mask_logits.data_ptr()),
  347. has_top_k_arr ? static_cast<int*>(top_k_arr.data_ptr()) : nullptr,
  348. batch_size, top_k_val, vocab_size, torch_current_stream);
  349. TORCH_CHECK(status == cudaSuccess,
  350. "TopKMaskLogits failed with error code " +
  351. std::string(cudaGetErrorString(status)));
  352. return mask_logits;
  353. }