interface.py

import math

import torch

from aphrodite.attention.ops.blocksparse_attention.utils import (
    dense_to_crow_col, get_head_sliding_step, get_sparse_attn_mask)
from aphrodite.common.utils import is_cpu, is_hip
from aphrodite.platforms import current_platform

IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
                         and current_platform.get_device_capability()[0] >= 8)

if IS_COMPUTE_8_OR_ABOVE:
    from aphrodite.attention.ops.blocksparse_attention.blocksparse_attention_kernel import (  # noqa: E501
        blocksparse_flash_attn_varlen_fwd)


class LocalStridedBlockSparseAttn(torch.nn.Module):

    def __init__(
        self,
        n_heads,
        max_seqlen,
        local_blocks,
        vert_stride,
        block_size,
        device=None,
        dtype=None,
        homo_head=False,
        active_head_range=None,
        q_block_size=None,
        use_spda=None,
    ):
        super().__init__()
        if use_spda is None:
            use_spda = is_hip() or is_cpu() or not \
                IS_COMPUTE_8_OR_ABOVE
        device = device or (torch.cuda.current_device()
                            if torch.cuda.is_available() else "cpu")
        device = torch.device(device)
        # NOTE: the aphrodite CPU backend supports BF16 instead of FP16.
        dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE
                          or device.type == "cpu" else torch.half)

        self.n_heads = n_heads
        self.max_seqlen = max_seqlen
        self.local_blocks = local_blocks
        self.vert_stride = vert_stride
        self.use_spda = use_spda
        self.dtype = dtype
        self.device = device
        self.block_size = block_size
        self.q_block_size = q_block_size
        self.homo_head = homo_head
        self.active_head_range = active_head_range
        self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride,
                                                       homo_head)

        sparse_layout, sparse_pattern, self.dense_attn_mask = (
            self.get_attn_pattern(dtype, device))

        if q_block_size is not None and q_block_size != block_size:
            if q_block_size > block_size:
                assert q_block_size % block_size == 0
                blocks_to_merge = q_block_size // block_size
                shape = sparse_pattern.shape
                sparse_pattern = sparse_pattern.view(shape[0], -1,
                                                     blocks_to_merge,
                                                     shape[-1])
                sparse_pattern = sparse_pattern.sum(2)
                sparse_layout = dense_to_crow_col(sparse_pattern)
            else:
                raise ValueError(
                    "Does not support smaller q_block_size. It will be slower."
                )

        self.sparse_layout = sparse_layout
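
    # Illustrative note on the q_block_size merge in __init__ above
    # (hypothetical numbers): with block_size=16 and q_block_size=32,
    # blocks_to_merge == 2, so every pair of adjacent query-block rows in
    # `sparse_pattern` is summed into a single row for the larger query block
    # before the CSR layout is rebuilt with `dense_to_crow_col`.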

    def get_attn_pattern(self, dtype, device):
        sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask(
            self.n_heads,
            self.max_seqlen,
            self.max_seqlen,
            dtype,
            device,
            block_size=self.block_size,
            local_blocks=self.local_blocks,
            vert_stride=self.vert_stride,
            homo_head=self.homo_head,
            return_dense=self.use_spda,
            dense_mask_type="bias",
        )
        if (not self.homo_head) and (self.active_head_range is not None):
            assert isinstance(self.active_head_range, tuple)
            assert (len(self.active_head_range) == 2)
            h_start, h_end = self.active_head_range
            sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout)
            if self.use_spda:
                dense_attn_mask = dense_attn_mask[h_start:h_end]
        return sparse_layout, sparse_pattern, dense_attn_mask

    def varlen_attn(self,
                    q,
                    k,
                    v,
                    cu_seqlens_k,
                    cu_seqlens_q=None,
                    sm_scale=None):
        """
        q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
            Supports grouped attention: `q[:, i*r:(i*r + r)]` corresponds to
            `k[:, i]`, where `r` is the q/kv head ratio.
        cu_seqlens_k: shape = (batch_size + 1,),
            cumulative sequence lengths marking the sample boundaries,
            e.g., `k[cu_seqlens_k[i]:cu_seqlens_k[i+1]]` is the key of
            sample i.
        cu_seqlens_q: shape = (batch_size + 1,).
            Default None: same as cu_seqlens_k for prefilling, or
            [0, 1, .., batch_size] for decoding.
            The only case where it needs to be specified is when q is a mix
            of prefilling and decoding.
        sm_scale: softmax scale, defaults to 1/sqrt(head_size).
        return: tensor with the same shape as q.
        (A commented call sketch follows this method.)
        """
        assert (
            IS_COMPUTE_8_OR_ABOVE
        ), "Requires compute capability of 8 or above (Ampere or newer) to use \
            Triton kernel."

        sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))

        return blocksparse_flash_attn_varlen_fwd(
            q,
            k,
            v,
            cu_seqlens_k,
            cu_seqlens_q,
            sm_scale,
            self.sparse_layout,
            block_size=self.block_size,
            q_block_size=self.q_block_size,
            max_seqlen=self.max_seqlen,
        )
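
    # Illustrative call sketch for `varlen_attn` (all sizes hypothetical):
    # with two prompts of lengths 100 and 60 packed into one batch,
    #   q, k, v: (160, n_heads, head_size) tensors on a CUDA device
    #   cu_seqlens_k = torch.tensor([0, 100, 160], dtype=torch.int32,
    #                               device="cuda")
    #   out = module.varlen_attn(q, k, v, cu_seqlens_k)  # out.shape == q.shape
    # cu_seqlens_q is left as None because all tokens belong to prefills.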

    @staticmethod
    def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1):
        """
        :param x: (total_tokens, n_heads, head_size)
        :return: (batch, n_heads * head_repeats, maxlen, head_size)
        (A shape sketch follows this method.)
        """
        x_padded = x.new_empty(
            len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2))
        cu_seqlens = cu_seqlens.cpu()
        for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
            x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0,
                                                             1).unsqueeze(1))
        return x_padded.flatten(1, 2)
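
    # Shape sketch (hypothetical sizes): for cu_seqlens = [0, 100, 160],
    # maxlen = 100, and head_repeats = 2, an input of shape (160, 8, 64)
    # becomes (2, 16, 100, 64); each kv head is repeated `head_repeats`
    # times so it lines up with the grouped query heads.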

    @staticmethod
    def transpose_and_unpad(x_padded, cu_seqlens):
        """
        :param x_padded: (batch, n_heads, length, head_size)
        :return: (total_tokens, n_heads, head_size)
        """
        cu_seqlens = cu_seqlens.cpu()
        total_n_tokens = cu_seqlens[-1]
        x = x_padded.new_empty(total_n_tokens, x_padded.size(1),
                               x_padded.size(3))
        for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
            x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1))
        return x
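
    # Round-trip sketch: transpose_and_unpad(transpose_and_pad(x, cu, maxlen, 1), cu)
    # recovers `x`, since only the first (e - s) rows of each padded sample
    # are read back.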

    def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
        """Fallback path for CPU, V100, or other pre-Ampere GPUs.
        NOTE: torch's scaled_dot_product_attention supports nested tensors,
        but that path seems extremely slow, so we pad instead.
        """
        assert (cu_seqlens_q is None or
                (cu_seqlens_q
                 == cu_seqlens_k).all()), "Can only handle prompt with SPDA."
        assert q.size(0) == k.size(0), "Can only handle prompt with SPDA."
        assert q.size(1) % k.size(1) == 0
        q_k_ratio = q.size(1) // k.size(1)
        sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
        cu_seqlens = cu_seqlens_k.cpu()
        maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

        if (self.dense_attn_mask.dtype != q.dtype
                or self.dense_attn_mask.device != q.device):
            _, _, self.dense_attn_mask = self.get_attn_pattern(
                q.dtype, q.device)
        attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen]

        q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1)
        k2, v2 = [
            self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio)
            for x in [k, v]
        ]
        spda_output = torch.nn.functional.scaled_dot_product_attention(
            q2, k2, v2, attn_mask=attn_mask, scale=sm_scale)
        return self.transpose_and_unpad(spda_output, cu_seqlens)
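
    # Dispatch sketch: on a CPU-only build (or a pre-Ampere GPU), use_spda
    # resolves to True in __init__, so forward() below routes the same packed
    # q/k/v + cu_seqlens_k call shown after `varlen_attn` through this padded
    # scaled_dot_product_attention path instead of the Triton kernel.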

    def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
        """Dispatch to `varlen_attn` (Ampere or newer) or `self.spda`
        (CPU; Volta, Turing, or older GPUs) based on the device type and
        CUDA compute capability.

        q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
            Supports grouped attention: `q[:, i*r:(i*r + r)]` corresponds to
            `k[:, i]`, where `r` is the q/kv head ratio.
        cu_seqlens_k: shape = (batch_size + 1,), cumulative sequence lengths
            marking the sample boundaries,
            e.g., `k[cu_seqlens_k[i]:cu_seqlens_k[i+1]]` is the key of
            sample i.
        cu_seqlens_q: shape = (batch_size + 1,).
            Default None: same as cu_seqlens_k for prefilling, or
            [0, 1, .., batch_size] for decoding.
            The only case where it needs to be specified is when q is a mix
            of prefilling and decoding.
        sm_scale: softmax scale, defaults to 1/sqrt(head_size).
        return: tensor with the same shape as q.
        (A runnable usage sketch follows at the end of this file.)
        """
        assert k.dim() == 3
        if self.use_spda:
            return self.spda(
                q,
                k,
                v,
                cu_seqlens_k,
                cu_seqlens_q=cu_seqlens_q,
                sm_scale=sm_scale,
            )
        return self.varlen_attn(q,
                                k,
                                v,
                                cu_seqlens_k,
                                cu_seqlens_q=cu_seqlens_q,
                                sm_scale=sm_scale)
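

# Minimal, illustrative usage sketch; all parameter values below are
# hypothetical and chosen small so it runs on CPU as well as on an
# Ampere-or-newer GPU.
if __name__ == "__main__":
    n_heads, head_size = 8, 64
    attn = LocalStridedBlockSparseAttn(
        n_heads=n_heads,
        max_seqlen=1024,
        local_blocks=4,
        vert_stride=4,
        block_size=64,
    )
    # Two packed prompts of lengths 100 and 60 (160 tokens total).
    cu_seqlens_k = torch.tensor([0, 100, 160],
                                dtype=torch.int32,
                                device=attn.device)
    q = torch.randn(160, n_heads, head_size,
                    dtype=attn.dtype, device=attn.device)
    k = torch.randn(160, n_heads, head_size,
                    dtype=attn.dtype, device=attn.device)
    v = torch.randn(160, n_heads, head_size,
                    dtype=attn.dtype, device=attn.device)
    out = attn(q, k, v, cu_seqlens_k)
    print(out.shape)  # torch.Size([160, 8, 64])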