# test_causal_conv1d.py

from typing import Optional

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange

from aphrodite.modeling.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)


def causal_conv1d_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1)
    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    if initial_states is None:
        out = F.conv1d(
            x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim
        )
    else:
        x = torch.cat([initial_states, x], dim=-1)
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
    out = out[..., :seqlen]
    if return_final_states:
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
            dtype_in
        )  # (batch, dim, width - 1)
        if final_states_out is not None:
            final_states_out.copy_(final_states)
        else:
            final_states_out = final_states
    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
    return (out, None) if not return_final_states else (out, final_states_out)
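

# Illustrative only (not used by the tests): a per-timestep sketch of what the
# reference above computes when initial_states, bias, and activation are all
# absent. Each y[:, :, t] is a per-channel dot product over the last `width`
# inputs ending at time t, with implicit zero padding on the left. The helper
# name below is our own and is not part of the kernel API.
def _causal_conv1d_loop_sketch(x: torch.Tensor,
                               weight: torch.Tensor) -> torch.Tensor:
    batch, dim, seqlen = x.shape
    _, width = weight.shape
    x_padded = F.pad(x, (width - 1, 0))  # left-pad so early steps see zeros
    out = torch.zeros_like(x)
    for t in range(seqlen):
        window = x_padded[:, :, t:t + width]  # inputs t - width + 1 .. t
        out[:, :, t] = (window * weight).sum(dim=-1)
    return out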


def causal_conv1d_update_ref(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
):
    """
    x: (batch, dim)
    conv_state: (batch, dim, width)
    weight: (dim, width)
    bias: (dim,)
    out: (batch, dim)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    batch, dim = x.shape
    width = weight.shape[1]
    assert conv_state.shape == (batch, dim, width)
    assert weight.shape == (dim, width)
    conv_state.copy_(
        torch.roll(conv_state, shifts=-1, dims=-1)
    )  # Update state (B D W)
    conv_state[:, :, -1] = x
    out = torch.sum(conv_state * weight, dim=-1)  # (B D)
    if bias is not None:
        out += bias
    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
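

# Illustrative usage of the single-step reference above (toy sizes chosen
# here, not part of the tests): the state is rolled left by one slot, the
# newest token is written into the last slot, and the output is a per-channel
# dot product of the state with the weight.
#
#     batch, dim, width = 1, 4, 3
#     x = torch.randn(batch, dim)
#     conv_state = torch.zeros(batch, dim, width)
#     weight = torch.randn(dim, width)
#     y = causal_conv1d_update_ref(x, conv_state, weight)
#     # conv_state[:, :, -1] now equals x, and
#     # y == (conv_state * weight).sum(dim=-1)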


@pytest.mark.parametrize("return_final_states", [False, True])
@pytest.mark.parametrize("has_initial_states", [False, True])
@pytest.mark.parametrize("channel_last", [False, True])
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("seqlen", [128, 512, 4096])
@pytest.mark.parametrize("dim", [64, 4096 + 32])
@pytest.mark.parametrize("batch", [1, 2])
def test_causal_conv1d(
    batch,
    dim,
    seqlen,
    width,
    has_bias,
    silu_activation,
    itype,
    channel_last,
    has_initial_states,
    return_final_states,
):
    if not channel_last and (has_initial_states or return_final_states):
        pytest.skip(
            "Only channel_last supports initial_states or return_final_states"
        )
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    # set seed
    torch.random.manual_seed(0)
    if not channel_last:
        x = torch.randn(
            batch, 4096 + dim + 64, seqlen, device=device, dtype=itype
        )[:, 4096:4096 + dim, :]
    else:
        x = rearrange(
            torch.randn(
                batch, seqlen, 4096 + dim + 64, device=device, dtype=itype
            )[:, :, 4096:4096 + dim],
            "b s d -> b d s",
        )
    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    if has_initial_states:
        initial_states = torch.randn(
            batch, width - 1, dim, device=device, dtype=itype
        ).transpose(1, 2)
    else:
        initial_states = None
    x_ref = x.detach().clone()
    weight_ref = weight.detach().clone()
    bias_ref = bias.detach().clone() if bias is not None else None
    initial_states_ref = (
        initial_states.detach().clone() if initial_states is not None else None
    )
    activation = None if not silu_activation else "silu"
    out, final_states = causal_conv1d_fn(
        x,
        weight,
        bias,
        initial_states=initial_states,
        return_final_states=return_final_states,
        activation=activation,
    )
    out_ref, final_states_ref = causal_conv1d_ref(
        x_ref,
        weight_ref,
        bias_ref,
        initial_states=initial_states_ref,
        return_final_states=return_final_states,
        activation=activation,
    )
    if return_final_states:
        assert final_states is not None and final_states_ref is not None
        assert torch.allclose(
            final_states, final_states_ref, rtol=rtol, atol=atol
        )
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
    if return_final_states:
        # Exercise final_states together with out (broadcast add over the
        # seqlen dimension); no additional assertion is made on these values.
        out += F.sigmoid(final_states).sum(dim=-1, keepdim=True)
        out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True)


@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
@pytest.mark.parametrize("batch", [1, 2])
def test_causal_conv1d_update(
    batch, dim, width, has_bias, silu_activation, itype
):
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    # set seed
    torch.random.manual_seed(0)
    x = torch.randn(batch, dim, device=device, dtype=itype)
    conv_state = torch.randn(batch, dim, width, device=device, dtype=itype)
    weight = torch.randn(
        dim, width, device=device, dtype=itype, requires_grad=True
    )
    if has_bias:
        bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True)
    else:
        bias = None
    conv_state_ref = conv_state.detach().clone()
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_update(
        x, conv_state, weight, bias, activation=activation
    )
    out_ref = causal_conv1d_update_ref(
        x, conv_state_ref, weight, bias, activation=activation
    )
    assert torch.equal(conv_state, conv_state_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)