decoder_xqa_impl_common.h

/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * Common utils to be shared between Precompiled and JIT implementation.
 */
#pragma once
// NOTE: we use int32_t sequence lengths because the gpt attention plugins use
// int32_t for them. XQA kernels assume all lengths are uint32_t.
#include "xqa_params.h"
// #include "decoder_xqa_common.h"
#include <cassert>
// Standard / CUDA headers used below (some may already be pulled in
// transitively via xqa_params.h).
#include <cuda.h>              // CUmodule, CUdeviceptr, cuModuleGetGlobal
#include <cuda_runtime_api.h>  // cudaError_t, cudaDeviceGetAttribute, cudaMemcpy

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <optional>
#include <ostream>
#include <stdexcept>

// void syncAndCheck(char const* const file, int const line)
// {
//   if (true)
//   {
//     cudaGetLastError();
//     cudaDeviceSynchronize();
//   }
// }
// #define sync_check_cuda_error() syncAndCheck(__FILE__, __LINE__)

inline void checkCuda(cudaError_t err) {
  if (err != cudaSuccess) {
    printf("%s\n", cudaGetErrorName(err));
    throw std::runtime_error(cudaGetErrorName(err));
  }
}

inline int getMultiProcessorCount() {
  int device_id;
  int multi_processor_count;
  checkCuda(cudaGetDevice(&device_id));
  checkCuda(cudaDeviceGetAttribute(&multi_processor_count,
                                   cudaDevAttrMultiProcessorCount, device_id));
  return multi_processor_count;
}
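
// Illustrative usage (not part of the original header): the SM count only
// needs to be queried once per device, e.g.
//   static int const kSmCount = getMultiProcessorCount();
// computeMultiBlockCount() below consumes this value as multiprocessor_count.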

template <typename T>
HOST_DEVICE_FUNC constexpr inline T divUp(T a, T b) {
  return (a + b - 1) / b;
}

template <typename T>
HOST_DEVICE_FUNC constexpr inline T roundUp(T a, T b) {
  return divUp(a, b) * b;
}

constexpr inline uint32_t exactDiv(uint32_t a, uint32_t b) {
  assert(a % b == 0);
  return a / b;
}
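
// Worked examples (integer arithmetic, added for clarity):
//   divUp(10u, 4u)    == 3   // ceil(10 / 4)
//   roundUp(10u, 4u)  == 12  // smallest multiple of 4 that is >= 10
//   exactDiv(12u, 4u) == 3   // asserts 12 % 4 == 0 before dividing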

using KVCachePageIndex = int32_t;
using SeqLenDataType = uint32_t;

struct KVCacheListParams {
  void const* pool = nullptr;
  KVCachePageIndex const* block_indices =
      nullptr;  // shape: [batchSize][beamWidth][2][maxNbPagesPerSeq].
  SeqLenDataType const* sequence_lengths =
      nullptr;  // shape: [batchSize][beamWidth] (for compatibility)
  // NOTE: max_num_blocks_per_sequence for paged kv cache.
  uint32_t capacity = 0;

  KVCacheListParams(void const* _pool, KVCachePageIndex const* _block_indices,
                    SeqLenDataType const* _sequence_lengths, uint32_t _capacity)
      : pool(_pool),
        block_indices(_block_indices),
        sequence_lengths(_sequence_lengths),
        capacity(_capacity) {}
  KVCacheListParams() = default;
};
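
// Illustrative indexing (assuming the row-major layout implied by the shape
// comment above; not part of the original header): the page index for the
// K (kv == 0) or V (kv == 1) half of sequence (b, beam) would be read as
//   block_indices[((b * beamWidth + beam) * 2 + kv) * maxNbPagesPerSeq + page]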

struct XQALaunchParam {
  uint32_t num_k_heads;
  void* output;
  // void const* qkv;
  KVCacheListParams kvCacheParams;
  uint32_t batch_size;
  float const* kv_scale_quant_orig = nullptr;
  int* cu_seq_lens = nullptr;
  uint32_t* semaphores = nullptr;
  void* scratch = nullptr;
};

struct XQAKernelLoadHashKey {
  Data_type data_type;
  unsigned int sm;

  bool operator==(XQAKernelLoadHashKey const& other) const {
    return data_type == other.data_type && sm == other.sm;
  }
};

struct XQAKernelLoadHasher {
  size_t operator()(XQAKernelLoadHashKey const& s) const {
    size_t key = s.data_type;
    key <<= 16;
    key ^= s.sm;
    return key;
  }
};
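
// Illustrative usage (hypothetical value type, not in the original header):
// the key/hasher pairs in this file are meant for unordered containers, e.g.
//   std::unordered_map<XQAKernelLoadHashKey, CubinHandle, XQAKernelLoadHasher>
// where CubinHandle is a placeholder for whatever the kernel loader stores.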

struct XQAKernelRuntimeHashKey {
  Data_type kv_data_type;
  unsigned int head_size;
  unsigned int beam_size;
  unsigned int num_q_heads_per_kv;
  unsigned int m_tilesize;
  unsigned int tokens_per_page;
  bool paged_kv_cache;
  bool multi_query_tokens;

  bool operator==(XQAKernelRuntimeHashKey const& other) const {
    return kv_data_type == other.kv_data_type &&
           head_size == other.head_size &&
           num_q_heads_per_kv == other.num_q_heads_per_kv &&
           beam_size == other.beam_size &&
           multi_query_tokens == other.multi_query_tokens &&
           m_tilesize == other.m_tilesize &&
           tokens_per_page == other.tokens_per_page &&
           paged_kv_cache == other.paged_kv_cache;
  }
};

std::ostream& operator<<(std::ostream& os, const XQAKernelRuntimeHashKey& key);

XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(
    XQAParams const& xqaParams);

void buildXQALaunchParams(XQALaunchParam& launchParams, XQAParams const& params,
                          KVCacheListParams kv_cache_buffer);

struct XQAKernelRuntimeHasher {
  size_t operator()(XQAKernelRuntimeHashKey const& s) const {
    size_t key = s.kv_data_type;
    key <<= 16;
    key ^= s.head_size;
    key <<= 8;
    key ^= s.num_q_heads_per_kv;
    key <<= 8;
    key ^= s.beam_size;
    key <<= 6;
    key ^= s.m_tilesize;
    key <<= 10;
    key ^= s.tokens_per_page;
    key <<= 1;
    key ^= s.paged_kv_cache;
    key <<= 1;
    key ^= s.multi_query_tokens;
    return key;
  }
};

// XQA kernel can be uniquely identified by (LoadHashKey, RuntimeHashKey).
struct XQAKernelFullHashKey {
  XQAKernelLoadHashKey load_key;
  XQAKernelRuntimeHashKey runtime_key;

  XQAKernelFullHashKey() = default;
  XQAKernelFullHashKey(XQAKernelLoadHashKey const& load_key,
                       XQAKernelRuntimeHashKey const& runtime_key)
      : load_key(load_key), runtime_key(runtime_key) {}

  XQAKernelFullHashKey(void const* buffer, size_t buffer_size) {
    TORCH_CHECK(sizeof(*this) <= buffer_size);
    memcpy(this, buffer, sizeof(*this));
  }

  bool operator==(XQAKernelFullHashKey const& other) const {
    return load_key == other.load_key && runtime_key == other.runtime_key;
  }

  size_t getSerializationSize() const { return sizeof(*this); }

  void serialize(void* buffer, size_t buffer_size) const {
    TORCH_CHECK(sizeof(*this) <= buffer_size);
    memcpy(buffer, this, sizeof(*this));
  }
};

struct XQAKernelFullHasher {
  size_t operator()(XQAKernelFullHashKey const& s) const {
    return XQAKernelLoadHasher()(s.load_key) ^
           XQAKernelRuntimeHasher()(s.runtime_key);
  }
};

std::uintptr_t constexpr kCudaMemAlign = 128;

inline int8_t* alignPtr(int8_t* ptr, uintptr_t to) {
  uintptr_t addr = (uintptr_t)ptr;
  if (addr % to) {
    addr += to - addr % to;
  }
  return (int8_t*)addr;
}

inline int8_t* nextWorkspacePtrCommon(int8_t* ptr,
                                      uintptr_t previousWorkspaceSize,
                                      uintptr_t const alignment) {
  uintptr_t addr = (uintptr_t)ptr;
  addr += previousWorkspaceSize;
  return alignPtr((int8_t*)addr, alignment);
}

inline int8_t* nextWorkspacePtrWithAlignment(
    int8_t* ptr, uintptr_t previousWorkspaceSize,
    uintptr_t const alignment = kCudaMemAlign) {
  return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, alignment);
}
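
// Illustrative usage (hypothetical buffer names and sizes, not in the original
// header): carving sub-buffers out of one workspace allocation, each aligned
// to kCudaMemAlign:
//   int8_t* ws          = static_cast<int8_t*>(workspace);
//   int8_t* semaphores  = ws;
//   int8_t* scratch     = nextWorkspacePtrWithAlignment(semaphores, semaphoreBytes);
//   int8_t* cu_seq_lens = nextWorkspacePtrWithAlignment(scratch, scratchBytes);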

template <typename T>
std::optional<T> getGlobalVar(CUmodule hmod, char const* const name,
                              bool required = false) {
  T* pVar = nullptr;
  size_t size = 0;
  auto const error = cuModuleGetGlobal(reinterpret_cast<CUdeviceptr*>(&pVar),
                                       &size, hmod, name);
  T ret;
  switch (error) {
    case CUDA_SUCCESS:
      TORCH_CHECK(size == sizeof(T));
      CUDACHECK(cudaMemcpy(&ret, pVar, size, cudaMemcpyDeviceToHost));
      break;
    case CUDA_ERROR_NOT_FOUND:
      if (!required) {
        return std::nullopt;
      }
      [[fallthrough]];
    default:
      cuErrCheck(("Failed to retrieve global variable from cubin.", error));
  }
  return std::optional<T>{std::move(ret)};
}
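
// Illustrative usage (hypothetical symbol name, not in the original header):
//   std::optional<uint32_t> v = getGlobalVar<uint32_t>(hmod, "kSomeCubinFlag");
//   if (v) { /* the cubin exports the symbol; *v holds its value */ }
// Passing required = true turns a missing symbol into a hard error instead of
// returning std::nullopt.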

inline int computeMultiBlockCount(XQAParams const& xqaParams, int batch_size,
                                  int multiprocessor_count) {
  int multi_block_count = 1;
  int num_kv_heads = xqaParams.num_kv_heads;
  int history_length = xqaParams.timestep;
  int32_t const maxNbSubSeq = kXQA_MAX_NUM_SUB_SEQ;

  multi_block_count = history_length / kMinHistoryTokensPerBlock;
  // Avoid using too many blocks for one sequence; otherwise the final
  // reduction may dominate.
  multi_block_count = std::min(
      multi_block_count,
      static_cast<int>(std::round(std::sqrt(multi_block_count * 8.F))));
  multi_block_count = std::max(multi_block_count, 1);
  // Adjust to kTargetWaveFactor. The count was already initialized using
  // kMinHistoryTokensPerBlock, so it only needs to decrease.
  double wave_count = (double)batch_size * num_kv_heads * multi_block_count /
                      (double)multiprocessor_count;
  double adj_factor = wave_count / (double)kTargetWaveFactor;
  if (adj_factor > 1.0) {
    multi_block_count = floor(multi_block_count / adj_factor);
  }
  multi_block_count = std::max(multi_block_count, 1);
  // Add a limitation due to the reserved workspace size. When batch_size is
  // large, multi-block is useless anyway, so a large workspace is not useful
  // and we can set a hard limit for the workspace size (computed from
  // maxNbSubSeq).
  multi_block_count =
      std::max(std::min(multi_block_count, maxNbSubSeq / batch_size), 1);

  TORCH_CHECK(multi_block_count >= 1,
              "MultiBlock count should be at least 1");
  TORCH_CHECK(
      multi_block_count == 1 || batch_size * multi_block_count <= maxNbSubSeq,
      "Insufficient workspace");
  return multi_block_count;
}
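
// Worked example of the sqrt clamp above (pure arithmetic, added for clarity):
// if history_length / kMinHistoryTokensPerBlock == 32, then
// round(sqrt(32 * 8.F)) == 16, so multi_block_count drops from 32 to 16 before
// the wave-factor and workspace limits are applied.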