@@ -3,7 +3,8 @@
 from typing import Optional
 
 import torch
 
-from causal_conv1d_cuda import causal_conv1d_fwd, causal_conv1d_update
+
+from aphrodite import _custom_ops as ops
 
 def causal_conv1d_fn(
@@ -58,12 +59,17 @@ def causal_conv1d_fn(
     else:
         final_states_out = None
 
-    out = causal_conv1d_fwd(x, weight, bias, seq_idx, initial_states,
-                            final_states_out, activation in ["silu", "swish"])
+    out = ops.causal_conv1d_fwd(x, weight, bias, seq_idx, initial_states,
+                                final_states_out,
+                                activation in ["silu", "swish"])
     return (out, None) if not return_final_states else (out, final_states_out)
 
 
-def causal_conv1d_up(x, conv_state, weight, bias=None, activation=None):
+def causal_conv1d_update(x: torch.Tensor,
+                         conv_state: torch.Tensor,
+                         weight: torch.Tensor,
+                         bias: Optional[torch.Tensor] = None,
+                         activation: Optional[str] = None):
     """
     x: (batch, dim)
     conv_state: (batch, dim, width)
@@ -73,5 +79,6 @@ def causal_conv1d_up(x, conv_state, weight, bias=None, activation=None):
     """
     if activation not in [None, "silu", "swish"]:
         raise NotImplementedError("activation must be None, silu, or swish")
-    activation = activation in ["silu", "swish"]
-    return causal_conv1d_update(x, conv_state, weight, bias, activation)
+    activation_bool = activation in ["silu", "swish"]
+    return ops.causal_conv1d_update(x, conv_state, weight, bias,
+                                    activation_bool)
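
For reference, a minimal usage sketch of the renamed `causal_conv1d_update` wrapper defined above. The `x` and `conv_state` shapes follow the docstring; the `(dim, width)` weight layout, the CUDA device, and the in-place state update are assumptions based on the upstream causal-conv1d kernel semantics, and the sizes are illustrative only, not part of this diff:

```python
import torch

# Illustrative sizes only (not part of the diff).
batch, dim, width = 2, 64, 4

x = torch.randn(batch, dim, device="cuda")                  # current-step input, (batch, dim)
conv_state = torch.zeros(batch, dim, width, device="cuda")  # rolling conv window, (batch, dim, width)
weight = torch.randn(dim, width, device="cuda")             # assumed (dim, width) filter layout
bias = torch.randn(dim, device="cuda")

# One decode step: conv_state is assumed to be updated in place by the kernel;
# the convolved, silu-activated output of shape (batch, dim) is returned.
out = causal_conv1d_update(x, conv_state, weight, bias, activation="silu")
```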