# test_mamba_ssm.py: tests comparing the Mamba selective-scan and
# selective-state-update ops against pure-PyTorch reference implementations.
  1. import pytest
  2. import torch
  3. import torch.nn.functional as F
  4. from einops import rearrange, repeat
  5. from aphrodite.common.utils import seed_everything
  6. from aphrodite.modeling.layers.mamba.ops.mamba_ssm import (
  7. selective_scan_fn, selective_state_update)
  8. def selective_state_update_ref(state,
  9. x,
  10. dt,
  11. A,
  12. B,
  13. C,
  14. D=None,
  15. z=None,
  16. dt_bias=None,
  17. dt_softplus=False):
  18. """
  19. Argument:
  20. state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
  21. x: (batch, dim) or (batch, nheads, dim)
  22. dt: (batch, dim) or (batch, nheads, dim)
  23. A: (dim, dstate) or (nheads, dim, dstate)
  24. B: (batch, dstate) or (batch, ngroups, dstate)
  25. C: (batch, dstate) or (batch, ngroups, dstate)
  26. D: (dim,) or (nheads, dim)
  27. z: (batch, dim) or (batch, nheads, dim)
  28. dt_bias: (dim,) or (nheads, dim)
  29. Return:
  30. out: (batch, dim) or (batch, nheads, dim)
  31. """
  32. has_heads = state.dim() > 3
  33. if state.dim() == 3:
  34. state = state.unsqueeze(1)
  35. if x.dim() == 2:
  36. x = x.unsqueeze(1)
  37. if dt.dim() == 2:
  38. dt = dt.unsqueeze(1)
  39. if A.dim() == 2:
  40. A = A.unsqueeze(0)
  41. if B.dim() == 2:
  42. B = B.unsqueeze(1)
  43. if C.dim() == 2:
  44. C = C.unsqueeze(1)
  45. if D is not None and D.dim() == 1:
  46. D = D.unsqueeze(0)
  47. if z is not None and z.dim() == 2:
  48. z = z.unsqueeze(1)
  49. if dt_bias is not None and dt_bias.dim() == 1:
  50. dt_bias = dt_bias.unsqueeze(0)
  51. batch, nheads, dim, dstate = state.shape
  52. assert x.shape == (batch, nheads, dim)
  53. assert dt.shape == x.shape
  54. assert A.shape == (nheads, dim, dstate)
  55. ngroups = B.shape[1]
  56. assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
  57. assert B.shape == (batch, ngroups, dstate)
  58. assert C.shape == B.shape
  59. if D is not None:
  60. assert D.shape == (nheads, dim)
  61. if z is not None:
  62. assert z.shape == x.shape
  63. if dt_bias is not None:
  64. assert dt_bias.shape == (nheads, dim)
  65. dt = dt + dt_bias
  66. dt = F.softplus(dt) if dt_softplus else dt
  67. dA = torch.exp(rearrange(dt, "b h d -> b h d 1") *
  68. A) # (batch, nheads, dim, dstate)
  69. B = repeat(B, "b g n -> b (g h) n",
  70. h=nheads // ngroups) # (batch, nheads, dstate)
  71. C = repeat(C, "b g n -> b (g h) n",
  72. h=nheads // ngroups) # (batch, nheads, dstate)
  73. dB = rearrange(dt, "b h d -> b h d 1") * rearrange(
  74. B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate)
  75. state.copy_(state * dA +
  76. dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate
  77. out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
  78. if D is not None:
  79. out += (x * D).to(out.dtype)
  80. out = (out if z is None else out * F.silu(z)).to(x.dtype)
  81. if not has_heads:
  82. out = out.squeeze(1)
  83. return out
  84. def selective_scan_ref(u,
  85. delta,
  86. A,
  87. B,
  88. C,
  89. D=None,
  90. z=None,
  91. delta_bias=None,
  92. delta_softplus=False,
  93. return_last_state=False,
  94. position_indices=None,
  95. prev_state=None):
  96. """
  97. u: r(B D L)
  98. delta: r(B D L)
  99. A: c(D N) or r(D N)
  100. B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
  101. C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
  102. D: r(D)
  103. z: r(B D L)
  104. delta_bias: r(D), fp32
  105. prev_state: r(B D N), fp32
  106. out: r(B D L)
  107. last_state (optional): r(B D dstate) or c(B D dstate)
  108. """
  109. dtype_in = u.dtype
  110. u = u.float()
  111. delta = delta.float()
  112. if delta_bias is not None:
  113. delta = delta + delta_bias[..., None].float()
  114. if delta_softplus:
  115. delta = F.softplus(delta)
  116. batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
  117. is_variable_B = B.dim() >= 3
  118. is_variable_C = C.dim() >= 3
  119. B = B.float()
  120. C = C.float()
  121. x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state
  122. ys = []
  123. deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
  124. if not is_variable_B:
  125. deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
  126. else:
  127. if B.dim() == 3:
  128. deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
  129. else:
  130. B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
  131. deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
  132. if is_variable_C and C.dim() == 4:
  133. C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
  134. last_state = None
  135. for i in range(u.shape[2]):
  136. if position_indices is not None and position_indices[0, i] == 0:
  137. x = deltaB_u[:, :, i]
  138. else:
  139. x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
  140. if not is_variable_C:
  141. y = torch.einsum('bdn,dn->bd', x, C)
  142. else:
  143. if C.dim() == 3:
  144. y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
  145. else:
  146. y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
  147. if i == u.shape[2] - 1:
  148. last_state = x
  149. ys.append(y)
  150. y = torch.stack(ys, dim=2) # (batch dim L)
  151. out = y if D is None else y + u * rearrange(D, "d -> d 1")
  152. if z is not None:
  153. out = out * F.silu(z)
  154. out = out.to(dtype=dtype_in)
  155. return out if not return_last_state else (out, last_state)
  156. @pytest.mark.parametrize('wtype', [torch.float32])
  157. @pytest.mark.parametrize('itype', [torch.float32])
  158. @pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096])
  159. @pytest.mark.parametrize("return_last_state", [True])
  160. @pytest.mark.parametrize('has_delta_bias', [True])
  161. @pytest.mark.parametrize('delta_softplus', [True])
  162. @pytest.mark.parametrize('has_z', [True])
  163. @pytest.mark.parametrize('has_D', [True])
  164. @pytest.mark.parametrize("varBC_groups", [1, 2])
  165. @pytest.mark.parametrize("is_variable_C", [True])
  166. @pytest.mark.parametrize("is_variable_B", [True])
  167. @pytest.mark.parametrize("scan_chunks", [1, 2, 3])
  168. def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D,
  169. has_z, has_delta_bias, delta_softplus,
  170. return_last_state, seqlen, itype, wtype, scan_chunks):
  171. if varBC_groups > 1 and (not is_variable_B or not is_variable_C):
  172. pytest.skip() # This config is not applicable
  173. device = 'cuda'
  174. rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3)
  175. if itype == torch.bfloat16:
  176. rtol, atol = 3e-2, 5e-2
  177. rtolw, atolw = (1e-3, 1e-3)
  178. if has_z: # If we have z, the errors on the weights seem higher
  179. rtolw = max(rtolw, rtol)
  180. atolw = max(atolw, atol)
  181. # set seed
  182. seed_everything(0)
  183. batch_size = 2
  184. dim = 4
  185. dstate = 8
  186. A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype))
  187. if not is_variable_B:
  188. B_shape = [dim, dstate]
  189. elif varBC_groups == 1:
  190. B_shape = [batch_size, dstate, seqlen]
  191. else:
  192. B_shape = [batch_size, varBC_groups, dstate, seqlen]
  193. B = torch.randn(B_shape,
  194. device=device,
  195. dtype=wtype if not is_variable_B else itype)
  196. if not is_variable_C:
  197. C_shape = [dim, dstate]
  198. elif varBC_groups == 1:
  199. C_shape = [batch_size, dstate, seqlen]
  200. else:
  201. C_shape = [batch_size, varBC_groups, dstate, seqlen]
  202. C = torch.randn(C_shape,
  203. device=device,
  204. dtype=wtype if not is_variable_C else itype)
  205. D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None
  206. z = torch.randn(batch_size, dim, seqlen, device=device,
  207. dtype=itype) if has_z else None
  208. delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32)
  209. ) if has_delta_bias else None
  210. u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype)
  211. delta = (0.5 *
  212. torch.rand(batch_size, dim, seqlen, device=device, dtype=itype))
  213. state = None
  214. state_ref = None
  215. out = None
  216. out_ref = None
  217. outs = []
  218. for c in range(scan_chunks):
  219. chunked_prompt_len = seqlen // scan_chunks
  220. chunk_start = chunked_prompt_len * c
  221. chunk_end = chunked_prompt_len * (c + 1)
  222. if c == scan_chunks - 1:
  223. chunk_end = seqlen
  224. _B = B
  225. if is_variable_B:
  226. _B = B[..., chunk_start:chunk_end]
  227. _C = C
  228. if is_variable_B:
  229. _C = C[..., chunk_start:chunk_end]
  230. _z = z
  231. if has_z:
  232. assert z is not None
  233. _z = z[..., chunk_start:chunk_end]
  234. out, *rest = selective_scan_fn(u[..., chunk_start:chunk_end],
  235. delta[..., chunk_start:chunk_end],
  236. A,
  237. _B,
  238. _C,
  239. D,
  240. z=_z,
  241. delta_bias=delta_bias,
  242. delta_softplus=delta_softplus,
  243. return_last_state=return_last_state,
  244. prev_state=state if c > 0 else None)
  245. outs.append(out)
  246. if return_last_state:
  247. state = rest[0]
  248. if len(outs) > 1:
  249. out = torch.cat(outs, dim=-1)
  250. out_ref, *rest = selective_scan_ref(u,
  251. delta,
  252. A,
  253. B,
  254. C,
  255. D,
  256. z=z,
  257. delta_bias=delta_bias,
  258. delta_softplus=delta_softplus,
  259. return_last_state=return_last_state)
  260. if return_last_state:
  261. state_ref = rest[0]
  262. assert out is not None and out_ref is not None
  263. assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
  264. if return_last_state:
  265. assert state is not None and state_ref is not None
  266. assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
  267. @pytest.mark.parametrize("itype",
  268. [torch.float32, torch.float16, torch.bfloat16])
  269. @pytest.mark.parametrize("has_z", [False, True])
  270. @pytest.mark.parametrize("dstate", [16, 32, 64])
  271. @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
  272. def test_selective_state_update(dim, dstate, has_z, itype):
  273. device = "cuda"
  274. rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
  275. if itype == torch.bfloat16:
  276. rtol, atol = 1e-2, 5e-2
  277. if torch.version.hip:
  278. atol *= 2
  279. # set seed
  280. seed_everything(0)
  281. batch_size = 1
  282. state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
  283. x = torch.randn(batch_size, dim, device=device, dtype=itype)
  284. dt = torch.randn(batch_size, dim, device=device, dtype=itype)
  285. dt_bias = torch.rand(dim, device=device) - 4.0
  286. A = -torch.rand(dim, dstate, device=device) - 1.0
  287. B = torch.randn(batch_size, dstate, device=device)
  288. C = torch.randn(batch_size, dstate, device=device)
  289. D = torch.randn(dim, device=device)
  290. z = torch.randn_like(x) if has_z else None
  291. state_ref = state.detach().clone()
  292. out = selective_state_update(state,
  293. x,
  294. dt,
  295. A,
  296. B,
  297. C,
  298. D=D,
  299. z=z,
  300. dt_bias=dt_bias,
  301. dt_softplus=True)
  302. out_ref = selective_state_update_ref(state_ref,
  303. x,
  304. dt,
  305. A,
  306. B,
  307. C,
  308. D=D,
  309. z=z,
  310. dt_bias=dt_bias,
  311. dt_softplus=True)
  312. assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
  313. assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
  314. @pytest.mark.parametrize("itype",
  315. [torch.float32, torch.float16, torch.bfloat16])
  316. @pytest.mark.parametrize("has_z", [False, True])
  317. @pytest.mark.parametrize("dstate", [16, 32, 64])
  318. @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
  319. def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
  320. device = "cuda"
  321. rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
  322. if itype == torch.bfloat16:
  323. rtol, atol = 7e-2, 7e-2
  324. if torch.version.hip:
  325. atol *= 2
  326. # set seed
  327. torch.random.manual_seed(0)
  328. batch_size = 16
  329. total_entries = 10 * batch_size
  330. state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
  331. state_indices = torch.randperm(total_entries)[:batch_size].to(
  332. dtype=torch.int32, device=device)
  333. x = torch.randn(batch_size, dim, device=device, dtype=itype)
  334. dt = torch.randn(batch_size, dim, device=device, dtype=itype)
  335. dt_bias = torch.rand(dim, device=device) - 4.0
  336. A = -torch.rand(dim, dstate, device=device) - 1.0
  337. B = torch.randn(batch_size, dstate, device=device)
  338. C = torch.randn(batch_size, dstate, device=device)
  339. D = torch.randn(dim, device=device)
  340. z = torch.randn_like(x) if has_z else None
  341. state_ref = state[state_indices, :].detach().clone()
  342. out = selective_state_update(state,
  343. x,
  344. dt,
  345. A,
  346. B,
  347. C,
  348. D=D,
  349. z=z,
  350. dt_bias=dt_bias,
  351. dt_softplus=True,
  352. state_batch_indices=state_indices)
  353. out_ref = selective_state_update_ref(state_ref,
  354. x,
  355. dt,
  356. A,
  357. B,
  358. C,
  359. D=D,
  360. z=z,
  361. dt_bias=dt_bias,
  362. dt_softplus=True)
  363. assert torch.allclose(state[state_indices, :],
  364. state_ref,
  365. rtol=rtol,
  366. atol=atol)
  367. assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
  368. @pytest.mark.parametrize("itype",
  369. [torch.float32, torch.float16, torch.bfloat16])
  370. @pytest.mark.parametrize("has_z", [False, True])
  371. @pytest.mark.parametrize("tie_hdim", [False, True])
  372. @pytest.mark.parametrize("ngroups", [1, 2, 4])
  373. @pytest.mark.parametrize("dstate", [16, 32, 64])
  374. @pytest.mark.parametrize("dim", [2048, 4096])
  375. def test_selective_state_update_with_heads_with_batch_indices(
  376. dim, dstate, ngroups, has_z, tie_hdim, itype):
  377. device = "cuda"
  378. rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
  379. if itype == torch.bfloat16:
  380. rtol, atol = 1e-1, 1e-1
  381. # set seed
  382. torch.random.manual_seed(0)
  383. batch_size = 16
  384. headdim = 64
  385. nheads = dim // headdim
  386. total_entries = 10 * batch_size
  387. state = torch.randn(total_entries,
  388. nheads,
  389. headdim,
  390. dstate,
  391. dtype=itype,
  392. device=device)
  393. state_indices = torch.randperm(total_entries)[:batch_size].to(
  394. dtype=torch.int32, device=device)
  395. x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
  396. if not tie_hdim:
  397. dt = torch.randn(batch_size,
  398. nheads,
  399. headdim,
  400. device=device,
  401. dtype=itype)
  402. dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
  403. A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
  404. D = torch.randn(nheads, headdim, device=device)
  405. else:
  406. dt = repeat(torch.randn(batch_size, nheads, device=device,
  407. dtype=itype),
  408. "b h -> b h p",
  409. p=headdim)
  410. dt_bias = repeat(torch.rand(nheads, device=device) - 4.0,
  411. "h -> h p",
  412. p=headdim)
  413. A = repeat(-torch.rand(nheads, device=device) - 1.0,
  414. "h -> h p n",
  415. p=headdim,
  416. n=dstate)
  417. D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
  418. B = torch.randn(batch_size, ngroups, dstate, device=device)
  419. C = torch.randn(batch_size, ngroups, dstate, device=device)
  420. z = torch.randn_like(x) if has_z else None
  421. state_ref = state[state_indices, :].detach().clone()
  422. out = selective_state_update(state,
  423. x,
  424. dt,
  425. A,
  426. B,
  427. C,
  428. D=D,
  429. z=z,
  430. dt_bias=dt_bias,
  431. dt_softplus=True,
  432. state_batch_indices=state_indices)
  433. out_ref = selective_state_update_ref(state_ref,
  434. x,
  435. dt,
  436. A,
  437. B,
  438. C,
  439. D=D,
  440. z=z,
  441. dt_bias=dt_bias,
  442. dt_softplus=True)
  443. print(f"Output max diff: {(out - out_ref).abs().max().item()}")
  444. print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
  445. assert torch.allclose(state[state_indices, :],
  446. state_ref,
  447. rtol=rtol,
  448. atol=atol)
  449. assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)