# mamba_ssm.py
# Copyright (c) 2024, Tri Dao, Albert Gu.

import torch
import triton
import triton.language as tl
from einops import rearrange
from packaging import version

from aphrodite import _custom_ops as ops

TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")

if TRITON3:

    @triton.jit
    def softplus(dt):
        dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
        return dt
else:

    @triton.jit
    def softplus(dt):
        dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
        return dt
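

# Note: for dt > 20, softplus(dt) equals dt to well within float32 precision,
# so the helper returns dt unchanged above that threshold (the same cutoff
# torch.nn.functional.softplus uses as its default `threshold=20`).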


@triton.heuristics(
    {"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics(
    {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
@triton.jit
def _selective_scan_update_kernel(
    # Pointers to matrices
    state_ptr,
    x_ptr,
    dt_ptr,
    dt_bias_ptr,
    A_ptr,
    B_ptr,
    C_ptr,
    D_ptr,
    z_ptr,
    out_ptr,
    # Matrix dimensions
    batch,
    nheads,
    dim,
    dstate,
    nheads_ngroups_ratio,
    # Strides
    stride_state_batch,
    stride_state_head,
    stride_state_dim,
    stride_state_dstate,
    stride_x_batch,
    stride_x_head,
    stride_x_dim,
    stride_dt_batch,
    stride_dt_head,
    stride_dt_dim,
    stride_dt_bias_head,
    stride_dt_bias_dim,
    stride_A_head,
    stride_A_dim,
    stride_A_dstate,
    stride_B_batch,
    stride_B_group,
    stride_B_dstate,
    stride_C_batch,
    stride_C_group,
    stride_C_dstate,
    stride_D_head,
    stride_D_dim,
    stride_z_batch,
    stride_z_head,
    stride_z_dim,
    stride_out_batch,
    stride_out_head,
    stride_out_dim,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    TIE_HDIM: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    HAS_D: tl.constexpr,
    HAS_Z: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
    x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
    if HAS_DT_BIAS:
        dt_bias_ptr += pid_h * stride_dt_bias_head
    A_ptr += pid_h * stride_A_head
    B_ptr += pid_b * stride_B_batch + (pid_h //
                                       nheads_ngroups_ratio) * stride_B_group
    C_ptr += pid_b * stride_C_batch + (pid_h //
                                       nheads_ngroups_ratio) * stride_C_group
    if HAS_Z:
        z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
    state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim +
                              offs_n[None, :] * stride_state_dstate)
    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
    if HAS_DT_BIAS:
        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
    if HAS_D:
        D_ptr += pid_h * stride_D_head
    A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim +
                      offs_n[None, :] * stride_A_dstate)
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_D:
        D_ptrs = D_ptr + offs_m * stride_D_dim
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    state = tl.load(state_ptrs,
                    mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
                    other=0.0)
    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
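    # When TIE_HDIM, A and dt (and dt_bias) are constant along the head's
    # channel dimension (stride 0), so scalar loads stand in for the per-row
    # loads of the general branch.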
    if not TIE_HDIM:
        dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptrs, mask=offs_m < dim,
                          other=0.0).to(tl.float32)
        if DT_SOFTPLUS:
            dt = softplus(dt)
        A = tl.load(A_ptrs,
                    mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
                    other=0.0).to(tl.float32)
        dA = tl.exp(A * dt[:, None])
    else:
        dt = tl.load(dt_ptr).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptr).to(tl.float32)
        if DT_SOFTPLUS:
            dt = softplus(dt)
        A = tl.load(A_ptr).to(tl.float32)
        dA = tl.exp(A * dt)  # scalar, not a matrix
    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
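    # Discretized SSM recurrence for a single timestep:
    #   state <- exp(dt * A) * state + (dt * B) * x
    #   out    = sum_n(state * C)  [+ D * x], gated by z * sigmoid(z) when z
    #   is given.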
    dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
    state = state * dA + dB * x[:, None]
    tl.store(state_ptrs,
             state,
             mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
    out = tl.sum(state * C[None, :], axis=1)
    if HAS_D:
        out += x * D
    if HAS_Z:
        out *= z * tl.sigmoid(z)
    tl.store(out_ptrs, out, mask=offs_m < dim)
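

# Launch grid: (cdiv(dim, BLOCK_SIZE_M), batch, nheads); each program updates
# one (BLOCK_SIZE_M, dstate) tile of a single head's recurrent state in place.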
def selective_state_update(state,
                           x,
                           dt,
                           A,
                           B,
                           C,
                           D=None,
                           z=None,
                           dt_bias=None,
                           dt_softplus=False):
    """
    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    """
    has_heads = state.dim() > 3
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)
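    # All inputs are now in the headed (batch, nheads, ...) layout the kernel
    # expects.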
    batch, nheads, dim, dstate = state.shape
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
    out = torch.empty_like(x)
    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
    z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
                 (0, 0, 0))
    # We don't want autotune since it will overwrite the state.
    # We instead tune by hand.
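    # (BLOCK_SIZE_M, num_warps), bucketed by dstate: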
    BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else
                               ((16, 4) if dstate <= 32 else
                                ((8, 4) if dstate <= 64 else
                                 ((4, 4) if dstate <= 128 else (4, 8)))))
    # A and dt are "tied" across the head's channel dim when their strides
    # over dim/dstate are 0; dt_bias may be None, so guard it before reading
    # its stride.
    tie_hdim = (A.stride(-1) == 0 and A.stride(-2) == 0
                and dt.stride(-1) == 0
                and (dt_bias is None or dt_bias.stride(-1) == 0))
    with torch.cuda.device(x.device.index):
        _selective_scan_update_kernel[grid](
            state,
            x,
            dt,
            dt_bias,
            A,
            B,
            C,
            D,
            z,
            out,
            batch,
            nheads,
            dim,
            dstate,
            nheads // ngroups,
            state.stride(0),
            state.stride(1),
            state.stride(2),
            state.stride(3),
            x.stride(0),
            x.stride(1),
            x.stride(2),
            dt.stride(0),
            dt.stride(1),
            dt.stride(2),
            # Pass (0, 0) when dt_bias is absent; unpacking a bare 0 would
            # raise a TypeError.
            *((dt_bias.stride(0), dt_bias.stride(1))
              if dt_bias is not None else (0, 0)),
            A.stride(0),
            A.stride(1),
            A.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            C.stride(0),
            C.stride(1),
            C.stride(2),
            *((D.stride(0), D.stride(1)) if D is not None else (0, 0)),
            z_strides[0],
            z_strides[1],
            z_strides[2],
            out.stride(0),
            out.stride(1),
            out.stride(2),
            dt_softplus,
            tie_hdim,
            BLOCK_SIZE_M,
            num_warps=num_warps,
        )
    if not has_heads:
        out = out.squeeze(1)
    return out
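

# A minimal usage sketch (hypothetical shapes, for illustration only):
#
#   batch, nheads, dim, dstate = 2, 4, 64, 16
#   state = torch.zeros(batch, nheads, dim, dstate, device="cuda")
#   x = torch.randn(batch, nheads, dim, device="cuda")
#   dt = torch.rand(batch, nheads, dim, device="cuda")
#   A = -torch.rand(nheads, dim, dstate, device="cuda")
#   B = torch.randn(batch, 1, dstate, device="cuda")  # ngroups = 1
#   C = torch.randn(batch, 1, dstate, device="cuda")
#   out = selective_state_update(state, x, dt, A, B, C, dt_softplus=True)
#   # `state` is updated in place; `out` has shape (batch, nheads, dim).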
def selective_scan_fn(u,
                      delta,
                      A,
                      B,
                      C,
                      D=None,
                      z=None,
                      delta_bias=None,
                      delta_softplus=False,
                      return_last_state=False,
                      position_indices=None,
                      prev_state=None):
    """If return_last_state is True, returns (out, last_state);
    last_state has shape (batch, dim, dstate).
    """
    if u.stride(-1) != 1:
        u = u.contiguous()
    if delta.stride(-1) != 1:
        delta = delta.contiguous()
    if D is not None:
        D = D.contiguous()
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if z is not None and z.stride(-1) != 1:
        z = z.contiguous()
    if B.dim() == 3:
        B = rearrange(B, "b dstate l -> b 1 dstate l")
    if C.dim() == 3:
        C = rearrange(C, "b dstate l -> b 1 dstate l")
    n_chunks = (u.shape[-1] + 2048 - 1) // 2048
    x = torch.zeros((
        u.shape[0],
        u.shape[1],
        n_chunks,
        int(A.shape[1] * 2),
    ),
                    device=u.device,
                    dtype=torch.float32,
                    requires_grad=False)
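    # Chunked-scan carry buffer: along the last dim, even slots are
    # initialized to 1 and odd slots hold the running state; prev_state is
    # injected into (and last_state read back from) the odd slots.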
    x[:, :, 0, 0::2] = 1
    if prev_state is not None:
        x[:, :, 0, 1::2].copy_(prev_state)
    out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias,
                                           delta_softplus, position_indices, x)
    last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
    if z is None:
        return out if not return_last_state else (out, last_state)
    else:
        out_z = rest[0]
        return out_z if not return_last_state else (out_z, last_state)
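

# A minimal usage sketch (hypothetical shapes; requires the compiled
# aphrodite custom ops):
#
#   batch, dim, dstate, seqlen = 2, 64, 16, 128
#   u = torch.randn(batch, dim, seqlen, device="cuda")
#   delta = torch.rand(batch, dim, seqlen, device="cuda")
#   A = -torch.rand(dim, dstate, device="cuda")
#   B = torch.randn(batch, dstate, seqlen, device="cuda")
#   C = torch.randn(batch, dstate, seqlen, device="cuda")
#   out, last_state = selective_scan_fn(u, delta, A, B, C,
#                                       return_last_state=True)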