causal_conv1d.py

# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
from typing import Optional

import torch

from aphrodite import _custom_ops as ops


def causal_conv1d_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out=None,
    activation: Optional[str] = "silu",
):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    seq_idx: (batch, seqlen)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1), to be written to
    activation: either None or "silu" or "swish"

    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    if x.stride(2) != 1 and x.stride(1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None
    if seq_idx is not None:
        assert (initial_states is
                None), "initial_states must be None if seq_idx is not None"
        assert (not return_final_states
                ), "If seq_idx is not None, we don't return final_states_out"
    seq_idx = seq_idx.contiguous() if seq_idx is not None else None
    if initial_states is not None and (initial_states.stride(2) != 1
                                       and initial_states.stride(1) != 1):
        initial_states = initial_states.contiguous()
    if return_final_states:
        assert (
            x.stride(1) == 1
        ), "Only channel-last layout supports returning final_states_out"
        if final_states_out is not None:
            assert (final_states_out.stride(2) == 1
                    or final_states_out.stride(1) == 1)
        else:
            # Allocate a channel-last buffer: (batch, dim, width - 1) with
            # stride(1) == 1, matching the layout the kernel writes to.
            batch, dim, seqlen = x.shape
            width = weight.shape[1]
            final_states_out = torch.empty(batch,
                                           width - 1,
                                           dim,
                                           device=x.device,
                                           dtype=x.dtype).transpose(1, 2)
    else:
        final_states_out = None
    out = ops.causal_conv1d_fwd(x, weight, bias, seq_idx, initial_states,
                                final_states_out, activation
                                in ["silu", "swish"])
    return (out, None) if not return_final_states else (out, final_states_out)
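

# A minimal usage sketch (added here for illustration; not part of the
# original file). It assumes the aphrodite CUDA extension backing
# `ops.causal_conv1d_fwd` is built and a GPU is available. Note that
# `return_final_states=True` requires channel-last input (x.stride(1) == 1),
# hence the transpose below:
#
#     batch, dim, seqlen, width = 2, 64, 128, 4
#     x = torch.randn(batch, seqlen, dim, device="cuda",
#                     dtype=torch.float16).transpose(1, 2)  # (batch, dim, seqlen)
#     weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)
#     out, final_states = causal_conv1d_fn(x, weight, activation="silu",
#                                          return_final_states=True)
#     # out: (batch, dim, seqlen); final_states: (batch, dim, width - 1)
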
def causal_conv1d_update(x: torch.Tensor,
                         conv_state: torch.Tensor,
                         weight: torch.Tensor,
                         bias: Optional[torch.Tensor] = None,
                         activation: Optional[str] = None,
                         conv_state_indices: Optional[torch.Tensor] = None):
    """
    x: (batch, dim)
    conv_state: (batch, dim, width)
    weight: (dim, width)
    bias: (dim,)
    conv_state_indices: (batch,), dtype int32
        If not None, the conv_state is a larger tensor along the batch dim,
        and we are selecting the batch coords specified by conv_state_indices.
        Useful for a continuous batching scenario.

    out: (batch, dim)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    # The CUDA op takes the activation as a boolean flag rather than a string.
    activation_bool = activation in ["silu", "swish"]
    return ops.causal_conv1d_update(x, conv_state, weight, bias,
                                    activation_bool, conv_state_indices)
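

# A minimal usage sketch for the single-token decode path (added for
# illustration; not part of the original file). `conv_state` acts as a rolling
# buffer of the last `width` inputs and is updated in place by the kernel:
#
#     batch, dim, width = 2, 64, 4
#     x = torch.randn(batch, dim, device="cuda", dtype=torch.float16)
#     conv_state = torch.zeros(batch, dim, width, device="cuda",
#                              dtype=torch.float16)
#     weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)
#     out = causal_conv1d_update(x, conv_state, weight, activation="silu")
#     # out: (batch, dim); conv_state now reflects the newly appended token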