advance_step.cu

/*
 * The goal of this GPU kernel is to advance input tensors on the GPU directly.
 * Current restrictions:
 *   1. Specialized for DraftModelRunner
 *   2. Supports flash_attn only
 */

#include "advance_step.cuh"
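
// advance_step.cuh is not shown here. A minimal sketch of what it is assumed
// to provide for this file to compile (names inferred from usage below; the
// real header may differ):
//
//   #include <torch/all.h>
//   #include <ATen/cuda/CUDAContext.h>
//
//   namespace prepare_inputs {
//   static constexpr int max_threads = 256;  // threads per block
//   static constexpr bool logging = false;   // enables the debug printfs
//   __host__ __device__ constexpr int div_ceil(int a, int b) {
//     return (a + b - 1) / b;
//   }
//   }  // namespace prepare_inputs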

namespace prepare_inputs {

template <int const num_threads>
__global__ void advance_step_kernel(int num_seqs, int num_queries,
                                    int block_size, long* input_tokens_ptr,
                                    long const* sampled_token_ids_ptr,
                                    long* input_positions_ptr,
                                    int* seq_lens_ptr, long* slot_mapping_ptr,
                                    int const* block_tables_ptr,
                                    int64_t const block_tables_stride) {
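  // One thread advances one query. Note this assumes the launch supplies at
  // least div_ceil(num_queries, num_threads) blocks; queries beyond
  // gridDim.x * num_threads would not be advanced.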
  int num_query_blocks = div_ceil(num_queries, num_threads);

  if (blockIdx.x >= num_query_blocks) {
    return;
  }

  int cur_query_id = blockIdx.x * num_threads + threadIdx.x;

  if (cur_query_id >= num_queries) {
    return;
  }

  // Update input_tokens
  input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
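
  // Growing the sequence by one token places the new token at 0-based
  // position next_seq_len - 1.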
  int seq_len = seq_lens_ptr[cur_query_id];
  int next_seq_len = seq_len + 1;
  int next_input_pos = next_seq_len - 1;

  // Update seq_lens
  seq_lens_ptr[cur_query_id] = next_seq_len;
  // Update input_positions
  input_positions_ptr[cur_query_id] = next_input_pos;
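
  // Map the new position to a physical KV-cache slot via this query's row of
  // the paged block table: slot = block_number * block_size + offset.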
  int const* seq_block_tables_ptr =
      block_tables_ptr + block_tables_stride * cur_query_id;

  int block_index = next_input_pos / block_size;
  int block_offset = next_input_pos % block_size;

  int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset;
  // Update slot_mapping
  slot_mapping_ptr[cur_query_id] = slot_num;
}
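
// Checks a tensor's shape (pass -1 to skip checking a dim), contiguity, and
// dtype; raises a torch error with full diagnostics on mismatch.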
inline void verify_tensor(std::string const& name, torch::Tensor& t,
                          int64_t const size_0, int64_t const size_1,
                          c10::ScalarType const type) {
  bool size_0_cond = true;
  if (size_0 != -1) {
    size_0_cond = t.size(0) == size_0;
  }

  bool size_1_cond = true;
  if (size_1 != -1) {
    size_1_cond = t.size(1) == size_1;
  }

  bool is_contiguous = t.is_contiguous();
  bool same_type = t.dtype() == type;

  bool pass = size_0_cond && size_1_cond && is_contiguous && same_type;
  if (!pass) {
    TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(),
                " is_cont = ", t.is_contiguous(), ", type = ", t.dtype(),
                " is not as expected: shape = [", size_0, ", ", size_1,
                "], type = ", type);
  }
}
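
// Host-side entry: validates shapes/dtypes, then launches the kernel on the
// current CUDA stream of the device holding the tensors.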
void advance_step(int num_seqs, int num_queries, int block_size,
                  torch::Tensor& input_tokens,       // type: long
                  torch::Tensor& sampled_token_ids,  // type: long
                  torch::Tensor& input_positions,    // type: long
                  torch::Tensor& seq_lens,           // type: int
                  torch::Tensor& slot_mapping,       // type: long
                  torch::Tensor& block_tables) {     // type: int
  if (logging) {
    printf("advance_step:\n");
    printf("  num_seqs = %d\n", num_seqs);
    printf("  num_queries = %d\n", num_queries);
    printf("  block_size = %d\n", block_size);
  }

  // Verify all tensors
  verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
  verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
                at::kLong);
  verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
  verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
  verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
  verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
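
  // Launch one thread block per SM on the tensors' current stream; each
  // thread advances a single query, so one pass covers up to
  // gridDim.x * max_threads queries.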
  int dev = sampled_token_ids.get_device();
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);

  int blocks;
  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);

  advance_step_kernel<max_threads><<<blocks, max_threads, 0, stream>>>(
      num_seqs, num_queries, block_size,
      reinterpret_cast<long*>(input_tokens.data_ptr()),
      reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
      reinterpret_cast<long*>(input_positions.data_ptr()),
      reinterpret_cast<int*>(seq_lens.data_ptr()),
      reinterpret_cast<long*>(slot_mapping.data_ptr()),
      reinterpret_cast<int const*>(block_tables.data_ptr()),
      block_tables.stride(0));
}

}  // namespace prepare_inputs
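
// Public wrapper with an int64_t signature, presumably to match a torch op
// registration; arguments are narrowed to int for the implementation above.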
void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
                  torch::Tensor& input_tokens,
                  torch::Tensor& sampled_token_ids,
                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
                  torch::Tensor& slot_mapping, torch::Tensor& block_tables) {
  prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens,
                               sampled_token_ids, input_positions, seq_lens,
                               slot_mapping, block_tables);
}
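
// Binding sketch (assumed; in the real build this would live in a separate
// bindings file, and the op/namespace names here are hypothetical). The
// wrapper above would match a schema like:
//
//   TORCH_LIBRARY(my_ops, m) {
//     m.def(
//         "advance_step(int num_seqs, int num_queries, int block_size, "
//         "Tensor! input_tokens, Tensor sampled_token_ids, "
//         "Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
//         "Tensor block_tables) -> ()");
//     m.impl("advance_step", torch::kCUDA, &advance_step);
//   }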