aphrodite_flash_attn.py

from typing import Optional, Union

import torch

import aphrodite._custom_ops as ops


def maybe_contiguous(x):
    # Only copy when the last dimension is strided; None passes through.
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal,
    window_size, softcap, alibi_slopes,
    return_softmax, *, out=None
):
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    (out, q, k, v, out_padded, softmax_lse,
     S_dmask, rng_state) = ops.fwd(
        q=q,
        k=k,
        v=v,
        out=out,
        alibi_slopes=alibi_slopes,
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        softcap=softcap,
        return_softmax=return_softmax,
        gen=None,
    )  # type: ignore
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state


def _flash_attn_varlen_forward(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p,
    softmax_scale,
    causal,
    window_size,
    softcap,
    alibi_slopes,
    return_softmax,
    block_table,
    *,
    out=None
):
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    (out, q, k, v, out_padded, softmax_lse,
     S_dmask, rng_state) = ops.varlen_fwd(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        dropout_p=dropout_p,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        block_table=block_table,
        return_softmax=return_softmax,
        gen=None,
        out=out,
        seqused_k=None,
        zero_tensors=False,
    )  # type: ignore
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state


class FlashAttnFunc(torch.autograd.Function):
    # Forward-only autograd wrapper around _flash_attn_forward; no backward()
    # is defined, so gradients cannot flow through this op.
    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        out=None,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        (out, q, k, v, out_padded, softmax_lse,
         S_dmask, rng_state) = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            out=out,
        )
        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)
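

# Illustrative usage sketch (not part of the original file): invoking the
# fixed-length path directly via FlashAttnFunc.apply, since this module does
# not define a flash_attn_func convenience wrapper. The shapes, dtype, and
# device below are assumptions about a typical setup; q/k/v are laid out as
# (batch, seqlen, num_heads, head_dim).
def _example_fixed_length_usage():
    batch, seqlen, num_heads, head_dim = 2, 16, 8, 64
    q = torch.randn(batch, seqlen, num_heads, head_dim,
                    device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    # Arguments are passed positionally, mirroring how the varlen wrapper
    # below calls FlashAttnVarlenFunc.apply.
    return FlashAttnFunc.apply(
        q, k, v,
        0.0,        # dropout_p
        None,       # softmax_scale (defaults to head_dim ** -0.5)
        True,       # causal
        (-1, -1),   # window_size
        0.0,        # softcap
        None,       # alibi_slopes
        False,      # deterministic
        False,      # return_softmax
    )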


class FlashAttnVarlenFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        block_table,
        out=None,
    ):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        (out, q, k, v, out_padded, softmax_lse,
         S_dmask, rng_state) = _flash_attn_varlen_forward(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            dropout_p,
            softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax and dropout_p > 0,
            block_table=block_table,
            out=out,
        )
        ctx.save_for_backward(
            q, k, v, out_padded, softmax_lse, cu_seqlens_q,
            cu_seqlens_k, rng_state
        )
        ctx.dropout_p = dropout_p
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic
        return out if not return_softmax else (out, softmax_lse, S_dmask)


def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
    *,
    out=None,
):
    return FlashAttnVarlenFunc.apply(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_k,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        block_table,
        out,
    )
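

# Illustrative usage sketch (not part of the original file): packing two
# variable-length sequences into a single flash_attn_varlen_func call. The
# lengths, head count, head_dim, dtype, and device are assumptions about a
# typical setup; cu_seqlens_q/k are int32 prefix sums over per-sequence
# lengths, and q/k/v are packed as (total_tokens, num_heads, head_dim).
def _example_varlen_usage():
    seqlens = torch.tensor([5, 3], dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32, device="cuda")
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    total_tokens, num_heads, head_dim = int(seqlens.sum()), 8, 64
    q = torch.randn(total_tokens, num_heads, head_dim,
                    device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=int(seqlens.max()),
        max_seqlen_k=int(seqlens.max()),
        causal=True,
    )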


def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    block_table: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    alibi_slopes=None,
    num_splits=0,
    return_softmax_lse=False,
    *,
    out=None,
):
    assert k_cache.stride(-1) == 1, (
        "k_cache must have contiguous last dimension"
    )
    assert v_cache.stride(-1) == 1, (
        "v_cache must have contiguous last dimension"
    )
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        # Broadcast a scalar cache length to a per-batch int32 tensor.
        cache_seqlens = torch.full(
            (k_cache.shape[0],), cache_seqlens,
            dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    block_table = maybe_contiguous(block_table)
    out, softmax_lse = ops.fwd_kvcache(
        q=q,
        kcache=k_cache,
        vcache=v_cache,
        k=k,
        v=v,
        seqlens_k=cache_seqlens,
        rotary_cos=rotary_cos,
        rotary_sin=rotary_sin,
        cache_batch_idx=cache_batch_idx,
        block_table=block_table,
        alibi_slopes=alibi_slopes,
        out=out,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        softcap=softcap,
        rotary_interleaved=rotary_interleaved,
        num_splits=num_splits,
    )  # type: ignore
    return (out, softmax_lse) if return_softmax_lse else out
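

# Illustrative usage sketch (not part of the original file): a single decode
# step against a preallocated, contiguous KV cache (no block_table). The
# batch size, cache length, head count, head_dim, dtype, and device are
# assumptions; passing cache_seqlens as an int relies on the broadcast-to-
# tensor branch inside flash_attn_with_kvcache above.
def _example_kvcache_decode_step():
    batch, cache_len, num_heads, head_dim = 2, 128, 8, 64
    q = torch.randn(batch, 1, num_heads, head_dim,
                    device="cuda", dtype=torch.float16)
    k_cache = torch.zeros(batch, cache_len, num_heads, head_dim,
                          device="cuda", dtype=torch.float16)
    v_cache = torch.zeros_like(k_cache)
    k_new = torch.randn(batch, 1, num_heads, head_dim,
                        device="cuda", dtype=torch.float16)
    v_new = torch.randn_like(k_new)
    # Assume 16 tokens are already in the cache; the new k/v for this step
    # are appended at that offset before attention is computed.
    return flash_attn_with_kvcache(
        q, k_cache, v_cache,
        k=k_new, v=v_new,
        cache_seqlens=16,
        causal=True,
    )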