1
0

utils.cuh 14 KB


  1. /*
  2. * Copyright (c) 2024 by PygmalionAI team.
  3. * Copyright (c) 2023 by FlashInfer team.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. #ifndef APHRODITE_UTILS_CUH_
  18. #define APHRODITE_UTILS_CUH_
  19. #include <cuda_runtime.h>
  20. #include <iostream>
  21. #include <sstream>
  22. #include <stdexcept>
  23. #include <vector>
  24. #include <torch/all.h>
  // Stringification helpers.
  //   STR(x)        -> string literal of x AFTER macro-expanding x
  //                    (e.g. STR(__LINE__) yields "123", not "__LINE__").
  //   STR_HELPER(x) -> string literal of the argument tokens exactly as written.
  // STR goes through STR_HELPER so that its argument is expanded before '#'
  // is applied. Definition order does not matter because macros are only
  // expanded at their point of use.
  #define STR(x) STR_HELPER(x)
  #define STR_HELPER(x) #x
  // APHRODITE_ALWAYS_DISALLOW_FP16_QK_REDUCTION: build-time switch meant to turn
  // off the fp16 QK-reduction path in order to reduce binary size. It defaults
  // to 0 unless it was already defined (e.g. with -D on the compiler command
  // line).
  // NOTE(review): DISPATCH_ALLOW_FP16_QK_REDUCTION in this header does not read
  // this macro (it always rejects fp16 QK reduction); confirm whether anything
  // else consults it.
  #if !defined(APHRODITE_ALWAYS_DISALLOW_FP16_QK_REDUCTION)
  #define APHRODITE_ALWAYS_DISALLOW_FP16_QK_REDUCTION 0
  #endif
  // APHRODITE_CUDA_CALL(func, ...): evaluate `func` once (the result is stored
  // in a cudaError_t) and, if it is not cudaSuccess, `return` that error code
  // from the enclosing function, which must therefore return cudaError_t.
  // When NDEBUG is not defined, the failure is also logged to std::cerr with the
  // error string, numeric code, file, line and the stringified call (via STR,
  // which must be defined at the point of use). Any extra macro arguments after
  // `func` are accepted and ignored.
  //
  // The temporary deliberately uses a project-prefixed name instead of a short
  // one such as `e`. Otherwise a call whose argument mentions a caller variable
  // of the same short name (e.g. APHRODITE_CUDA_CALL(e)) would bind to the
  // macro's own, still-uninitialized temporary.
  #ifndef NDEBUG
  #define APHRODITE_CUDA_CALL(func, ...) \
    { \
      cudaError_t aphrodite_cuda_call_err_ = (func); \
      if (aphrodite_cuda_call_err_ != cudaSuccess) { \
        std::cerr << "CUDA Error: " \
                  << cudaGetErrorString(aphrodite_cuda_call_err_) << " (" \
                  << aphrodite_cuda_call_err_ << ") " << __FILE__ \
                  << ": line " << __LINE__ << " at function " << STR(func) \
                  << std::endl; \
        return aphrodite_cuda_call_err_; \
      } \
    }
  #else
  #define APHRODITE_CUDA_CALL(func, ...) \
    { \
      cudaError_t aphrodite_cuda_call_err_ = (func); \
      if (aphrodite_cuda_call_err_ != cudaSuccess) { \
        return aphrodite_cuda_call_err_; \
      } \
    }
  #endif
  // DISPATCH_ALLOW_FP16_QK_REDUCTION(allow_fp16_qk_reduction,
  //                                  ALLOW_FP16_QK_REDUCTION, ...):
  // Only the `false` specialization is ever instantiated. If the runtime flag
  // is false, __VA_ARGS__ runs with `constexpr bool ALLOW_FP16_QK_REDUCTION =
  // false`. If it is true, std::runtime_error is thrown instead.
  // NOTE(review): this macro ignores APHRODITE_ALWAYS_DISALLOW_FP16_QK_REDUCTION,
  // so fp16 QK reduction stays unavailable even when that macro is 0 -- confirm
  // this is intentional (presumably to keep binary size down).
  #define DISPATCH_ALLOW_FP16_QK_REDUCTION(allow_fp16_qk_reduction, \
                                           ALLOW_FP16_QK_REDUCTION, ...) \
    if (!(allow_fp16_qk_reduction)) { \
      constexpr bool ALLOW_FP16_QK_REDUCTION = false; \
      __VA_ARGS__ \
    } else { \
      throw std::runtime_error("FP16_QK_REDUCTION disabled at compile time"); \
    }
  // DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, ...): runtime-to-compile-time
  // dispatch. Runs __VA_ARGS__ with `constexpr size_t NUM_FRAGS_X` bound to 1
  // or 2. Any other value throws std::invalid_argument.
  // The runtime argument is parenthesized at every use so that an expression
  // argument such as `cond ? 2 : 1` is compared as a whole rather than being
  // split by operator precedence. Note that the argument may be evaluated
  // more than once, so it should be free of side effects.
  #define DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, ...) \
    if ((num_frags_x) == 1) { \
      constexpr size_t NUM_FRAGS_X = 1; \
      __VA_ARGS__ \
    } else if ((num_frags_x) == 2) { \
      constexpr size_t NUM_FRAGS_X = 2; \
      __VA_ARGS__ \
    } else { \
      std::ostringstream err_msg; \
      err_msg << "Unsupported num_frags_x: " << (num_frags_x); \
      throw std::invalid_argument(err_msg.str()); \
    }
  // DISPATCH_NUM_FRAGS_Z(max_frags_z, NUM_FRAGS_Z, ...): picks the largest
  // power of two in {8, 4, 2, 1} that does not exceed max_frags_z. It then runs
  // __VA_ARGS__ with `constexpr size_t NUM_FRAGS_Z` bound to that value.
  // Values below 1 throw std::invalid_argument.
  // The runtime argument is parenthesized at every use so that an expression
  // argument such as `cond ? 4 : 1` is compared as a whole. The argument may be
  // evaluated more than once, so it should be free of side effects.
  #define DISPATCH_NUM_FRAGS_Z(max_frags_z, NUM_FRAGS_Z, ...) \
    if ((max_frags_z) >= 8) { \
      constexpr size_t NUM_FRAGS_Z = 8; \
      __VA_ARGS__ \
    } else if ((max_frags_z) >= 4) { \
      constexpr size_t NUM_FRAGS_Z = 4; \
      __VA_ARGS__ \
    } else if ((max_frags_z) >= 2) { \
      constexpr size_t NUM_FRAGS_Z = 2; \
      __VA_ARGS__ \
    } else if ((max_frags_z) >= 1) { \
      constexpr size_t NUM_FRAGS_Z = 1; \
      __VA_ARGS__ \
    } else { \
      std::ostringstream err_msg; \
      err_msg << "Unsupported max_frags_z: " << (max_frags_z); \
      throw std::invalid_argument(err_msg.str()); \
    }
  // DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...): runs __VA_ARGS__ with
  // `constexpr size_t GROUP_SIZE` bound to group_size when it is 1, 2, 4 or 8.
  // Any other value throws std::invalid_argument.
  // The runtime argument is parenthesized at every use so that an expression
  // argument such as `cond ? 4 : 1` is compared as a whole. The argument may be
  // evaluated more than once, so it should be free of side effects.
  #define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
    if ((group_size) == 1) { \
      constexpr size_t GROUP_SIZE = 1; \
      __VA_ARGS__ \
    } else if ((group_size) == 2) { \
      constexpr size_t GROUP_SIZE = 2; \
      __VA_ARGS__ \
    } else if ((group_size) == 4) { \
      constexpr size_t GROUP_SIZE = 4; \
      __VA_ARGS__ \
    } else if ((group_size) == 8) { \
      constexpr size_t GROUP_SIZE = 8; \
      __VA_ARGS__ \
    } else { \
      std::ostringstream err_msg; \
      err_msg << "Unsupported group_size: " << (group_size); \
      throw std::invalid_argument(err_msg.str()); \
    }
  // DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...): switches on a runtime
  // MaskMode value. It runs __VA_ARGS__ with `constexpr MaskMode MASK_MODE`
  // bound to the matching enumerator (kNone, kCausal or kCustom). Any other
  // value throws std::invalid_argument whose message contains the value's
  // integer representation.
  #define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...) \
    switch (mask_mode) { \
      case MaskMode::kCustom: { \
        constexpr MaskMode MASK_MODE = MaskMode::kCustom; \
        __VA_ARGS__ \
        break; \
      } \
      case MaskMode::kCausal: { \
        constexpr MaskMode MASK_MODE = MaskMode::kCausal; \
        __VA_ARGS__ \
        break; \
      } \
      case MaskMode::kNone: { \
        constexpr MaskMode MASK_MODE = MaskMode::kNone; \
        __VA_ARGS__ \
        break; \
      } \
      default: { \
        std::ostringstream unsupported_msg; \
        unsupported_msg << "Unsupported mask_mode: " \
                        << static_cast<int>(mask_mode); \
        throw std::invalid_argument(unsupported_msg.str()); \
      } \
    }
  // DISPATCH_LOGITS_POST_HOOK(logits_soft_cap, LOGITS_POST_HOOK, ...): runs
  // __VA_ARGS__ with `constexpr LogitsPostHook LOGITS_POST_HOOK` set as follows:
  //   logits_soft_cap > 0   -> LogitsPostHook::kSoftCap
  //   logits_soft_cap == 0  -> LogitsPostHook::kNone
  // Anything else (negative values, NaN) throws std::invalid_argument.
  // The runtime argument is parenthesized at every use so that an expression
  // argument such as `cond ? 0.f : cap` is compared as a whole. The argument
  // may be evaluated more than once, so it should be free of side effects.
  #define DISPATCH_LOGITS_POST_HOOK(logits_soft_cap, LOGITS_POST_HOOK, ...) \
    if ((logits_soft_cap) > 0.f) { \
      constexpr LogitsPostHook LOGITS_POST_HOOK = LogitsPostHook::kSoftCap; \
      __VA_ARGS__ \
    } else if ((logits_soft_cap) == 0.f) { \
      constexpr LogitsPostHook LOGITS_POST_HOOK = LogitsPostHook::kNone; \
      __VA_ARGS__ \
    } else { \
      std::ostringstream err_msg; \
      err_msg << "Invalid logits_soft_cap (should be >= 0): " \
              << (logits_soft_cap); \
      throw std::invalid_argument(err_msg.str()); \
    }
  // DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...): runs __VA_ARGS__ with
  // `constexpr size_t HEAD_DIM` bound to head_dim when it is 64, 128 or 256.
  // Any other value throws std::invalid_argument naming the value.
  // The argument is parenthesized in the error message as well. Otherwise an
  // expression argument such as `cond ? 96 : 64` would be split by `<<`
  // precedence, and the message would report the condition instead of the
  // head dimension.
  #define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \
    switch (head_dim) { \
      case 64: { \
        constexpr size_t HEAD_DIM = 64; \
        __VA_ARGS__ \
        break; \
      } \
      case 128: { \
        constexpr size_t HEAD_DIM = 128; \
        __VA_ARGS__ \
        break; \
      } \
      case 256: { \
        constexpr size_t HEAD_DIM = 256; \
        __VA_ARGS__ \
        break; \
      } \
      default: { \
        std::ostringstream err_msg; \
        err_msg << "Unsupported head_dim: " << (head_dim); \
        throw std::invalid_argument(err_msg.str()); \
      } \
    }
  // DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, ...):
  // switches on a runtime PosEncodingMode value. It runs __VA_ARGS__ with
  // `constexpr PosEncodingMode POS_ENCODING_MODE` bound to the matching
  // enumerator (kNone, kRoPELlama or kALiBi). Any other value throws
  // std::invalid_argument whose message contains the value's integer
  // representation.
  #define DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, ...) \
    switch (pos_encoding_mode) { \
      case PosEncodingMode::kALiBi: { \
        constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kALiBi; \
        __VA_ARGS__ \
        break; \
      } \
      case PosEncodingMode::kRoPELlama: { \
        constexpr PosEncodingMode POS_ENCODING_MODE = \
            PosEncodingMode::kRoPELlama; \
        __VA_ARGS__ \
        break; \
      } \
      case PosEncodingMode::kNone: { \
        constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone; \
        __VA_ARGS__ \
        break; \
      } \
      default: { \
        std::ostringstream unsupported_msg; \
        unsupported_msg << "Unsupported pos_encoding_mode: " \
                        << static_cast<int>(pos_encoding_mode); \
        throw std::invalid_argument(unsupported_msg.str()); \
      } \
    }
  // DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...): runs
  // __VA_ARGS__ with `constexpr size_t ALIGNED_VEC_SIZE` bound to
  // aligned_vec_size when it is 16, 8, 4, 2 or 1. Any other value throws
  // std::invalid_argument naming the value.
  // The argument is parenthesized in the error message as well. Otherwise an
  // expression argument such as `cond ? 3 : 16` would be split by `<<`
  // precedence, and the message would report the condition instead of the
  // size.
  #define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \
    switch (aligned_vec_size) { \
      case 16: { \
        constexpr size_t ALIGNED_VEC_SIZE = 16; \
        __VA_ARGS__ \
        break; \
      } \
      case 8: { \
        constexpr size_t ALIGNED_VEC_SIZE = 8; \
        __VA_ARGS__ \
        break; \
      } \
      case 4: { \
        constexpr size_t ALIGNED_VEC_SIZE = 4; \
        __VA_ARGS__ \
        break; \
      } \
      case 2: { \
        constexpr size_t ALIGNED_VEC_SIZE = 2; \
        __VA_ARGS__ \
        break; \
      } \
      case 1: { \
        constexpr size_t ALIGNED_VEC_SIZE = 1; \
        __VA_ARGS__ \
        break; \
      } \
      default: { \
        std::ostringstream err_msg; \
        err_msg << "Unsupported aligned_vec_size: " << (aligned_vec_size); \
        throw std::invalid_argument(err_msg.str()); \
      } \
    }
  namespace aphrodite {

  /*!
   * \brief Integer ceiling division, computed as (x + y - 1) / y.
   * \note For non-negative x and positive y this equals ceil(x / y). The sum
   *   x + y - 1 is evaluated in the usual-arithmetic-conversion type of T1 and
   *   T2, so it can overflow when x is close to that type's maximum. The result
   *   is converted back to T1.
   */
  template <typename T1, typename T2>
  __forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) {
    return (x + y - 1) / y;
  }

  /*!
   * \brief Debug helper that copies `size` elements of type T to the host with
   *   a blocking cudaMemcpy (DeviceToHost) and prints them to std::cout,
   *   separated by spaces and preceded by `prefix`.
   * \param device_ptr Source pointer handed to cudaMemcpy as device memory. It
   *   is taken as pointer-to-const so read-only buffers can be printed too.
   * \param size Number of elements (not bytes) to copy and print.
   * \param prefix Text printed before the values.
   * \note If cudaMemcpy fails, the error is reported on std::cerr and nothing
   *   is printed to std::cout. The zero-filled host buffer is never shown as
   *   if it were device data. cudaMemcpy may also report an error left behind
   *   by earlier asynchronous work.
   * \note Requires `std::ostream << T` to be available.
   */
  template <typename T>
  inline void DebugPrintCUDAArray(const T* device_ptr, size_t size,
                                  std::string prefix = "") {
    std::vector<T> host_array(size);
    const cudaError_t status =
        cudaMemcpy(host_array.data(), device_ptr, size * sizeof(T),
                   cudaMemcpyDeviceToHost);
    if (status != cudaSuccess) {
      std::cerr << prefix << "[DebugPrintCUDAArray] cudaMemcpy failed: "
                << cudaGetErrorString(status) << " (" << status << ")"
                << std::endl;
      return;
    }
    std::cout << prefix;
    for (const T& value : host_array) {
      std::cout << value << " ";
    }
    std::cout << std::endl;
  }

  /*!
   * \brief Return x - y if x > y, otherwise return 0: an unsigned subtraction
   *   that clamps at zero instead of wrapping around.
   */
  __device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x,
                                                             uint32_t y) {
    return (x > y) ? x - y : 0U;
  }

  /*!
   * \brief Exchange the values of two uint32_t variables (device code only).
   */
  __device__ __forceinline__ void swap(uint32_t& a, uint32_t& b) {
    uint32_t tmp = a;
    a = b;
    b = tmp;
  }

  }  // namespace aphrodite
  253. #endif // APHRODITE_UTILS_CUH_