// marlin_moe_kernel_ku8b128.h (716 B)
  #pragma once
  #include "marlin_moe_kernel.h"
  3. namespace marlin_moe {
  4. bool call_marlin_moe_kernel_ku8b128(
  5. aphrodite::ScalarType const& q_type, int thread_n_blocks,
  6. int thread_k_blocks, bool has_act_order, int group_blocks, int num_threads,
  7. int blocks, int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
  8. const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
  9. const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr,
  10. int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts,
  11. int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks,
  12. bool replicate_input, bool apply_weights, int m_block, int max_par,
  13. int cfg_max_m_blocks);
  14. }