// decoder_xqa_impl_precompiled.cpp

#include "env_utils.h"
#include "decoder_xqa_impl_precompiled.h"
#include <cuda.h>
#include <functional>
#include <memory>
#include <mutex>
// Standard headers used directly in this file.
#include <algorithm>
#include <cassert>
#include <iostream>
#include <stdexcept>
#include <unordered_map>
#include "cubin/xqa_kernel_cubin.h"
#include "decoder_xqa_runner.h"
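// Returns the size in bytes of a single element for the given CUtensorMap
// data type; throws for types this implementation does not handle.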
uint32_t getElemBytes(CUtensorMapDataType_enum dataType) {
  switch (dataType) {
    case CU_TENSOR_MAP_DATA_TYPE_UINT8:
      return 1;
    case CU_TENSOR_MAP_DATA_TYPE_UINT16:
      return 2;
    case CU_TENSOR_MAP_DATA_TYPE_UINT32:
      return 4;
    case CU_TENSOR_MAP_DATA_TYPE_INT32:
      return 4;
    case CU_TENSOR_MAP_DATA_TYPE_UINT64:
      return 8;
    case CU_TENSOR_MAP_DATA_TYPE_INT64:
      return 8;
    case CU_TENSOR_MAP_DATA_TYPE_FLOAT16:
      return 2;
    case CU_TENSOR_MAP_DATA_TYPE_FLOAT32:
      return 4;
    case CU_TENSOR_MAP_DATA_TYPE_FLOAT64:
      return 8;
    case CU_TENSOR_MAP_DATA_TYPE_BFLOAT16:
      return 2;
    case CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ:
      return 4;
    case CU_TENSOR_MAP_DATA_TYPE_TFLOAT32:
      return 4;
    case CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ:
      return 4;
  }
  throw std::runtime_error("unsupported data type");
}
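// Builds a CUtensorMap that describes the paged KV cache pool as a 4D tensor
// (headElems, tokensPerPage, nbKHeads, page index). Each TMA box covers at
// most 128 bytes of a head and up to nbTokensPerTile tokens; the swizzle mode
// is picked from the resulting partition size.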
CUtensorMap makeTensorMapForPagedKVCache(void const* addr,
                                         CUtensorMapDataType_enum dataType,
                                         uint32_t headElems, uint32_t nbKHeads,
                                         uint32_t tokensPerPage,
                                         uint32_t nbTokensPerTile = 64) {
  CUtensorMap tensorMap{};
  uint32_t elemBytes = getElemBytes(dataType);
  uint64_t const globalDims[] = {headElems, tokensPerPage, nbKHeads, 1U << 31};
  uint32_t const headBytes = elemBytes * headElems;
  uint64_t const globalStrides[] = {headBytes, headBytes * tokensPerPage,
                                    headBytes * tokensPerPage * nbKHeads};
  TORCH_CHECK(headElems <= 256);
  uint32_t const paddedHeadElems =
      headElems <= 64 ? 64 : (headElems <= 128 ? 128 : 256);
  uint32_t const partElems =
      std::min(elemBytes * paddedHeadElems, 128U) / elemBytes;
  uint32_t const boxDims[] = {partElems,
                              std::min(tokensPerPage, nbTokensPerTile), 1, 1};
  uint32_t const elemStrides[] = {1, 1, 1, 1};
  auto const swizzle = [&] {
    switch (partElems) {
      case 128:
        return CU_TENSOR_MAP_SWIZZLE_128B;
      case 64:
        return CU_TENSOR_MAP_SWIZZLE_64B;
      default:
        throw std::runtime_error("unsupported cache head size");
    }
  }();
  cuErrCheck(cuTensorMapEncodeTiled(
      &tensorMap, dataType, 4, const_cast<void*>(addr), globalDims,
      globalStrides, boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
      swizzle, CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
  return tensorMap;
}
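// Convenience wrapper that derives the paged-KV tensor map from XQAParams,
// treating cache elements as single bytes (CU_TENSOR_MAP_DATA_TYPE_UINT8).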
CUtensorMap makeTensorMapForKVCache(XQAParams const& xqaParams,
                                    KVCacheListParams const& kv_cache_buffer) {
  return makeTensorMapForPagedKVCache(
      kv_cache_buffer.pool, CU_TENSOR_MAP_DATA_TYPE_UINT8, xqaParams.head_size,
      xqaParams.num_kv_heads, xqaParams.tokens_per_block);
}
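// Holds the precompiled XQA cubins for one (data type, SM) combination and
// maps runtime configurations (hash keys) to loaded CUfunctions.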
class XQAKernelList {
 public:
  using TKernelMeta = XQAKernelMetaInfo;
  XQAKernelList(Data_type type, unsigned int sm)
      : mDataType(type),
        mKernelMeta(&sXqaKernelMetaInfo[0]),
        mKernelMetaCount(sizeof(sXqaKernelMetaInfo) /
                         sizeof(sXqaKernelMetaInfo[0])),
        mSM(sm) {
    mForceXQA = forceXQAKernels();
  }
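  // Loads every cubin in sXqaKernelMetaInfo that matches this SM and data
  // type, resolves the kernel entry point plus its "smemSize"/"kernelType"
  // globals, and indexes the result by XQAKernelRuntimeHashKey.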
  void loadXQAKernels() {
    std::cout << "entering load XQA Kernels\n";
    if (!mFunctions.empty()) {
      return;
    }
    std::cout << "here mKernelMetaCount=" << mKernelMetaCount << std::endl;
    for (unsigned int i = 0; i < mKernelMetaCount; ++i) {
      auto const& kernelMeta = mKernelMeta[i];
      if (kernelMeta.mSM != mSM || kernelMeta.mDataType != mDataType) continue;
      // Cubins for kernels that would take the JIT path are removed from
      // kernelMeta.
      if (kernelMeta.mCubin == nullptr) continue;
      CUmodule hmod{0};
      auto findModuleIter = mModules.find(kernelMeta.mCubin);
      if (findModuleIter != mModules.end()) {
        hmod = findModuleIter->second;
      } else {
        cuErrCheck(cuModuleLoadData(&hmod, kernelMeta.mCubin));
        mModules.insert(std::make_pair(kernelMeta.mCubin, hmod));
      }
      XQAKernelFuncInfo funcInfo{};
      funcInfo.mMetaInfoIndex = i;
      cuErrCheck(cuModuleGetFunction(&funcInfo.mDeviceFunction, hmod,
                                     kernelMeta.mFuncName));
      funcInfo.mSharedMemBytes =
          getGlobalVar<uint32_t>(hmod, "smemSize", true).value();
      funcInfo.mKernelType =
          getGlobalVar<XQAKernelType>(hmod, "kernelType", false)
              .value_or(XQAKernelType::kAMPERE_WARP_SPECIALIZED);
      /* Set 46KB threshold here because we have to take static/driver shared
       * memory into consideration. */
      if (funcInfo.mSharedMemBytes >= 46 * 1024) {
        cuErrCheck(
            cuFuncSetAttribute(funcInfo.mDeviceFunction,
                               CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                               funcInfo.mSharedMemBytes));
      }
      XQAKernelRuntimeHashKey hash_key{
          kernelMeta.mKVDataType,   kernelMeta.mHeadDim,
          kernelMeta.mBeamWidth,    kernelMeta.mNumQHeadsOverKV,
          kernelMeta.mMTileSize,    kernelMeta.mTokensPerPage,
          kernelMeta.mPagedKVCache, kernelMeta.mMultiQueryTokens};
      mFunctions.insert(std::make_pair(hash_key, funcInfo));
    }
  }
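  // Returns true if a precompiled kernel exists for this configuration, i.e.
  // the hash key derived from xqaParams is present in mFunctions.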
  bool supportConfig(XQAParams const& xqaParams) const {
    unsigned int head_size = xqaParams.head_size;
    int num_q_heads = xqaParams.num_q_heads;
    int num_kv_heads = xqaParams.num_kv_heads;
    TORCH_CHECK(num_q_heads % num_kv_heads == 0,
                "numQHeads should be a multiple of numKVHeads.");
    unsigned int num_q_heads_over_kv = num_q_heads / num_kv_heads;
    unsigned int beam_width = xqaParams.beam_width;
    // MultiQueryToken kernels can support any num_q_heads_over_kv that is a
    // power of 2.
    unsigned int kernel_num_q_heads_over_kv =
        xqaParams.multi_query_tokens ? 0 : num_q_heads_over_kv;
    unsigned int m_tilesize;
    if (xqaParams.multi_query_tokens) {
      // MultiQueryToken kernels can handle either 16 or 32 for the M direction
      // per CTA.
      m_tilesize = xqaParams.generation_input_length <= 16 ? 16 : 32;
    } else {
      m_tilesize = num_q_heads_over_kv;
    }
    XQAKernelRuntimeHashKey hash_key = {
        xqaParams.kv_cache_data_type,
        head_size,
        beam_width,
        kernel_num_q_heads_over_kv,
        m_tilesize,
        xqaParams.paged_kv_cache
            ? static_cast<unsigned int>(xqaParams.tokens_per_block)
            : 0,
        xqaParams.paged_kv_cache,
        xqaParams.multi_query_tokens};
    auto const findIter = mFunctions.find(hash_key);
    return findIter != mFunctions.end();
  }
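  // Heuristic hook for deciding whether XQA is worthwhile; this build simply
  // always says yes.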
  bool mayHavePerfGain(XQAParams const& xqaParams,
                       int multiprocessor_count) const {
    return true;
  }
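  // Looks up the precompiled kernel matching xqaParams, assembles the kernel
  // parameter list (adding a CUtensorMap for the Hopper fp8 GMMA variant),
  // and launches it on the given stream with grid
  // (multi_block, num_kv_heads, batch_size) and block (128, 1, 2 or 3).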
  template <typename T>
  void run(XQAParams const& xqaParams, KVCacheListParams const& kv_cache_buffer,
           int multiprocessor_count, cudaStream_t const& stream) const {
    unsigned int head_size = xqaParams.head_size;
    int num_q_heads = xqaParams.num_q_heads;
    int num_kv_heads = xqaParams.num_kv_heads;
    TORCH_CHECK(num_q_heads % num_kv_heads == 0,
                "numQHeads should be a multiple of numKVHeads.");
    unsigned int num_q_heads_over_kv = num_q_heads / num_kv_heads;
    unsigned int beam_width = xqaParams.beam_width;
    unsigned int batch_beam_size = xqaParams.batch_size * beam_width;
    XQALaunchParam launchParams;
    buildXQALaunchParams(launchParams, xqaParams, kv_cache_buffer);
    void* xqa_q_input_ptr = const_cast<void*>(xqaParams.qHeads);
    XQAKernelRuntimeHashKey hash_key =
        getRuntimeHashKeyFromXQAParams(xqaParams);
    auto const findIter = mFunctions.find(hash_key);
    TORCH_CHECK(findIter != mFunctions.end(), "XQAKernelFunc not found.");
    auto const& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
    const CUfunction func = findIter->second.mDeviceFunction;
    unsigned int const shared_mem_bytes = findIter->second.mSharedMemBytes;
    auto const kernelType = findIter->second.mKernelType;
    if (false && xqaParams.multi_query_tokens) {
      // The multi-query-token path is disabled here.
    } else {
      bool const isGmmaKernel =
          (kernelType == XQAKernelType::kHOPPER_WARP_SPECIALIZED);
      TORCH_CHECK(isGmmaKernel == (mSM == kSM_90 &&
                                   xqaParams.kv_cache_data_type ==
                                       XQADataType::DATA_TYPE_E4M3 &&
                                   xqaParams.beam_width == 1));
      constexpr uint32_t kMAX_NB_KERNEL_PARAMS = 11;
      uint32_t const maxNbKernelParams = (isGmmaKernel ? 11 : 10);
      uint32_t idxNextParam = 0;
      void* kernelParams[kMAX_NB_KERNEL_PARAMS];
      auto appendParam = [&](auto* p) mutable {
        TORCH_CHECK(idxNextParam < maxNbKernelParams);
        kernelParams[idxNextParam++] = p;
      };
      appendParam(&launchParams.num_k_heads);
      appendParam(&launchParams.output);
      appendParam(&xqa_q_input_ptr);
      appendParam(&launchParams.kvCacheParams);
      appendParam(&launchParams.batch_size);
      appendParam(&launchParams.kv_scale_quant_orig);
      CUtensorMap tensorMap{};
      if (isGmmaKernel) {
        tensorMap = makeTensorMapForKVCache(xqaParams, kv_cache_buffer);
        appendParam(&tensorMap);
      }
      appendParam(&launchParams.semaphores);
      appendParam(&launchParams.scratch);
      kernelParams[idxNextParam] =
          nullptr;  // one extra nullptr at end as guard.
      int multi_block = 1;
      if (xqaParams.multi_block_mode) {
        multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size,
                                             multiprocessor_count);
      }
      unsigned int const blockz = isGmmaKernel ? 3 : 2;
      cuErrCheck(cuLaunchKernel(func, multi_block, xqaParams.num_kv_heads,
                                xqaParams.batch_size, 128, 1, blockz,
                                shared_mem_bytes, stream, kernelParams,
                                nullptr));
    }
  }
 protected:
  Data_type mDataType;
  TKernelMeta const* mKernelMeta;
  unsigned int mKernelMetaCount;
  unsigned int mSM;
  std::unordered_map<unsigned long long const*, CUmodule> mModules;
  bool mForceXQA = false;
  struct XQAKernelFuncInfo {
    unsigned int mMetaInfoIndex;
    unsigned int mSharedMemBytes;
    CUfunction mDeviceFunction;
    XQAKernelType mKernelType;
  };
  std::unordered_map<XQAKernelRuntimeHashKey, XQAKernelFuncInfo,
                     XQAKernelRuntimeHasher>
      mFunctions;
};
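// Per-device singleton that caches XQAKernelList instances keyed by
// (data type, SM version) so each cubin set is loaded at most once per device.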
class XQAKernelLoader {
 public:
  XQAKernelList const* getXQAKernels(Data_type type, unsigned int sm) {
    static std::mutex s_mutex;
    std::lock_guard<std::mutex> lg(s_mutex);
    XQAKernelLoadHashKey hash_key{type, sm};
    auto const findIter = mKernels.find(hash_key);
    if (findIter == mKernels.end()) {
      XQAKernelList* newKernel = new XQAKernelList{type, sm};
      newKernel->loadXQAKernels();
      mKernels.insert(
          std::make_pair(hash_key, std::unique_ptr<XQAKernelList>(newKernel)));
      return newKernel;
    } else {
      return findIter->second.get();
    }
  }
  static XQAKernelLoader& Get() {
    int device_id = getDevice();
    static std::unique_ptr<XQAKernelLoader> s_factory[32] = {nullptr};
    assert(device_id >= 0 && device_id < 32);
    if (s_factory[device_id] == nullptr) {
      s_factory[device_id] =
          std::make_unique<XQAKernelLoader>(XQAKernelLoader());
    }
    return *(s_factory[device_id]);
  }

 private:
  XQAKernelLoader() = default;
  std::unordered_map<XQAKernelLoadHashKey, const std::unique_ptr<XQAKernelList>,
                     XQAKernelLoadHasher>
      mKernels;
};
inline XQAKernelList const* getXQAKernels(Data_type type, unsigned int sm) {
  return XQAKernelLoader::Get().getXQAKernels(type, sm);
}
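// Dispatches the templated run() on the runner's activation data type
// (fp16 or bf16) using the kernels cached for the current device's SM.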
#define XQA_KERNEL_RUN(DATA_TYPE)                                  \
  xqa_kernel->template run<DATA_TYPE>(xqa_params, kv_cache_buffer, \
                                      multi_processor_count, stream)
void DecoderXQAImplPrecompiled::runDispatchBuffer(
    XQAParams const& xqa_params, KVCacheListParams const& kv_cache_buffer,
    cudaStream_t const& stream) {
  XQAKernelList const* xqa_kernel =
      getXQAKernels(mRunner->mDataType, getSMVersion());
  int multi_processor_count = mRunner->mMultiProcessorCount;
  if (mRunner->mDataType == DATA_TYPE_FP16) {
    XQA_KERNEL_RUN(__half);
  } else {
    XQA_KERNEL_RUN(__nv_bfloat16);
  }
}
void DecoderXQAImplPrecompiled::runWithKVBlockArray(
    XQAParams const& xqa_params, KVCacheListParams const& kv_block_array,
    cudaStream_t const& stream) {
  runDispatchBuffer(xqa_params, kv_block_array, stream);
}
#undef XQA_KERNEL_RUN