marlin_moe_ops.cu

/*
 * Modified by Neural Magic
 * Copyright (C) Marlin.2024 Elias Frantar
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <iostream>

#include "core/scalar_type.hpp"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"

template <typename T>
inline std::string str(T x) {
  return std::to_string(x);
}

namespace marlin_moe {

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800

// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {
  int start_row = block_rows * blockIdx.x;
  int finish_row = start_row + block_rows;
  if (finish_row > size_m) {
    finish_row = size_m;
  }
  int cur_block_rows = finish_row - start_row;

  int row_stride = size_k * sizeof(half) / 16;

  auto permute_row = [&](int row) {
    int iters = size_k / blockDim.x;
    int rest = size_k % blockDim.x;

    int offset = row * row_stride;

    half const* a_row_half =
        reinterpret_cast<half const*>(a_int4_ptr + offset);
    half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);

    int base_k = 0;

    for (int i = 0; i < iters; i++) {
      int cur_k = base_k + threadIdx.x;
      int src_pos = perm_int_ptr[cur_k];

      out_half[cur_k] = a_row_half[src_pos];

      base_k += blockDim.x;
    }

    if (rest) {
      if (threadIdx.x < rest) {
        int cur_k = base_k + threadIdx.x;
        int src_pos = perm_int_ptr[cur_k];

        out_half[cur_k] = a_row_half[src_pos];
      }
    }
  };

  for (int i = 0; i < cur_block_rows; i++) {
    int cur_row = start_row + i;
    if (cur_row < size_m) {
      permute_row(cur_row);
    }
  }
}

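// Counts how many entries of "topk_ids" map to each expert, rounds each count
// up to a multiple of "block_size", and writes the resulting exclusive prefix
// sums into "expert_offsets" (length num_experts + 1). Launched below with a
// single block whose thread count equals the number of experts.
//
// Illustrative example: with num_experts = 3, block_size = 4 and
// topk_ids = [0, 1, 1, 2, 2, 2, 2, 2], the per-expert counts are [1, 2, 5];
// after rounding each count up to a multiple of 4 and prefix-summing,
// expert_offsets becomes [0, 4, 8, 16].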
__global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
                                       int* __restrict__ expert_offsets,
                                       int topk_length, int block_size) {
  int expert_id = threadIdx.x;
  int num_experts = blockDim.x;

  int occurrences = 0;
  for (int i = 0; i < topk_length; ++i) {
    occurrences += (topk_ids[i] == expert_id);
  }
  expert_offsets[expert_id + 1] = occurrences;
  __syncthreads();

  if (threadIdx.x == 0) {
    int tot_offset = 0;
    expert_offsets[0] = 0;
    for (int i = 0; i < num_experts; ++i) {
      tot_offset += ceildiv(expert_offsets[i + 1], block_size) * block_size;
      expert_offsets[i + 1] = tot_offset;
    }
  }
  __syncthreads();
}

#else

__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {
  // Marlin is not implemented yet for SM < 8.0
  assert(false);
  return;
}

__global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
                                       int* __restrict__ expert_offsets,
                                       int topk_length, int block_size) {
  // Marlin is not implemented yet for SM < 8.0
  assert(false);
  return;
}

#endif

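// thread_config_t captures the per-thread-block tile shape: thread_k and
// thread_n are the K and N extents of the weight tile handled by one block,
// and num_threads is the block size. exec_config_t pairs a tile config with
// max_m_blocks, the maximum number of 16-row M blocks a single kernel
// invocation may cover.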
typedef struct {
  int thread_k;
  int thread_n;
  int num_threads;
} thread_config_t;

typedef struct {
  int max_m_blocks;
  thread_config_t tb_cfg;
} exec_config_t;

thread_config_t small_batch_thread_configs[] = {
    // Ordered by priority

    // thread_k, thread_n, num_threads
    {128, 128, 256},  // Default
    {128, 64, 128},   // Reduce N 2X, same K
    {64, 256, 256},   // Reduce K 2X, increase N 2X
    {64, 128, 128},   // Reduce K 2X, same N
};

thread_config_t large_batch_thread_configs[] = {
    // Ordered by priority

    // thread_k, thread_n, num_threads
    {64, 256, 256},   // Default
    {128, 128, 256},  // Reduce N 2X, increase K 2X
    {64, 128, 128},   // Reduce N 2X, same K
    {128, 64, 128},   // Reduce N 4X, increase K 2X
};

int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
                          int prob_n, int prob_k, int num_bits, int group_size,
                          bool has_act_order, bool is_k_full) {
  bool cache_scales_chunk = has_act_order && !is_k_full;

  int tb_n = th_config.thread_n;
  int tb_k = th_config.thread_k;

  // Get max scale groups per thread-block
  int tb_groups;
  if (group_size == -1) {
    tb_groups = 1;
  } else if (group_size == 0) {
    tb_groups = ceildiv(tb_k, 32);  // Worst case is 32 group size
  } else {
    tb_groups = ceildiv(tb_k, group_size);
  }

  if (cache_scales_chunk) {
    int load_groups =
        tb_groups * STAGES * 2;          // Chunk size is 2x pipeline over dim K
    load_groups = max(load_groups, 32);  // We load at least 32 scale groups
    return load_groups * tb_n * 2;

  } else {
    int tb_scales = tb_groups * tb_n * 2;

    return tb_scales * STAGES;
  }
}

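// Example of the non-chunked path above (illustrative; assumes the pipeline
// depth STAGES == 4): thread_n = 256, thread_k = 64 and group_size = 64 give
// tb_groups = 1, so tb_scales = 1 * 256 * 2 = 512 bytes of fp16 scales per
// stage, or 2048 bytes across the pipeline.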
bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
                         int prob_m, int prob_n, int prob_k, int num_bits,
                         int scales_cache_size, int max_shared_mem) {
  int pack_factor = 32 / num_bits;

  // Get B size
  int tb_k = th_config.thread_k;
  int tb_n = th_config.thread_n;

  int b_size = (tb_k * tb_n / pack_factor) * 4;

  // Get A size
  int m_blocks = ceildiv(prob_m, 16);
  int tb_max_m = 16;

  while (true) {
    if (m_blocks >= max_m_blocks) {
      tb_max_m *= max_m_blocks;
      break;
    }
    max_m_blocks--;
    if (max_m_blocks == 0) {
      TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
    }
  }

  int a_size = (tb_max_m * tb_k) * 2;

  float pipe_size = (a_size + b_size) * STAGES;

  TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity

  return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
}

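// Illustrative numbers for the check above (assuming STAGES == 4): with
// thread_k = 64, thread_n = 256, num_bits = 4, max_m_blocks = 4 and a
// sufficiently large prob_m, b_size = 64 * 256 / 8 * 4 = 8192 bytes and
// a_size = 64 * 64 * 2 = 8192 bytes, so pipe_size = (8192 + 8192) * 4 = 65536
// bytes, which must fit within 95% of the shared memory left after reserving
// the scales cache.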
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
                     int prob_m, int prob_n, int prob_k, int num_bits,
                     int group_size, bool has_act_order, bool is_k_full,
                     int max_shared_mem) {
  // Sanity
  if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
      th_config.num_threads == -1) {
    return false;
  }

  // Verify K/N are divisible by thread K/N
  if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
    return false;
  }

  // thread_k can be only 128 or 64 (because it must be less than groupsize
  // which is 128)
  if (th_config.thread_k != 128 && th_config.thread_k != 64) {
    return false;
  }

  // Verify min for thread K/N
  if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
    return false;
  }

  // num_threads must be at least 128 (= 4 warps)
  if (th_config.num_threads < 128) {
    return false;
  }

  // Determine cache for scales
  int scales_cache_size =
      get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
                            group_size, has_act_order, is_k_full);

  // Check that pipeline fits into cache
  if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                           num_bits, scales_cache_size, max_shared_mem)) {
    return false;
  }

  return true;
}

exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
                                      int num_bits, int group_size,
                                      bool has_act_order, bool is_k_full,
                                      int max_shared_mem) {
  int max_m_blocks = 4;
  while (max_m_blocks > 0) {
    if (prob_m <= 16) {
      for (auto th_config : small_batch_thread_configs) {
        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                            num_bits, group_size, has_act_order, is_k_full,
                            max_shared_mem)) {
          return exec_config_t{max_m_blocks, th_config};
        }
      }
    } else {
      for (auto th_config : large_batch_thread_configs) {
        if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
                            num_bits, group_size, has_act_order, is_k_full,
                            max_shared_mem)) {
          return exec_config_t{max_m_blocks, th_config};
        }
      }
    }

    max_m_blocks--;  // Process less M blocks per invocation to reduce cache
                     // usage
  }

  return exec_config_t{0, {-1, -1, -1}};
}

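// Each expansion of CALL_MOE_KERNEL_FUNCTION adds one "else if" branch to the
// dispatch chain in marlin_mm_moe below (which opens with an empty
// "if (false) {}"). A branch is taken when the named wrapper reports that it
// handled the given quantization type and tile configuration; if none does,
// the trailing "else" raises the "Unsupported shapes" error.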
#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION)                             \
  else if (KERNEL_FUNCTION(q_type, thread_n_blocks, thread_k_blocks,          \
                           has_act_order, group_blocks, num_threads, blocks,  \
                           max_shared_mem, stream, A_ptr, B_ptr, C_ptr,       \
                           sorted_ids_ptr, topk_weights_ptr, s_ptr,           \
                           g_idx_ptr, expert_offsets_ptr, num_groups,         \
                           expert_idx, num_experts, topk, prob_m, prob_n,     \
                           prob_k, tot_m, locks, replicate_input,             \
                           apply_weights, m_block, max_par,                   \
                           exec_cfg.max_m_blocks)) {                          \
  }

void marlin_mm_moe(const void* A, const void* B, void* C,
                   const void* sorted_ids, const void* topk_weights,
                   const void* topk_ids, const void* s, const void* g_idx,
                   const void* perm, void* a_tmp, void* expert_offsets,
                   int prob_m, int prob_n, int prob_k, void* workspace,
                   aphrodite::ScalarType const& q_type, bool has_act_order,
                   bool is_k_full, int num_groups, int group_size,
                   int num_experts, int topk, int moe_block_size, int dev,
                   cudaStream_t stream, int thread_k, int thread_n, int sms,
                   int max_par, bool replicate_input, bool apply_weights) {
  TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
              ", ", prob_n, ", ", prob_k, "]");

  if (sms == -1) {
    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  }

  int max_shared_mem = 0;
  cudaDeviceGetAttribute(&max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  TORCH_CHECK(max_shared_mem > 0);

  int num_bits = q_type.size_bits();

  // Set thread config
  exec_config_t exec_cfg;
  if (thread_k != -1 && thread_n != -1) {
    // User-defined config
    exec_cfg =
        exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}};
  } else {
    // Auto config
    exec_cfg =
        determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
                                has_act_order, is_k_full, max_shared_mem);
  }

  TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
                  is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
                                  prob_m, prob_n, prob_k, num_bits, group_size,
                                  has_act_order, is_k_full, max_shared_mem),
              "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
              ", thread_k = ", exec_cfg.tb_cfg.thread_k,
              ", thread_n = ", exec_cfg.tb_cfg.thread_n,
              ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [",
              prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
              ", group_size = ", group_size,
              ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
              ", max_shared_mem = ", max_shared_mem);

  int num_threads = exec_cfg.tb_cfg.num_threads;
  thread_k = exec_cfg.tb_cfg.thread_k;
  thread_n = exec_cfg.tb_cfg.thread_n;

  int thread_k_blocks = thread_k / 16;
  int thread_n_blocks = thread_n / 16;

  int blocks = sms;

  TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
              " is not divisible by thread_n = ", thread_n);
  TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
              " is not divisible by thread_k = ", thread_k);
  int group_blocks = 0;
  if (has_act_order) {
    if (is_k_full) {
      TORCH_CHECK(group_size != -1);
      group_blocks = group_size / 16;
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    } else {
      TORCH_CHECK(group_size == 0);
      group_blocks = 0;
    }
  } else {
    if (group_size == -1) {
      group_blocks = -1;
    } else {
      group_blocks = group_size / 16;
      TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k,
                  " is not divisible by group_blocks = ", group_blocks);
    }
  }

  int tot_m = prob_m;

  const int* topk_ids_ptr = (const int*)topk_ids;
  int* expert_offsets_ptr = (int*)expert_offsets;
  compute_expert_offsets<<<1, num_experts, 0, stream>>>(
      topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size);

  bool do_permute_a = has_act_order;

  // If we have a full K, then we can run the non-act-order version of Marlin
  // (since the weight rows are reordered by increasing group ids, and by
  // having a full K, we have full original groups)
  if (is_k_full) {
    has_act_order = false;
  }

  int pack_factor = 32 / q_type.size_bits();
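  // Within the per-expert loop below, B, the scales, g_idx and perm advance by
  // one expert-sized stride per iteration; A, C, sorted_ids and the top-k
  // weights are shared across experts and passed to the kernel unchanged.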
  for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) {
    const int4* A_ptr = (const int4*)A;
    int4* a_tmp_ptr = (int4*)a_tmp;
    const int4* B_ptr =
        (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx;
    int4* C_ptr = (int4*)C;
    const float* topk_weights_ptr = (const float*)topk_weights;
    const int* sorted_ids_ptr = (const int*)sorted_ids;
    const int4* s_ptr =
        (const int4*)s +
        (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) *
         prob_n / 8) *
            expert_idx;
    const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
    const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
    int* locks = (int*)workspace;

    if (do_permute_a) {
      // Permute A columns
      int topk_rows = replicate_input ? tot_m : tot_m * topk;
      int block_rows = ceildiv(topk_rows, blocks);
      permute_cols_kernel<<<blocks, num_threads, 0, stream>>>(
          A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows);
      A_ptr = a_tmp_ptr;
    }

    int tot_m_blocks = ceildiv(tot_m, 16);
    for (int m_block = 0; m_block < tot_m_blocks;
         m_block += 4 * exec_cfg.max_m_blocks) {
      if (false) {
      }
      CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8)
      CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128)
      else {
        TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
                               str(prob_n) + ", " + str(prob_k) + "]" +
                               ", has_act_order = " + str(has_act_order) +
                               ", num_groups = " + str(num_groups) +
                               ", group_size = " + str(group_size) +
                               ", thread_n_blocks = " + str(thread_n_blocks) +
                               ", thread_k_blocks = " + str(thread_k_blocks));
      }
    }
  }
}

}  // namespace marlin_moe

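// Shape notes derived from the implementation below: the output "c" is
// allocated as [size_m, topk, size_n]; "b_scales" must be 3-D, i.e.
// [num_experts, num_groups, size_n]; "g_idx" and "perm" are indexed with a
// per-expert stride of size_k; and "b_q_weights" stores one packed weight
// matrix per expert, laid out contiguously.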
torch::Tensor marlin_gemm_moe(
    const torch::Tensor& a, const torch::Tensor& b_q_weights,
    const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
    const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
    const torch::Tensor& g_idx, const torch::Tensor& perm,
    torch::Tensor& workspace, aphrodite::ScalarTypeTorchPtr const& b_q_type,
    int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full,
    int64_t num_experts, int64_t topk, int64_t moe_block_size,
    bool replicate_input, bool apply_weights) {
  TORCH_CHECK(*b_q_type == aphrodite::kU4B8 || *b_q_type == aphrodite::kU8B128,
              "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str());

  int pack_factor = 32 / b_q_type->size_bits();

  int max_par = 4;

  int dev = a.get_device();

  auto options_dtype =
      torch::TensorOptions().dtype(a.dtype()).device(a.device());
  auto options_int =
      torch::TensorOptions().dtype(torch::kInt).device(a.device());
  torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype);
  torch::Tensor a_tmp =
      replicate_input ? torch::zeros({size_m, size_k}, options_dtype)
                      : torch::zeros({size_m, topk, size_k}, options_dtype);
  torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int);

  // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_k = -1;
  // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
  // auto -1)
  int thread_n = -1;
  // sms: number of SMs to use for the kernel (can usually be left as auto -1)
  int sms = -1;

  // Detect groupsize and act_order
  int num_groups = -1;
  int group_size = -1;
  bool has_act_order = g_idx.size(1) != 0;

  int b_rank = b_scales.sizes().size();
  TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3");
  TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2),
              " is not size_n = ", size_n);
  num_groups = b_scales.size(1);

  if (has_act_order) {
    if (is_k_full) {
      TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
      TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
                  ", is not divisible by num_groups = ", num_groups);
      group_size = size_k / num_groups;
    } else {
      group_size = 0;
    }
  } else {
    if (num_groups > 1) {
      TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k,
                  ", is not divisible by num_groups = ", num_groups);
      group_size = size_k / num_groups;
    } else {
      group_size = -1;
    }
  }

  marlin_moe::marlin_mm_moe(
      a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(),
      topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
      g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
      expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
      *b_q_type, has_act_order, is_k_full, num_groups, group_size, num_experts,
      topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
      thread_n, sms, max_par, replicate_input, apply_weights);

  return c;
}