# test_causal_conv1d.py

from typing import Optional

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange

from aphrodite.modeling.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)


def causal_conv1d_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
):
    """
    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1)
    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    x = x.to(weight.dtype)
    seqlen = x.shape[-1]
    dim, width = weight.shape
    if initial_states is None:
        out = F.conv1d(x,
                       weight.unsqueeze(1),
                       bias,
                       padding=width - 1,
                       groups=dim)
    else:
        x = torch.cat([initial_states, x], dim=-1)
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
    out = out[..., :seqlen]
    if return_final_states:
        final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
            dtype_in)  # (batch, dim, width - 1)
        if final_states_out is not None:
            final_states_out.copy_(final_states)
        else:
            final_states_out = final_states
    out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
    return (out, None) if not return_final_states else (out, final_states_out)
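

# Illustrative sketch (not part of the original test suite): a minimal,
# hypothetical helper showing the shape contract documented in
# causal_conv1d_ref's docstring, with toy sizes. It is not collected by pytest
# (no "test_" prefix) and is never called by the tests below.
def _example_causal_conv1d_ref_shapes():
    batch, dim, seqlen, width = 2, 4, 8, 3
    x = torch.randn(batch, dim, seqlen)
    weight = torch.randn(dim, width)
    # With return_final_states=True the reference also returns the cached
    # last (width - 1) inputs along the sequence dimension.
    out, final_states = causal_conv1d_ref(x, weight, return_final_states=True)
    assert out.shape == (batch, dim, seqlen)
    assert final_states.shape == (batch, dim, width - 1)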


def causal_conv1d_update_ref(x: torch.Tensor,
                             conv_state: torch.Tensor,
                             weight: torch.Tensor,
                             bias: Optional[torch.Tensor] = None,
                             activation: Optional[str] = None):
    """
    x: (batch, dim)
    conv_state: (batch, dim, width)
    weight: (dim, width)
    bias: (dim,)
    out: (batch, dim)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    dtype_in = x.dtype
    batch, dim = x.shape
    width = weight.shape[1]
    assert conv_state.shape == (batch, dim, width)
    assert weight.shape == (dim, width)
    conv_state.copy_(torch.roll(conv_state, shifts=-1,
                                dims=-1))  # Update state (B D W)
    conv_state[:, :, -1] = x
    out = torch.sum(conv_state * weight, dim=-1)  # (B D)
    if bias is not None:
        out += bias
    return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
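

# Illustrative sketch (not part of the original test suite): a hypothetical
# helper showing one decode step against causal_conv1d_update_ref with toy
# sizes. The state is rolled left by one slot and the newest token lands in
# the last slot, as implemented in the reference above. Never called by the
# tests below.
def _example_causal_conv1d_update_ref_step():
    batch, dim, width = 2, 4, 3
    x = torch.randn(batch, dim)
    conv_state = torch.zeros(batch, dim, width)
    weight = torch.randn(dim, width)
    out = causal_conv1d_update_ref(x, conv_state, weight, activation="silu")
    assert out.shape == (batch, dim)
    # conv_state is updated in place; its newest slot now holds x.
    assert torch.equal(conv_state[:, :, -1], x)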


@pytest.mark.parametrize("return_final_states", [False, True])
@pytest.mark.parametrize("has_initial_states", [False, True])
@pytest.mark.parametrize("channel_last", [False, True])
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("seqlen", [128, 512, 4096])
@pytest.mark.parametrize('dim', [64, 4096 + 32])
@pytest.mark.parametrize('batch', [1, 2])
def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation,
                       itype, channel_last, has_initial_states,
                       return_final_states):
    if not channel_last and (has_initial_states or return_final_states):
        pytest.skip(
            "Only channel_last supports initial_states or return_final_states")
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    # set seed
    torch.random.manual_seed(0)
    if not channel_last:
        x = torch.randn(batch,
                        4096 + dim + 64,
                        seqlen,
                        device=device,
                        dtype=itype)[:, 4096:4096 + dim, :]
    else:
        x = rearrange(
            torch.randn(batch,
                        seqlen,
                        4096 + dim + 64,
                        device=device,
                        dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s")
    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    if has_initial_states:
        initial_states = torch.randn(batch,
                                     width - 1,
                                     dim,
                                     device=device,
                                     dtype=itype).transpose(1, 2)
    else:
        initial_states = None
    x_ref = x.detach().clone()
    weight_ref = weight.detach().clone()
    bias_ref = bias.detach().clone() if bias is not None else None
    initial_states_ref = initial_states.detach().clone(
    ) if initial_states is not None else None
    activation = None if not silu_activation else "silu"
    out, final_states = causal_conv1d_fn(
        x,
        weight,
        bias,
        initial_states=initial_states,
        return_final_states=return_final_states,
        activation=activation)
    out_ref, final_states_ref = causal_conv1d_ref(
        x_ref,
        weight_ref,
        bias_ref,
        initial_states=initial_states_ref,
        return_final_states=return_final_states,
        activation=activation)
    if return_final_states:
        assert final_states is not None and final_states_ref is not None
        assert torch.allclose(final_states,
                              final_states_ref,
                              rtol=rtol,
                              atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
    if return_final_states:
        out += F.sigmoid(final_states).sum(dim=-1, keepdim=True)
        out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True)


@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
@pytest.mark.parametrize("batch", [1, 2])
def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation,
                              itype):
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    # set seed
    torch.random.manual_seed(0)
    batch = 2
    x = torch.randn(batch, dim, device=device, dtype=itype)
    conv_state = torch.randn(batch, dim, width, device=device, dtype=itype)
    weight = torch.randn(dim,
                         width,
                         device=device,
                         dtype=itype,
                         requires_grad=True)
    if has_bias:
        bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True)
    else:
        bias = None
    conv_state_ref = conv_state.detach().clone()
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_update(x,
                               conv_state,
                               weight,
                               bias,
                               activation=activation)
    out_ref = causal_conv1d_update_ref(x,
                                       conv_state_ref,
                                       weight,
                                       bias,
                                       activation=activation)
    assert torch.equal(conv_state, conv_state_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)


@pytest.mark.parametrize("itype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 4, 5])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias,
                                                silu_activation, itype):
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    # set seed
    torch.random.manual_seed(0)
    batch = 64
    x = torch.randn(batch, dim, device=device, dtype=itype)
    total_entries = 10 * batch
    conv_state = torch.randn(total_entries,
                             dim,
                             width,
                             device=device,
                             dtype=itype)
    conv_state_indices = torch.randperm(total_entries)[:batch].to(
        dtype=torch.int32, device=device)
    weight = torch.randn(dim,
                         width,
                         device=device,
                         dtype=itype,
                         requires_grad=True)
    if has_bias:
        bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True)
    else:
        bias = None
    conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
    activation = None if not silu_activation else "silu"
    out = causal_conv1d_update(x,
                               conv_state,
                               weight,
                               bias,
                               activation=activation,
                               conv_state_indices=conv_state_indices)
    out_ref = causal_conv1d_update_ref(x,
                                       conv_state_ref,
                                       weight,
                                       bias,
                                       activation=activation)
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
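

# Illustrative sketch (not part of the original test suite): a hypothetical
# helper restating the batch-gather semantics checked above, using only the
# reference path (no CUDA kernel) with toy sizes. conv_state holds many cache
# rows and conv_state_indices selects which rows the current batch touches;
# unselected rows must be left untouched. Never called by the tests.
def _example_conv_state_gather_semantics():
    total_entries, batch, dim, width = 6, 2, 4, 3
    conv_state = torch.zeros(total_entries, dim, width)
    conv_state_indices = torch.tensor([4, 1])
    x = torch.randn(batch, dim)
    weight = torch.randn(dim, width)
    # Gather the selected rows and run one reference update step on them only.
    gathered = conv_state[conv_state_indices].clone()
    out = causal_conv1d_update_ref(x, gathered, weight)
    assert out.shape == (batch, dim)
    # Rows not listed in conv_state_indices stay untouched in the full state.
    assert torch.equal(conv_state[0], torch.zeros(dim, width))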