decoder_xqa_impl.h 877 B

12345678910111213141516171819202122232425262728
#pragma once

#include <cstdint>
#include <memory>

#include "xqa_params.h"
#include "decoder_xqa_impl_common.h"
  4. class DecoderXQARunner;
  5. class DecoderXQAImpl {
  6. public:
  7. void run(XQAParams const& xqa_params,
  8. KVCacheListParams const& kv_cache_buffer,
  9. cudaStream_t const& stream);
  10. enum class ImplType {
  11. kPrecompiled = 0,
  12. };
  13. static std::unique_ptr<DecoderXQAImpl> create(DecoderXQARunner* runner,
  14. ImplType implType);
  15. protected:
  16. DecoderXQAImpl(DecoderXQARunner* runner) : mRunner(runner) {}
  17. virtual void runWithKVBlockArray(XQAParams const& xqa_params,
  18. KVCacheListParams const& kv_block_array,
  19. cudaStream_t const& stream) = 0;
  20. DecoderXQARunner* mRunner;
  21. };
  22. enum class XQAKernelType : int32_t {
  23. kAMPERE_WARP_SPECIALIZED = 0,
  24. kHOPPER_WARP_SPECIALIZED = 1
  25. };