# mamba_ssm.py
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py

import torch
import triton
import triton.language as tl
from packaging import version

from aphrodite import _custom_ops as ops

TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
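
# Both branches compute softplus(dt) = log(1 + exp(dt)); two variants exist
# because the tl.math API changed in Triton 3.0. For dt > 20, softplus(dt)
# equals dt to float32 precision, so dt is returned unchanged there, which
# also keeps the result finite when exp(dt) overflows.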
if TRITON3:

    @triton.jit
    def softplus(dt):
        dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
        return dt
else:

    @triton.jit
    def softplus(dt):
        dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
        return dt
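

# The kernel below computes a single decode-time step of the selective SSM
# recurrence for one (batch, head) pair and one block of channels:
#
#     dA        = exp(dt * A)                  # discretized state transition
#     new_state = state * dA + (dt * B) * x    # written back in place
#     out       = sum(new_state * C, axis=-1)  # optionally + D * x, then
#                                              # gated by z * sigmoid(z)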
@triton.heuristics(
    {"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics({
    "HAS_STATE_BATCH_INDICES":
    lambda args: args["state_batch_indices_ptr"] is not None
})
@triton.heuristics(
    {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
@triton.jit
def _selective_scan_update_kernel(
    # Pointers to matrices
    state_ptr,
    x_ptr,
    dt_ptr,
    dt_bias_ptr,
    A_ptr,
    B_ptr,
    C_ptr,
    D_ptr,
    z_ptr,
    out_ptr,
    state_batch_indices_ptr,
    # Matrix dimensions
    batch,
    nheads,
    dim,
    dstate,
    nheads_ngroups_ratio,
    # Strides
    stride_state_batch,
    stride_state_head,
    stride_state_dim,
    stride_state_dstate,
    stride_x_batch,
    stride_x_head,
    stride_x_dim,
    stride_dt_batch,
    stride_dt_head,
    stride_dt_dim,
    stride_dt_bias_head,
    stride_dt_bias_dim,
    stride_A_head,
    stride_A_dim,
    stride_A_dstate,
    stride_B_batch,
    stride_B_group,
    stride_B_dstate,
    stride_C_batch,
    stride_C_group,
    stride_C_dstate,
    stride_D_head,
    stride_D_dim,
    stride_z_batch,
    stride_z_head,
    stride_z_dim,
    stride_out_batch,
    stride_out_head,
    stride_out_dim,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    TIE_HDIM: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    HAS_D: tl.constexpr,
    HAS_Z: tl.constexpr,
    HAS_STATE_BATCH_INDICES: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
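    # Launch grid (set by `grid` in selective_state_update below):
    #   axis 0 tiles the channel dim in blocks of BLOCK_SIZE_M,
    #   axis 1 indexes the batch, axis 2 indexes the heads.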
    pid_m = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)

    # If HAS_STATE_BATCH_INDICES is true, the ssm state's batch coordinate
    # is taken from state_batch_indices_ptr; otherwise the state coordinate
    # is the same as the batch id.
    if HAS_STATE_BATCH_INDICES:
        state_batch_indices_ptr += pid_b
        state_batch_idx = tl.load(state_batch_indices_ptr)
        state_ptr += (state_batch_idx * stride_state_batch +
                      pid_h * stride_state_head)
    else:
        state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head

    x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
    if HAS_DT_BIAS:
        dt_bias_ptr += pid_h * stride_dt_bias_head
    A_ptr += pid_h * stride_A_head
    B_ptr += pid_b * stride_B_batch + (pid_h //
                                       nheads_ngroups_ratio) * stride_B_group
    C_ptr += pid_b * stride_C_batch + (pid_h //
                                       nheads_ngroups_ratio) * stride_C_group
    if HAS_Z:
        z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
    state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim +
                              offs_n[None, :] * stride_state_dstate)
    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
    if HAS_DT_BIAS:
        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
    if HAS_D:
        D_ptr += pid_h * stride_D_head
    A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim +
                      offs_n[None, :] * stride_A_dstate)
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_D:
        D_ptrs = D_ptr + offs_m * stride_D_dim
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    state = tl.load(state_ptrs,
                    mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
                    other=0.0)
    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if not TIE_HDIM:
        dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptrs, mask=offs_m < dim,
                          other=0.0).to(tl.float32)
        if DT_SOFTPLUS:
            dt = softplus(dt)
        A = tl.load(A_ptrs,
                    mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
                    other=0.0).to(tl.float32)
        dA = tl.exp(A * dt[:, None])
    else:
        # A, dt (and dt_bias) are shared across the channel dimension,
        # so they can be loaded once as scalars.
        dt = tl.load(dt_ptr).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptr).to(tl.float32)
        if DT_SOFTPLUS:
            dt = softplus(dt)
        A = tl.load(A_ptr).to(tl.float32)
        dA = tl.exp(A * dt)  # scalar, not a matrix

    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

    dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
    state = state * dA + dB * x[:, None]
    tl.store(state_ptrs,
             state,
             mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
    out = tl.sum(state * C[None, :], axis=1)
    if HAS_D:
        out += x * D
    if HAS_Z:
        out *= z * tl.sigmoid(z)
    tl.store(out_ptrs, out, mask=offs_m < dim)
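

# Host-side wrapper: broadcasts 3-D (headless) inputs to the 4-D layout,
# validates shapes, picks hand-tuned launch parameters, and launches the
# kernel. `state` is updated in place. When `state_batch_indices` is given,
# it selects, per batch row, which row of `state` to read and update.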
def selective_state_update(state,
                           x,
                           dt,
                           A,
                           B,
                           C,
                           D=None,
                           z=None,
                           dt_bias=None,
                           dt_softplus=False,
                           state_batch_indices=None):
    """
    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
        state_batch_indices: (batch,)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    """
    has_heads = state.dim() > 3
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)
    _, nheads, dim, dstate = state.shape
    batch = x.shape[0]
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
    if state_batch_indices is not None:
        assert state_batch_indices.shape == (batch, )
    out = torch.empty_like(x)
    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
    z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
                 (0, 0, 0))
    # We don't want autotune since it will overwrite the state.
    # We instead tune by hand.
    BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else
                               ((16, 4) if dstate <= 32 else
                                ((8, 4) if dstate <= 64 else
                                 ((4, 4) if dstate <= 128 else ((4, 8))))))
    # TIE_HDIM: A, dt (and dt_bias, if present) are broadcast along the
    # channel dim (stride 0), so the kernel loads them once per head as
    # scalars. Guard the dt_bias check so dt_bias=None doesn't raise.
    tie_hdim = (A.stride(-1) == 0 and A.stride(-2) == 0
                and dt.stride(-1) == 0
                and (dt_bias is None or dt_bias.stride(-1) == 0))
    with torch.cuda.device(x.device.index):
        _selective_scan_update_kernel[grid](
            state,
            x,
            dt,
            dt_bias,
            A,
            B,
            C,
            D,
            z,
            out,
            state_batch_indices,
            batch,
            nheads,
            dim,
            dstate,
            nheads // ngroups,
            state.stride(0),
            state.stride(1),
            state.stride(2),
            state.stride(3),
            x.stride(0),
            x.stride(1),
            x.stride(2),
            dt.stride(0),
            dt.stride(1),
            dt.stride(2),
            # The kernel expects two dt_bias strides (head, dim); pass zeros
            # when dt_bias is absent so the argument list stays aligned.
            *((dt_bias.stride(0), dt_bias.stride(1))
              if dt_bias is not None else (0, 0)),
            A.stride(0),
            A.stride(1),
            A.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            C.stride(0),
            C.stride(1),
            C.stride(2),
            # Likewise for the two D strides.
            *((D.stride(0), D.stride(1)) if D is not None else (0, 0)),
            z_strides[0],
            z_strides[1],
            z_strides[2],
            out.stride(0),
            out.stride(1),
            out.stride(2),
            dt_softplus,
            tie_hdim,
            BLOCK_SIZE_M,
            num_warps=num_warps,
        )
    if not has_heads:
        out = out.squeeze(1)
    return out
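

# Full-sequence (prefill) path: wraps the aphrodite custom op
# `ops.selective_scan_fwd`. The float32 buffer `x` allocated inside
# checkpoints the scan state once per 2048-token chunk in an interleaved
# even/odd layout: the odd entries of chunk 0 are seeded from `prev_state`
# and the odd entries of the last chunk hold the final state.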
def selective_scan_fn(u,
                      delta,
                      A,
                      B,
                      C,
                      D=None,
                      z=None,
                      delta_bias=None,
                      delta_softplus=False,
                      return_last_state=False,
                      position_indices=None,
                      prev_state=None):
    """If return_last_state is True, returns (out, last_state);
    last_state has shape (batch, dim, dstate).
    """
    if u.stride(-1) != 1:
        u = u.contiguous()
    if delta.stride(-1) != 1:
        delta = delta.contiguous()
    if D is not None:
        D = D.contiguous()
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if z is not None and z.stride(-1) != 1:
        z = z.contiguous()
    if B.dim() == 3:
        B = B.unsqueeze(1)
    if C.dim() == 3:
        C = C.unsqueeze(1)
    n_chunks = (u.shape[-1] + 2048 - 1) // 2048
    x = torch.zeros((
        u.shape[0],
        u.shape[1],
        n_chunks,
        int(A.shape[1] * 2),
    ),
                    device=u.device,
                    dtype=torch.float32,
                    requires_grad=False)
    x[:, :, 0, 0::2] = 1
    if prev_state is not None:
        x[:, :, 0, 1::2].copy_(prev_state)
    out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias,
                                           delta_softplus, position_indices, x)
    last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
    if z is None:
        return out if not return_last_state else (out, last_state)
    else:
        out_z = rest[0]
        return out_z if not return_last_state else (out_z, last_state)
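

if __name__ == "__main__":
    # Minimal smoke test for selective_state_update: a sketch, assuming a
    # CUDA device with Triton is available; shapes are made up for
    # illustration. The fused kernel is checked against a pure-PyTorch
    # evaluation of the same recurrence:
    #     new_state = state * exp(dt * A) + (dt * B) * x
    #     out       = sum(new_state * C, axis=-1)
    torch.manual_seed(0)
    batch, nheads, dim, dstate = 2, 4, 64, 16
    device = "cuda"
    state = torch.randn(batch, nheads, dim, dstate, device=device)
    x = torch.randn(batch, nheads, dim, device=device)
    dt = torch.rand(batch, nheads, dim, device=device)
    A = -torch.rand(nheads, dim, dstate, device=device)
    B = torch.randn(batch, nheads, dstate, device=device)  # ngroups == nheads
    C = torch.randn(batch, nheads, dstate, device=device)

    # Reference computed from the pre-update state, since the kernel
    # overwrites `state` in place.
    new_state = (state * torch.exp(dt.unsqueeze(-1) * A) +
                 dt.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1))
    expected = (new_state * C.unsqueeze(2)).sum(-1)

    out = selective_state_update(state, x, dt, A, B, C)
    torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(state, new_state, rtol=1e-3, atol=1e-3)
    print("selective_state_update matches the PyTorch reference")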