import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from aphrodite.modeling.layers.mamba.ops.mamba_ssm import (
    selective_scan_fn, selective_state_update)
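
# The tests below check the fused kernels exported by
# aphrodite.modeling.layers.mamba.ops.mamba_ssm (selective_scan_fn and
# selective_state_update) against the eager PyTorch reference implementations
# defined in this file.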


def selective_state_update_ref(
    state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
):
    """
    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    """
    has_heads = state.dim() > 3
    # Normalize all inputs to the (batch, nheads, ...) layout.
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)
    batch, nheads, dim, dstate = state.shape
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
        dt = dt + dt_bias
    dt = F.softplus(dt) if dt_softplus else dt
    dA = torch.exp(
        rearrange(dt, "b h d -> b h d 1") * A
    )  # (batch, nheads, dim, dstate)
    B = repeat(
        B, "b g n -> b (g h) n", h=nheads // ngroups
    )  # (batch, nheads, dstate)
    C = repeat(
        C, "b g n -> b (g h) n", h=nheads // ngroups
    )  # (batch, nheads, dstate)
    dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
        B, "b h n -> b h 1 n"
    )  # (batch, nheads, dim, dstate)
    state.copy_(
        state * dA + dB * rearrange(x, "b h d -> b h d 1")
    )  # (batch, nheads, dim, dstate)
    out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
    if D is not None:
        out += (x * D).to(out.dtype)
    out = (out if z is None else out * F.silu(z)).to(x.dtype)
    if not has_heads:
        out = out.squeeze(1)
    return out
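
# For reference, the single-timestep SSM recurrence implemented above is
#   state' = state * exp(dt * A) + (dt * B) * x
#   out    = C . state' (+ D * x), optionally gated by silu(z),
# applied independently per batch element, head, and channel.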


def selective_scan_ref(
    u,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    delta_softplus=False,
    return_last_state=False,
    position_indices=None,
    prev_state=None,
):
    """
    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32
    prev_state: r(B D N), fp32
    out: r(B D L)
    last_state (optional): r(B D dstate) or c(B D dstate)
    """
    dtype_in = u.dtype
    u = u.float()
    delta = delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)
    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    B = B.float()
    C = C.float()
    x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state
    ys = []
    deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
    if not is_variable_B:
        deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
    else:
        if B.dim() == 3:
            deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
        else:
            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
            deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
    last_state = None
    for i in range(u.shape[2]):
        if position_indices is not None and position_indices[0, i] == 0:
            x = deltaB_u[:, :, i]
        else:
            x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum("bdn,dn->bd", x, C)
        else:
            if C.dim() == 3:
                y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
            else:
                y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
        if i == u.shape[2] - 1:
            last_state = x
        ys.append(y)
    y = torch.stack(ys, dim=2)  # (batch dim L)
    out = y if D is None else y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=dtype_in)
    return out if not return_last_state else (out, last_state)
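
# test_selective_scan runs the fused selective_scan_fn either over the whole
# sequence or in several chunks (feeding the returned last state back in as
# prev_state), then compares the concatenated output and the final state
# against a single selective_scan_ref pass over the full sequence.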


@pytest.mark.parametrize("wtype", [torch.float32])
@pytest.mark.parametrize("itype", [torch.float32])
@pytest.mark.parametrize("seqlen", [128, 256, 512, 1024, 2048, 4096])
@pytest.mark.parametrize("return_last_state", [True])
@pytest.mark.parametrize("has_delta_bias", [True])
@pytest.mark.parametrize("delta_softplus", [True])
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("has_D", [True])
@pytest.mark.parametrize("varBC_groups", [1, 2])
@pytest.mark.parametrize("is_variable_C", [True])
@pytest.mark.parametrize("is_variable_B", [True])
@pytest.mark.parametrize("scan_chunks", [1, 2, 3])
def test_selective_scan(
    is_variable_B,
    is_variable_C,
    varBC_groups,
    has_D,
    has_z,
    has_delta_bias,
    delta_softplus,
    return_last_state,
    seqlen,
    itype,
    wtype,
    scan_chunks,
):
    if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
        pytest.skip()  # This config is not applicable
    device = "cuda"
    rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
    if itype == torch.bfloat16:
        rtol, atol = 3e-2, 5e-2
    rtolw, atolw = (1e-3, 1e-3)
    if has_z:  # If we have z, the errors on the weights seem higher
        rtolw = max(rtolw, rtol)
        atolw = max(atolw, atol)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 2
    dim = 4
    dstate = 8
    A = -0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)
    if not is_variable_B:
        B_shape = [dim, dstate]
    elif varBC_groups == 1:
        B_shape = [batch_size, dstate, seqlen]
    else:
        B_shape = [batch_size, varBC_groups, dstate, seqlen]
    B = torch.randn(
        B_shape, device=device, dtype=wtype if not is_variable_B else itype
    )
    if not is_variable_C:
        C_shape = [dim, dstate]
    elif varBC_groups == 1:
        C_shape = [batch_size, dstate, seqlen]
    else:
        C_shape = [batch_size, varBC_groups, dstate, seqlen]
    C = torch.randn(
        C_shape, device=device, dtype=wtype if not is_variable_C else itype
    )
    D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
    z = (
        torch.randn(batch_size, dim, seqlen, device=device, dtype=itype)
        if has_z
        else None
    )
    delta_bias = (
        (0.5 * torch.rand(dim, device=device, dtype=torch.float32))
        if has_delta_bias
        else None
    )
    u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype)
    delta = 0.5 * torch.rand(
        batch_size, dim, seqlen, device=device, dtype=itype
    )
    state = None
    state_ref = None
    out = None
    out_ref = None
    outs = []
    for c in range(scan_chunks):
        chunked_prompt_len = seqlen // scan_chunks
        chunk_start = chunked_prompt_len * c
        chunk_end = chunked_prompt_len * (c + 1)
        if c == scan_chunks - 1:
            chunk_end = seqlen
        _B = B
        if is_variable_B:
            _B = B[..., chunk_start:chunk_end]
        _C = C
        if is_variable_C:
            _C = C[..., chunk_start:chunk_end]
        _z = z
        if has_z:
            assert z is not None
            _z = z[..., chunk_start:chunk_end]
        out, *rest = selective_scan_fn(
            u[..., chunk_start:chunk_end],
            delta[..., chunk_start:chunk_end],
            A,
            _B,
            _C,
            D,
            z=_z,
            delta_bias=delta_bias,
            delta_softplus=delta_softplus,
            return_last_state=return_last_state,
            prev_state=state if c > 0 else None,
        )
        outs.append(out)
        if return_last_state:
            state = rest[0]
    if len(outs) > 1:
        out = torch.cat(outs, dim=-1)
    out_ref, *rest = selective_scan_ref(
        u,
        delta,
        A,
        B,
        C,
        D,
        z=z,
        delta_bias=delta_bias,
        delta_softplus=delta_softplus,
        return_last_state=return_last_state,
    )
    if return_last_state:
        state_ref = rest[0]

    assert out is not None and out_ref is not None
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
    if return_last_state:
        assert state is not None and state_ref is not None
        assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
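
# test_selective_state_update runs a single decode step through the fused
# kernel and through selective_state_update_ref on a cloned state, then checks
# that both the updated state and the output agree within dtype-dependent
# tolerances.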


@pytest.mark.parametrize(
    "itype", [torch.float32, torch.float16, torch.bfloat16]
)
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_selective_state_update(dim, dstate, has_z, itype):
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
        if torch.version.hip:
            atol *= 2
    # set seed
    torch.random.manual_seed(0)
    batch_size = 1
    state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
    x = torch.randn(batch_size, dim, device=device, dtype=itype)
    dt = torch.randn(batch_size, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(batch_size, dstate, device=device)
    C = torch.randn(batch_size, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None
    state_ref = state.detach().clone()
    out = selective_state_update(
        state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True
    )
    out_ref = selective_state_update_ref(
        state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True
    )

    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)