decoder_xqa_impl_common.cpp

#include "decoder_xqa_impl_common.h"

// Overloading << operator for XQAKernelRuntimeHashKey
std::ostream& operator<<(std::ostream& os, const XQAKernelRuntimeHashKey& key) {
  os << "{kv_data_type: " << key.kv_data_type
     << ", head_size: " << key.head_size << ", beam_size: " << key.beam_size
     << ", num_q_heads_per_kv: " << key.num_q_heads_per_kv
     << ", m_tilesize: " << key.m_tilesize
     << ", tokens_per_page: " << key.tokens_per_page
     << ", paged_kv_cache: " << (key.paged_kv_cache ? "true" : "false")
     << ", multi_query_tokens: " << (key.multi_query_tokens ? "true" : "false")
     << "}";
  return os;
}
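
// Illustrative usage sketch (not part of the original file): this overload lets
// a runtime hash key be embedded directly in diagnostics, for example when a
// kernel lookup fails. `xqaParams` and the error text below are assumptions.
//
//   XQAKernelRuntimeHashKey key = getRuntimeHashKeyFromXQAParams(xqaParams);
//   std::ostringstream msg;
//   msg << "no XQA kernel registered for " << key;  // uses the operator<< above
//   TORCH_CHECK(false, msg.str());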

XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(
    XQAParams const& xqaParams) {
  unsigned int head_size = xqaParams.head_size;
  unsigned int num_q_heads = xqaParams.num_q_heads;
  unsigned int num_kv_heads = xqaParams.num_kv_heads;
  TORCH_CHECK(num_q_heads % num_kv_heads == 0,
              "numQHeads should be multiple of numKVHeads.");
  unsigned int num_q_heads_over_kv = num_q_heads / num_kv_heads;
  unsigned int beam_width = xqaParams.beam_width;

  // Use mTileSize = 16 kernels when qSeqLen <= 16.
  unsigned int qSeqLen =
      static_cast<unsigned int>(xqaParams.generation_input_length);
  unsigned int mTileSize = qSeqLen <= 16 ? 16 : 32;

  // MultiQueryToken kernels can support any num_q_heads_over_kv that is a
  // power of 2.
  unsigned int kernel_num_q_heads_over_kv =
      xqaParams.multi_query_tokens ? 0 : num_q_heads_over_kv;
  // MultiQueryToken kernels can handle either 16 or 32 for the M direction per CTA.
  unsigned int kernel_m_tilesize =
      xqaParams.multi_query_tokens ? mTileSize : num_q_heads_over_kv;

  return {xqaParams.kv_cache_data_type,
          head_size,
          beam_width,
          kernel_num_q_heads_over_kv,
          kernel_m_tilesize,
          xqaParams.paged_kv_cache
              ? static_cast<unsigned int>(xqaParams.tokens_per_block)
              : 0,
          xqaParams.paged_kv_cache,
          xqaParams.multi_query_tokens};
}
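
// Hypothetical sketch (not part of this file; the real hasher, if any, lives
// alongside XQAKernelRuntimeHashKey in the header): keys returned above are
// meant to index a table of prebuilt kernels, e.g. a std::unordered_map.
// A minimal hash-combine over the same eight fields, assuming kv_data_type is
// convertible to an integer, could look like:
//
//   struct XQAKernelRuntimeHashKeyHasher {
//     size_t operator()(XQAKernelRuntimeHashKey const& k) const {
//       size_t h = std::hash<unsigned int>{}(
//           static_cast<unsigned int>(k.kv_data_type));
//       auto mix = [&h](size_t v) {
//         h ^= v + 0x9e3779b97f4a7c15ull + (h << 6) + (h >> 2);
//       };
//       mix(k.head_size);
//       mix(k.beam_size);
//       mix(k.num_q_heads_per_kv);
//       mix(k.m_tilesize);
//       mix(k.tokens_per_page);
//       mix(k.paged_kv_cache);
//       mix(k.multi_query_tokens);
//       return h;
//     }
//   };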

// Set up launch params. ioScratch (for RoPE and output type conversion) is not
// used here.
void buildXQALaunchParams(XQALaunchParam& launchParams, XQAParams const& params,
                          KVCacheListParams kv_cache_buffer) {
  TORCH_CHECK(
      params.data_type == DATA_TYPE_FP16 || params.data_type == DATA_TYPE_BF16,
      "Only fp16 or bf16 supported now.");
  memset(&launchParams, 0, sizeof(XQALaunchParam));
  launchParams.num_k_heads = params.num_kv_heads;
  launchParams.output = static_cast<uint8_t*>(params.output);
  launchParams.batch_size = params.batch_size;
  launchParams.kv_scale_quant_orig = params.kv_scale_quant_orig;
  launchParams.semaphores = params.semaphores;

  // Workspace.
  int8_t* workspace = reinterpret_cast<int8_t*>(params.workspaces);
  // workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(
  //     workspace, 2 * params.head_size * params.num_q_heads *
  //                    params.total_num_input_tokens);
  // unsigned int batch_beam_size = params.batch_size * params.beam_width;
  // const size_t cu_seqlens_size = sizeof(int) * (batch_beam_size + 0);
  // launchParams.cu_seq_lens (workspace);
  // launchParams.cu_seq_lens = launchParams.cu_seq_lens;
  // workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(
  //     workspace, cu_seqlens_size);
  // launchParams.rotary_inv_freq_buf = reinterpret_cast<float*>(workspace);
  // auto const multi_block_workspace_alignment = tensorrt_llm::common::roundUp(
  //     sizeof(half) * params.head_size *
  //         (params.num_q_heads / params.num_kv_heads) * params.beam_width,
  //     128);
  // workspace = tensorrt_llm::common::nextWorkspacePtrWithAlignment(
  //     workspace, rotary_inv_freq_size, multi_block_workspace_alignment);
  launchParams.scratch = reinterpret_cast<void*>(workspace);
  launchParams.kvCacheParams = kv_cache_buffer;
}
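
// Illustrative call sequence (not from this file; every name other than the
// two functions above is an assumption): a caller typically derives the hash
// key to select a prebuilt kernel, then fills the launch parameters right
// before launching it.
//
//   XQAParams xqaParams = /* filled in by the attention op */;
//   KVCacheListParams kvCacheList = /* built from the paged KV-cache pools */;
//
//   XQAKernelRuntimeHashKey key = getRuntimeHashKeyFromXQAParams(xqaParams);
//   // ... look up the kernel matching `key` in the runtime kernel table ...
//
//   XQALaunchParam launchParams;
//   buildXQALaunchParams(launchParams, xqaParams, kvCacheList);
//   // ... pass &launchParams as the kernel argument and launch on the stream ...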