causal_conv1d.py

# Copyright (c) 2024, Tri Dao.

from typing import Optional

import torch

from aphrodite import _custom_ops as ops


def causal_conv1d_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out=None,
    activation: str = "silu",
):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    seq_idx: (batch, seqlen)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1), to be written to
    activation: either None or "silu" or "swish"

    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    if x.stride(2) != 1 and x.stride(1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None
    if seq_idx is not None:
        assert initial_states is None, (
            "initial_states must be None if seq_idx is not None")
        assert not return_final_states, (
            "If seq_idx is not None, we don't return final_states_out")
    seq_idx = seq_idx.contiguous() if seq_idx is not None else None
    if initial_states is not None and (initial_states.stride(2) != 1
                                       and initial_states.stride(1) != 1):
        initial_states = initial_states.contiguous()
    if return_final_states:
        assert x.stride(1) == 1, (
            "Only channel-last layout supports returning final_states_out")
        if final_states_out is not None:
            assert (final_states_out.stride(2) == 1
                    or final_states_out.stride(1) == 1)
        else:
            batch, dim, seqlen = x.shape
            width = weight.shape[1]
            final_states_out = torch.empty(batch,
                                           width - 1,
                                           dim,
                                           device=x.device,
                                           dtype=x.dtype).transpose(1, 2)
    else:
        final_states_out = None

    out = ops.causal_conv1d_fwd(x, weight, bias, seq_idx, initial_states,
                                final_states_out,
                                activation in ["silu", "swish"])
    return (out, None) if not return_final_states else (out, final_states_out)
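

# A minimal pure-PyTorch reference for the fused forward op above, handy for
# testing where the custom kernel is unavailable. It is a sketch assuming the
# standard causal depthwise-conv semantics (groups=dim, left padding of
# width - 1), and it skips the seq_idx / initial_states / final_states paths.
# The name `causal_conv1d_ref` is ours, not part of the library API.
def causal_conv1d_ref(x: torch.Tensor,
                      weight: torch.Tensor,
                      bias: Optional[torch.Tensor] = None,
                      activation: Optional[str] = None) -> torch.Tensor:
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    out: (batch, dim, seqlen)
    """
    import torch.nn.functional as F
    dim, width = weight.shape
    # Depthwise conv with width - 1 padding, then truncation back to seqlen,
    # so position t only sees inputs at positions <= t (causality).
    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1,
                   groups=dim)[..., :x.shape[-1]]
    return F.silu(out) if activation in ["silu", "swish"] else out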


def causal_conv1d_update(x: torch.Tensor,
                         conv_state: torch.Tensor,
                         weight: torch.Tensor,
                         bias: Optional[torch.Tensor] = None,
                         activation: Optional[str] = None):
    """
    x: (batch, dim)
    conv_state: (batch, dim, width)
    weight: (dim, width)
    bias: (dim,)

    out: (batch, dim)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    activation_bool = activation in ["silu", "swish"]
    return ops.causal_conv1d_update(x, conv_state, weight, bias,
                                    activation_bool)
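

# A matching single-step reference for the decode-time update, again a sketch
# for testing rather than the library's kernel: it rolls `conv_state` left by
# one, writes the new token into the last column, and takes a per-channel dot
# product with the filter, which is what the docstring above describes. The
# name `causal_conv1d_update_ref` is ours, not part of the library API.
def causal_conv1d_update_ref(x: torch.Tensor,
                             conv_state: torch.Tensor,
                             weight: torch.Tensor,
                             bias: Optional[torch.Tensor] = None,
                             activation: Optional[str] = None) -> torch.Tensor:
    """
    x: (batch, dim)
    conv_state: (batch, dim, width), updated in place
    weight: (dim, width)
    bias: (dim,)
    out: (batch, dim)
    """
    # Shift the rolling window one step and append the new token's features.
    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
    conv_state[:, :, -1] = x
    # One output per channel: the window dotted with that channel's filter.
    out = torch.sum(conv_state * weight.unsqueeze(0), dim=-1)
    if bias is not None:
        out = out + bias
    if activation in ["silu", "swish"]:
        out = torch.nn.functional.silu(out)
    return out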