import math

import torch

from aphrodite.attention.ops.blocksparse_attention.utils import (
    dense_to_crow_col, get_head_sliding_step, get_sparse_attn_mask)
from aphrodite.common.utils import is_cpu, is_hip
from aphrodite.platforms import current_platform

IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
                         and current_platform.get_device_capability()[0] >= 8)

if IS_COMPUTE_8_OR_ABOVE:
    from aphrodite.attention.ops.blocksparse_attention.blocksparse_attention_kernel import (  # noqa: E501
        blocksparse_flash_attn_varlen_fwd)


class LocalStridedBlockSparseAttn(torch.nn.Module):

    def __init__(
        self,
        n_heads,
        max_seqlen,
        local_blocks,
        vert_stride,
        block_size,
        device=None,
        dtype=None,
        homo_head=False,
        active_head_range=None,
        q_block_size=None,
        use_spda=None,
    ):
        super().__init__()
        if use_spda is None:
            # Fall back to torch SDPA on ROCm, CPU, or pre-Ampere GPUs,
            # where the Triton kernel is unavailable.
            use_spda = is_hip() or is_cpu() or not IS_COMPUTE_8_OR_ABOVE
        device = device or (torch.cuda.current_device()
                            if torch.cuda.is_available() else "cpu")
        device = torch.device(device)
        # NOTE: the aphrodite CPU backend supports BF16 instead of FP16.
        dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE
                          or device.type == "cpu" else torch.half)

        self.n_heads = n_heads
        self.max_seqlen = max_seqlen
        self.local_blocks = local_blocks
        self.vert_stride = vert_stride
        self.use_spda = use_spda
        self.dtype = dtype
        self.device = device
        self.block_size = block_size
        self.q_block_size = q_block_size
        self.homo_head = homo_head
        self.active_head_range = active_head_range
        self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride,
                                                       homo_head)

        sparse_layout, sparse_pattern, self.dense_attn_mask = (
            self.get_attn_pattern(dtype, device))

        if q_block_size is not None and q_block_size != block_size:
            if q_block_size > block_size:
                assert q_block_size % block_size == 0
                # Merge groups of key-block rows so the kernel can use a
                # larger query tile than the key/value block size.
                blocks_to_merge = q_block_size // block_size
                shape = sparse_pattern.shape
                sparse_pattern = sparse_pattern.view(shape[0], -1,
                                                     blocks_to_merge,
                                                     shape[-1])
                sparse_pattern = sparse_pattern.sum(2)
                sparse_layout = dense_to_crow_col(sparse_pattern)
            else:
                raise ValueError(
                    "q_block_size smaller than block_size is not supported, "
                    "as it would be slower.")

        self.sparse_layout = sparse_layout
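
    # Added note (illustrative, not from the original source): with this
    # local + vertical-stride layout, query block `i` of a given head roughly
    # attends to the last `local_blocks` key blocks ending at `i`, plus every
    # `vert_stride`-th earlier key block (offset per head via
    # `head_sliding_step` unless `homo_head=True`). For example, with
    # `local_blocks=4` and `vert_stride=8`, block row 20 would keep key
    # blocks 17-20 locally and a strided subset of the remaining history
    # such as 0, 8, 16, all under the causal constraint.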

    def get_attn_pattern(self, dtype, device):
        sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask(
            self.n_heads,
            self.max_seqlen,
            self.max_seqlen,
            dtype,
            device,
            block_size=self.block_size,
            local_blocks=self.local_blocks,
            vert_stride=self.vert_stride,
            homo_head=self.homo_head,
            return_dense=self.use_spda,
            dense_mask_type="bias",
        )
        if (not self.homo_head) and (self.active_head_range is not None):
            assert isinstance(self.active_head_range, tuple)
            assert (len(self.active_head_range) == 2)
            h_start, h_end = self.active_head_range
            sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout)
            if self.use_spda:
                dense_attn_mask = dense_attn_mask[h_start:h_end]
        return sparse_layout, sparse_pattern, dense_attn_mask
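
    # Added note (illustrative, not from the original source): as consumed by
    # the Triton kernel, `sparse_layout` is expected to be a pair of per-head
    # index tensors in a CSR-like block format produced by
    # `dense_to_crow_col`: row pointers (`crow`) and column indices (`col`)
    # of the non-zero key blocks for each query block row. E.g., a per-head
    # block pattern
    #     [[1, 0, 0],
    #      [1, 1, 0],
    #      [1, 1, 1]]
    # would compress to crow = [0, 1, 3, 6] and col = [0, 0, 1, 0, 1, 2].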

    def varlen_attn(self,
                    q,
                    k,
                    v,
                    cu_seqlens_k,
                    cu_seqlens_q=None,
                    sm_scale=None):
        """
        q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
            Supports grouped attention: `q[:, i*r:(i*r + r)]`
            corresponds to `k[:, i]`, where `r` is the q/k head ratio.
        cu_seqlens_k: shape = (batch_size + 1,), cumulative sequence lengths
            marking the sample boundaries, i.e.,
            `k[cu_seqlens_k[i]:cu_seqlens_k[i+1]]` is the k of sample i.
        cu_seqlens_q: shape = (batch_size + 1,).
            Default None: same as cu_seqlens_k for prefilling, or
            [0, 1, ..., batch_size] for decoding.
            The only case where it needs to be specified is when q is a mix
            of prefilling and decoding.
        sm_scale: softmax scale, defaults to 1/sqrt(head_size).
        return: tensor with the same shape as q.
        """
        assert IS_COMPUTE_8_OR_ABOVE, (
            "Requires compute capability of 8 or above (Ampere or newer) "
            "to use the Triton kernel.")

        sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))

        return blocksparse_flash_attn_varlen_fwd(
            q,
            k,
            v,
            cu_seqlens_k,
            cu_seqlens_q,
            sm_scale,
            self.sparse_layout,
            block_size=self.block_size,
            q_block_size=self.q_block_size,
            max_seqlen=self.max_seqlen,
        )
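
    # Added example (illustrative, not from the original source): for a
    # prefill batch of two prompts with lengths 5 and 7, the packed q/k/v
    # hold 12 tokens and cu_seqlens_k = [0, 5, 12]; cu_seqlens_q can stay
    # None. For a pure decode step on the same two sequences, q holds one
    # token per sequence, so cu_seqlens_q defaults to [0, 1, 2] while
    # cu_seqlens_k still marks the full cached lengths, e.g. [0, 6, 13].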

    @staticmethod
    def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1):
        """
        :param x: (total_tokens, n_heads, head_size)
        :return: (batch, n_heads * head_repeats, maxlen, head_size)
        """
        x_padded = x.new_empty(
            len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2))
        cu_seqlens = cu_seqlens.cpu()
        for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
            x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0,
                                                             1).unsqueeze(1))
        return x_padded.flatten(1, 2)

    @staticmethod
    def transpose_and_unpad(x_padded, cu_seqlens):
        """
        :param x_padded: (batch, n_heads, length, head_size)
        :return: (total_tokens, n_heads, head_size)
        """
        cu_seqlens = cu_seqlens.cpu()
        total_n_tokens = cu_seqlens[-1]
        x = x_padded.new_empty(total_n_tokens, x_padded.size(1),
                               x_padded.size(3))
        for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
            x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1))
        return x
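
    # Added example (illustrative, not from the original source): for the
    # packed batch above (lengths [5, 7], so 12 tokens), 8 kv heads and
    # head_size 64, `transpose_and_pad(k, cu_seqlens, 7, head_repeats=2)`
    # yields a (2, 16, 7, 64) tensor in which each kv head is repeated to
    # match 16 query heads, while `transpose_and_unpad` maps a padded
    # (2, 16, 7, 64) output back to the packed (12, 16, 64) layout expected
    # by the caller.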

    def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
        """For CPU, V100 or other older GPUs.
        NOTE: torch SDPA supports nested tensors, but it appears to be
        extremely slow, so we pad instead.
        """
        assert (cu_seqlens_q is None or
                (cu_seqlens_q
                 == cu_seqlens_k).all()), "Can only handle prompt with SPDA."
        assert q.size(0) == k.size(0), "Can only handle prompt with SPDA."
        assert q.size(1) % k.size(1) == 0
        q_k_ratio = q.size(1) // k.size(1)
        sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
        cu_seqlens = cu_seqlens_k.cpu()
        maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

        if (self.dense_attn_mask.dtype != q.dtype
                or self.dense_attn_mask.device != q.device):
            _, _, self.dense_attn_mask = self.get_attn_pattern(
                q.dtype, q.device)
        attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen]

        q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1)
        k2, v2 = [
            self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio)
            for x in [k, v]
        ]
        spda_output = torch.nn.functional.scaled_dot_product_attention(
            q2, k2, v2, attn_mask=attn_mask, scale=sm_scale)
        return self.transpose_and_unpad(spda_output, cu_seqlens)

    def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
        """Dispatch to `varlen_attn` (Ampere or newer) or `self.spda`
        (CPU, Volta, Turing or older) based on the device type and the
        CUDA compute capability.

        q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
            Supports grouped attention: `q[:, i*r:(i*r + r)]`
            corresponds to `k[:, i]`, where `r` is the q/k head ratio.
        cu_seqlens_k: shape = (batch_size + 1,), cumulative sequence lengths
            marking the sample boundaries, i.e.,
            `k[cu_seqlens_k[i]:cu_seqlens_k[i+1]]` is the k of sample i.
        cu_seqlens_q: shape = (batch_size + 1,).
            Default None: same as cu_seqlens_k for prefilling, or
            [0, 1, ..., batch_size] for decoding.
            The only case where it needs to be specified is when q is a mix
            of prefilling and decoding.
        sm_scale: softmax scale, defaults to 1/sqrt(head_size).
        return: tensor with the same shape as q.
        """
        assert k.dim() == 3
        if self.use_spda:
            return self.spda(
                q,
                k,
                v,
                cu_seqlens_k,
                cu_seqlens_q=cu_seqlens_q,
                sm_scale=sm_scale,
            )
        return self.varlen_attn(q,
                                k,
                                v,
                                cu_seqlens_k,
                                cu_seqlens_q=cu_seqlens_q,
                                sm_scale=sm_scale)
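

# Added usage sketch (not part of the original module): a minimal end-to-end
# example of packing two variable-length prompts and running the module. The
# head counts, sequence lengths, and blocksparse settings below are
# illustrative assumptions, not values taken from any model config.
if __name__ == "__main__":
    n_heads, head_size = 16, 64
    attn = LocalStridedBlockSparseAttn(
        n_heads=n_heads,
        max_seqlen=256,
        local_blocks=4,
        vert_stride=8,
        block_size=64,
    )
    # Two prompts of lengths 5 and 7, packed along the token dimension.
    seqlens = torch.tensor([5, 7])
    cu_seqlens_k = torch.cat(
        [torch.zeros(1, dtype=torch.int32),
         seqlens.cumsum(0).to(torch.int32)]).to(attn.device)
    total_tokens = int(seqlens.sum())
    q = torch.randn(total_tokens, n_heads, head_size,
                    dtype=attn.dtype, device=attn.device)
    k = torch.randn(total_tokens, n_heads, head_size,
                    dtype=attn.dtype, device=attn.device)
    v = torch.randn(total_tokens, n_heads, head_size,
                    dtype=attn.dtype, device=attn.device)
    # Dispatches to the Triton kernel on Ampere+ GPUs, or torch SDPA
    # otherwise; the output shape matches q.
    out = attn(q, k, v, cu_seqlens_k)
    print(out.shape)  # torch.Size([12, 16, 64])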
|