Browse Source

add phi3_small support with blocksparse attention

AlpinDale 7 months ago
parent
commit
696f2cd59c

+ 1 - 0
aphrodite/attention/backends/abstract.py

@@ -111,6 +111,7 @@ class AttentionImpl(ABC, Generic[T]):
         alibi_slopes: Optional[List[float]] = None,
         sliding_window: Optional[int] = None,
         kv_cache_dtype: str = "auto",
+        blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         raise NotImplementedError
 

+ 405 - 0
aphrodite/attention/backends/blocksparse_attn.py

@@ -0,0 +1,405 @@
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+import torch
+
+from aphrodite.attention.backends.abstract import (AttentionBackend,
+                                                   AttentionImpl,
+                                                   AttentionMetadata)
+from aphrodite.attention.ops.blocksparse_attention.interface import (
+    LocalStridedBlockSparseAttn, get_head_sliding_step)
+from aphrodite.attention.ops.paged_attn import PagedAttention
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size)
+
+
+@dataclass
+class BlocksparseParams:
+    max_seqlen: int
+
+    # Num q heads per tensor-parallel rank/partition
+    num_heads: int  # per TP partition
+    # Num kv heads per tensor-parallel rank/partition
+    num_kv_heads: int
+
+    # block size used for blocksparse attention.
+    # This is the block_size used in `local_blocks`, `vert_stride`.
+    block_size: int
+
+    # Number of blocks for local attention, i.e., number of
+    # local attended tokens / `sparse_block_size`
+    local_blocks: int
+
+    # Attend to one block per every `vert_stride` blocks.
+    # Controlling the sparsity
+    vert_stride: int
+    """
+    If to use the same vertical stride offset for all heads, 
+    i.e., attend to the same block of tokens on all heads.
+    By default, it is False, i.e., attention on the non-local 
+    blocks depends on the `head_idx`, that is on
+    blocks satisfying 
+    `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0`
+    where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`,
+            `block_idx = position_id // sparse_block_size`.
+    See `..ops.blocksparse_attention.utils:get_sparse_attn_mask`
+    for more detail.
+    """
+    homo_head: bool = False
+
+    # If within a group, the kv offsets that each q attends is the same or no.
+    homo_head_group: bool = False
+
+    # Decided by homo_head and homo_head group
+    head_sliding_step: int = field(init=False)
+
+    # range of q heads to for a TP rank
+    active_head_range: Tuple = field(init=False)
+
+    def __post_init__(self):
+        assert self.block_size > 0
+        assert self.local_blocks >= 0
+        assert self.vert_stride >= 1
+        assert self.num_heads % self.num_kv_heads == 0
+
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        total_heads = tp_size * self.num_heads
+        total_kv_heads = tp_size * self.num_kv_heads
+
+        if self.homo_head:
+            self.head_sliding_step = 0
+        elif self.homo_head_group:
+            head_sliding_step = get_head_sliding_step(total_kv_heads,
+                                                      self.vert_stride)
+            # negative indicates sliding along kv heads, i.e., homo q group
+            self.head_sliding_step = -head_sliding_step
+        else:
+            self.head_sliding_step = get_head_sliding_step(
+                total_heads, self.vert_stride)
+
+        self.active_head_range = (
+            tp_rank * self.num_heads,
+            (tp_rank + 1) * self.num_heads,
+        )
+
+
+class BlocksparseFlashAttentionBackend(AttentionBackend):
+
+    @staticmethod
+    def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
+        return BlocksparseFlashAttentionImpl
+
+    @staticmethod
+    def make_metadata(*args, **kwargs) -> "BlocksparseFlashAttentionMetadata":
+        return BlocksparseFlashAttentionMetadata(*args, **kwargs)
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
+                                                 num_kv_heads, head_size)
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: Dict[int, int],
+    ) -> None:
+        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: Dict[int, List[int]],
+    ) -> None:
+        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+
+
+@dataclass
+class BlocksparseFlashAttentionMetadata(AttentionMetadata):
+    """A copy of Metadata for FlashAttentionBackend,
+    to avoid having to install flash_attn.
+    NOTE: Any python object stored here is not updated when it is
+    cuda-graph replayed. If you have values that need to be changed
+    dynamically, it should be stored in tensor. The tensor has to be
+    updated from `CUDAGraphRunner.forward` API.
+    """
+    # (batch_size,). The sequence length per sequence. Sequence length means
+    # the computed tokens + new tokens None if it is a decoding.
+    seq_lens: Optional[List[int]]
+    # seq_lens stored as a tensor.
+    seq_lens_tensor: Optional[torch.Tensor]
+
+    # NOTE(sang): Definition of context_len, query_len, and seq_len.
+    # |---------- N-1 iteration --------|
+    # |---------------- N iteration ---------------------|
+    # |- tokenA -|......................|-- newTokens ---|
+    # |---------- context_len ----------|
+    # |-------------------- seq_len ----------------------|
+    #                                   |-- query_len ---|
+
+    # Maximum query length in the batch. None for decoding.
+    max_query_len: Optional[int]
+    # Maximum sequence length among prefill batch. 0 if there are decoding
+    # requests only.
+    max_prefill_seq_len: int
+    # Maximum sequence length among decode batch. 0 if there are prefill
+    # requests only.
+    max_decode_seq_len: int
+    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
+    # the batch, used to index into subquery. E.g., if the subquery length
+    # is [4, 6], it is [0, 4, 10].
+    query_start_loc: Optional[torch.Tensor]
+    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
+    # the batch, used to index into sequence. E.g., if the sequence length is
+    # [4, 6], it is [0, 4, 10].
+    seq_start_loc: Optional[torch.Tensor]
+    # (batch_size,) A tensor of context lengths (tokens that are computed
+    # so far).
+    context_lens_tensor: Optional[torch.Tensor]
+
+    # (batch_size, max_blocks_per_seq).
+    # Block addresses per sequence. (Seq id -> list of physical block)
+    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
+    # in the kv cache. Each block can contain up to block_size tokens.
+    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
+    # captured.
+    block_tables: Optional[torch.Tensor]
+
+    # Whether or not if cuda graph is enabled.
+    # Cuda-graph is currently enabled for decoding only.
+    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
+    use_cuda_graph: bool
+
+    _cached_prefill_metadata: Optional[
+        "BlocksparseFlashAttentionMetadata"] = None
+    _cached_decode_metadata: Optional[
+        "BlocksparseFlashAttentionMetadata"] = None
+
+    @property
+    def prefill_metadata(
+            self) -> Optional["BlocksparseFlashAttentionMetadata"]:
+        if self.num_prefills == 0:
+            return None
+
+        if self._cached_prefill_metadata is not None:
+            return self._cached_prefill_metadata
+
+        assert self.seq_lens is not None
+        assert self.seq_lens_tensor is not None
+        assert self.query_start_loc is not None
+        assert self.context_lens_tensor is not None
+        assert self.block_tables is not None
+        assert self.seq_start_loc is not None
+
+        self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata(
+            num_prefills=self.num_prefills,
+            num_prefill_tokens=self.num_prefill_tokens,
+            num_decode_tokens=0,
+            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
+            seq_lens=self.seq_lens[:self.num_prefills],
+            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
+            max_query_len=self.max_query_len,
+            max_prefill_seq_len=self.max_prefill_seq_len,
+            max_decode_seq_len=0,
+            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
+            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
+            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
+            block_tables=self.block_tables[:self.num_prefills],
+            use_cuda_graph=False,
+        )
+        return self._cached_prefill_metadata
+
+    @property
+    def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
+        if self.num_decode_tokens == 0:
+            return None
+
+        if self._cached_decode_metadata is not None:
+            return self._cached_decode_metadata
+        assert self.block_tables is not None
+        assert self.seq_lens_tensor is not None
+
+        self._cached_decode_metadata = BlocksparseFlashAttentionMetadata(
+            num_prefills=0,
+            num_prefill_tokens=0,
+            num_decode_tokens=self.num_decode_tokens,
+            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
+            seq_lens=None,
+            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
+            max_query_len=None,
+            max_prefill_seq_len=0,
+            max_decode_seq_len=self.max_decode_seq_len,
+            query_start_loc=None,
+            seq_start_loc=None,
+            context_lens_tensor=None,
+            block_tables=self.block_tables[self.num_prefills:],
+            use_cuda_graph=self.use_cuda_graph,
+        )
+        return self._cached_decode_metadata
+
+
+class BlocksparseFlashAttentionImpl(AttentionImpl):
+    """
+    If the input tensors contain prompt tokens, the layout is as follows:
+    |<--------------- num_prompt_tokens -------------->|
+    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
+    Otherwise, the layout is as follows:
+    |<------------------ num_generation_tokens (M) ----------------->|
+    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
+    Generation tokens can contain padding when cuda-graph is used.
+    Currently, prompt tokens don't contain any padding.
+    The prompts might have different lengths, while the generation tokens
+    always have length 1.
+    """
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: int,
+        alibi_slopes: Optional[List[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        assert blocksparse_params is not None
+        assert alibi_slopes is None, ValueError(
+            "Alibi not support for blocksparse flash attention.")
+        assert sliding_window is None, ValueError(
+            "sliding_window is invalid for blocksparse attention.")
+
+        if "num_heads" not in blocksparse_params:
+            blocksparse_params["num_heads"] = num_heads
+        if "num_kv_heads" not in blocksparse_params:
+            blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads
+        self.blocksparse_params = BlocksparseParams(**blocksparse_params)
+        self.kv_cache_dtype = kv_cache_dtype
+
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        self.alibi_slopes = alibi_slopes
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
+        self.local_blocks = self.blocksparse_params.local_blocks
+        self.vert_stride = self.blocksparse_params.vert_stride
+        self.sparse_block_size = self.blocksparse_params.block_size
+        self.head_sliding_step = self.blocksparse_params.head_sliding_step
+
+        suppored_head_sizes = PagedAttention.get_supported_head_sizes()
+        if head_size not in suppored_head_sizes:
+            raise ValueError(
+                f"Head size {head_size} is not supported by PagedAttention. "
+                f"Supported head sizes are: {suppored_head_sizes}.")
+
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        total_num_heads = num_heads * self.tp_size
+        self.bs_attn = LocalStridedBlockSparseAttn(
+            total_num_heads,
+            self.blocksparse_params.max_seqlen,
+            self.blocksparse_params.local_blocks,
+            self.blocksparse_params.vert_stride,
+            self.blocksparse_params.block_size,
+            homo_head=self.blocksparse_params.homo_head,
+            active_head_range=self.blocksparse_params.active_head_range,
+        )
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: BlocksparseFlashAttentionMetadata,
+        kv_scale: float = 1.0,
+    ) -> torch.Tensor:
+        """Forward pass with FlashAttention and PagedAttention.
+        Args:
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+            attn_metadata: Metadata for attention.
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """
+        num_tokens, hidden_size = query.shape
+        # Reshape the query, key, and value tensors.
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        if kv_cache is not None:
+            key_cache, value_cache = PagedAttention.split_kv_cache(
+                kv_cache, self.num_kv_heads, self.head_size)
+
+            # Reshape the input keys and values and store them in the cache.
+            # If kv_cache is not provided, the new key and value tensors are
+            # not cached. This happens during the initial memory profiling run.
+
+            PagedAttention.write_to_paged_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                kv_scale,
+            )
+
+        if prefill_meta := attn_metadata.prefill_metadata:
+
+            # Prompt run.
+            # normal attention
+            # When block_tables are not filled, it means q and k are the
+            # prompt, and they have the same length.
+
+            assert kv_cache is None \
+                    or prefill_meta.block_tables is None \
+                    or prefill_meta.block_tables.numel() == 0, \
+                "Does not support prefix-enabled attention."
+
+            output = self.bs_attn(
+                q=query,
+                k=key,
+                v=value,
+                cu_seqlens_q=prefill_meta.seq_start_loc,
+                cu_seqlens_k=prefill_meta.seq_start_loc,
+                sm_scale=self.scale,
+            )
+
+        if decode_meta := attn_metadata.decode_metadata:
+            # Decoding run.
+            output = PagedAttention.forward_decode(
+                query,
+                key_cache,
+                value_cache,
+                decode_meta.block_tables,
+                decode_meta.seq_lens_tensor,
+                self.blocksparse_params.max_seqlen,
+                self.kv_cache_dtype,
+                self.num_kv_heads,
+                self.scale,
+                self.alibi_slopes,
+                kv_scale,
+                tp_rank=self.tp_rank,
+                blocksparse_local_blocks=self.local_blocks,
+                blocksparse_vert_stride=self.vert_stride,
+                blocksparse_block_size=self.sparse_block_size,
+                blocksparse_head_sliding_step=self.head_sliding_step,
+            )
+
+        # Reshape the output tensor.
+        return output.view(num_tokens, hidden_size)

+ 5 - 1
aphrodite/attention/backends/flash_attn.py

@@ -1,6 +1,6 @@
 """Attention layer with FlashAttention."""
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
 from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
@@ -220,7 +220,10 @@ class FlashAttentionImpl(AttentionImpl):
         alibi_slopes: Optional[List[float]],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
+        assert blocksparse_params is None, ValueError(
+            "FlashAttention does not support block-sparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -240,6 +243,7 @@ class FlashAttentionImpl(AttentionImpl):
             # paged KV cache.
             raise ValueError(
                 "Sliding window is not supported in FlashAttention.")
+
         support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
         if head_size not in support_head_sizes:
             raise ValueError(

+ 3 - 0
aphrodite/attention/backends/flashinfer.py

@@ -169,7 +169,10 @@ class FlashInferImpl(AttentionImpl):
         alibi_slopes: Optional[List[float]],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
+        assert blocksparse_params is None, ValueError(
+            "FlashInfer does not support block-sparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)

+ 5 - 2
aphrodite/attention/backends/rocm_flash_attn.py

@@ -1,7 +1,7 @@
 """Attention layer ROCm GPUs."""
-from dataclasses import dataclass
 import os
-from typing import List, Optional, Tuple, Type
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
 from loguru import logger
@@ -200,7 +200,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         alibi_slopes: Optional[List[float]],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
+        assert blocksparse_params is None, ValueError(
+            "ROCm FlashAttention does not support block-sparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)

+ 4 - 1
aphrodite/attention/backends/torch_sdpa.py

@@ -1,7 +1,7 @@
 """ Attention layer with torch scaled_dot_product_attention
     and PagedAttention."""
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
 from torch.nn.functional import scaled_dot_product_attention
@@ -101,7 +101,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         alibi_slopes: Optional[List[float]],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
+    blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
+        assert blocksparse_params is None, ValueError(
+            "Torch SDPA does not support block-sparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)

+ 4 - 1
aphrodite/attention/backends/xformers.py

@@ -1,6 +1,6 @@
 """Attention layer with xFormers and PagedAttention."""
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 import torch
 from xformers import ops as xops
@@ -210,7 +210,10 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         alibi_slopes: Optional[List[float]],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
+        assert blocksparse_params is None, ValueError(
+            "XFormers does not support block-sparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)

+ 12 - 4
aphrodite/attention/layer.py

@@ -1,5 +1,5 @@
 """Attention layer."""
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 import torch
 import torch.nn as nn
@@ -12,9 +12,11 @@ from aphrodite.quantization.base_config import QuantizationConfig
 
 class Attention(nn.Module):
     """Attention layer.
+
     This class takes query, key, and value tensors as input. The input tensors
     can either contain prompt tokens or generation tokens.
     The class does the following:
+
     1. Store the input key and value tensors in the KV cache.
     2. Perform (multi-head/multi-query/grouped-query) attention.
     3. Return the output tensor.
@@ -30,6 +32,7 @@ class Attention(nn.Module):
         sliding_window: Optional[int] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         super().__init__()
         if cache_config is not None:
@@ -60,15 +63,18 @@ class Attention(nn.Module):
             # to self._kv_scale in a native float32 value after weight loading.
             self.quant_method = quant_method
             self.quant_method.create_weights(self)
+
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
         dtype = torch.get_default_dtype()
         attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
                                         sliding_window, dtype, kv_cache_dtype,
-                                        block_size)
+                                        block_size, blocksparse_params
+                                        is not None)
         impl_cls = attn_backend.get_impl_cls()
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
-                             alibi_slopes, sliding_window, kv_cache_dtype)
+                             alibi_slopes, sliding_window, kv_cache_dtype,
+                             blocksparse_params)
 
     def forward(
         self,
@@ -78,11 +84,13 @@ class Attention(nn.Module):
         kv_cache: Optional[torch.Tensor],
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        return self.impl.forward(query, key, value, kv_cache, attn_metadata)
+        return self.impl.forward(query, key, value, kv_cache, attn_metadata,
+                                 self._kv_scale)
 
     def extra_repr(self) -> str:
         s = f"head_size={self.impl.head_size}"  # type: ignore
         s += f", num_heads={self.impl.num_heads}"  # type: ignore
         s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
         s += f", scale={self.impl.scale}"  # type: ignore
+        s += f", backend={self.impl.__class__.__name__}"
         return s

+ 0 - 0
aphrodite/attention/ops/blocksparse_attention/__init__.py


+ 422 - 0
aphrodite/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py

@@ -0,0 +1,422 @@
+import torch
+import triton
+import triton.language as tl
+
+
+def blocksparse_flash_attn_varlen_fwd(
+        q,
+        k,
+        v,  # (#tokens, n_heads, head_size)
+        cu_seqlens_k,
+        cu_seqlens_q,
+        sm_scale,
+        sparse_layout,
+        *,
+        block_size=64,
+        q_block_size=None,
+        max_seqlen=None):
+    # split q to blocks
+
+    assert isinstance(sparse_layout, (list, tuple))
+
+    _, n_heads, head_size = q.shape
+    batch_size = cu_seqlens_k.size(0) - 1
+    q_block_size = q_block_size or block_size
+
+    assert q.dim() == k.dim() == v.dim() == 3
+    assert q.size(1) % k.size(1) == 0
+    assert q.size(2) == k.size(2)
+    # TODO: allow k, v to have different head_size
+    assert k.shape == v.shape
+    assert cu_seqlens_k.dim() == 1
+
+    q_k_ratio = q.size(1) // k.size(1)
+
+    if cu_seqlens_q is None:
+        if q.size(0) == batch_size:  # decoding only
+            cu_seqlens_q = torch.arange(
+                0,
+                batch_size + 1,
+                dtype=cu_seqlens_k.dtype,
+                device=cu_seqlens_k.device,
+            )
+        elif q.size(0) == k.size(0):
+            cu_seqlens_q = cu_seqlens_k
+        else:
+            raise ValueError("cu_seqlens_q must be specified\
+                    if it mix of prefilling and decoding.")
+    else:
+        assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)
+
+    # switch to use cpu to avoid too many kernel launches when iterated over
+    q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
+    k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()
+
+    assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), (
+        "length of q should either be 1 (decoding) or same as k (prefilling).")
+
+    if max_seqlen:
+        assert k_lens.max() <= max_seqlen
+
+    n_blocks = (q_lens + q_block_size - 1) // q_block_size
+
+    q_batch_ids = torch.tensor(
+        [i for i, n in enumerate(n_blocks) for _ in range(n)],
+        dtype=cu_seqlens_q.dtype,
+        device=cu_seqlens_q.device,
+    )
+    q_start_sids = torch.tensor(
+        [i * q_block_size for n in n_blocks for i in range(n)],
+        dtype=cu_seqlens_q.dtype,
+        device=cu_seqlens_q.device,
+    )
+
+    out = q.new_empty(q.shape)
+    cu_seqlens_q = cu_seqlens_q.contiguous()
+    cu_seqlens_k = cu_seqlens_k.contiguous()
+
+    layout_crow_indices, layout_col_indices = sparse_layout
+    block_d = triton.next_power_of_2(head_size)
+
+    decoding_only = (q_lens == 1).all().item()
+    grid = (len(q_start_sids), n_heads, 1)
+
+    _fwd_kernel_batch_inference[grid](
+        q,
+        k,
+        v,
+        out,
+        sm_scale,
+        cu_seqlens_q[:-1],
+        cu_seqlens_q[1:],
+        cu_seqlens_k[:-1],
+        cu_seqlens_k[1:],
+        q_batch_ids,
+        q_start_sids,
+        0,
+        *q.stride(),
+        0,
+        *k.stride(),
+        0,
+        *v.stride(),
+        0,
+        *out.stride(),
+        layout_crow_indices,
+        layout_col_indices,
+        *layout_crow_indices.stride(),
+        *layout_col_indices.stride(),
+        q_k_ratio,
+        HAS_BATCH_DIM=False,
+        D_HEAD=head_size,
+        BLOCK_M=q_block_size,
+        BLOCK_N=block_size,
+        BLOCK_D=block_d,
+        BLOCK_M_LOADING=(16 if decoding_only else
+                         q_block_size),  # smaller for decoding
+        EVEN_D=block_d == head_size,
+        num_warps=1 if decoding_only else 4,
+        num_stages=3)
+
+    return out
+
+
+@triton.jit
+def _fwd_kernel_inner(
+    acc,
+    l_i,
+    m_i,
+    q,
+    Q,
+    k_block_col_idx,
+    layout_col_ptr,
+    layout_col_stride_h,
+    layout_col_stride_m,
+    k_ptrs,
+    v_ptrs,
+    off_h,
+    offs_m,
+    offs_n,
+    offs_d,
+    stride_kt,
+    stride_vt,
+    sm_scale,
+    k_seqlen,
+    past_len,
+    LAST_K_BLOCK: tl.constexpr,
+    BLOCK_M_LOADING: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    D_HEAD: tl.constexpr,
+    EVEN_D: tl.constexpr,
+    M_LT_N: tl.constexpr,
+):
+    k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h +
+                         k_block_col_idx * layout_col_stride_m).to(tl.int32)
+    start_n = k_block_id * BLOCK_N
+    if LAST_K_BLOCK:
+        if EVEN_D:
+            k = tl.load(
+                k_ptrs + start_n * stride_kt,
+                mask=offs_n[None, :] + start_n < k_seqlen,
+            )
+        else:
+            k = tl.load(
+                k_ptrs + start_n * stride_kt,
+                mask=(offs_n[None, :] + start_n < k_seqlen) &
+                (offs_d[:, None] < D_HEAD),
+            )
+    else:
+        if EVEN_D:
+            k = tl.load(k_ptrs + start_n * stride_kt)
+        else:
+            k = tl.load(k_ptrs + start_n * stride_kt,
+                        mask=offs_d[:, None] < D_HEAD)
+
+    qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
+    qk += tl.dot(q, k)
+    qk *= sm_scale
+
+    # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
+    if LAST_K_BLOCK | M_LT_N:
+        qk += tl.where(
+            offs_m[:, None] + past_len >= (start_n + offs_n[None, :]),
+            0,
+            float("-inf"),
+        )
+
+    # flash-attn2
+    m_ij = tl.maximum(m_i, tl.max(qk, 1))
+    p = tl.math.exp2(qk - m_ij[:, None])
+    l_ij = tl.sum(p, 1)
+    alpha = tl.math.exp2(m_i - m_ij)
+    acc = acc * alpha[:, None]
+    # update m_i
+    m_i = m_ij
+    l_i = l_i * alpha + l_ij
+
+    p = p.to(Q.dtype.element_ty)
+    # update acc
+    if LAST_K_BLOCK:
+        if EVEN_D:
+            v = tl.load(
+                v_ptrs + start_n * stride_vt,
+                mask=offs_n[:, None] + start_n < k_seqlen,
+            )
+        else:
+            v = tl.load(
+                v_ptrs + start_n * stride_vt,
+                mask=(offs_n[:, None] + start_n < k_seqlen) &
+                (offs_d[None, :] < D_HEAD),
+            )
+    else:
+        if EVEN_D:
+            v = tl.load(v_ptrs + start_n * stride_vt)
+        else:
+            v = tl.load(v_ptrs + start_n * stride_vt,
+                        mask=offs_d[None, :] < D_HEAD)
+
+    acc += tl.dot(p, v)
+
+    return acc, l_i, m_i
+
+
+@triton.heuristics({
+    "M_LT_N":
+    lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"],
+})
+@triton.jit
+def _fwd_kernel_batch_inference(
+    Q,
+    K,
+    V,
+    Out,
+    sm_scale,
+    q_batch_starts,
+    q_batch_ends,
+    k_batch_starts,
+    k_batch_ends,
+    q_batch_ids,
+    q_start_sids,
+    stride_qb,
+    stride_qt,
+    stride_qh,
+    stride_qd,
+    stride_kb,
+    stride_kt,
+    stride_kh,
+    stride_kd,
+    stride_vb,
+    stride_vt,
+    stride_vh,
+    stride_vd,
+    stride_ob,
+    stride_ot,
+    stride_oh,
+    stride_od,
+    layout_crow_ptr,
+    layout_col_ptr,
+    layout_crow_stride_h,
+    layout_crow_stride_m,
+    layout_col_stride_h,
+    layout_col_stride_m,
+    q_k_ratio,
+    HAS_BATCH_DIM: tl.constexpr,
+    D_HEAD: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+    BLOCK_M_LOADING: tl.constexpr,
+    EVEN_D: tl.constexpr,
+    M_LT_N: tl.constexpr,
+):
+    """
+    NOTATION:
+    pid: position id
+    sid: storage id
+    sbid: storage block id
+    pbid: position block id
+    offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)
+    TODO:
+    Optimize grouped-attn
+    """
+    off_zm = tl.program_id(0)
+    off_h = tl.program_id(1)
+
+    off_h_for_kv = off_h // q_k_ratio
+
+    if HAS_BATCH_DIM:
+        off_z = tl.program_id(2)
+        Q += off_z * stride_qb
+        K += off_z * stride_kb
+        V += off_z * stride_vb
+        Out += off_z * stride_ob
+        start_m = off_zm
+        q_start_sid = start_m * BLOCK_M  # always 0 for decoding
+    else:
+        off_z = tl.load(q_batch_ids + off_zm).to(tl.int32)  # [0, 0, 0, 1]
+        q_start_sid = tl.load(q_start_sids + off_zm)
+        start_m = q_start_sid // BLOCK_M  # q_sbid
+
+    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_d = tl.arange(0, BLOCK_D)
+
+    q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
+    q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start
+    k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
+    k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start
+    past_len = k_seqlen - q_seqlen
+
+    Q += q_cu_start * stride_qt + off_h * stride_qh
+    K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
+    V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
+    Out += q_cu_start * stride_ot + off_h * stride_oh
+
+    q_pbid = (past_len + q_start_sid) // BLOCK_M
+
+    if EVEN_D:
+        q = tl.load(
+            Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
+            mask=offs_m[:, None] < q_seqlen,
+        )
+    else:
+        q = tl.load(
+            Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
+            mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
+            other=0,
+        )
+
+    sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +
+                       q_pbid * layout_crow_stride_m)
+
+    # TODO: load at once, with any Triton version
+    # that supports `tl.split`, e.g., Triton 3.0
+    k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
+    k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)
+
+    m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf")
+    l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)
+
+    k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
+    v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd
+
+    sm_scale *= (
+        1.44269504  # 1/log2 as we use base2 for exponential and logarithm
+    )
+
+    for k_block_col_idx in range(k_block_start, k_block_end - 1):
+        acc, l_i, m_i = _fwd_kernel_inner(
+            acc,
+            l_i,
+            m_i,
+            q,
+            Q,
+            k_block_col_idx,
+            layout_col_ptr,
+            layout_col_stride_h,
+            layout_col_stride_m,
+            k_ptrs,
+            v_ptrs,
+            off_h,
+            offs_m,
+            offs_n,
+            offs_d,
+            stride_kt,
+            stride_vt,
+            sm_scale,
+            k_seqlen,
+            past_len,
+            False,
+            BLOCK_M_LOADING,
+            BLOCK_N,
+            D_HEAD,
+            EVEN_D,
+            M_LT_N,
+        )
+
+    acc, l_i, m_i = _fwd_kernel_inner(
+        acc,
+        l_i,
+        m_i,
+        q,
+        Q,
+        k_block_end - 1,
+        layout_col_ptr,
+        layout_col_stride_h,
+        layout_col_stride_m,
+        k_ptrs,
+        v_ptrs,
+        off_h,
+        offs_m,
+        offs_n,
+        offs_d,
+        stride_kt,
+        stride_vt,
+        sm_scale,
+        k_seqlen,
+        past_len,
+        True,
+        BLOCK_M_LOADING,
+        BLOCK_N,
+        D_HEAD,
+        EVEN_D,
+        M_LT_N,
+    )
+
+    # flash-attn 2
+    m_i += tl.math.log2(l_i)
+    acc = acc / l_i[:, None]
+
+    # write output
+    if EVEN_D:
+        tl.store(
+            Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
+            acc,
+            mask=offs_m[:, None] < q_seqlen,
+        )
+    else:
+        tl.store(
+            Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
+            acc,
+            mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
+        )

+ 235 - 0
aphrodite/attention/ops/blocksparse_attention/interface.py

@@ -0,0 +1,235 @@
+import math
+
+import torch
+
+from aphrodite.attention.ops.blocksparse_attention.utils import (
+    dense_to_crow_col, get_head_sliding_step, get_sparse_attn_mask)
+from aphrodite.common.utils import is_cpu, is_hip
+
+IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
+                         and torch.cuda.get_device_capability()[0] >= 8)
+
+if IS_COMPUTE_8_OR_ABOVE:
+    from aphrodite.attention.ops.blocksparse_attention.blocksparse_attention_kernel import \
+        blocksparse_flash_attn_varlen_fwd  # noqa: E501
+
+
+class LocalStridedBlockSparseAttn(torch.nn.Module):
+
+    def __init__(
+        self,
+        n_heads,
+        max_seqlen,
+        local_blocks,
+        vert_stride,
+        block_size,
+        device=None,
+        dtype=None,
+        homo_head=False,
+        active_head_range=None,
+        q_block_size=None,
+        use_spda=None,
+    ):
+        super().__init__()
+        if use_spda is None:
+            use_spda = is_hip() or is_cpu() or not \
+                       IS_COMPUTE_8_OR_ABOVE
+        device = device or (torch.cuda.current_device()
+                            if torch.cuda.is_available() else "cpu")
+        device = torch.device(device)
+        # NOTE: aphrodite CPU backend support BF16 instead of FP16.
+        dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE
+                          or device.type == "cpu" else torch.half)
+
+        self.n_heads = n_heads
+        self.max_seqlen = max_seqlen
+        self.local_blocks = local_blocks
+        self.vert_stride = vert_stride
+        self.use_spda = use_spda
+        self.dtype = dtype
+        self.device = device
+        self.block_size = block_size
+        self.q_block_size = q_block_size
+        self.homo_head = homo_head
+        self.active_head_range = active_head_range
+        self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride,
+                                                       homo_head)
+
+        sparse_layout, sparse_pattern, self.dense_attn_mask = (
+            self.get_attn_pattern(dtype, device))
+
+        if q_block_size is not None and q_block_size != block_size:
+            if q_block_size > block_size:
+                assert q_block_size % block_size == 0
+                blocks_to_merge = q_block_size // block_size
+                shape = sparse_pattern.shape
+                sparse_pattern = sparse_pattern.view(shape[0], -1,
+                                                     blocks_to_merge,
+                                                     shape[-1])
+                sparse_pattern = sparse_pattern.sum(2)
+                sparse_layout = dense_to_crow_col(sparse_pattern)
+            else:
+                raise ValueError(
+                    "Does not support smaller q_block_size. It will be slower."
+                )
+
+        self.sparse_layout = sparse_layout
+
+    def get_attn_pattern(self, dtype, device):
+        sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask(
+            self.n_heads,
+            self.max_seqlen,
+            self.max_seqlen,
+            dtype,
+            device,
+            block_size=self.block_size,
+            local_blocks=self.local_blocks,
+            vert_stride=self.vert_stride,
+            homo_head=self.homo_head,
+            return_dense=self.use_spda,
+            dense_mask_type="bias",
+        )
+        if (not self.homo_head) and (self.active_head_range is not None):
+            assert isinstance(self.active_head_range, tuple)
+            assert (len(self.active_head_range) == 2)
+            h_start, h_end = self.active_head_range
+            sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout)
+            if self.use_spda:
+                dense_attn_mask = dense_attn_mask[h_start:h_end]
+        return sparse_layout, sparse_pattern, dense_attn_mask
+
+    def varlen_attn(self,
+                    q,
+                    k,
+                    v,
+                    cu_seqlens_k,
+                    cu_seqlens_q=None,
+                    sm_scale=None):
+        """
+        q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
+        Support grouped attention, with `q[:, i*r:(i*r + r)]`
+        is correspondent to `k[:, i]`, where `r` is the q/k ratio.
+        cu_seqlens_k: shape=(batch_size + 1,), 
+        indicating segment of samples, 
+        e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i
+        cu_seqlens_q: shape=(batch_size + 1, ).
+        Default None: same as cu_seqlens_k for prefilling or
+        [0, 1, .., batch_size] for decoding.
+        The only case you need to specify is when q is a mix of 
+        prefilling and decoding.
+        sm_scale: softmax scale, default to 1/sqrt(head_size).
+        return: tensor of shape as q.
+        """
+        assert (
+            IS_COMPUTE_8_OR_ABOVE
+        ), "Requires compute capability of 8 or above (Ampere or newer) to use \
+            Triton kernel."
+
+        sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
+
+        return blocksparse_flash_attn_varlen_fwd(
+            q,
+            k,
+            v,
+            cu_seqlens_k,
+            cu_seqlens_q,
+            sm_scale,
+            self.sparse_layout,
+            block_size=self.block_size,
+            q_block_size=self.q_block_size,
+            max_seqlen=self.max_seqlen,
+        )
+
+    @staticmethod
+    def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1):
+        """
+        :param x: (total_tokens, n_heads, head_size)
+        :return: (batch, n_heads, length, head_size)
+        """
+        x_padded = x.new_empty(
+            len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2))
+        cu_seqlens = cu_seqlens.cpu()
+        for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
+            x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0,
+                                                             1).unsqueeze(1))
+        return x_padded.flatten(1, 2)
+
+    @staticmethod
+    def transpose_and_unpad(x_padded, cu_seqlens):
+        """
+        :param x_padded: (batch, n_heads, length, head_size)
+        :return: (total_tokens, n_heads, head_size)
+        """
+        cu_seqlens = cu_seqlens.cpu()
+        total_n_tokens = cu_seqlens[-1]
+        x = x_padded.new_empty(total_n_tokens, x_padded.size(1),
+                               x_padded.size(3))
+        for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
+            x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1))
+        return x
+
+    def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
+        """For CPU, V100 or other older GPUs.
+        NOTE: torch SPDA supports nested tensor, 
+        but seems extremely slow. Choose to pad instead.
+        """
+        assert (cu_seqlens_q is None or
+                (cu_seqlens_q
+                 == cu_seqlens_k).all()), "Can only handle prompt with SPDA."
+        assert q.size(0) == k.size(0), "can only handle prompt with SPDA."
+
+        assert q.size(1) % k.size(1) == 0
+        q_k_ratio = q.size(1) // k.size(1)
+        sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
+        cu_seqlens = cu_seqlens_k.cpu()
+        maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+
+        if (self.dense_attn_mask.dtype != q.dtype
+                or self.dense_attn_mask.device != q.device):
+            _, _, self.dense_attn_mask = self.get_attn_pattern(
+                q.dtype, q.device)
+        attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen]
+
+        q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1)
+        k2, v2 = [
+            self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio)
+            for x in [k, v]
+        ]
+        spda_output = torch.nn.functional.scaled_dot_product_attention(
+            q2, k2, v2, attn_mask=attn_mask, scale=sm_scale)
+        return self.transpose_and_unpad(spda_output, cu_seqlens)
+
+    def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
+        """Dispatch to `varlen_attn` (Ampere or newer) or 
+        `self.spda`(cpu, Volta, Turing or older)based on 
+        the type of device used and cuda compute capability.
+        q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
+                Support grouped attention, with `q[:, i*r:(i*r + r)]`
+                is correspondent to `k[:, i]`, where `r` is the q/k ratio.
+        cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples,
+                    e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i
+        cu_seqlens_q: shape=(batch_size + 1, ).
+                    Default None: same as cu_seqlens_k for prefilling or
+                    [0, 1, .., batch_size] for decoding.
+                    The only case you need to specify 
+                    is when q is a mix of prefilling 
+                    and decoding.
+        sm_scale: softmax scale, default to 1/sqrt(head_size).
+        return: tensor of shape as q.
+        """
+        assert k.dim() == 3
+        if self.use_spda:
+            return self.spda(
+                q,
+                k,
+                v,
+                cu_seqlens_k,
+                cu_seqlens_q=cu_seqlens_q,
+                sm_scale=sm_scale,
+            )
+        return self.varlen_attn(q,
+                                k,
+                                v,
+                                cu_seqlens_k,
+                                cu_seqlens_q=cu_seqlens_q,
+                                sm_scale=sm_scale)

+ 216 - 0
aphrodite/attention/ops/blocksparse_attention/utils.py

@@ -0,0 +1,216 @@
+# Helper functions for 3D sparse pattern
+# These function are not optimized and very inefficient.
+# Avoid calling them too frequent or use a cache mechanism.
+
+from functools import lru_cache
+
+import torch
+import triton
+from scipy import sparse
+
+
+def dense_to_crow_col(x: torch.Tensor):
+    """Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing.
+    NOTE: col_indices padded -1
+    """
+    device = x.device
+    pad = -1
+    dim = x.dim()
+    assert x.dim() in (2, 3)
+    if x.dim() == 2:
+        x = x[None]
+    x = [sparse.csr_matrix(xi.bool().cpu().numpy()) for xi in x]
+    crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
+    cols = [torch.from_numpy(xi.indices) for xi in x]
+    max_cols = max(len(xi) for xi in cols)
+    cols = [
+        torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])])
+        for xi in cols
+    ]
+    cols = torch.vstack(cols)
+    if dim == 2:
+        crows = crows[0]
+        cols = cols[0]
+    return crows.to(device), cols.to(device)
+
+
+def crow_col_to_dense(crows: torch.Tensor,
+                      cols: torch.Tensor,
+                      dtype: torch.dtype = torch.float16):
+    dim = crows.dim()
+    if dim == 1:
+        crows = crows[None]
+        cols = cols[None]
+    device = crows.device
+    crows, cols = crows.cpu(), cols.cpu()  # faster in cpu
+    shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1)
+    x = torch.zeros(shape, dtype=dtype)
+    for i in range(shape[0]):
+        for j in range(shape[1]):
+            x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1
+    if dim == 1:
+        x = x[0]
+    return x.to(device)
+
+
+def dense_to_ccol_row(x: torch.Tensor):
+    """Similar, but to CSC format"""
+    x = x.transpose(-2, -1)
+    return dense_to_crow_col(x)
+
+
+def ccol_row_to_dense(ccol: torch.Tensor,
+                      rows: torch.Tensor,
+                      dtype: torch.dtype = torch.float16):
+    return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous()
+
+
+def _get_sparse_attn_mask_homo_head(
+    q_len: int,
+    max_seqlen: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    block_size: int = 128,
+    local_blocks: int = 4,
+    vert_stride: int = 4,
+    return_dense: bool = False,
+):
+    """
+    :return: a tuple of 3:
+        - tuple of crow_indices, col_indices representation 
+            of CSR format.
+        - block dense mask
+        - all token dense mask (be aware that it can be 
+            OOM if it is too big) if `return_dense==True`, 
+            otherwise, None
+    """
+    with torch.no_grad():
+        num_blocks = triton.cdiv(max_seqlen, block_size)
+        q_pos = torch.arange(num_blocks)[:, None]
+        k_pos = torch.arange(num_blocks)[None]
+        mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0
+        block_mask_dense = (((q_pos >= k_pos)
+                             & ((q_pos - k_pos < local_blocks)
+                                | mask_vert_strided)).to(device).to(dtype))
+        num_blocks_q = triton.cdiv(q_len, block_size)
+        block_mask_dense_output = (dense_to_crow_col(
+            block_mask_dense[-num_blocks_q:].contiguous()))
+    if return_dense:
+        mask_dense = torch.kron(
+            block_mask_dense,
+            block_mask_dense.new_ones((block_size, block_size)),
+        )
+        causal_mask = torch.tril(torch.ones(
+            max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
+        mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask
+        return (
+            block_mask_dense_output,
+            block_mask_dense,
+            mask_dense,
+        )
+    else:
+        return (
+            block_mask_dense_output,
+            block_mask_dense,
+            None,
+        )
+
+
+def binary_mask_to_bias(mask_dense: torch.Tensor):
+    mask_dense = 1 - mask_dense
+    mask_dense.masked_fill_(mask_dense.bool(), -torch.inf)
+    return mask_dense
+
+
+def get_head_sliding_step(n_heads: int,
+                          vert_stride: int,
+                          homo_head: bool = False):
+    if homo_head:
+        return 0
+    return max(1, int(vert_stride / n_heads))
+
+
+@lru_cache
+def get_sparse_attn_mask(
+    n_heads: int,
+    q_len: int,
+    max_seqlen: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    block_size: int = 64,
+    local_blocks: int = 4,
+    vert_stride: int = 4,
+    homo_head: bool = True,
+    return_dense: bool = False,
+    dense_mask_type: str = "binary",
+):
+    """
+    :param dense_mask_type: "binary" (0 for skip token, 1 for others)
+        or "bias" (-inf for skip token, 0 or others)
+    :return: a tuple of 3:
+        - tuple of crow_indices, col_indices representation 
+            of CSR format.
+        - block dense mask
+        - all token dense mask (be aware that it can be OOM if it 
+            is too big) if `return_dense==True`, otherwise, None
+    """
+    assert dense_mask_type in ("binary", "bias")
+    if homo_head:
+        with torch.no_grad():
+            (crow, col), block_mask_dense, mask_dense = (
+                _get_sparse_attn_mask_homo_head(
+                    q_len,
+                    max_seqlen,
+                    dtype,
+                    device,
+                    block_size,
+                    local_blocks,
+                    vert_stride,
+                    return_dense,
+                ))
+            crow = crow[None].expand(n_heads, crow.shape[0])
+            col = col[None].expand(n_heads, col.shape[0])
+            if return_dense:
+                mask_dense = mask_dense[None].expand(n_heads,
+                                                     *mask_dense.shape)
+                if dense_mask_type == "bias":
+                    mask_dense = binary_mask_to_bias(mask_dense)
+            return (crow, col), block_mask_dense, mask_dense
+
+    with torch.no_grad():
+        num_blocks = triton.cdiv(max_seqlen, block_size)
+        q_pos = torch.arange(num_blocks)[None, :, None]
+        k_pos = torch.arange(num_blocks)[None, None]
+        head_sliding_step = get_head_sliding_step(n_heads, vert_stride)
+        mask_vert_strided = [
+            (torch.arange(num_blocks) + h * head_sliding_step + 1) %
+            vert_stride == 0 for h in range(n_heads)
+        ]
+        mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
+        block_mask_dense = (((q_pos >= k_pos)
+                             & ((q_pos - k_pos < local_blocks)
+                                | mask_vert_strided)).to(device).to(dtype))
+        num_blocks_q = triton.cdiv(q_len, block_size)
+        block_mask_dense_output = block_mask_dense[:, -num_blocks_q:]
+    if return_dense:
+        mask_dense = torch.kron(
+            block_mask_dense,
+            block_mask_dense.new_ones((block_size, block_size)),
+        )
+        causal_mask = torch.tril(torch.ones(
+            max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
+        mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None]
+        if dense_mask_type == "bias":
+            mask_dense = binary_mask_to_bias(mask_dense)
+
+        return (
+            dense_to_crow_col(block_mask_dense_output),
+            block_mask_dense,
+            mask_dense,
+        )
+    else:
+        return (
+            dense_to_crow_col(block_mask_dense_output),
+            block_mask_dense,
+            None,
+        )

+ 23 - 0
aphrodite/attention/ops/paged_attn.py

@@ -92,7 +92,20 @@ class PagedAttention:
         scale: float,
         alibi_slopes: Optional[torch.Tensor],
         kv_scale: float,
+        tp_rank: int = 0,
+        blocksparse_local_blocks: int = 0,
+        blocksparse_vert_stride: int = 0,
+        blocksparse_block_size: int = 64,
+        blocksparse_head_sliding_step: int = 0,
     ) -> torch.Tensor:
+        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
+            # use blocksparse paged attention
+            block_size = value_cache.size(-1)
+            assert (blocksparse_block_size > 0 and
+                    blocksparse_block_size % block_size == 0), \
+                (f"{blocksparse_block_size=} needs to be a multiple of"
+                 f"{block_size=} used in block_tables.")
+
         output = torch.empty_like(query)
 
         block_size = value_cache.shape[3]
@@ -124,6 +137,11 @@ class PagedAttention:
                 alibi_slopes,
                 kv_cache_dtype,
                 kv_scale,
+                tp_rank,
+                blocksparse_local_blocks,
+                blocksparse_vert_stride,
+                blocksparse_block_size,
+                blocksparse_head_sliding_step,
             )
         else:
             # Run PagedAttention V2.
@@ -156,6 +174,11 @@ class PagedAttention:
                 alibi_slopes,
                 kv_cache_dtype,
                 kv_scale,
+                tp_rank,
+                blocksparse_local_blocks,
+                blocksparse_vert_stride,
+                blocksparse_block_size,
+                blocksparse_head_sliding_step,
             )
         return output
 

+ 10 - 3
aphrodite/attention/selector.py

@@ -5,6 +5,7 @@ from typing import Optional, Type
 
 import torch
 from loguru import logger
+
 from aphrodite.attention.backends.abstract import AttentionBackend
 from aphrodite.common.utils import is_cpu, is_hip
 
@@ -28,7 +29,14 @@ def get_attn_backend(
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
+    is_blocksparse: bool = False,
 ) -> Type[AttentionBackend]:
+
+    if is_blocksparse:
+        logger.info("Using BlocksparseFlashAttention backend.")
+        from aphrodite.attention.backends.blocksparse_attn import \
+            BlocksparseFlashAttentionBackend
+        return BlocksparseFlashAttentionBackend
     """Determine which attention backend to use and only import
     the selected backend module.
     """
@@ -38,7 +46,6 @@ def get_attn_backend(
     if backend == _Backend.FLASH_ATTN:
         from aphrodite.attention.backends.flash_attn import \
             FlashAttentionBackend  # noqa: F401
-        logger.info("Using FlashAttention backend.")
         return FlashAttentionBackend
     if backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
@@ -136,8 +143,8 @@ def which_attn_to_use(
         try:
             import vllm_flash_attn  # noqa: F401
 
-            from aphrodite.attention.backends.flash_attn import (  # noqa: F401
-                FlashAttentionBackend)
+            from aphrodite.attention.backends.flash_attn import \
+                FlashAttentionBackend  # noqa: F401
 
             supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
             if head_size not in supported_sizes:

+ 1 - 0
aphrodite/endpoints/openai/serving_engine.py

@@ -122,6 +122,7 @@ class OpenAIServing:
                 token_logprob = step_top_logprobs[token_id].logprob
                 token = step_top_logprobs[token_id].decoded_token
                 logprobs.tokens.append(token)
+                token_logprob = max(token_logprob, -9999.0)
                 logprobs.token_logprobs.append(token_logprob)
 
                 if num_output_top_logprobs:

+ 1 - 0
aphrodite/modeling/models/__init__.py

@@ -54,6 +54,7 @@ _GENERATION_MODELS = {
     "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
     "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
     "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
+    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
 }
 
 _EMBEDDING_MODELS = {

+ 446 - 0
aphrodite/modeling/models/phi3_small.py

@@ -0,0 +1,446 @@
+import math
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers.configuration_utils import PretrainedConfig
+
+from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig, LoRAConfig
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size)
+from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
+
+
+def load_column_parallel_weight(param: torch.nn.Parameter,
+                                loaded_weight: torch.Tensor):
+    tp = get_tensor_model_parallel_world_size()
+    rk = get_tensor_model_parallel_rank()
+    assert param.size(0) * tp == loaded_weight.size(0)
+    s = rk * param.size(0)
+    e = (rk + 1) * param.size(0)
+    loaded_weight = loaded_weight[s:e]
+    assert param.shape == loaded_weight.shape
+    param.data.copy_(loaded_weight)
+
+
+class HeadMajorQKVParallelLinear(QKVParallelLinear):
+
+    def weight_loader(self, param: torch.nn.Parameter,
+                      loaded_weight: torch.Tensor):
+        return load_column_parallel_weight(param, loaded_weight)
+
+
+class HeadMajorColumnParallelLinear(MergedColumnParallelLinear):
+
+    def weight_loader(self, param: torch.nn.Parameter,
+                      loaded_weight: torch.Tensor):
+        return load_column_parallel_weight(param, loaded_weight)
+
+
+@torch.jit.script
+def quick_gelu(x):
+    return x * torch.sigmoid(1.702 * x)
+
+
+@torch.jit.script
+def gegelu(input, limit: Optional[float] = None):
+    a_gelu, a_linear = input[..., ::2], input[..., 1::2]
+    if limit is not None:
+        a_gelu = torch.where(torch.isinf(a_gelu), a_gelu,
+                             a_gelu.clamp(min=None, max=limit))
+        a_linear = torch.where(
+            torch.isinf(a_linear),
+            a_linear,
+            a_linear.clamp(min=-limit, max=limit),
+        )
+    out_gelu = quick_gelu(a_gelu)
+    return out_gelu * (a_linear + 1)
+
+
+class Phi3SmallMLP(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        assert (self.config.hidden_act == "gegelu"
+                ), "Only `gegelu` is supported for the 4.7 series of models .."
+        self.hidden_size = config.hidden_size
+        self.gegelu_limit = config.gegelu_limit
+        self.intermediate_size = config.intermediate_size
+
+        self.up_proj = HeadMajorColumnParallelLinear(
+            self.hidden_size,
+            2 * [self.intermediate_size],
+            bias=True,
+            quant_config=quant_config,
+        )
+        self.down_proj = RowParallelLinear(
+            self.intermediate_size,
+            self.hidden_size,
+            bias=True,
+            quant_config=quant_config,
+        )
+
+    def forward(self, x):
+        gate_up, _ = self.up_proj(x)
+        x = gegelu(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class Phi3SmallSelfAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_idx: int,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.layer_idx = layer_idx
+        self.config = config
+        self.sparse_block_size = config.blocksparse_block_size
+        self.homo_heads = config.blocksparse_homo_head_pattern
+        self.local_blocks = config.blocksparse_num_local_blocks
+        self.vert_stride = config.blocksparse_vert_stride
+
+        assert (config.blocksparse_block_size ==
+                config.blocksparse_triton_kernel_block_size)
+
+        self.hidden_size = config.hidden_size
+        # Number of Query Heads
+        self.num_heads = config.num_attention_heads
+
+        self.head_dim = self.hidden_size // self.num_heads
+        self.tp_size = get_tensor_model_parallel_world_size()
+        # Number of total Key Value Heads before tensor parallel
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_q_per_kv = self.num_heads // self.num_key_value_heads
+        if self.tp_size > 1:
+            assert self.num_key_value_heads % self.tp_size == 0
+        self.num_kv_heads_per_partion = max(
+            1, self.num_key_value_heads // self.tp_size)
+        self.num_heads_per_partition = self.num_heads // self.tp_size
+
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_embedding_base = config.rope_embedding_base
+        self.rope_position_scale = config.rope_position_scale
+        self.is_causal = True
+
+        norm_factor = None
+        if config.mup_use_scaling:
+            norm_factor = self.head_dim / config.mup_attn_multiplier
+        else:
+            norm_factor = math.sqrt(self.head_dim)
+        self.scale = 1 / norm_factor
+
+        self.query_key_value = HeadMajorQKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.num_heads,
+            self.num_key_value_heads,
+            bias=True,
+            quant_config=quant_config,
+        )
+
+        self.dense = RowParallelLinear(self.hidden_size,
+                                       self.hidden_size,
+                                       bias=True,
+                                       quant_config=quant_config)
+
+        if getattr(self.config, "rope_scaling", None) is not None:
+            rope_scaling = self.config.rope_scaling
+            for key in rope_scaling:
+                if isinstance(rope_scaling[key], list):
+                    rope_scaling[key] = tuple(rope_scaling[key])
+
+            if "factor" not in rope_scaling:
+                rope_scaling["factor"] = self.rope_position_scale
+        else:
+            rope_scaling = {
+                "type": "linear",
+                "factor": self.rope_position_scale,
+            }
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_embedding_base,
+            rope_scaling=rope_scaling,
+        )
+
+        # blocksparse params
+        self.blocksparse_block_size = config.blocksparse_block_size
+        self.blocksparse_num_local_blocks = config.blocksparse_num_local_blocks
+        self.blocksparse_vert_stride = config.blocksparse_vert_stride
+
+        use_dense_attn = (getattr(self.config,
+                                  "dense_attention_every_n_layers", None)
+                          and (self.layer_idx + 1) %
+                          self.config.dense_attention_every_n_layers == 0)
+
+        bs_params = None
+        if not use_dense_attn:
+            bs_params = {
+                'max_seqlen': self.max_position_embeddings,
+                'num_heads': self.num_heads_per_partition,
+                "num_kv_heads": self.num_kv_heads_per_partion,
+                "block_size": self.sparse_block_size,
+                "local_blocks": self.local_blocks,
+                "vert_stride": self.vert_stride,
+                "homo_head": self.homo_heads
+            }
+
+        self.attn = Attention(
+            self.num_heads_per_partition,
+            self.head_dim,
+            self.scale,
+            num_kv_heads=self.num_kv_heads_per_partion,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            blocksparse_params=bs_params,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[Tuple[torch.Tensor]]]:
+        qkv, _ = self.query_key_value(hidden_states)
+
+        qkv = qkv.view(qkv.shape[:-1] +
+                       (-1, (self.num_q_per_kv + 2), self.head_dim))
+        q, k, v = qkv.split([self.num_q_per_kv, 1, 1], dim=-2)
+
+        # NOTE: this is required by RotaryEmbed, which indeed does not have to
+        # TODO: allow 3D QK for rotary forward
+        q = q.reshape(-1, self.head_dim * self.num_heads_per_partition)
+        k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)
+        v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)
+
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata=attn_metadata)
+        output, _ = self.dense(attn_output)
+
+        return output
+
+
+class Phi3SmallDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_idx: int,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = Phi3SmallSelfAttention(config,
+                                                layer_idx,
+                                                cache_config=cache_config,
+                                                quant_config=quant_config)
+        self.mlp = Phi3SmallMLP(config, quant_config)
+
+        self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.layer_norm_epsilon)
+        self.post_attention_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+class Phi3SmallModel(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size)
+        self.mup_embedding_multiplier = config.mup_embedding_multiplier
+        self.layers = nn.ModuleList([
+            Phi3SmallDecoderLayer(config, layer_idx, cache_config,
+                                  quant_config)
+            for layer_idx in range(config.num_hidden_layers)
+        ])
+
+        self.final_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.layer_norm_epsilon)
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: Optional[torch.LongTensor],
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata = None,
+    ):
+        hidden_states = self.embed_tokens(input_ids)
+        if (self.mup_embedding_multiplier is not None
+                and self.mup_embedding_multiplier > 0.0):
+            hidden_states = hidden_states * self.mup_embedding_multiplier
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states = layer(
+                positions,
+                hidden_states,
+                kv_caches[i],
+                attn_metadata,
+            )
+        hidden_states = self.final_layernorm(hidden_states)
+        return hidden_states
+
+
+class Phi3SmallForCausalLM(nn.Module):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(
+        self,
+        config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Phi3SmallModel(config, cache_config, quant_config)
+        self.vocab_size = config.vocab_size
+        self.mup_width_multiplier = config.mup_width_multiplier
+        self.lm_head = ParallelLMHead(
+            self.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+        )
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
+
+        # tokens in tiktoken but not used
+        if hasattr(config, 'dummy_token_indices'):
+            device = self.lm_head.weight.device
+            self.register_buffer('dummy_token_indices',
+                                 torch.LongTensor(
+                                     config.dummy_token_indices).to(device),
+                                 persistent=False)
+        else:
+            self.dummy_token_indices = None
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, value):
+        self.lm_head = value
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+                                       sampling_metadata)
+        if self.dummy_token_indices is not None and logits is not None:
+            logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
+        return logits
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: Optional[torch.LongTensor],
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        output_hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            kv_caches=kv_caches,
+            attn_metadata=attn_metadata,
+        )
+        output_hidden_states = output_hidden_states
+        return output_hidden_states
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+
+        next_tokens = self.sampler(logits / self.mup_width_multiplier,
+                                   sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+        self.lm_head.weight.data.copy_(self.model.embed_tokens.weight.data)

+ 170 - 59
kernels/attention/attention_kernels.cu

@@ -86,6 +86,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
 // Grid: (num_heads, num_seqs, max_num_partitions).
 template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
           int NUM_THREADS, aphrodite::Fp8KVCacheDataType KV_DTYPE,
+          bool IS_BLOCK_SPARSE,
           int PARTITION_SIZE = 0>  // Zero means no partitioning.
 __device__ void paged_attention_kernel(
     float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
@@ -105,7 +106,9 @@ __device__ void paged_attention_kernel(
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float kv_scale) {
+    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
+    const int blocksparse_vert_stride, const int blocksparse_block_size,
+    const int blocksparse_head_sliding_step) {
   const int seq_idx = blockIdx.y;
   const int partition_idx = blockIdx.z;
   const int max_num_partitions = gridDim.z;
@@ -172,8 +175,8 @@ __device__ void paged_attention_kernel(
   // Each thread in a thread group has a different part of the query.
   // For example, if the the thread group size is 4, then the first thread in
   // the group has 0, 4, 8, ... th vectors of the query, and the second thread
-  // has 1, 5, 9, ... th vectors of the query, and so on. NOTE: Because q is
-  // split from a qkv tensor, it may not be contiguous.
+  // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
+  // q is split from a qkv tensor, it may not be contiguous.
   const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
   __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
 #pragma unroll
@@ -183,8 +186,8 @@ __device__ void paged_attention_kernel(
     q_vecs[thread_group_offset][i] =
         *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
   }
-  __syncthreads();  // TODO: possible speedup if this is replaced with a memory
-                    // wall right before we use q_vecs
+  __syncthreads();  // TODO: possible speedup if this is replaced with a
+                    // memory wall right before we use q_vecs
 
   // Memory planning.
   extern __shared__ char shared_mem[];
@@ -203,11 +206,55 @@ __device__ void paged_attention_kernel(
   // Each thread group in a warp fetches a key from the block, and computes
   // dot product with the query.
   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+
+  // blocksparse specific vars
+  int bs_block_offset;
+  int q_bs_block_id;
+  if constexpr (IS_BLOCK_SPARSE) {
+    // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
+    // blocksparse_block_size);
+    q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
+    if (blocksparse_head_sliding_step >= 0)
+      // sliding on q heads
+      bs_block_offset =
+          (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
+    else
+      // sliding on kv heads
+      bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
+                            (-blocksparse_head_sliding_step) +
+                        1;
+  }
+
   for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
        block_idx += NUM_WARPS) {
-    // NOTE: The block number is stored in int32. However, we cast it to int64
-    // because int32 can lead to overflow when this variable is multiplied by
-    // large numbers (e.g., kv_block_stride).
+    // NOTE: The block number is stored in int32. However, we cast it to
+    // int64 because int32 can lead to overflow when this variable is multiplied
+    // by large numbers (e.g., kv_block_stride).
+    // For blocksparse attention: skip computation on blocks that are not
+    // attended
+    if constexpr (IS_BLOCK_SPARSE) {
+      const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
+      const bool is_remote =
+          ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
+      const bool is_local =
+          (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
+      if (!is_remote && !is_local) {
+        for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
+          const int physical_block_offset =
+              (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
+          const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+
+          if (thread_group_offset == 0) {
+            // NOTE: assign very large number to skipped tokens to
+            // avoid contribution to the sumexp softmax normalizer. This will
+            // not be used at computing sum(softmax*v) as the blocks will be
+            // skipped.
+            logits[token_idx - start_token_idx] = -FLT_MAX;
+          }
+        }
+        continue;
+      }
+    }
     const int64_t physical_block_number =
         static_cast<int64_t>(block_table[block_idx]);
 
@@ -333,9 +380,18 @@ __device__ void paged_attention_kernel(
   zero(zero_value);
   for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
        block_idx += NUM_WARPS) {
-    // NOTE: The block number is stored in int32. However, we cast it to int64
-    // because int32 can lead to overflow when this variable is multiplied by
-    // large numbers (e.g., kv_block_stride).
+    // NOTE: The block number is stored in int32. However, we cast it to
+    // int64 because int32 can lead to overflow when this variable is multiplied
+    // by large numbers (e.g., kv_block_stride).
+    // For blocksparse attention: skip computation on blocks that are not
+    // attended
+    if constexpr (IS_BLOCK_SPARSE) {
+      int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
+      if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
+          !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
+        continue;
+      }
+    }
     const int64_t physical_block_number =
         static_cast<int64_t>(block_table[block_idx]);
     const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
@@ -363,9 +419,9 @@ __device__ void paged_attention_kernel(
                                                                     kv_scale);
         }
         if (block_idx == num_seq_blocks - 1) {
-          // NOTE: When v_vec contains the tokens that are out of the context,
-          // we should explicitly zero out the values since they may contain
-          // NaNs.
+          // NOTE: When v_vec contains the tokens that are out of the
+          // context, we should explicitly zero out the values since they may
+          // contain NaNs.
           scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
 #pragma unroll
           for (int j = 0; j < V_VEC_SIZE; j++) {
@@ -388,8 +444,8 @@ __device__ void paged_attention_kernel(
     accs[i] = acc;
   }
 
-  // NOTE: A barrier is required because the shared memory space for logits
-  // is reused for the output.
+  // NOTE: A barrier is required because the shared memory space for
+  // logits is reused for the output.
   __syncthreads();
 
   // Perform reduction across warps.
@@ -441,8 +497,8 @@ __device__ void paged_attention_kernel(
 
 // Grid: (num_heads, num_seqs, 1).
 template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
-          int NUM_THREADS,
-          aphrodite::Fp8KVCacheDataType KV_DTYPE>
+          int NUM_THREADS, aphrodite::Fp8KVCacheDataType KV_DTYPE,
+          bool IS_BLOCK_SPARSE>
 __global__ void paged_attention_v1_kernel(
     scalar_t* __restrict__ out,           // [num_seqs, num_heads, head_size]
     const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
@@ -457,18 +513,23 @@ __global__ void paged_attention_v1_kernel(
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float kv_scale) {
+    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
+    const int blocksparse_vert_stride, const int blocksparse_block_size,
+    const int blocksparse_head_sliding_step) {
   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
-                         KV_DTYPE>(
+                         KV_DTYPE, IS_BLOCK_SPARSE>(
       /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
       v_cache, num_kv_heads, scale, block_tables, seq_lens,
       max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
-      kv_head_stride, kv_scale);
+      kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks,
+      blocksparse_vert_stride, blocksparse_block_size,
+      blocksparse_head_sliding_step);
 }
 
 // Grid: (num_heads, num_seqs, max_num_partitions).
 template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
           int NUM_THREADS, aphrodite::Fp8KVCacheDataType KV_DTYPE,
+          bool IS_BLOCK_SPARSE,
           int PARTITION_SIZE>
 __global__ void paged_attention_v2_kernel(
     float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
@@ -488,12 +549,16 @@ __global__ void paged_attention_v2_kernel(
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float kv_scale) {
+    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
+    const int blocksparse_vert_stride, const int blocksparse_block_size,
+    const int blocksparse_head_sliding_step) {
   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
-                         KV_DTYPE, PARTITION_SIZE>(
+                         KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
       exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
       block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
-      kv_block_stride, kv_head_stride, kv_scale);
+      kv_block_stride, kv_head_stride, kv_scale, tp_rank,
+      blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
+      blocksparse_head_sliding_step);
 }
 
 // Grid: (num_heads, num_seqs).
@@ -605,27 +670,34 @@ __global__ void paged_attention_v2_reduce_kernel(
 
 }  // namespace aphrodite
 
-#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                \
-  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                \
-      ((void*)aphrodite::paged_attention_v1_kernel<                         \
-          T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>),       \
-      shared_mem_size);                                                     \
-  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,   \
-                                       NUM_THREADS, KV_DTYPE>               \
-      <<<grid, block, shared_mem_size, stream>>>(                           \
-          out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
-          scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq,    \
-          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,      \
-          kv_scale);
+#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                   \
+  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                   \
+      ((void*)aphrodite::paged_attention_v1_kernel<                            \
+          T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE,            \
+          IS_BLOCK_SPARSE>),                                                   \
+      shared_mem_size);                                                        \
+  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,      \
+                                       NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \
+      <<<grid, block, shared_mem_size, stream>>>(                              \
+          out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads,    \
+          scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq,       \
+          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,         \
+          kv_scale, tp_rank, blocksparse_local_blocks,                         \
+          blocksparse_vert_stride, blocksparse_block_size,                     \
+          blocksparse_head_sliding_step);
 
 // TODO: Tune NUM_THREADS.
 template <typename T, typename CACHE_T, int BLOCK_SIZE,
-          aphrodite::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128>
+          aphrodite::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
+          int NUM_THREADS = 128>
 void paged_attention_v1_launcher(
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
+    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
+    const int tp_rank, const int blocksparse_local_blocks,
+    const int blocksparse_vert_stride, const int blocksparse_block_size,
+    const int blocksparse_head_sliding_step) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -692,23 +764,36 @@ void paged_attention_v1_launcher(
   }
 }
 
-#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE)                   \
-  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>(             \
+#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE)  \
+  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,              \
+                              IS_BLOCK_SPARSE>(                              \
       out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
-      seq_lens, max_seq_len, alibi_slopes, kv_scale);
+      seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank,                \
+      blocksparse_local_blocks, blocksparse_vert_stride,                     \
+      blocksparse_block_size, blocksparse_head_sliding_step);
+
+#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
+  switch (is_block_sparse) {                                               \
+    case true:                                                             \
+      CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true);     \
+      break;                                                               \
+    case false:                                                            \
+      CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false);    \
+      break;                                                               \
+  }
 
 // NOTE: To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
 #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE)         \
   switch (block_size) {                                           \
     case 8:                                                       \
-      CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_DTYPE);                  \
+      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE);         \
       break;                                                      \
     case 16:                                                      \
-      CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_DTYPE);                 \
+      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE);        \
       break;                                                      \
     case 32:                                                      \
-      CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_DTYPE);                 \
+      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE);        \
       break;                                                      \
     default:                                                      \
       TORCH_CHECK(false, "Unsupported block size: ", block_size); \
@@ -728,18 +813,26 @@ void paged_attention_v1(
     torch::Tensor& seq_lens,      // [num_seqs]
     int block_size, int max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, float kv_scale){
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+  const bool is_block_sparse = (blocksparse_vert_stride > 1);
+
+  DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
+                             CALL_V1_LAUNCHER_BLOCK_SIZE)
+}
 
-    DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
-                               CALL_V1_LAUNCHER_BLOCK_SIZE)}
 #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                   \
   aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,      \
-                                       NUM_THREADS, KV_DTYPE, PARTITION_SIZE>  \
+                                       NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
+                                       PARTITION_SIZE>                         \
       <<<grid, block, shared_mem_size, stream>>>(                              \
           exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
           value_cache_ptr, num_kv_heads, scale, block_tables_ptr,              \
           seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride,    \
-          kv_block_stride, kv_head_stride, kv_scale);                          \
+          kv_block_stride, kv_head_stride, kv_scale, tp_rank,                  \
+          blocksparse_local_blocks, blocksparse_vert_stride,                   \
+          blocksparse_block_size, blocksparse_head_sliding_step);              \
   aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS,       \
                                               PARTITION_SIZE>                  \
       <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                \
@@ -747,14 +840,17 @@ void paged_attention_v1(
           max_num_partitions);
 
 template <typename T, typename CACHE_T, int BLOCK_SIZE,
-          aphrodite::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128,
-          int PARTITION_SIZE = 512>
+          aphrodite::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
+          int NUM_THREADS = 128, int PARTITION_SIZE = 512>
 void paged_attention_v2_launcher(
     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
+    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
+    const int tp_rank, const int blocksparse_local_blocks,
+    const int blocksparse_vert_stride, const int blocksparse_block_size,
+    const int blocksparse_head_sliding_step) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -825,24 +921,36 @@ void paged_attention_v2_launcher(
   }
 }
 
-#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE)                    \
-  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>(              \
+#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE)   \
+  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,               \
+                              IS_BLOCK_SPARSE>(                               \
       out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,      \
       num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
-      kv_scale);
+      kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,   \
+      blocksparse_block_size, blocksparse_head_sliding_step);
+
+#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
+  switch (is_block_sparse) {                                               \
+    case true:                                                             \
+      CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true);     \
+      break;                                                               \
+    case false:                                                            \
+      CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false);    \
+      break;                                                               \
+  }
 
 // NOTE: To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
 #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE)         \
   switch (block_size) {                                           \
     case 8:                                                       \
-      CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE);                  \
+      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE);         \
       break;                                                      \
     case 16:                                                      \
-      CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE);                 \
+      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE);        \
       break;                                                      \
     case 32:                                                      \
-      CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE);                 \
+      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE);        \
       break;                                                      \
     default:                                                      \
       TORCH_CHECK(false, "Unsupported block size: ", block_size); \
@@ -866,7 +974,10 @@ void paged_attention_v2(
     torch::Tensor& seq_lens,      // [num_seqs]
     int block_size, int max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, float kv_scale) {
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+  const bool is_block_sparse = (blocksparse_vert_stride > 1);
   DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
                              CALL_V2_LAUNCHER_BLOCK_SIZE)
 }

+ 21 - 16
kernels/cpu/attention.cpp

@@ -415,14 +415,17 @@ void paged_attention_v1_impl_launcher(
   }
 }  // namespace
 
-void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
-                        torch::Tensor& key_cache, torch::Tensor& value_cache,
-                        int num_kv_heads, float scale,
-                        torch::Tensor& block_tables, torch::Tensor& seq_lens,
-                        int block_size, int max_seq_len,
-                        const c10::optional<torch::Tensor>& alibi_slopes,
-                        const std::string& kv_cache_dtype, float kv_scale) {
+void paged_attention_v1(
+    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
+    int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
   TORCH_CHECK(kv_scale == 1.0f);
+  TORCH_CHECK(blocksparse_vert_stride <= 1,
+              "CPU backend does not support blocksparse attention yet.");
   APHRODITE_DISPATCH_FLOATING_TYPES(
       query.scalar_type(), "paged_attention_v1_impl", [&] {
         CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
@@ -726,16 +729,18 @@ void paged_attention_v2_impl_launcher(
   }
 }  // namespace
 
-void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
-                        torch::Tensor& max_logits, torch::Tensor& tmp_out,
-                        torch::Tensor& query, torch::Tensor& key_cache,
-                        torch::Tensor& value_cache, int num_kv_heads,
-                        float scale, torch::Tensor& block_tables,
-                        torch::Tensor& seq_lens, int block_size,
-                        int max_seq_len,
-                        const c10::optional<torch::Tensor>& alibi_slopes,
-                        const std::string& kv_cache_dtype, float kv_scale) {
+void paged_attention_v2(
+    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
+    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
+    int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
   TORCH_CHECK(kv_scale == 1.0f);
+  TORCH_CHECK(blocksparse_vert_stride <= 1,
+              "CPU backend does not support blocksparse attention yet.");
   APHRODITE_DISPATCH_FLOATING_TYPES(
       query.scalar_type(), "paged_attention_v2_impl", [&] {
         CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)

+ 17 - 16
kernels/ops.h

@@ -2,23 +2,24 @@
 
 #include <torch/extension.h>
 
-void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
-                        torch::Tensor& key_cache, torch::Tensor& value_cache,
-                        int num_kv_heads, float scale,
-                        torch::Tensor& block_tables, torch::Tensor& seq_lens,
-                        int block_size, int max_seq_len,
-                        const c10::optional<torch::Tensor>& alibi_slopes,
-                        const std::string& kv_cache_dtype, float kv_scale);
+void paged_attention_v1(
+    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
+    int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step);
 
-void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
-                        torch::Tensor& max_logits, torch::Tensor& tmp_out,
-                        torch::Tensor& query, torch::Tensor& key_cache,
-                        torch::Tensor& value_cache, int num_kv_heads,
-                        float scale, torch::Tensor& block_tables,
-                        torch::Tensor& seq_lens, int block_size,
-                        int max_seq_len,
-                        const c10::optional<torch::Tensor>& alibi_slopes,
-                        const std::string& kv_cache_dtype, float kv_scale);
+void paged_attention_v2(
+    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
+    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
+    int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step);
 
 void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
               float epsilon);