12345678910111213141516171819202122232425262728 |
- #pragma once
- #include "xqa_params.h"
- #include "decoder_xqa_impl_common.h"
- class DecoderXQARunner;
- class DecoderXQAImpl {
- public:
- void run(XQAParams const& xqa_params,
- KVCacheListParams const& kv_cache_buffer,
- cudaStream_t const& stream);
- enum class ImplType {
- kPrecompiled = 0,
- };
- static std::unique_ptr<DecoderXQAImpl> create(DecoderXQARunner* runner,
- ImplType implType);
- protected:
- DecoderXQAImpl(DecoderXQARunner* runner) : mRunner(runner) {}
- virtual void runWithKVBlockArray(XQAParams const& xqa_params,
- KVCacheListParams const& kv_block_array,
- cudaStream_t const& stream) = 0;
- DecoderXQARunner* mRunner;
- };
// Identifies a family of XQA kernels. The enumerator names refer to the
// Ampere and Hopper GPU generations. NOTE(review): the code that maps each
// value to concrete kernels is not visible here — confirm there.
//
// Values follow declaration order (0, 1) and are stored as a 32-bit signed
// integer. Do not reorder the enumerators without auditing every consumer of
// the numeric values.
enum class XQAKernelType : int32_t {
  kAMPERE_WARP_SPECIALIZED,  // == 0
  kHOPPER_WARP_SPECIALIZED,  // == 1
};
|