|
@@ -62,29 +62,6 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
|
|
torch::Tensor expert_ids,
|
|
|
torch::Tensor num_tokens_post_pad);
|
|
|
|
|
|
-std::vector<torch::Tensor> selective_scan_fwd(
|
|
|
- const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
|
|
|
- const torch::Tensor& B, const torch::Tensor& C,
|
|
|
- const c10::optional<torch::Tensor>& D_,
|
|
|
- const c10::optional<torch::Tensor>& z_,
|
|
|
- const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
|
|
|
- const c10::optional<torch::Tensor>& index_,
|
|
|
- const c10::optional<torch::Tensor>& x);
|
|
|
-
|
|
|
-at::Tensor causal_conv1d_update(const at::Tensor& x,
|
|
|
- const at::Tensor& conv_state,
|
|
|
- const at::Tensor& weight,
|
|
|
- const c10::optional<at::Tensor>& bias_,
|
|
|
- bool silu_activation);
|
|
|
-
|
|
|
-at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
|
|
|
- const c10::optional<at::Tensor>& bias_,
|
|
|
- const c10::optional<at::Tensor>& seq_idx_,
|
|
|
- const c10::optional<at::Tensor>& seq_pos_idx_,
|
|
|
- const c10::optional<at::Tensor>& initial_states_,
|
|
|
- const c10::optional<at::Tensor>& final_states_out_,
|
|
|
- bool silu_activation);
|
|
|
-
|
|
|
#ifndef USE_ROCM
|
|
|
using fptr_t = int64_t;
|
|
|
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
|
|
@@ -105,4 +82,24 @@ std::tuple<torch::Tensor, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
|
|
fptr_t _fa);
|
|
|
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
|
|
|
const std::vector<std::vector<int64_t>>& offsets);
|
|
|
+std::vector<torch::Tensor> selective_scan_fwd(
|
|
|
+ const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
|
|
|
+ const torch::Tensor& B, const torch::Tensor& C,
|
|
|
+ const c10::optional<torch::Tensor>& D_,
|
|
|
+ const c10::optional<torch::Tensor>& z_,
|
|
|
+ const c10::optional<torch::Tensor>& delta_bias_, bool delta_softplus,
|
|
|
+ const c10::optional<torch::Tensor>& index_,
|
|
|
+ const c10::optional<torch::Tensor>& x);
|
|
|
+at::Tensor causal_conv1d_update(const at::Tensor& x,
|
|
|
+ const at::Tensor& conv_state,
|
|
|
+ const at::Tensor& weight,
|
|
|
+ const c10::optional<at::Tensor>& bias_,
|
|
|
+ bool silu_activation);
|
|
|
+at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
|
|
|
+ const c10::optional<at::Tensor>& bias_,
|
|
|
+ const c10::optional<at::Tensor>& seq_idx_,
|
|
|
+ const c10::optional<at::Tensor>& seq_pos_idx_,
|
|
|
+ const c10::optional<at::Tensor>& initial_states_,
|
|
|
+ const c10::optional<at::Tensor>& final_states_out_,
|
|
|
+ bool silu_activation);
|
|
|
#endif
|