/*
 * The goal of this GPU kernel is to advance input tensors on the GPU directly.
 * Current restrictions:
 *   1. Specialized for DraftModelRunner
 *   2. Supports flash_attn only
 */

#include "advance_step.cuh"

namespace prepare_inputs {

//
template <int const num_threads>
__global__ void advance_step_flashattn_kernel(
    int num_seqs, int num_queries, int block_size, long* input_tokens_ptr,
    long const* sampled_token_ids_ptr, long* input_positions_ptr,
    int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
    int64_t const block_tables_stride) {
  int num_query_blocks = div_ceil(num_queries, num_threads);

  if (blockIdx.x >= num_query_blocks) {
    return;
  }

  int cur_query_id = blockIdx.x * num_threads + threadIdx.x;

  if (cur_query_id >= num_queries) {
    return;
  }

  // Update input_tokens
  input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];

  int seq_len = seq_lens_ptr[cur_query_id];
  int next_seq_len = seq_len + 1;
  int next_input_pos = next_seq_len - 1;

  // Update seq_lens
  seq_lens_ptr[cur_query_id] = next_seq_len;
  // Update input_positions
  input_positions_ptr[cur_query_id] = next_input_pos;

  int const* seq_block_tables_ptr =
      block_tables_ptr + block_tables_stride * cur_query_id;

  int block_index = next_input_pos / block_size;
  int block_offset = next_input_pos % block_size;

  int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset;
  // Update slot_mapping
  slot_mapping_ptr[cur_query_id] = slot_num;
}
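
// Illustrative example (not in the original source): with block_size = 16 and
// next_seq_len = 35, next_input_pos = 34, so block_index = 34 / 16 = 2 and
// block_offset = 34 % 16 = 2. If this sequence's block table maps logical
// block 2 to physical block 7, the new token is written to KV cache slot
// 7 * 16 + 2 = 114.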
inline void verify_tensor(std::string const& name, torch::Tensor& t,
                          int64_t const size_0, int64_t const size_1,
                          c10::ScalarType const type) {
  bool size_0_cond = true;
  if (size_0 != -1) {
    size_0_cond = t.size(0) == size_0;
  }

  bool size_1_cond = true;
  if (size_1 != -1) {
    size_1_cond = t.size(1) == size_1;
  }

  bool is_contiguous = t.is_contiguous();
  bool same_type = t.dtype() == type;

  bool pass = size_0_cond && size_1_cond && is_contiguous && same_type;
  if (!pass) {
    TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(),
                " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(),
                " is not as expected: shape = [", size_0, ", ", size_1,
                "], type = ", type);
  }
}
__global__ void advance_step_flashinfer_kernel(
    int num_threads, int num_seqs, int num_queries, int block_size,
    long* input_tokens_ptr, long const* sampled_token_ids_ptr,
    long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
    int const* block_tables_ptr, int64_t const block_tables_stride,
    int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
  int num_query_blocks = div_ceil(num_queries, num_threads);

  if (blockIdx.x < num_query_blocks) {
    int cur_query_id = blockIdx.x * num_threads + threadIdx.x;

    if (cur_query_id < num_queries) {
      // Update input_tokens
      input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];

      int seq_len = seq_lens_ptr[cur_query_id];
      int next_seq_len = seq_len + 1;
      int next_input_pos = next_seq_len - 1;

      // Update seq_lens
      seq_lens_ptr[cur_query_id] = next_seq_len;
      // Update input_positions
      input_positions_ptr[cur_query_id] = next_input_pos;

      int const* seq_block_tables_ptr =
          block_tables_ptr + block_tables_stride * cur_query_id;

      int block_index = next_input_pos / block_size;
      int block_offset = next_input_pos % block_size;

      // Update paged_kv_last_page_len
      paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1;

      int slot_num =
          seq_block_tables_ptr[block_index] * block_size + block_offset;
      // Update slot_mapping
      slot_mapping_ptr[cur_query_id] = slot_num;
      block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size);
    }
  }
}
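
// Illustrative values (not in the original source): with block_size = 16 and
// next_seq_len = 17, block_offset = 16 % 16 = 0, so paged_kv_last_page_len
// becomes 1 (one token used in the last page) and block_table_bound becomes
// div_ceil(17, 16) = 2 (pages in use for this sequence).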
__global__ void advance_step_flashinfer_indptr_kernel(
    int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
    int* block_table_bound_ptr) {
  int idx = blockIdx.x * num_threads + threadIdx.x;

  // Update paged_kv_indptr
  if (idx < num_queries) {
    int sum = 0;
    for (int i = 0; i <= idx; ++i) {
      sum += block_table_bound_ptr[i];
    }
    paged_kv_indptr_ptr[idx + 1] = sum;
  }
}
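
// Illustrative note (not in the original source): each thread recomputes an
// inclusive prefix sum over block_table_bound, a simple O(num_queries)-reads-
// per-thread approach rather than a parallel scan. For example,
// block_table_bound = [2, 3, 1] produces paged_kv_indptr[1..3] = [2, 5, 6];
// paged_kv_indptr[0] is assumed to already hold 0 on entry.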
__global__ void advance_step_flashinfer_indices_kernel(
    int num_threads, int num_seqs, int num_queries, int const* block_tables_ptr,
    int64_t const block_tables_stride, int* paged_kv_indices_ptr,
    int* paged_kv_indptr_ptr, int* block_table_bound_ptr) {
  int idx = blockIdx.x * num_threads + threadIdx.x;
  int row = idx / block_tables_stride;
  int col = idx % block_tables_stride;

  if (row < num_queries && col < block_table_bound_ptr[row]) {
    paged_kv_indices_ptr[paged_kv_indptr_ptr[row] + col] =
        block_tables_ptr[row * block_tables_stride + col];
  }
  // if cudagraph, fill padded seqs with the last valid seq's indptr
  if (num_queries < row && row <= num_seqs) {
    paged_kv_indptr_ptr[row] = paged_kv_indptr_ptr[num_queries];
  }
}
void advance_step_flashattn(int num_seqs, int num_queries, int block_size,
                            torch::Tensor& input_tokens,       // type: long
                            torch::Tensor& sampled_token_ids,  // type: long
                            torch::Tensor& input_positions,    // type: long
                            torch::Tensor& seq_lens,           // type: int
                            torch::Tensor& slot_mapping,       // type: long
                            torch::Tensor& block_tables) {     // type: int
  if (logging) {
    printf("advance_step_flashattn:\n");
    printf("  num_seqs = %d\n", num_seqs);
    printf("  num_queries = %d\n", num_queries);
    printf("  block_size = %d\n", block_size);
  }
  // Verify all tensors
  verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
  verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
                at::kLong);
  verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
  verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
  verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
  verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);

  int dev = sampled_token_ids.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);

  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  advance_step_flashattn_kernel<max_threads>
      <<<blocks, max_threads, 0, stream>>>(
          num_seqs, num_queries, block_size,
          reinterpret_cast<long*>(input_tokens.data_ptr()),
          reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
          reinterpret_cast<long*>(input_positions.data_ptr()),
          reinterpret_cast<int*>(seq_lens.data_ptr()),
          reinterpret_cast<long*>(slot_mapping.data_ptr()),
          reinterpret_cast<int const*>(block_tables.data_ptr()),
          block_tables.stride(0));
}
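
/*
 * Minimal host-side usage sketch (illustrative only, not part of the original
 * source). It assumes `logging` and `max_threads` come from advance_step.cuh
 * and that all tensors already live on the same CUDA device, as the real
 * caller guarantees.
 *
 *   int num_seqs = 4, num_queries = 4, block_size = 16;
 *   auto opts_long = torch::dtype(torch::kLong).device(torch::kCUDA);
 *   auto opts_int = torch::dtype(torch::kInt).device(torch::kCUDA);
 *   auto input_tokens = torch::zeros({num_seqs}, opts_long);
 *   auto sampled_token_ids = torch::zeros({num_queries, 1}, opts_long);
 *   auto input_positions = torch::zeros({num_seqs}, opts_long);
 *   auto seq_lens = torch::full({num_seqs}, 8, opts_int);
 *   auto slot_mapping = torch::zeros({num_seqs}, opts_long);
 *   auto block_tables = torch::zeros({num_seqs, 8}, opts_int);
 *   prepare_inputs::advance_step_flashattn(
 *       num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
 *       input_positions, seq_lens, slot_mapping, block_tables);
 */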
void advance_step_flashinfer(
    int num_seqs, int num_queries, int block_size,
    torch::Tensor& input_tokens,            // type: long
    torch::Tensor& sampled_token_ids,       // type: long
    torch::Tensor& input_positions,         // type: long
    torch::Tensor& seq_lens,                // type: int
    torch::Tensor& slot_mapping,            // type: long
    torch::Tensor& block_tables,            // type: int
    torch::Tensor& paged_kv_indices,        // type: int
    torch::Tensor& paged_kv_indptr,         // type: int
    torch::Tensor& paged_kv_last_page_len,  // type: int
    torch::Tensor& block_table_bound) {     // type: int
  if (logging) {
    printf("advance_step_flashinfer:\n");
    printf("  num_seqs = %d\n", num_seqs);
    printf("  num_queries = %d\n", num_queries);
    printf("  block_size = %d\n", block_size);
    printf("  block_tables.stride(0) = %ld\n",
           static_cast<long>(block_tables.stride(0)));
  }
  // Verify all tensors
  verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
  // verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
  //               at::kLong);
  verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
  verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
  verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
  verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);

  verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt);
  verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, at::kInt);
  verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1,
                at::kInt);
  verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt);

  int dev = sampled_token_ids.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);

  int blocks;
  int threads;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
  cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);

  if (logging) {
    printf("launching kernel with %d blocks\n", blocks);
  }

  // TODO(will): support arbitrary block_tables stride
  if ((blocks * threads) / block_tables.stride(0) < num_queries) {
    TORCH_CHECK(false,
                "multi-step: not enough threads to map block_table to "
                "FlashInfer's paged_kv_indices on GPU. Try reducing the "
                "number of seqs, increasing the block size, or taking "
                "smaller steps. num_queries = ", num_queries,
                ", block_tables.stride(0) = ", block_tables.stride(0),
                ", blocks = ", blocks, ", max_threads = ", threads);
  }

  advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>(
      threads, num_seqs, num_queries, block_size,
      reinterpret_cast<long*>(input_tokens.data_ptr()),
      reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
      reinterpret_cast<long*>(input_positions.data_ptr()),
      reinterpret_cast<int*>(seq_lens.data_ptr()),
      reinterpret_cast<long*>(slot_mapping.data_ptr()),
      reinterpret_cast<int const*>(block_tables.data_ptr()),
      block_tables.stride(0),
      reinterpret_cast<int*>(paged_kv_last_page_len.data_ptr()),
      reinterpret_cast<int*>(block_table_bound.data_ptr()));

  advance_step_flashinfer_indptr_kernel<<<blocks, threads, 0, stream>>>(
      threads, num_seqs, num_queries,
      reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
      reinterpret_cast<int*>(block_table_bound.data_ptr()));

  advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>(
      threads, num_seqs, num_queries,
      reinterpret_cast<int const*>(block_tables.data_ptr()),
      block_tables.stride(0),
      reinterpret_cast<int*>(paged_kv_indices.data_ptr()),
      reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
      reinterpret_cast<int*>(block_table_bound.data_ptr()));
}
}  // namespace prepare_inputs

void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
                            int64_t block_size, torch::Tensor& input_tokens,
                            torch::Tensor& sampled_token_ids,
                            torch::Tensor& input_positions,
                            torch::Tensor& seq_lens,
                            torch::Tensor& slot_mapping,
                            torch::Tensor& block_tables) {
  prepare_inputs::advance_step_flashattn(
      num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
      input_positions, seq_lens, slot_mapping, block_tables);
}

void advance_step_flashinfer(
    int64_t num_seqs, int64_t num_queries, int64_t block_size,
    torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
    torch::Tensor& input_positions, torch::Tensor& seq_lens,
    torch::Tensor& slot_mapping, torch::Tensor& block_tables,
    torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
    torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) {
  prepare_inputs::advance_step_flashinfer(
      num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
      input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices,
      paged_kv_indptr, paged_kv_last_page_len, block_table_bound);
}