
add phi3_small support with blocksparse attention

AlpinDale 7 months ago
parent
commit
696f2cd59c

+ 1 - 0
aphrodite/attention/backends/abstract.py

@@ -111,6 +111,7 @@ class AttentionImpl(ABC, Generic[T]):
         alibi_slopes: Optional[List[float]] = None,
         sliding_window: Optional[int] = None,
         kv_cache_dtype: str = "auto",
+        blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         raise NotImplementedError
 
 

+ 405 - 0
aphrodite/attention/backends/blocksparse_attn.py

@@ -0,0 +1,405 @@
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+import torch
+
+from aphrodite.attention.backends.abstract import (AttentionBackend,
+                                                   AttentionImpl,
+                                                   AttentionMetadata)
+from aphrodite.attention.ops.blocksparse_attention.interface import (
+    LocalStridedBlockSparseAttn, get_head_sliding_step)
+from aphrodite.attention.ops.paged_attn import PagedAttention
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size)
+
+
+@dataclass
+class BlocksparseParams:
+    max_seqlen: int
+
+    # Num q heads per tensor-parallel rank/partition
+    num_heads: int  # per TP partition
+    # Num kv heads per tensor-parallel rank/partition
+    num_kv_heads: int
+
+    # block size used for blocksparse attention.
+    # This is the block_size used in `local_blocks`, `vert_stride`.
+    block_size: int
+
+    # Number of blocks for local attention, i.e., number of
+    # local attended tokens / `sparse_block_size`
+    local_blocks: int
+
+    # Attend to one block per every `vert_stride` blocks.
+    # Controlling the sparsity
+    vert_stride: int
+    """
+    If to use the same vertical stride offset for all heads, 
+    i.e., attend to the same block of tokens on all heads.
+    By default, it is False, i.e., attention on the non-local 
+    blocks depends on the `head_idx`, that is on
+    blocks satisfying 
+    `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0`
+    where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`,
+            `block_idx = position_id // sparse_block_size`.
+    See `..ops.blocksparse_attention.utils:get_sparse_attn_mask`
+    for more detail.
+    """
+    homo_head: bool = False
+
+    # Whether, within a group, the kv offsets that each q head attends to
+    # are the same.
+    homo_head_group: bool = False
+
+    # Determined by homo_head and homo_head_group.
+    head_sliding_step: int = field(init=False)
+
+    # Range of q heads for this TP rank.
+    active_head_range: Tuple = field(init=False)
+
+    def __post_init__(self):
+        assert self.block_size > 0
+        assert self.local_blocks >= 0
+        assert self.vert_stride >= 1
+        assert self.num_heads % self.num_kv_heads == 0
+
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        total_heads = tp_size * self.num_heads
+        total_kv_heads = tp_size * self.num_kv_heads
+
+        if self.homo_head:
+            self.head_sliding_step = 0
+        elif self.homo_head_group:
+            head_sliding_step = get_head_sliding_step(total_kv_heads,
+                                                      self.vert_stride)
+            # negative indicates sliding along kv heads, i.e., homo q group
+            self.head_sliding_step = -head_sliding_step
+        else:
+            self.head_sliding_step = get_head_sliding_step(
+                total_heads, self.vert_stride)
+
+        self.active_head_range = (
+            tp_rank * self.num_heads,
+            (tp_rank + 1) * self.num_heads,
+        )
+
+
+class BlocksparseFlashAttentionBackend(AttentionBackend):
+
+    @staticmethod
+    def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
+        return BlocksparseFlashAttentionImpl
+
+    @staticmethod
+    def make_metadata(*args, **kwargs) -> "BlocksparseFlashAttentionMetadata":
+        return BlocksparseFlashAttentionMetadata(*args, **kwargs)
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
+                                                 num_kv_heads, head_size)
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: Dict[int, int],
+    ) -> None:
+        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: Dict[int, List[int]],
+    ) -> None:
+        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+
+
+@dataclass
+class BlocksparseFlashAttentionMetadata(AttentionMetadata):
+    """A copy of Metadata for FlashAttentionBackend,
+    to avoid having to install flash_attn.
+    NOTE: Any python object stored here is not updated when it is
+    cuda-graph replayed. If you have values that need to be changed
+    dynamically, they should be stored in a tensor. The tensor has to be
+    updated from the `CUDAGraphRunner.forward` API.
+    """
+    # (batch_size,). The sequence length per sequence. Sequence length means
+    # the computed tokens + new tokens. None if it is a decoding batch.
+    seq_lens: Optional[List[int]]
+    # seq_lens stored as a tensor.
+    seq_lens_tensor: Optional[torch.Tensor]
+
+    # NOTE(sang): Definition of context_len, query_len, and seq_len.
+    # |---------- N-1 iteration --------|
+    # |---------------- N iteration ---------------------|
+    # |- tokenA -|......................|-- newTokens ---|
+    # |---------- context_len ----------|
+    # |-------------------- seq_len ----------------------|
+    #                                   |-- query_len ---|
+
+    # Maximum query length in the batch. None for decoding.
+    max_query_len: Optional[int]
+    # Maximum sequence length among prefill batch. 0 if there are decoding
+    # requests only.
+    max_prefill_seq_len: int
+    # Maximum sequence length among decode batch. 0 if there are prefill
+    # requests only.
+    max_decode_seq_len: int
+    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
+    # the batch, used to index into subquery. E.g., if the subquery length
+    # is [4, 6], it is [0, 4, 10].
+    query_start_loc: Optional[torch.Tensor]
+    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
+    # the batch, used to index into sequence. E.g., if the sequence length is
+    # [4, 6], it is [0, 4, 10].
+    seq_start_loc: Optional[torch.Tensor]
+    # (batch_size,) A tensor of context lengths (tokens that are computed
+    # so far).
+    context_lens_tensor: Optional[torch.Tensor]
+
+    # (batch_size, max_blocks_per_seq).
+    # Block addresses per sequence. (Seq id -> list of physical block)
+    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
+    # in the kv cache. Each block can contain up to block_size tokens.
+    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
+    # captured.
+    block_tables: Optional[torch.Tensor]
+
+    # Whether or not cuda graph is enabled.
+    # Cuda-graph is currently enabled for decoding only.
+    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
+    use_cuda_graph: bool
+
+    _cached_prefill_metadata: Optional[
+        "BlocksparseFlashAttentionMetadata"] = None
+    _cached_decode_metadata: Optional[
+        "BlocksparseFlashAttentionMetadata"] = None
+
+    @property
+    def prefill_metadata(
+            self) -> Optional["BlocksparseFlashAttentionMetadata"]:
+        if self.num_prefills == 0:
+            return None
+
+        if self._cached_prefill_metadata is not None:
+            return self._cached_prefill_metadata
+
+        assert self.seq_lens is not None
+        assert self.seq_lens_tensor is not None
+        assert self.query_start_loc is not None
+        assert self.context_lens_tensor is not None
+        assert self.block_tables is not None
+        assert self.seq_start_loc is not None
+
+        self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata(
+            num_prefills=self.num_prefills,
+            num_prefill_tokens=self.num_prefill_tokens,
+            num_decode_tokens=0,
+            slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
+            seq_lens=self.seq_lens[:self.num_prefills],
+            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
+            max_query_len=self.max_query_len,
+            max_prefill_seq_len=self.max_prefill_seq_len,
+            max_decode_seq_len=0,
+            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
+            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
+            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
+            block_tables=self.block_tables[:self.num_prefills],
+            use_cuda_graph=False,
+        )
+        return self._cached_prefill_metadata
+
+    @property
+    def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
+        if self.num_decode_tokens == 0:
+            return None
+
+        if self._cached_decode_metadata is not None:
+            return self._cached_decode_metadata
+        assert self.block_tables is not None
+        assert self.seq_lens_tensor is not None
+
+        self._cached_decode_metadata = BlocksparseFlashAttentionMetadata(
+            num_prefills=0,
+            num_prefill_tokens=0,
+            num_decode_tokens=self.num_decode_tokens,
+            slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
+            seq_lens=None,
+            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
+            max_query_len=None,
+            max_prefill_seq_len=0,
+            max_decode_seq_len=self.max_decode_seq_len,
+            query_start_loc=None,
+            seq_start_loc=None,
+            context_lens_tensor=None,
+            block_tables=self.block_tables[self.num_prefills:],
+            use_cuda_graph=self.use_cuda_graph,
+        )
+        return self._cached_decode_metadata
+
+
+class BlocksparseFlashAttentionImpl(AttentionImpl):
+    """
+    If the input tensors contain prompt tokens, the layout is as follows:
+    |<--------------- num_prompt_tokens -------------->|
+    |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
+    Otherwise, the layout is as follows:
+    |<------------------ num_generation_tokens (M) ----------------->|
+    |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
+    Generation tokens can contain padding when cuda-graph is used.
+    Currently, prompt tokens don't contain any padding.
+    The prompts might have different lengths, while the generation tokens
+    always have length 1.
+    """
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: int,
+        alibi_slopes: Optional[List[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        assert blocksparse_params is not None
+        assert alibi_slopes is None, ValueError(
+            "Alibi is not supported for blocksparse flash attention.")
+        assert sliding_window is None, ValueError(
+            "sliding_window is invalid for blocksparse attention.")
+
+        if "num_heads" not in blocksparse_params:
+            blocksparse_params["num_heads"] = num_heads
+        if "num_kv_heads" not in blocksparse_params:
+            blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads
+        self.blocksparse_params = BlocksparseParams(**blocksparse_params)
+        self.kv_cache_dtype = kv_cache_dtype
+
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        self.alibi_slopes = alibi_slopes
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
+        self.local_blocks = self.blocksparse_params.local_blocks
+        self.vert_stride = self.blocksparse_params.vert_stride
+        self.sparse_block_size = self.blocksparse_params.block_size
+        self.head_sliding_step = self.blocksparse_params.head_sliding_step
+
+        supported_head_sizes = PagedAttention.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
+            raise ValueError(
+                f"Head size {head_size} is not supported by PagedAttention. "
+                f"Supported head sizes are: {supported_head_sizes}.")
+
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        total_num_heads = num_heads * self.tp_size
+        self.bs_attn = LocalStridedBlockSparseAttn(
+            total_num_heads,
+            self.blocksparse_params.max_seqlen,
+            self.blocksparse_params.local_blocks,
+            self.blocksparse_params.vert_stride,
+            self.blocksparse_params.block_size,
+            homo_head=self.blocksparse_params.homo_head,
+            active_head_range=self.blocksparse_params.active_head_range,
+        )
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: BlocksparseFlashAttentionMetadata,
+        kv_scale: float = 1.0,
+    ) -> torch.Tensor:
+        """Forward pass with FlashAttention and PagedAttention.
+        Args:
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+            attn_metadata: Metadata for attention.
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """
+        num_tokens, hidden_size = query.shape
+        # Reshape the query, key, and value tensors.
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        if kv_cache is not None:
+            key_cache, value_cache = PagedAttention.split_kv_cache(
+                kv_cache, self.num_kv_heads, self.head_size)
+
+            # Reshape the input keys and values and store them in the cache.
+            # If kv_cache is not provided, the new key and value tensors are
+            # not cached. This happens during the initial memory profiling run.
+
+            PagedAttention.write_to_paged_cache(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                kv_scale,
+            )
+
+        if prefill_meta := attn_metadata.prefill_metadata:
+
+            # Prompt run.
+            # normal attention
+            # When block_tables are not filled, it means q and k are the
+            # prompt, and they have the same length.
+
+            assert kv_cache is None \
+                    or prefill_meta.block_tables is None \
+                    or prefill_meta.block_tables.numel() == 0, \
+                "Does not support prefix-enabled attention."
+
+            output = self.bs_attn(
+                q=query,
+                k=key,
+                v=value,
+                cu_seqlens_q=prefill_meta.seq_start_loc,
+                cu_seqlens_k=prefill_meta.seq_start_loc,
+                sm_scale=self.scale,
+            )
+
+        if decode_meta := attn_metadata.decode_metadata:
+            # Decoding run.
+            output = PagedAttention.forward_decode(
+                query,
+                key_cache,
+                value_cache,
+                decode_meta.block_tables,
+                decode_meta.seq_lens_tensor,
+                self.blocksparse_params.max_seqlen,
+                self.kv_cache_dtype,
+                self.num_kv_heads,
+                self.scale,
+                self.alibi_slopes,
+                kv_scale,
+                tp_rank=self.tp_rank,
+                blocksparse_local_blocks=self.local_blocks,
+                blocksparse_vert_stride=self.vert_stride,
+                blocksparse_block_size=self.sparse_block_size,
+                blocksparse_head_sliding_step=self.head_sliding_step,
+            )
+
+        # Reshape the output tensor.
+        return output.view(num_tokens, hidden_size)
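For readers new to this layout, the sparsity rule documented on `BlocksparseParams` above can be illustrated with a small standalone sketch (written for this note, not part of the commit; the helper name and toy sizes are invented). It enumerates the causal key blocks that one query block attends to under the local + vertical-stride pattern:

# Standalone illustration of the local + vertical-stride rule from the
# BlocksparseParams docstring. Helper and toy sizes are hypothetical.
def attended_blocks(q_block_idx, head_idx, *, local_blocks, vert_stride,
                    head_sliding_step):
    """Causal key-block indices that query block `q_block_idx` attends to."""
    keep = []
    for k_block_idx in range(q_block_idx + 1):  # causal: past/current blocks
        is_local = q_block_idx - k_block_idx < local_blocks
        is_strided = ((k_block_idx + head_idx * head_sliding_step + 1)
                      % vert_stride == 0)
        if is_local or is_strided:
            keep.append(k_block_idx)
    return keep

print(attended_blocks(15, 0, local_blocks=8, vert_stride=8,
                      head_sliding_step=1))
# [7, 8, 9, 10, 11, 12, 13, 14, 15]: the 8 local blocks plus strided block 7

With `homo_head=False`, other heads shift the strided columns by `head_idx * head_sliding_step`, so different heads cover different distant blocks.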

+ 5 - 1
aphrodite/attention/backends/flash_attn.py

@@ -1,6 +1,6 @@
 """Attention layer with FlashAttention."""
 """Attention layer with FlashAttention."""
 from dataclasses import dataclass
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 
 import torch
 import torch
 from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
 from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
@@ -220,7 +220,10 @@ class FlashAttentionImpl(AttentionImpl):
         alibi_slopes: Optional[List[float]],
         alibi_slopes: Optional[List[float]],
         sliding_window: Optional[int],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
         kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
     ) -> None:
+        assert blocksparse_params is None, ValueError(
+            "FlashAttention does not support block-sparse attention.")
         self.num_heads = num_heads
         self.num_heads = num_heads
         self.head_size = head_size
         self.head_size = head_size
         self.scale = float(scale)
         self.scale = float(scale)
@@ -240,6 +243,7 @@ class FlashAttentionImpl(AttentionImpl):
             # paged KV cache.
             # paged KV cache.
             raise ValueError(
             raise ValueError(
                 "Sliding window is not supported in FlashAttention.")
                 "Sliding window is not supported in FlashAttention.")
+
         support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
         support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
         if head_size not in support_head_sizes:
         if head_size not in support_head_sizes:
             raise ValueError(
             raise ValueError(

+ 3 - 0
aphrodite/attention/backends/flashinfer.py

@@ -169,7 +169,10 @@ class FlashInferImpl(AttentionImpl):
         alibi_slopes: Optional[List[float]],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
+        assert blocksparse_params is None, ValueError(
+            "FlashInfer does not support block-sparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)

+ 5 - 2
aphrodite/attention/backends/rocm_flash_attn.py

@@ -1,7 +1,7 @@
 """Attention layer ROCm GPUs."""
 """Attention layer ROCm GPUs."""
-from dataclasses import dataclass
 import os
 import os
-from typing import List, Optional, Tuple, Type
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 
 import torch
 import torch
 from loguru import logger
 from loguru import logger
@@ -200,7 +200,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         alibi_slopes: Optional[List[float]],
         alibi_slopes: Optional[List[float]],
         sliding_window: Optional[int],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
         kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
     ) -> None:
+        assert blocksparse_params is None, ValueError(
+            "ROCm FlashAttention does not support block-sparse attention.")
         self.num_heads = num_heads
         self.num_heads = num_heads
         self.head_size = head_size
         self.head_size = head_size
         self.scale = float(scale)
         self.scale = float(scale)

+ 4 - 1
aphrodite/attention/backends/torch_sdpa.py

@@ -1,7 +1,7 @@
 """ Attention layer with torch scaled_dot_product_attention
 """ Attention layer with torch scaled_dot_product_attention
     and PagedAttention."""
     and PagedAttention."""
 from dataclasses import dataclass
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 
 import torch
 import torch
 from torch.nn.functional import scaled_dot_product_attention
 from torch.nn.functional import scaled_dot_product_attention
@@ -101,7 +101,10 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         alibi_slopes: Optional[List[float]],
         alibi_slopes: Optional[List[float]],
         sliding_window: Optional[int],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
         kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
+        assert blocksparse_params is None, ValueError(
+            "Torch SDPA does not support block-sparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)

+ 4 - 1
aphrodite/attention/backends/xformers.py

@@ -1,6 +1,6 @@
 """Attention layer with xFormers and PagedAttention."""
 """Attention layer with xFormers and PagedAttention."""
 from dataclasses import dataclass
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
 
 
 import torch
 import torch
 from xformers import ops as xops
 from xformers import ops as xops
@@ -210,7 +210,10 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         alibi_slopes: Optional[List[float]],
         alibi_slopes: Optional[List[float]],
         sliding_window: Optional[int],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
         kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
     ) -> None:
+        assert blocksparse_params is None, ValueError(
+            "XFormers does not support block-sparse attention.")
         self.num_heads = num_heads
         self.num_heads = num_heads
         self.head_size = head_size
         self.head_size = head_size
         self.scale = float(scale)
         self.scale = float(scale)

+ 12 - 4
aphrodite/attention/layer.py

@@ -1,5 +1,5 @@
 """Attention layer."""
 """Attention layer."""
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 
 import torch
 import torch
 import torch.nn as nn
 import torch.nn as nn
@@ -12,9 +12,11 @@ from aphrodite.quantization.base_config import QuantizationConfig
 
 
 class Attention(nn.Module):
 class Attention(nn.Module):
     """Attention layer.
     """Attention layer.
+
     This class takes query, key, and value tensors as input. The input tensors
     This class takes query, key, and value tensors as input. The input tensors
     can either contain prompt tokens or generation tokens.
     can either contain prompt tokens or generation tokens.
     The class does the following:
     The class does the following:
+
     1. Store the input key and value tensors in the KV cache.
     1. Store the input key and value tensors in the KV cache.
     2. Perform (multi-head/multi-query/grouped-query) attention.
     2. Perform (multi-head/multi-query/grouped-query) attention.
     3. Return the output tensor.
     3. Return the output tensor.
@@ -30,6 +32,7 @@ class Attention(nn.Module):
         sliding_window: Optional[int] = None,
         sliding_window: Optional[int] = None,
         cache_config: Optional[CacheConfig] = None,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
     ) -> None:
         super().__init__()
         super().__init__()
         if cache_config is not None:
         if cache_config is not None:
@@ -60,15 +63,18 @@ class Attention(nn.Module):
             # to self._kv_scale in a native float32 value after weight loading.
             # to self._kv_scale in a native float32 value after weight loading.
             self.quant_method = quant_method
             self.quant_method = quant_method
             self.quant_method.create_weights(self)
             self.quant_method.create_weights(self)
+
         # During model initialization, the default dtype is set as the model
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
         # weight and activation dtype.
         dtype = torch.get_default_dtype()
         dtype = torch.get_default_dtype()
         attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
         attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
                                         sliding_window, dtype, kv_cache_dtype,
                                         sliding_window, dtype, kv_cache_dtype,
-                                        block_size)
+                                        block_size, blocksparse_params
+                                        is not None)
         impl_cls = attn_backend.get_impl_cls()
         impl_cls = attn_backend.get_impl_cls()
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
-                             alibi_slopes, sliding_window, kv_cache_dtype)
+                             alibi_slopes, sliding_window, kv_cache_dtype,
+                             blocksparse_params)
 
 
     def forward(
     def forward(
         self,
         self,
@@ -78,11 +84,13 @@ class Attention(nn.Module):
         kv_cache: Optional[torch.Tensor],
         kv_cache: Optional[torch.Tensor],
         attn_metadata: AttentionMetadata,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
     ) -> torch.Tensor:
-        return self.impl.forward(query, key, value, kv_cache, attn_metadata)
+        return self.impl.forward(query, key, value, kv_cache, attn_metadata,
+                                 self._kv_scale)
 
 
     def extra_repr(self) -> str:
     def extra_repr(self) -> str:
         s = f"head_size={self.impl.head_size}"  # type: ignore
         s = f"head_size={self.impl.head_size}"  # type: ignore
         s += f", num_heads={self.impl.num_heads}"  # type: ignore
         s += f", num_heads={self.impl.num_heads}"  # type: ignore
         s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
         s += f", num_kv_heads={self.impl.num_kv_heads}"  # type: ignore
         s += f", scale={self.impl.scale}"  # type: ignore
         s += f", scale={self.impl.scale}"  # type: ignore
+        s += f", backend={self.impl.__class__.__name__}"
         return s
         return s

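To make the wiring concrete, here is a hypothetical call site (not from this commit; the values and keyword arguments are illustrative only) that passes `blocksparse_params` through the `Attention` wrapper changed above. The dictionary keys mirror the fields of `BlocksparseParams`:

# Hypothetical usage sketch: constructing an Attention layer with
# block-sparse parameters. Values are illustrative only.
from aphrodite.attention.layer import Attention

blocksparse_params = {
    "max_seqlen": 8192,   # longest sequence the sparse layout is built for
    "local_blocks": 16,   # number of dense local blocks
    "vert_stride": 8,     # attend to one extra block every 8 blocks
    "block_size": 64,     # sparse block size (multiple of the KV-cache block size)
    "homo_head": False,   # head-dependent vertical offsets
    # "num_heads" / "num_kv_heads" are filled in by the backend if omitted
}

attn = Attention(
    num_heads=32,
    head_size=128,
    scale=128 ** -0.5,
    num_kv_heads=8,
    blocksparse_params=blocksparse_params,  # selects the blocksparse backend
)

Passing a non-None `blocksparse_params` is what makes `get_attn_backend` pick the blocksparse backend, as shown in the `selector.py` change below.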
+ 0 - 0
aphrodite/attention/ops/blocksparse_attention/__init__.py


+ 422 - 0
aphrodite/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py

@@ -0,0 +1,422 @@
+import torch
+import triton
+import triton.language as tl
+
+
+def blocksparse_flash_attn_varlen_fwd(
+        q,
+        k,
+        v,  # (#tokens, n_heads, head_size)
+        cu_seqlens_k,
+        cu_seqlens_q,
+        sm_scale,
+        sparse_layout,
+        *,
+        block_size=64,
+        q_block_size=None,
+        max_seqlen=None):
+    # split q to blocks
+
+    assert isinstance(sparse_layout, (list, tuple))
+
+    _, n_heads, head_size = q.shape
+    batch_size = cu_seqlens_k.size(0) - 1
+    q_block_size = q_block_size or block_size
+
+    assert q.dim() == k.dim() == v.dim() == 3
+    assert q.size(1) % k.size(1) == 0
+    assert q.size(2) == k.size(2)
+    # TODO: allow k, v to have different head_size
+    assert k.shape == v.shape
+    assert cu_seqlens_k.dim() == 1
+
+    q_k_ratio = q.size(1) // k.size(1)
+
+    if cu_seqlens_q is None:
+        if q.size(0) == batch_size:  # decoding only
+            cu_seqlens_q = torch.arange(
+                0,
+                batch_size + 1,
+                dtype=cu_seqlens_k.dtype,
+                device=cu_seqlens_k.device,
+            )
+        elif q.size(0) == k.size(0):
+            cu_seqlens_q = cu_seqlens_k
+        else:
+            raise ValueError("cu_seqlens_q must be specified\
+                    if it mix of prefilling and decoding.")
+    else:
+        assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0)
+
+    # compute lengths on cpu to avoid too many kernel launches when iterating
+    q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu()
+    k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu()
+
+    assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), (
+        "length of q should either be 1 (decoding) or same as k (prefilling).")
+
+    if max_seqlen:
+        assert k_lens.max() <= max_seqlen
+
+    n_blocks = (q_lens + q_block_size - 1) // q_block_size
+
+    q_batch_ids = torch.tensor(
+        [i for i, n in enumerate(n_blocks) for _ in range(n)],
+        dtype=cu_seqlens_q.dtype,
+        device=cu_seqlens_q.device,
+    )
+    q_start_sids = torch.tensor(
+        [i * q_block_size for n in n_blocks for i in range(n)],
+        dtype=cu_seqlens_q.dtype,
+        device=cu_seqlens_q.device,
+    )
+
+    out = q.new_empty(q.shape)
+    cu_seqlens_q = cu_seqlens_q.contiguous()
+    cu_seqlens_k = cu_seqlens_k.contiguous()
+
+    layout_crow_indices, layout_col_indices = sparse_layout
+    block_d = triton.next_power_of_2(head_size)
+
+    decoding_only = (q_lens == 1).all().item()
+    grid = (len(q_start_sids), n_heads, 1)
+
+    _fwd_kernel_batch_inference[grid](
+        q,
+        k,
+        v,
+        out,
+        sm_scale,
+        cu_seqlens_q[:-1],
+        cu_seqlens_q[1:],
+        cu_seqlens_k[:-1],
+        cu_seqlens_k[1:],
+        q_batch_ids,
+        q_start_sids,
+        0,
+        *q.stride(),
+        0,
+        *k.stride(),
+        0,
+        *v.stride(),
+        0,
+        *out.stride(),
+        layout_crow_indices,
+        layout_col_indices,
+        *layout_crow_indices.stride(),
+        *layout_col_indices.stride(),
+        q_k_ratio,
+        HAS_BATCH_DIM=False,
+        D_HEAD=head_size,
+        BLOCK_M=q_block_size,
+        BLOCK_N=block_size,
+        BLOCK_D=block_d,
+        BLOCK_M_LOADING=(16 if decoding_only else
+                         q_block_size),  # smaller for decoding
+        EVEN_D=block_d == head_size,
+        num_warps=1 if decoding_only else 4,
+        num_stages=3)
+
+    return out
+
+
+@triton.jit
+def _fwd_kernel_inner(
+    acc,
+    l_i,
+    m_i,
+    q,
+    Q,
+    k_block_col_idx,
+    layout_col_ptr,
+    layout_col_stride_h,
+    layout_col_stride_m,
+    k_ptrs,
+    v_ptrs,
+    off_h,
+    offs_m,
+    offs_n,
+    offs_d,
+    stride_kt,
+    stride_vt,
+    sm_scale,
+    k_seqlen,
+    past_len,
+    LAST_K_BLOCK: tl.constexpr,
+    BLOCK_M_LOADING: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    D_HEAD: tl.constexpr,
+    EVEN_D: tl.constexpr,
+    M_LT_N: tl.constexpr,
+):
+    k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h +
+                         k_block_col_idx * layout_col_stride_m).to(tl.int32)
+    start_n = k_block_id * BLOCK_N
+    if LAST_K_BLOCK:
+        if EVEN_D:
+            k = tl.load(
+                k_ptrs + start_n * stride_kt,
+                mask=offs_n[None, :] + start_n < k_seqlen,
+            )
+        else:
+            k = tl.load(
+                k_ptrs + start_n * stride_kt,
+                mask=(offs_n[None, :] + start_n < k_seqlen) &
+                (offs_d[:, None] < D_HEAD),
+            )
+    else:
+        if EVEN_D:
+            k = tl.load(k_ptrs + start_n * stride_kt)
+        else:
+            k = tl.load(k_ptrs + start_n * stride_kt,
+                        mask=offs_d[:, None] < D_HEAD)
+
+    qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
+    qk += tl.dot(q, k)
+    qk *= sm_scale
+
+    # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N
+    if LAST_K_BLOCK | M_LT_N:
+        qk += tl.where(
+            offs_m[:, None] + past_len >= (start_n + offs_n[None, :]),
+            0,
+            float("-inf"),
+        )
+
+    # flash-attn2
+    m_ij = tl.maximum(m_i, tl.max(qk, 1))
+    p = tl.math.exp2(qk - m_ij[:, None])
+    l_ij = tl.sum(p, 1)
+    alpha = tl.math.exp2(m_i - m_ij)
+    acc = acc * alpha[:, None]
+    # update m_i
+    m_i = m_ij
+    l_i = l_i * alpha + l_ij
+
+    p = p.to(Q.dtype.element_ty)
+    # update acc
+    if LAST_K_BLOCK:
+        if EVEN_D:
+            v = tl.load(
+                v_ptrs + start_n * stride_vt,
+                mask=offs_n[:, None] + start_n < k_seqlen,
+            )
+        else:
+            v = tl.load(
+                v_ptrs + start_n * stride_vt,
+                mask=(offs_n[:, None] + start_n < k_seqlen) &
+                (offs_d[None, :] < D_HEAD),
+            )
+    else:
+        if EVEN_D:
+            v = tl.load(v_ptrs + start_n * stride_vt)
+        else:
+            v = tl.load(v_ptrs + start_n * stride_vt,
+                        mask=offs_d[None, :] < D_HEAD)
+
+    acc += tl.dot(p, v)
+
+    return acc, l_i, m_i
+
+
+@triton.heuristics({
+    "M_LT_N":
+    lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"],
+})
+@triton.jit
+def _fwd_kernel_batch_inference(
+    Q,
+    K,
+    V,
+    Out,
+    sm_scale,
+    q_batch_starts,
+    q_batch_ends,
+    k_batch_starts,
+    k_batch_ends,
+    q_batch_ids,
+    q_start_sids,
+    stride_qb,
+    stride_qt,
+    stride_qh,
+    stride_qd,
+    stride_kb,
+    stride_kt,
+    stride_kh,
+    stride_kd,
+    stride_vb,
+    stride_vt,
+    stride_vh,
+    stride_vd,
+    stride_ob,
+    stride_ot,
+    stride_oh,
+    stride_od,
+    layout_crow_ptr,
+    layout_col_ptr,
+    layout_crow_stride_h,
+    layout_crow_stride_m,
+    layout_col_stride_h,
+    layout_col_stride_m,
+    q_k_ratio,
+    HAS_BATCH_DIM: tl.constexpr,
+    D_HEAD: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+    BLOCK_D: tl.constexpr,
+    BLOCK_M_LOADING: tl.constexpr,
+    EVEN_D: tl.constexpr,
+    M_LT_N: tl.constexpr,
+):
+    """
+    NOTATION:
+    pid: position id
+    sid: storage id
+    sbid: storage block id
+    pbid: position block id
+    offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col)
+    TODO:
+    Optimize grouped-attn
+    """
+    off_zm = tl.program_id(0)
+    off_h = tl.program_id(1)
+
+    off_h_for_kv = off_h // q_k_ratio
+
+    if HAS_BATCH_DIM:
+        off_z = tl.program_id(2)
+        Q += off_z * stride_qb
+        K += off_z * stride_kb
+        V += off_z * stride_vb
+        Out += off_z * stride_ob
+        start_m = off_zm
+        q_start_sid = start_m * BLOCK_M  # always 0 for decoding
+    else:
+        off_z = tl.load(q_batch_ids + off_zm).to(tl.int32)  # [0, 0, 0, 1]
+        q_start_sid = tl.load(q_start_sids + off_zm)
+        start_m = q_start_sid // BLOCK_M  # q_sbid
+
+    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING)
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_d = tl.arange(0, BLOCK_D)
+
+    q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32)
+    q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start
+    k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32)
+    k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start
+    past_len = k_seqlen - q_seqlen
+
+    Q += q_cu_start * stride_qt + off_h * stride_qh
+    K += k_cu_start * stride_kt + off_h_for_kv * stride_kh
+    V += k_cu_start * stride_vt + off_h_for_kv * stride_vh
+    Out += q_cu_start * stride_ot + off_h * stride_oh
+
+    q_pbid = (past_len + q_start_sid) // BLOCK_M
+
+    if EVEN_D:
+        q = tl.load(
+            Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
+            mask=offs_m[:, None] < q_seqlen,
+        )
+    else:
+        q = tl.load(
+            Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
+            mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
+            other=0,
+        )
+
+    sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +
+                       q_pbid * layout_crow_stride_m)
+
+    # TODO: load at once, with any Triton version
+    # that supports `tl.split`, e.g., Triton 3.0
+    k_block_start = tl.load(sparse_crow_ptr).to(tl.int32)
+    k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32)
+
+    m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf")
+    l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32)
+
+    k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd
+    v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd
+
+    sm_scale *= (
+        1.44269504  # 1/log2 as we use base2 for exponential and logarithm
+    )
+
+    for k_block_col_idx in range(k_block_start, k_block_end - 1):
+        acc, l_i, m_i = _fwd_kernel_inner(
+            acc,
+            l_i,
+            m_i,
+            q,
+            Q,
+            k_block_col_idx,
+            layout_col_ptr,
+            layout_col_stride_h,
+            layout_col_stride_m,
+            k_ptrs,
+            v_ptrs,
+            off_h,
+            offs_m,
+            offs_n,
+            offs_d,
+            stride_kt,
+            stride_vt,
+            sm_scale,
+            k_seqlen,
+            past_len,
+            False,
+            BLOCK_M_LOADING,
+            BLOCK_N,
+            D_HEAD,
+            EVEN_D,
+            M_LT_N,
+        )
+
+    acc, l_i, m_i = _fwd_kernel_inner(
+        acc,
+        l_i,
+        m_i,
+        q,
+        Q,
+        k_block_end - 1,
+        layout_col_ptr,
+        layout_col_stride_h,
+        layout_col_stride_m,
+        k_ptrs,
+        v_ptrs,
+        off_h,
+        offs_m,
+        offs_n,
+        offs_d,
+        stride_kt,
+        stride_vt,
+        sm_scale,
+        k_seqlen,
+        past_len,
+        True,
+        BLOCK_M_LOADING,
+        BLOCK_N,
+        D_HEAD,
+        EVEN_D,
+        M_LT_N,
+    )
+
+    # flash-attn 2
+    m_i += tl.math.log2(l_i)
+    acc = acc / l_i[:, None]
+
+    # write output
+    if EVEN_D:
+        tl.store(
+            Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
+            acc,
+            mask=offs_m[:, None] < q_seqlen,
+        )
+    else:
+        tl.store(
+            Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od,
+            acc,
+            mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
+        )
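The kernel above implements the flash-attention-2 online-softmax recurrence with base-2 exponentials, which is why `sm_scale` is pre-multiplied by 1.44269504 (1/ln 2) before the loop. A minimal PyTorch sketch of that recurrence, written for this note (names and toy sizes are invented) and checked against a full softmax:

# Plain-PyTorch sketch of the per-block update (m_i, l_i, acc) used in
# _fwd_kernel_inner, with exp2 instead of exp.
import torch

def online_softmax_attention(q, k, v, sm_scale, block_n=4):
    sm_scale = sm_scale * 1.44269504  # switch from exp to exp2
    m = torch.full((q.size(0),), float("-inf"))  # running row maxima
    l = torch.zeros(q.size(0))                   # running softmax denominators
    acc = torch.zeros(q.size(0), v.size(1))      # running weighted values
    for start in range(0, k.size(0), block_n):
        kb, vb = k[start:start + block_n], v[start:start + block_n]
        qk = (q @ kb.T) * sm_scale
        m_new = torch.maximum(m, qk.max(dim=1).values)
        p = torch.exp2(qk - m_new[:, None])
        alpha = torch.exp2(m - m_new)            # rescale previous partials
        acc = acc * alpha[:, None] + p @ vb
        l = l * alpha + p.sum(dim=1)
        m = m_new
    return acc / l[:, None]

q, k, v = torch.randn(3, 8), torch.randn(16, 8), torch.randn(16, 8)
ref = torch.softmax((q @ k.T) / 8 ** 0.5, dim=-1) @ v
out = online_softmax_attention(q, k, v, sm_scale=8 ** -0.5)
assert torch.allclose(out, ref, atol=1e-5)

The Triton kernel performs the same update per sparse key block, skipping blocks that the CSR layout (`layout_crow_indices`, `layout_col_indices`) marks as not attended.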

+ 235 - 0
aphrodite/attention/ops/blocksparse_attention/interface.py

@@ -0,0 +1,235 @@
+import math
+
+import torch
+
+from aphrodite.attention.ops.blocksparse_attention.utils import (
+    dense_to_crow_col, get_head_sliding_step, get_sparse_attn_mask)
+from aphrodite.common.utils import is_cpu, is_hip
+
+IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available()
+                         and torch.cuda.get_device_capability()[0] >= 8)
+
+if IS_COMPUTE_8_OR_ABOVE:
+    from aphrodite.attention.ops.blocksparse_attention.blocksparse_attention_kernel import \
+        blocksparse_flash_attn_varlen_fwd  # noqa: E501
+
+
+class LocalStridedBlockSparseAttn(torch.nn.Module):
+
+    def __init__(
+        self,
+        n_heads,
+        max_seqlen,
+        local_blocks,
+        vert_stride,
+        block_size,
+        device=None,
+        dtype=None,
+        homo_head=False,
+        active_head_range=None,
+        q_block_size=None,
+        use_spda=None,
+    ):
+        super().__init__()
+        if use_spda is None:
+            use_spda = is_hip() or is_cpu() or not \
+                       IS_COMPUTE_8_OR_ABOVE
+        device = device or (torch.cuda.current_device()
+                            if torch.cuda.is_available() else "cpu")
+        device = torch.device(device)
+        # NOTE: aphrodite CPU backend supports BF16 instead of FP16.
+        dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE
+                          or device.type == "cpu" else torch.half)
+
+        self.n_heads = n_heads
+        self.max_seqlen = max_seqlen
+        self.local_blocks = local_blocks
+        self.vert_stride = vert_stride
+        self.use_spda = use_spda
+        self.dtype = dtype
+        self.device = device
+        self.block_size = block_size
+        self.q_block_size = q_block_size
+        self.homo_head = homo_head
+        self.active_head_range = active_head_range
+        self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride,
+                                                       homo_head)
+
+        sparse_layout, sparse_pattern, self.dense_attn_mask = (
+            self.get_attn_pattern(dtype, device))
+
+        if q_block_size is not None and q_block_size != block_size:
+            if q_block_size > block_size:
+                assert q_block_size % block_size == 0
+                blocks_to_merge = q_block_size // block_size
+                shape = sparse_pattern.shape
+                sparse_pattern = sparse_pattern.view(shape[0], -1,
+                                                     blocks_to_merge,
+                                                     shape[-1])
+                sparse_pattern = sparse_pattern.sum(2)
+                sparse_layout = dense_to_crow_col(sparse_pattern)
+            else:
+                raise ValueError(
+                    "Does not support smaller q_block_size. It will be slower."
+                )
+
+        self.sparse_layout = sparse_layout
+
+    def get_attn_pattern(self, dtype, device):
+        sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask(
+            self.n_heads,
+            self.max_seqlen,
+            self.max_seqlen,
+            dtype,
+            device,
+            block_size=self.block_size,
+            local_blocks=self.local_blocks,
+            vert_stride=self.vert_stride,
+            homo_head=self.homo_head,
+            return_dense=self.use_spda,
+            dense_mask_type="bias",
+        )
+        if (not self.homo_head) and (self.active_head_range is not None):
+            assert isinstance(self.active_head_range, tuple)
+            assert (len(self.active_head_range) == 2)
+            h_start, h_end = self.active_head_range
+            sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout)
+            if self.use_spda:
+                dense_attn_mask = dense_attn_mask[h_start:h_end]
+        return sparse_layout, sparse_pattern, dense_attn_mask
+
+    def varlen_attn(self,
+                    q,
+                    k,
+                    v,
+                    cu_seqlens_k,
+                    cu_seqlens_q=None,
+                    sm_scale=None):
+        """
+        q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
+        Supports grouped attention, where `q[:, i*r:(i*r + r)]`
+        corresponds to `k[:, i]`, with `r` being the q/k head ratio.
+        cu_seqlens_k: shape=(batch_size + 1,),
+        indicating segments of samples,
+        e.g., `k[cu_seqlens[i]:cu_seqlens[i+1]]` are the keys of sample i.
+        cu_seqlens_q: shape=(batch_size + 1,).
+        Default None: same as cu_seqlens_k for prefilling or
+        [0, 1, .., batch_size] for decoding.
+        The only case you need to specify it is when q is a mix of
+        prefilling and decoding.
+        sm_scale: softmax scale; defaults to 1/sqrt(head_size).
+        return: tensor of the same shape as q.
+        """
+        assert (
+            IS_COMPUTE_8_OR_ABOVE
+        ), "Requires compute capability of 8 or above (Ampere or newer) to use \
+            Triton kernel."
+
+        sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
+
+        return blocksparse_flash_attn_varlen_fwd(
+            q,
+            k,
+            v,
+            cu_seqlens_k,
+            cu_seqlens_q,
+            sm_scale,
+            self.sparse_layout,
+            block_size=self.block_size,
+            q_block_size=self.q_block_size,
+            max_seqlen=self.max_seqlen,
+        )
+
+    @staticmethod
+    def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1):
+        """
+        :param x: (total_tokens, n_heads, head_size)
+        :return: (batch, n_heads, length, head_size)
+        """
+        x_padded = x.new_empty(
+            len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2))
+        cu_seqlens = cu_seqlens.cpu()
+        for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
+            x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0,
+                                                             1).unsqueeze(1))
+        return x_padded.flatten(1, 2)
+
+    @staticmethod
+    def transpose_and_unpad(x_padded, cu_seqlens):
+        """
+        :param x_padded: (batch, n_heads, length, head_size)
+        :return: (total_tokens, n_heads, head_size)
+        """
+        cu_seqlens = cu_seqlens.cpu()
+        total_n_tokens = cu_seqlens[-1]
+        x = x_padded.new_empty(total_n_tokens, x_padded.size(1),
+                               x_padded.size(3))
+        for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])):
+            x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1))
+        return x
+
+    def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
+        """For CPU, V100 or other older GPUs.
+        NOTE: torch scaled_dot_product_attention supports nested tensors,
+        but it seems extremely slow, so we pad instead.
+        """
+        assert (cu_seqlens_q is None or
+                (cu_seqlens_q
+                 == cu_seqlens_k).all()), "Can only handle prompt with SPDA."
+        assert q.size(0) == k.size(0), "can only handle prompt with SPDA."
+
+        assert q.size(1) % k.size(1) == 0
+        q_k_ratio = q.size(1) // k.size(1)
+        sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1))
+        cu_seqlens = cu_seqlens_k.cpu()
+        maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+
+        if (self.dense_attn_mask.dtype != q.dtype
+                or self.dense_attn_mask.device != q.device):
+            _, _, self.dense_attn_mask = self.get_attn_pattern(
+                q.dtype, q.device)
+        attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen]
+
+        q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1)
+        k2, v2 = [
+            self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio)
+            for x in [k, v]
+        ]
+        spda_output = torch.nn.functional.scaled_dot_product_attention(
+            q2, k2, v2, attn_mask=attn_mask, scale=sm_scale)
+        return self.transpose_and_unpad(spda_output, cu_seqlens)
+
+    def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
+        """Dispatch to `varlen_attn` (Ampere or newer) or 
+        `self.spda`(cpu, Volta, Turing or older)based on 
+        the type of device used and cuda compute capability.
+        q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
+                Support grouped attention, with `q[:, i*r:(i*r + r)]`
+                is correspondent to `k[:, i]`, where `r` is the q/k ratio.
+        cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples,
+                    e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i
+        cu_seqlens_q: shape=(batch_size + 1, ).
+                    Default None: same as cu_seqlens_k for prefilling or
+                    [0, 1, .., batch_size] for decoding.
+                    The only case you need to specify 
+                    is when q is a mix of prefilling 
+                    and decoding.
+        sm_scale: softmax scale, default to 1/sqrt(head_size).
+        return: tensor of shape as q.
+        """
+        assert k.dim() == 3
+        if self.use_spda:
+            return self.spda(
+                q,
+                k,
+                v,
+                cu_seqlens_k,
+                cu_seqlens_q=cu_seqlens_q,
+                sm_scale=sm_scale,
+            )
+        return self.varlen_attn(q,
+                                k,
+                                v,
+                                cu_seqlens_k,
+                                cu_seqlens_q=cu_seqlens_q,
+                                sm_scale=sm_scale)
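A hypothetical usage sketch (not from the commit; shapes and sizes are made up) of `LocalStridedBlockSparseAttn` on a packed varlen batch, following the `forward()` docstring. On CPU or pre-Ampere GPUs it takes the padded `scaled_dot_product_attention` path; on Ampere or newer it uses the Triton kernel:

# Two prompts of lengths 5 and 7, packed along the token dimension.
import torch
from aphrodite.attention.ops.blocksparse_attention.interface import (
    LocalStridedBlockSparseAttn)

n_heads, head_size = 8, 64
attn = LocalStridedBlockSparseAttn(
    n_heads=n_heads,
    max_seqlen=1024,
    local_blocks=4,
    vert_stride=4,
    block_size=64,
)

total_tokens = 5 + 7
q = torch.randn(total_tokens, n_heads, head_size,
                dtype=attn.dtype, device=attn.device)
k = torch.randn(total_tokens, n_heads, head_size,
                dtype=attn.dtype, device=attn.device)
v = torch.randn(total_tokens, n_heads, head_size,
                dtype=attn.dtype, device=attn.device)
cu_seqlens_k = torch.tensor([0, 5, 12], dtype=torch.int32,
                            device=attn.device)

out = attn(q, k, v, cu_seqlens_k)  # same shape as q: (12, 8, 64)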

+ 216 - 0
aphrodite/attention/ops/blocksparse_attention/utils.py

@@ -0,0 +1,216 @@
+# Helper functions for 3D sparse pattern
+# These functions are not optimized and are very inefficient.
+# Avoid calling them too frequently, or use a caching mechanism.
+
+from functools import lru_cache
+
+import torch
+import triton
+from scipy import sparse
+
+
+def dense_to_crow_col(x: torch.Tensor):
+    """Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing.
+    NOTE: col_indices padded -1
+    """
+    device = x.device
+    pad = -1
+    dim = x.dim()
+    assert x.dim() in (2, 3)
+    if x.dim() == 2:
+        x = x[None]
+    x = [sparse.csr_matrix(xi.bool().cpu().numpy()) for xi in x]
+    crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
+    cols = [torch.from_numpy(xi.indices) for xi in x]
+    max_cols = max(len(xi) for xi in cols)
+    cols = [
+        torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])])
+        for xi in cols
+    ]
+    cols = torch.vstack(cols)
+    if dim == 2:
+        crows = crows[0]
+        cols = cols[0]
+    return crows.to(device), cols.to(device)
+
+
+def crow_col_to_dense(crows: torch.Tensor,
+                      cols: torch.Tensor,
+                      dtype: torch.dtype = torch.float16):
+    dim = crows.dim()
+    if dim == 1:
+        crows = crows[None]
+        cols = cols[None]
+    device = crows.device
+    crows, cols = crows.cpu(), cols.cpu()  # faster on cpu
+    shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1)
+    x = torch.zeros(shape, dtype=dtype)
+    for i in range(shape[0]):
+        for j in range(shape[1]):
+            x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1
+    if dim == 1:
+        x = x[0]
+    return x.to(device)
+
+
+def dense_to_ccol_row(x: torch.Tensor):
+    """Similar, but to CSC format"""
+    x = x.transpose(-2, -1)
+    return dense_to_crow_col(x)
+
+
+def ccol_row_to_dense(ccol: torch.Tensor,
+                      rows: torch.Tensor,
+                      dtype: torch.dtype = torch.float16):
+    return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous()
+
+
+def _get_sparse_attn_mask_homo_head(
+    q_len: int,
+    max_seqlen: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    block_size: int = 128,
+    local_blocks: int = 4,
+    vert_stride: int = 4,
+    return_dense: bool = False,
+):
+    """
+    :return: a tuple of 3:
+        - tuple of crow_indices, col_indices representation 
+            of CSR format.
+        - block dense mask
+        - all-token dense mask (be aware that it can cause
+            OOM if it is too big) if `return_dense==True`,
+            otherwise, None
+    """
+    with torch.no_grad():
+        num_blocks = triton.cdiv(max_seqlen, block_size)
+        q_pos = torch.arange(num_blocks)[:, None]
+        k_pos = torch.arange(num_blocks)[None]
+        mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0
+        block_mask_dense = (((q_pos >= k_pos)
+                             & ((q_pos - k_pos < local_blocks)
+                                | mask_vert_strided)).to(device).to(dtype))
+        num_blocks_q = triton.cdiv(q_len, block_size)
+        block_mask_dense_output = (dense_to_crow_col(
+            block_mask_dense[-num_blocks_q:].contiguous()))
+    if return_dense:
+        mask_dense = torch.kron(
+            block_mask_dense,
+            block_mask_dense.new_ones((block_size, block_size)),
+        )
+        causal_mask = torch.tril(torch.ones(
+            max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
+        mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask
+        return (
+            block_mask_dense_output,
+            block_mask_dense,
+            mask_dense,
+        )
+    else:
+        return (
+            block_mask_dense_output,
+            block_mask_dense,
+            None,
+        )
+
+
+def binary_mask_to_bias(mask_dense: torch.Tensor):
+    mask_dense = 1 - mask_dense
+    mask_dense.masked_fill_(mask_dense.bool(), -torch.inf)
+    return mask_dense
+
+
+def get_head_sliding_step(n_heads: int,
+                          vert_stride: int,
+                          homo_head: bool = False):
+    if homo_head:
+        return 0
+    return max(1, int(vert_stride / n_heads))
+
+
+@lru_cache
+def get_sparse_attn_mask(
+    n_heads: int,
+    q_len: int,
+    max_seqlen: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    block_size: int = 64,
+    local_blocks: int = 4,
+    vert_stride: int = 4,
+    homo_head: bool = True,
+    return_dense: bool = False,
+    dense_mask_type: str = "binary",
+):
+    """
+    :param dense_mask_type: "binary" (0 for skip token, 1 for others)
+        or "bias" (-inf for skip token, 0 for others)
+    :return: a tuple of 3:
+        - tuple of crow_indices, col_indices representation 
+            of CSR format.
+        - block dense mask
+        - all-token dense mask (be aware that it can cause OOM if it
+            is too big) if `return_dense==True`, otherwise, None
+    """
+    assert dense_mask_type in ("binary", "bias")
+    if homo_head:
+        with torch.no_grad():
+            (crow, col), block_mask_dense, mask_dense = (
+                _get_sparse_attn_mask_homo_head(
+                    q_len,
+                    max_seqlen,
+                    dtype,
+                    device,
+                    block_size,
+                    local_blocks,
+                    vert_stride,
+                    return_dense,
+                ))
+            crow = crow[None].expand(n_heads, crow.shape[0])
+            col = col[None].expand(n_heads, col.shape[0])
+            if return_dense:
+                mask_dense = mask_dense[None].expand(n_heads,
+                                                     *mask_dense.shape)
+                if dense_mask_type == "bias":
+                    mask_dense = binary_mask_to_bias(mask_dense)
+            return (crow, col), block_mask_dense, mask_dense
+
+    with torch.no_grad():
+        num_blocks = triton.cdiv(max_seqlen, block_size)
+        q_pos = torch.arange(num_blocks)[None, :, None]
+        k_pos = torch.arange(num_blocks)[None, None]
+        head_sliding_step = get_head_sliding_step(n_heads, vert_stride)
+        mask_vert_strided = [
+            (torch.arange(num_blocks) + h * head_sliding_step + 1) %
+            vert_stride == 0 for h in range(n_heads)
+        ]
+        mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1)
+        block_mask_dense = (((q_pos >= k_pos)
+                             & ((q_pos - k_pos < local_blocks)
+                                | mask_vert_strided)).to(device).to(dtype))
+        num_blocks_q = triton.cdiv(q_len, block_size)
+        block_mask_dense_output = block_mask_dense[:, -num_blocks_q:]
+    if return_dense:
+        mask_dense = torch.kron(
+            block_mask_dense,
+            block_mask_dense.new_ones((block_size, block_size)),
+        )
+        causal_mask = torch.tril(torch.ones(
+            max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:]
+        mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None]
+        if dense_mask_type == "bias":
+            mask_dense = binary_mask_to_bias(mask_dense)
+
+        return (
+            dense_to_crow_col(block_mask_dense_output),
+            block_mask_dense,
+            mask_dense,
+        )
+    else:
+        return (
+            dense_to_crow_col(block_mask_dense_output),
+            block_mask_dense,
+            None,
+        )
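
The non-homogeneous branch above decides, per head, which key/value blocks a query block may attend to: every block inside the causal local window of `local_blocks`, plus every `vert_stride`-th block shifted by a head-dependent offset. As a sanity check, the same predicate can be reproduced with plain PyTorch; the toy sizes below are illustrative only and nothing here calls the Triton-backed helpers.

import torch

# Toy configuration; real values come from BlocksparseParams.
num_blocks, local_blocks, vert_stride, n_heads = 8, 2, 4, 2
head_sliding_step = max(1, int(vert_stride / n_heads))  # same rule as above

q_pos = torch.arange(num_blocks)[None, :, None]   # query block index
k_pos = torch.arange(num_blocks)[None, None]      # key block index
strided = torch.vstack([
    (torch.arange(num_blocks) + h * head_sliding_step + 1) % vert_stride == 0
    for h in range(n_heads)
]).unsqueeze(1)

# causal AND (inside the local window OR on this head's vertical stride)
block_mask = (q_pos >= k_pos) & ((q_pos - k_pos < local_blocks) | strided)
print(block_mask.int())  # shape [n_heads, num_blocks, num_blocks], entries 0/1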

+ 23 - 0
aphrodite/attention/ops/paged_attn.py

@@ -92,7 +92,20 @@ class PagedAttention:
         scale: float,
         scale: float,
         alibi_slopes: Optional[torch.Tensor],
         alibi_slopes: Optional[torch.Tensor],
         kv_scale: float,
         kv_scale: float,
+        tp_rank: int = 0,
+        blocksparse_local_blocks: int = 0,
+        blocksparse_vert_stride: int = 0,
+        blocksparse_block_size: int = 64,
+        blocksparse_head_sliding_step: int = 0,
     ) -> torch.Tensor:
     ) -> torch.Tensor:
+        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
+            # use blocksparse paged attention
+            block_size = value_cache.size(-1)
+            assert (blocksparse_block_size > 0 and
+                    blocksparse_block_size % block_size == 0), \
+                (f"{blocksparse_block_size=} needs to be a multiple of"
+                 f"{block_size=} used in block_tables.")
+
         output = torch.empty_like(query)
         output = torch.empty_like(query)
 
 
         block_size = value_cache.shape[3]
         block_size = value_cache.shape[3]
@@ -124,6 +137,11 @@ class PagedAttention:
                 alibi_slopes,
                 alibi_slopes,
                 kv_cache_dtype,
                 kv_cache_dtype,
                 kv_scale,
                 kv_scale,
+                tp_rank,
+                blocksparse_local_blocks,
+                blocksparse_vert_stride,
+                blocksparse_block_size,
+                blocksparse_head_sliding_step,
             )
             )
         else:
         else:
             # Run PagedAttention V2.
             # Run PagedAttention V2.
@@ -156,6 +174,11 @@ class PagedAttention:
                 alibi_slopes,
                 alibi_slopes,
                 kv_cache_dtype,
                 kv_cache_dtype,
                 kv_scale,
                 kv_scale,
+                tp_rank,
+                blocksparse_local_blocks,
+                blocksparse_vert_stride,
+                blocksparse_block_size,
+                blocksparse_head_sliding_step,
             )
             )
         return output
         return output
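
A rough sketch of the gating added above: the decode path only takes the blocksparse kernel when a vertical stride greater than 1 is supplied, and the sparse block size must then be a multiple of the KV-cache block size used by the block tables. The helper below is hypothetical; it mirrors the check rather than calling `forward_decode`, whose other arguments are omitted here.

def uses_blocksparse_decode(blocksparse_vert_stride,
                            blocksparse_block_size,
                            kv_cache_block_size):
    # Hypothetical helper mirroring the new check in forward_decode.
    if blocksparse_vert_stride is None or blocksparse_vert_stride <= 1:
        return False  # fall back to dense paged attention
    assert (blocksparse_block_size > 0
            and blocksparse_block_size % kv_cache_block_size == 0), (
        f"{blocksparse_block_size=} needs to be a multiple of "
        f"{kv_cache_block_size=} used in block_tables.")
    return True

assert uses_blocksparse_decode(8, 64, 16)
assert not uses_blocksparse_decode(0, 64, 16)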
 
 

+ 10 - 3
aphrodite/attention/selector.py

@@ -5,6 +5,7 @@ from typing import Optional, Type
 
 
 import torch
 import torch
 from loguru import logger
 from loguru import logger
+
 from aphrodite.attention.backends.abstract import AttentionBackend
 from aphrodite.attention.backends.abstract import AttentionBackend
 from aphrodite.common.utils import is_cpu, is_hip
 from aphrodite.common.utils import is_cpu, is_hip
 
 
@@ -28,7 +29,14 @@ def get_attn_backend(
     dtype: torch.dtype,
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     kv_cache_dtype: Optional[str],
     block_size: int,
     block_size: int,
+    is_blocksparse: bool = False,
 ) -> Type[AttentionBackend]:
 ) -> Type[AttentionBackend]:
     """Determine which attention backend to use and only import
     """Determine which attention backend to use and only import
     the selected backend module.
     the selected backend module.
     """
     """
+
+    if is_blocksparse:
+        logger.info("Using BlocksparseFlashAttention backend.")
+        from aphrodite.attention.backends.blocksparse_attn import \
+            BlocksparseFlashAttentionBackend
+        return BlocksparseFlashAttentionBackend
@@ -38,7 +46,6 @@ def get_attn_backend(
     if backend == _Backend.FLASH_ATTN:
     if backend == _Backend.FLASH_ATTN:
         from aphrodite.attention.backends.flash_attn import \
         from aphrodite.attention.backends.flash_attn import \
             FlashAttentionBackend  # noqa: F401
             FlashAttentionBackend  # noqa: F401
-        logger.info("Using FlashAttention backend.")
         return FlashAttentionBackend
         return FlashAttentionBackend
     if backend == _Backend.XFORMERS:
     if backend == _Backend.XFORMERS:
         logger.info("Using XFormers backend.")
         logger.info("Using XFormers backend.")
@@ -136,8 +143,8 @@ def which_attn_to_use(
         try:
         try:
             import vllm_flash_attn  # noqa: F401
             import vllm_flash_attn  # noqa: F401
 
 
-            from aphrodite.attention.backends.flash_attn import (  # noqa: F401
-                FlashAttentionBackend)
+            from aphrodite.attention.backends.flash_attn import \
+                FlashAttentionBackend  # noqa: F401
 
 
             supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
             supported_sizes = FlashAttentionBackend.get_supported_head_sizes()
             if head_size not in supported_sizes:
             if head_size not in supported_sizes:
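
Hedged usage sketch for the new flag: when blocksparse parameters are present, the selector now short-circuits before any device- or dtype-based choice. The function below is a toy stand-in for that branch, not the real `get_attn_backend` (whose remaining parameters are unchanged and omitted here).

from typing import Type

from loguru import logger

from aphrodite.attention.backends.abstract import AttentionBackend


def select_backend(is_blocksparse: bool,
                   fallback: Type[AttentionBackend]) -> Type[AttentionBackend]:
    # Toy stand-in for the new short-circuit in get_attn_backend().
    if is_blocksparse:
        logger.info("Using BlocksparseFlashAttention backend.")
        from aphrodite.attention.backends.blocksparse_attn import (
            BlocksparseFlashAttentionBackend)
        return BlocksparseFlashAttentionBackend
    return fallback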

+ 1 - 0
aphrodite/endpoints/openai/serving_engine.py

@@ -122,6 +122,7 @@ class OpenAIServing:
                 token_logprob = step_top_logprobs[token_id].logprob
                 token_logprob = step_top_logprobs[token_id].logprob
                 token = step_top_logprobs[token_id].decoded_token
                 token = step_top_logprobs[token_id].decoded_token
                 logprobs.tokens.append(token)
                 logprobs.tokens.append(token)
+                token_logprob = max(token_logprob, -9999.0)
                 logprobs.token_logprobs.append(token_logprob)
                 logprobs.token_logprobs.append(token_logprob)
 
 
                 if num_output_top_logprobs:
                 if num_output_top_logprobs:
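
The clamp above keeps a token's logprob finite before it is appended to the response; without it, a masked-out token (for example one whose logit is filled with `-inf`, as the new Phi-3-Small model does for its dummy tiktoken indices) would surface as `-inf` and presumably break JSON serialization of the API response. Minimal illustration:

token_logprob = float("-inf")            # e.g. a fully masked token
token_logprob = max(token_logprob, -9999.0)
assert token_logprob == -9999.0          # finite, serializable value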

+ 1 - 0
aphrodite/modeling/models/__init__.py

@@ -54,6 +54,7 @@ _GENERATION_MODELS = {
     "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
     "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
     "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
     "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
     "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
     "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
+    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
 }
 }
 
 
 _EMBEDDING_MODELS = {
 _EMBEDDING_MODELS = {

+ 446 - 0
aphrodite/modeling/models/phi3_small.py

@@ -0,0 +1,446 @@
+import math
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers.configuration_utils import PretrainedConfig
+
+from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig, LoRAConfig
+from aphrodite.common.sequence import SamplerOutput
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size)
+from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear,
+                                              QKVParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.rotary_embedding import get_rope
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.quantization.base_config import QuantizationConfig
+
+
+def load_column_parallel_weight(param: torch.nn.Parameter,
+                                loaded_weight: torch.Tensor):
+    tp = get_tensor_model_parallel_world_size()
+    rk = get_tensor_model_parallel_rank()
+    assert param.size(0) * tp == loaded_weight.size(0)
+    s = rk * param.size(0)
+    e = (rk + 1) * param.size(0)
+    loaded_weight = loaded_weight[s:e]
+    assert param.shape == loaded_weight.shape
+    param.data.copy_(loaded_weight)
+
+
+class HeadMajorQKVParallelLinear(QKVParallelLinear):
+
+    def weight_loader(self, param: torch.nn.Parameter,
+                      loaded_weight: torch.Tensor):
+        return load_column_parallel_weight(param, loaded_weight)
+
+
+class HeadMajorColumnParallelLinear(MergedColumnParallelLinear):
+
+    def weight_loader(self, param: torch.nn.Parameter,
+                      loaded_weight: torch.Tensor):
+        return load_column_parallel_weight(param, loaded_weight)
+
+
+@torch.jit.script
+def quick_gelu(x):
+    return x * torch.sigmoid(1.702 * x)
+
+
+@torch.jit.script
+def gegelu(input, limit: Optional[float] = None):
+    a_gelu, a_linear = input[..., ::2], input[..., 1::2]
+    if limit is not None:
+        a_gelu = torch.where(torch.isinf(a_gelu), a_gelu,
+                             a_gelu.clamp(min=None, max=limit))
+        a_linear = torch.where(
+            torch.isinf(a_linear),
+            a_linear,
+            a_linear.clamp(min=-limit, max=limit),
+        )
+    out_gelu = quick_gelu(a_gelu)
+    return out_gelu * (a_linear + 1)
+
+
+class Phi3SmallMLP(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        assert (self.config.hidden_act == "gegelu"
+                ), "Only `gegelu` is supported for the 4.7 series of models .."
+        self.hidden_size = config.hidden_size
+        self.gegelu_limit = config.gegelu_limit
+        self.intermediate_size = config.intermediate_size
+
+        self.up_proj = HeadMajorColumnParallelLinear(
+            self.hidden_size,
+            2 * [self.intermediate_size],
+            bias=True,
+            quant_config=quant_config,
+        )
+        self.down_proj = RowParallelLinear(
+            self.intermediate_size,
+            self.hidden_size,
+            bias=True,
+            quant_config=quant_config,
+        )
+
+    def forward(self, x):
+        gate_up, _ = self.up_proj(x)
+        x = gegelu(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class Phi3SmallSelfAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_idx: int,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.layer_idx = layer_idx
+        self.config = config
+        self.sparse_block_size = config.blocksparse_block_size
+        self.homo_heads = config.blocksparse_homo_head_pattern
+        self.local_blocks = config.blocksparse_num_local_blocks
+        self.vert_stride = config.blocksparse_vert_stride
+
+        assert (config.blocksparse_block_size ==
+                config.blocksparse_triton_kernel_block_size)
+
+        self.hidden_size = config.hidden_size
+        # Number of Query Heads
+        self.num_heads = config.num_attention_heads
+
+        self.head_dim = self.hidden_size // self.num_heads
+        self.tp_size = get_tensor_model_parallel_world_size()
+        # Number of total Key Value Heads before tensor parallel
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_q_per_kv = self.num_heads // self.num_key_value_heads
+        if self.tp_size > 1:
+            assert self.num_key_value_heads % self.tp_size == 0
+        self.num_kv_heads_per_partition = max(
+            1, self.num_key_value_heads // self.tp_size)
+        self.num_heads_per_partition = self.num_heads // self.tp_size
+
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_embedding_base = config.rope_embedding_base
+        self.rope_position_scale = config.rope_position_scale
+        self.is_causal = True
+
+        norm_factor = None
+        if config.mup_use_scaling:
+            norm_factor = self.head_dim / config.mup_attn_multiplier
+        else:
+            norm_factor = math.sqrt(self.head_dim)
+        self.scale = 1 / norm_factor
+
+        self.query_key_value = HeadMajorQKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.num_heads,
+            self.num_key_value_heads,
+            bias=True,
+            quant_config=quant_config,
+        )
+
+        self.dense = RowParallelLinear(self.hidden_size,
+                                       self.hidden_size,
+                                       bias=True,
+                                       quant_config=quant_config)
+
+        if getattr(self.config, "rope_scaling", None) is not None:
+            rope_scaling = self.config.rope_scaling
+            for key in rope_scaling:
+                if isinstance(rope_scaling[key], list):
+                    rope_scaling[key] = tuple(rope_scaling[key])
+
+            if "factor" not in rope_scaling:
+                rope_scaling["factor"] = self.rope_position_scale
+        else:
+            rope_scaling = {
+                "type": "linear",
+                "factor": self.rope_position_scale,
+            }
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            base=self.rope_embedding_base,
+            rope_scaling=rope_scaling,
+        )
+
+        # blocksparse params
+        self.blocksparse_block_size = config.blocksparse_block_size
+        self.blocksparse_num_local_blocks = config.blocksparse_num_local_blocks
+        self.blocksparse_vert_stride = config.blocksparse_vert_stride
+
+        use_dense_attn = (getattr(self.config,
+                                  "dense_attention_every_n_layers", None)
+                          and (self.layer_idx + 1) %
+                          self.config.dense_attention_every_n_layers == 0)
+
+        bs_params = None
+        if not use_dense_attn:
+            bs_params = {
+                'max_seqlen': self.max_position_embeddings,
+                'num_heads': self.num_heads_per_partition,
+                "num_kv_heads": self.num_kv_heads_per_partion,
+                "block_size": self.sparse_block_size,
+                "local_blocks": self.local_blocks,
+                "vert_stride": self.vert_stride,
+                "homo_head": self.homo_heads
+            }
+
+        self.attn = Attention(
+            self.num_heads_per_partition,
+            self.head_dim,
+            self.scale,
+            num_kv_heads=self.num_kv_heads_per_partition,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            blocksparse_params=bs_params,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[Tuple[torch.Tensor]]]:
+        qkv, _ = self.query_key_value(hidden_states)
+
+        qkv = qkv.view(qkv.shape[:-1] +
+                       (-1, (self.num_q_per_kv + 2), self.head_dim))
+        q, k, v = qkv.split([self.num_q_per_kv, 1, 1], dim=-2)
+
+        # NOTE: the rotary embedding currently requires flattened 2D Q/K,
+        # even though it should not have to. TODO: allow 3D Q/K there.
+        q = q.reshape(-1, self.head_dim * self.num_heads_per_partition)
+        k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partition)
+        v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partition)
+
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata=attn_metadata)
+        output, _ = self.dense(attn_output)
+
+        return output
+
+
+class Phi3SmallDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_idx: int,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = Phi3SmallSelfAttention(config,
+                                                layer_idx,
+                                                cache_config=cache_config,
+                                                quant_config=quant_config)
+        self.mlp = Phi3SmallMLP(config, quant_config)
+
+        self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.layer_norm_epsilon)
+        self.post_attention_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+class Phi3SmallModel(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size)
+        self.mup_embedding_multiplier = config.mup_embedding_multiplier
+        self.layers = nn.ModuleList([
+            Phi3SmallDecoderLayer(config, layer_idx, cache_config,
+                                  quant_config)
+            for layer_idx in range(config.num_hidden_layers)
+        ])
+
+        self.final_layernorm = nn.LayerNorm(config.hidden_size,
+                                            eps=config.layer_norm_epsilon)
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: Optional[torch.LongTensor],
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ):
+        hidden_states = self.embed_tokens(input_ids)
+        if (self.mup_embedding_multiplier is not None
+                and self.mup_embedding_multiplier > 0.0):
+            hidden_states = hidden_states * self.mup_embedding_multiplier
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states = layer(
+                positions,
+                hidden_states,
+                kv_caches[i],
+                attn_metadata,
+            )
+        hidden_states = self.final_layernorm(hidden_states)
+        return hidden_states
+
+
+class Phi3SmallForCausalLM(nn.Module):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(
+        self,
+        config,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Phi3SmallModel(config, cache_config, quant_config)
+        self.vocab_size = config.vocab_size
+        self.mup_width_multiplier = config.mup_width_multiplier
+        self.lm_head = ParallelLMHead(
+            self.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+        )
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = Sampler()
+
+        # tokens in tiktoken but not used
+        if hasattr(config, 'dummy_token_indices'):
+            device = self.lm_head.weight.device
+            self.register_buffer('dummy_token_indices',
+                                 torch.LongTensor(
+                                     config.dummy_token_indices).to(device),
+                                 persistent=False)
+        else:
+            self.dummy_token_indices = None
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, value):
+        self.lm_head = value
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head.weight, hidden_states,
+                                       sampling_metadata)
+        if self.dummy_token_indices is not None and logits is not None:
+            logits.index_fill_(-1, self.dummy_token_indices, -torch.inf)
+        return logits
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: Optional[torch.LongTensor],
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        output_hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            kv_caches=kv_caches,
+            attn_metadata=attn_metadata,
+        )
+        return output_hidden_states
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+
+        next_tokens = self.sampler(logits / self.mup_width_multiplier,
+                                   sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
+        self.lm_head.weight.data.copy_(self.model.embed_tokens.weight.data)
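
One detail of `Phi3SmallSelfAttention.forward` worth spelling out is the head-major QKV layout: the fused projection appears to store each KV group's `num_q_per_kv` query heads next to their key and value head, so a single `view` plus `split` separates Q, K and V. A toy shape check of that slicing, with synthetic sizes unrelated to the real config:

import torch

num_tokens, head_dim = 3, 4
num_q_per_kv, num_kv_heads = 2, 5      # illustrative per-partition values
qkv = torch.randn(num_tokens, num_kv_heads * (num_q_per_kv + 2) * head_dim)

qkv = qkv.view(qkv.shape[:-1] + (-1, num_q_per_kv + 2, head_dim))
q, k, v = qkv.split([num_q_per_kv, 1, 1], dim=-2)

assert q.shape == (num_tokens, num_kv_heads, num_q_per_kv, head_dim)
assert k.shape == (num_tokens, num_kv_heads, 1, head_dim)
assert v.shape == (num_tokens, num_kv_heads, 1, head_dim)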

+ 170 - 59
kernels/attention/attention_kernels.cu

@@ -86,6 +86,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
 // Grid: (num_heads, num_seqs, max_num_partitions).
 // Grid: (num_heads, num_seqs, max_num_partitions).
 template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
 template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
           int NUM_THREADS, aphrodite::Fp8KVCacheDataType KV_DTYPE,
           int NUM_THREADS, aphrodite::Fp8KVCacheDataType KV_DTYPE,
+          bool IS_BLOCK_SPARSE,
           int PARTITION_SIZE = 0>  // Zero means no partitioning.
           int PARTITION_SIZE = 0>  // Zero means no partitioning.
 __device__ void paged_attention_kernel(
 __device__ void paged_attention_kernel(
     float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
     float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
@@ -105,7 +106,9 @@ __device__ void paged_attention_kernel(
     const int max_num_blocks_per_seq,
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float kv_scale) {
+    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
+    const int blocksparse_vert_stride, const int blocksparse_block_size,
+    const int blocksparse_head_sliding_step) {
   const int seq_idx = blockIdx.y;
   const int seq_idx = blockIdx.y;
   const int partition_idx = blockIdx.z;
   const int partition_idx = blockIdx.z;
   const int max_num_partitions = gridDim.z;
   const int max_num_partitions = gridDim.z;
@@ -172,8 +175,8 @@ __device__ void paged_attention_kernel(
   // Each thread in a thread group has a different part of the query.
   // Each thread in a thread group has a different part of the query.
   // For example, if the the thread group size is 4, then the first thread in
   // For example, if the the thread group size is 4, then the first thread in
   // the group has 0, 4, 8, ... th vectors of the query, and the second thread
   // the group has 0, 4, 8, ... th vectors of the query, and the second thread
-  // has 1, 5, 9, ... th vectors of the query, and so on. NOTE: Because q is
-  // split from a qkv tensor, it may not be contiguous.
+  // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
+  // q is split from a qkv tensor, it may not be contiguous.
   const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
   const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
   __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
   __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
 #pragma unroll
 #pragma unroll
@@ -183,8 +186,8 @@ __device__ void paged_attention_kernel(
     q_vecs[thread_group_offset][i] =
     q_vecs[thread_group_offset][i] =
         *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
         *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
   }
   }
-  __syncthreads();  // TODO: possible speedup if this is replaced with a memory
-                    // wall right before we use q_vecs
+  __syncthreads();  // TODO: possible speedup if this is replaced with a
+                    // memory wall right before we use q_vecs
 
 
   // Memory planning.
   // Memory planning.
   extern __shared__ char shared_mem[];
   extern __shared__ char shared_mem[];
@@ -203,11 +206,55 @@ __device__ void paged_attention_kernel(
   // Each thread group in a warp fetches a key from the block, and computes
   // Each thread group in a warp fetches a key from the block, and computes
   // dot product with the query.
   // dot product with the query.
   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+
+  // blocksparse specific vars
+  int bs_block_offset;
+  int q_bs_block_id;
+  if constexpr (IS_BLOCK_SPARSE) {
+    // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
+    // blocksparse_block_size);
+    q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
+    if (blocksparse_head_sliding_step >= 0)
+      // sliding on q heads
+      bs_block_offset =
+          (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
+    else
+      // sliding on kv heads
+      bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
+                            (-blocksparse_head_sliding_step) +
+                        1;
+  }
+
   for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
   for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
        block_idx += NUM_WARPS) {
        block_idx += NUM_WARPS) {
-    // NOTE: The block number is stored in int32. However, we cast it to int64
-    // because int32 can lead to overflow when this variable is multiplied by
-    // large numbers (e.g., kv_block_stride).
+    // NOTE: The block number is stored in int32. However, we cast it to
+    // int64 because int32 can lead to overflow when this variable is multiplied
+    // by large numbers (e.g., kv_block_stride).
+    // For blocksparse attention: skip computation on blocks that are not
+    // attended
+    if constexpr (IS_BLOCK_SPARSE) {
+      const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
+      const bool is_remote =
+          ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
+      const bool is_local =
+          (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
+      if (!is_remote && !is_local) {
+        for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
+          const int physical_block_offset =
+              (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
+          const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+
+          if (thread_group_offset == 0) {
+            // NOTE: assign very large number to skipped tokens to
+            // avoid contribution to the sumexp softmax normalizer. This will
+            // not be used at computing sum(softmax*v) as the blocks will be
+            // skipped.
+            logits[token_idx - start_token_idx] = -FLT_MAX;
+          }
+        }
+        continue;
+      }
+    }
     const int64_t physical_block_number =
     const int64_t physical_block_number =
         static_cast<int64_t>(block_table[block_idx]);
         static_cast<int64_t>(block_table[block_idx]);
 
 
@@ -333,9 +380,18 @@ __device__ void paged_attention_kernel(
   zero(zero_value);
   zero(zero_value);
   for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
   for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
        block_idx += NUM_WARPS) {
        block_idx += NUM_WARPS) {
-    // NOTE: The block number is stored in int32. However, we cast it to int64
-    // because int32 can lead to overflow when this variable is multiplied by
-    // large numbers (e.g., kv_block_stride).
+    // NOTE: The block number is stored in int32. However, we cast it to
+    // int64 because int32 can lead to overflow when this variable is multiplied
+    // by large numbers (e.g., kv_block_stride).
+    // For blocksparse attention: skip computation on blocks that are not
+    // attended
+    if constexpr (IS_BLOCK_SPARSE) {
+      int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
+      if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
+          !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
+        continue;
+      }
+    }
     const int64_t physical_block_number =
     const int64_t physical_block_number =
         static_cast<int64_t>(block_table[block_idx]);
         static_cast<int64_t>(block_table[block_idx]);
     const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
     const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
@@ -363,9 +419,9 @@ __device__ void paged_attention_kernel(
                                                                     kv_scale);
                                                                     kv_scale);
         }
         }
         if (block_idx == num_seq_blocks - 1) {
         if (block_idx == num_seq_blocks - 1) {
-          // NOTE: When v_vec contains the tokens that are out of the context,
-          // we should explicitly zero out the values since they may contain
-          // NaNs.
+          // NOTE: When v_vec contains the tokens that are out of the
+          // context, we should explicitly zero out the values since they may
+          // contain NaNs.
           scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
           scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
 #pragma unroll
 #pragma unroll
           for (int j = 0; j < V_VEC_SIZE; j++) {
           for (int j = 0; j < V_VEC_SIZE; j++) {
@@ -388,8 +444,8 @@ __device__ void paged_attention_kernel(
     accs[i] = acc;
     accs[i] = acc;
   }
   }
 
 
-  // NOTE: A barrier is required because the shared memory space for logits
-  // is reused for the output.
+  // NOTE: A barrier is required because the shared memory space for
+  // logits is reused for the output.
   __syncthreads();
   __syncthreads();
 
 
   // Perform reduction across warps.
   // Perform reduction across warps.
@@ -441,8 +497,8 @@ __device__ void paged_attention_kernel(
 
 
 // Grid: (num_heads, num_seqs, 1).
 // Grid: (num_heads, num_seqs, 1).
 template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
 template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
-          int NUM_THREADS,
-          aphrodite::Fp8KVCacheDataType KV_DTYPE>
+          int NUM_THREADS, aphrodite::Fp8KVCacheDataType KV_DTYPE,
+          bool IS_BLOCK_SPARSE>
 __global__ void paged_attention_v1_kernel(
 __global__ void paged_attention_v1_kernel(
     scalar_t* __restrict__ out,           // [num_seqs, num_heads, head_size]
     scalar_t* __restrict__ out,           // [num_seqs, num_heads, head_size]
     const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
     const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
@@ -457,18 +513,23 @@ __global__ void paged_attention_v1_kernel(
     const int max_num_blocks_per_seq,
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float kv_scale) {
+    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
+    const int blocksparse_vert_stride, const int blocksparse_block_size,
+    const int blocksparse_head_sliding_step) {
   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
-                         KV_DTYPE>(
+                         KV_DTYPE, IS_BLOCK_SPARSE>(
       /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
       /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
       v_cache, num_kv_heads, scale, block_tables, seq_lens,
       v_cache, num_kv_heads, scale, block_tables, seq_lens,
       max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
       max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
-      kv_head_stride, kv_scale);
+      kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks,
+      blocksparse_vert_stride, blocksparse_block_size,
+      blocksparse_head_sliding_step);
 }
 }
 
 
 // Grid: (num_heads, num_seqs, max_num_partitions).
 // Grid: (num_heads, num_seqs, max_num_partitions).
 template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
 template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
           int NUM_THREADS, aphrodite::Fp8KVCacheDataType KV_DTYPE,
           int NUM_THREADS, aphrodite::Fp8KVCacheDataType KV_DTYPE,
+          bool IS_BLOCK_SPARSE,
           int PARTITION_SIZE>
           int PARTITION_SIZE>
 __global__ void paged_attention_v2_kernel(
 __global__ void paged_attention_v2_kernel(
     float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
     float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
@@ -488,12 +549,16 @@ __global__ void paged_attention_v2_kernel(
     const int max_num_blocks_per_seq,
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float kv_scale) {
+    const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
+    const int blocksparse_vert_stride, const int blocksparse_block_size,
+    const int blocksparse_head_sliding_step) {
   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
-                         KV_DTYPE, PARTITION_SIZE>(
+                         KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
       exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
       exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
       block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
       block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
-      kv_block_stride, kv_head_stride, kv_scale);
+      kv_block_stride, kv_head_stride, kv_scale, tp_rank,
+      blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
+      blocksparse_head_sliding_step);
 }
 }
 
 
 // Grid: (num_heads, num_seqs).
 // Grid: (num_heads, num_seqs).
@@ -605,27 +670,34 @@ __global__ void paged_attention_v2_reduce_kernel(
 
 
 }  // namespace aphrodite
 }  // namespace aphrodite
 
 
-#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                \
-  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                \
-      ((void*)aphrodite::paged_attention_v1_kernel<                         \
-          T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE>),       \
-      shared_mem_size);                                                     \
-  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,   \
-                                       NUM_THREADS, KV_DTYPE>               \
-      <<<grid, block, shared_mem_size, stream>>>(                           \
-          out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
-          scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq,    \
-          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,      \
-          kv_scale);
+#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                   \
+  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                   \
+      ((void*)aphrodite::paged_attention_v1_kernel<                            \
+          T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE,            \
+          IS_BLOCK_SPARSE>),                                                   \
+      shared_mem_size);                                                        \
+  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,      \
+                                       NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE> \
+      <<<grid, block, shared_mem_size, stream>>>(                              \
+          out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads,    \
+          scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq,       \
+          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,         \
+          kv_scale, tp_rank, blocksparse_local_blocks,                         \
+          blocksparse_vert_stride, blocksparse_block_size,                     \
+          blocksparse_head_sliding_step);
 
 
 // TODO: Tune NUM_THREADS.
 // TODO: Tune NUM_THREADS.
 template <typename T, typename CACHE_T, int BLOCK_SIZE,
 template <typename T, typename CACHE_T, int BLOCK_SIZE,
-          aphrodite::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128>
+          aphrodite::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
+          int NUM_THREADS = 128>
 void paged_attention_v1_launcher(
 void paged_attention_v1_launcher(
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
+    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
+    const int tp_rank, const int blocksparse_local_blocks,
+    const int blocksparse_vert_stride, const int blocksparse_block_size,
+    const int blocksparse_head_sliding_step) {
   int num_seqs = query.size(0);
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int num_heads = query.size(1);
   int head_size = query.size(2);
   int head_size = query.size(2);
@@ -692,23 +764,36 @@ void paged_attention_v1_launcher(
   }
   }
 }
 }
 
 
-#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE)                   \
-  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>(             \
+#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE)  \
+  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,              \
+                              IS_BLOCK_SPARSE>(                              \
       out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
       out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
-      seq_lens, max_seq_len, alibi_slopes, kv_scale);
+      seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank,                \
+      blocksparse_local_blocks, blocksparse_vert_stride,                     \
+      blocksparse_block_size, blocksparse_head_sliding_step);
+
+#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
+  switch (is_block_sparse) {                                               \
+    case true:                                                             \
+      CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true);     \
+      break;                                                               \
+    case false:                                                            \
+      CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false);    \
+      break;                                                               \
+  }
 
 
 // NOTE: To reduce the compilation time, we omitted block sizes
 // NOTE: To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
 // 1, 2, 4, 64, 128, 256.
 #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE)         \
 #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE)         \
   switch (block_size) {                                           \
   switch (block_size) {                                           \
     case 8:                                                       \
     case 8:                                                       \
-      CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_DTYPE);                  \
+      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE);         \
       break;                                                      \
       break;                                                      \
     case 16:                                                      \
     case 16:                                                      \
-      CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_DTYPE);                 \
+      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE);        \
       break;                                                      \
       break;                                                      \
     case 32:                                                      \
     case 32:                                                      \
-      CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_DTYPE);                 \
+      CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE);        \
       break;                                                      \
       break;                                                      \
     default:                                                      \
     default:                                                      \
       TORCH_CHECK(false, "Unsupported block size: ", block_size); \
       TORCH_CHECK(false, "Unsupported block size: ", block_size); \
@@ -728,18 +813,26 @@ void paged_attention_v1(
     torch::Tensor& seq_lens,      // [num_seqs]
     torch::Tensor& seq_lens,      // [num_seqs]
     int block_size, int max_seq_len,
     int block_size, int max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, float kv_scale){
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+  const bool is_block_sparse = (blocksparse_vert_stride > 1);
+
+  DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
+                             CALL_V1_LAUNCHER_BLOCK_SIZE)
+}
 
 
-    DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
-                               CALL_V1_LAUNCHER_BLOCK_SIZE)}
 #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                   \
 #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                   \
   aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,      \
   aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,      \
-                                       NUM_THREADS, KV_DTYPE, PARTITION_SIZE>  \
+                                       NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE, \
+                                       PARTITION_SIZE>                         \
       <<<grid, block, shared_mem_size, stream>>>(                              \
       <<<grid, block, shared_mem_size, stream>>>(                              \
           exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
           exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
           value_cache_ptr, num_kv_heads, scale, block_tables_ptr,              \
           value_cache_ptr, num_kv_heads, scale, block_tables_ptr,              \
           seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride,    \
           seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride,    \
-          kv_block_stride, kv_head_stride, kv_scale);                          \
+          kv_block_stride, kv_head_stride, kv_scale, tp_rank,                  \
+          blocksparse_local_blocks, blocksparse_vert_stride,                   \
+          blocksparse_block_size, blocksparse_head_sliding_step);              \
   aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS,       \
   aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS,       \
                                               PARTITION_SIZE>                  \
                                               PARTITION_SIZE>                  \
       <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                \
       <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                \
@@ -747,14 +840,17 @@ void paged_attention_v1(
           max_num_partitions);
           max_num_partitions);
 
 
 template <typename T, typename CACHE_T, int BLOCK_SIZE,
 template <typename T, typename CACHE_T, int BLOCK_SIZE,
-          aphrodite::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS = 128,
-          int PARTITION_SIZE = 512>
+          aphrodite::Fp8KVCacheDataType KV_DTYPE, bool IS_BLOCK_SPARSE,
+          int NUM_THREADS = 128, int PARTITION_SIZE = 512>
 void paged_attention_v2_launcher(
 void paged_attention_v2_launcher(
     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale) {
+    const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
+    const int tp_rank, const int blocksparse_local_blocks,
+    const int blocksparse_vert_stride, const int blocksparse_block_size,
+    const int blocksparse_head_sliding_step) {
   int num_seqs = query.size(0);
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int num_heads = query.size(1);
   int head_size = query.size(2);
   int head_size = query.size(2);
@@ -825,24 +921,36 @@ void paged_attention_v2_launcher(
   }
   }
 }
 }
 
 
-#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE)                    \
-  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE>(              \
+#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE)   \
+  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,               \
+                              IS_BLOCK_SPARSE>(                               \
       out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,      \
       out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,      \
       num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
       num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
-      kv_scale);
+      kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,   \
+      blocksparse_block_size, blocksparse_head_sliding_step);
+
+#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
+  switch (is_block_sparse) {                                               \
+    case true:                                                             \
+      CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true);     \
+      break;                                                               \
+    case false:                                                            \
+      CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false);    \
+      break;                                                               \
+  }
 
 
 // NOTE: To reduce the compilation time, we omitted block sizes
 // NOTE: To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
 // 1, 2, 4, 64, 128, 256.
 #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE)         \
 #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE)         \
   switch (block_size) {                                           \
   switch (block_size) {                                           \
     case 8:                                                       \
     case 8:                                                       \
-      CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_DTYPE);                  \
+      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE);         \
       break;                                                      \
       break;                                                      \
     case 16:                                                      \
     case 16:                                                      \
-      CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_DTYPE);                 \
+      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE);        \
       break;                                                      \
       break;                                                      \
     case 32:                                                      \
     case 32:                                                      \
-      CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_DTYPE);                 \
+      CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE);        \
       break;                                                      \
       break;                                                      \
     default:                                                      \
     default:                                                      \
       TORCH_CHECK(false, "Unsupported block size: ", block_size); \
       TORCH_CHECK(false, "Unsupported block size: ", block_size); \
@@ -866,7 +974,10 @@ void paged_attention_v2(
     torch::Tensor& seq_lens,      // [num_seqs]
     torch::Tensor& seq_lens,      // [num_seqs]
     int block_size, int max_seq_len,
     int block_size, int max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, float kv_scale) {
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
+  const bool is_block_sparse = (blocksparse_vert_stride > 1);
   DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
   DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype,
                              CALL_V2_LAUNCHER_BLOCK_SIZE)
                              CALL_V2_LAUNCHER_BLOCK_SIZE)
 }
 }
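
Outside CUDA, the skip condition added to both the QK and the V loops reduces to: keep a KV block if it lies within the last `blocksparse_local_blocks` blocks up to the query's block, or if it falls on the head-dependent vertical stride. A small Python rendering of that predicate (it mirrors the `IS_BLOCK_SPARSE` branch above; it is not a binding to the kernel):

def keep_kv_block(k_bs_block_id, q_bs_block_id, bs_block_offset,
                  local_blocks, vert_stride):
    # Mirrors the is_remote / is_local test in paged_attention_kernel.
    is_remote = (k_bs_block_id + bs_block_offset) % vert_stride == 0
    is_local = k_bs_block_id > q_bs_block_id - local_blocks
    return is_remote or is_local


# Query token sits in sparse block 9; keep the last 2 blocks plus every
# 4th block (shifted by the head-dependent offset, here 1).
kept = [b for b in range(10) if keep_kv_block(b, 9, 1, 2, 4)]
assert kept == [3, 7, 8, 9]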

+ 21 - 16
kernels/cpu/attention.cpp

@@ -415,14 +415,17 @@ void paged_attention_v1_impl_launcher(
   }
   }
 }  // namespace
 }  // namespace
 
 
-void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
-                        torch::Tensor& key_cache, torch::Tensor& value_cache,
-                        int num_kv_heads, float scale,
-                        torch::Tensor& block_tables, torch::Tensor& seq_lens,
-                        int block_size, int max_seq_len,
-                        const c10::optional<torch::Tensor>& alibi_slopes,
-                        const std::string& kv_cache_dtype, float kv_scale) {
+void paged_attention_v1(
+    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
+    int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
   TORCH_CHECK(kv_scale == 1.0f);
   TORCH_CHECK(kv_scale == 1.0f);
+  TORCH_CHECK(blocksparse_vert_stride <= 1,
+              "CPU backend does not support blocksparse attention yet.");
   APHRODITE_DISPATCH_FLOATING_TYPES(
   APHRODITE_DISPATCH_FLOATING_TYPES(
       query.scalar_type(), "paged_attention_v1_impl", [&] {
       query.scalar_type(), "paged_attention_v1_impl", [&] {
         CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
         CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
@@ -726,16 +729,18 @@ void paged_attention_v2_impl_launcher(
   }
   }
 }  // namespace
 }  // namespace
 
 
-void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
-                        torch::Tensor& max_logits, torch::Tensor& tmp_out,
-                        torch::Tensor& query, torch::Tensor& key_cache,
-                        torch::Tensor& value_cache, int num_kv_heads,
-                        float scale, torch::Tensor& block_tables,
-                        torch::Tensor& seq_lens, int block_size,
-                        int max_seq_len,
-                        const c10::optional<torch::Tensor>& alibi_slopes,
-                        const std::string& kv_cache_dtype, float kv_scale) {
+void paged_attention_v2(
+    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
+    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
+    int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
   TORCH_CHECK(kv_scale == 1.0f);
   TORCH_CHECK(kv_scale == 1.0f);
+  TORCH_CHECK(blocksparse_vert_stride <= 1,
+              "CPU backend does not support blocksparse attention yet.");
   APHRODITE_DISPATCH_FLOATING_TYPES(
   APHRODITE_DISPATCH_FLOATING_TYPES(
       query.scalar_type(), "paged_attention_v2_impl", [&] {
       query.scalar_type(), "paged_attention_v2_impl", [&] {
         CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
         CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)

+ 17 - 16
kernels/ops.h

@@ -2,23 +2,24 @@
 
 
 #include <torch/extension.h>
 #include <torch/extension.h>
 
 
-void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
-                        torch::Tensor& key_cache, torch::Tensor& value_cache,
-                        int num_kv_heads, float scale,
-                        torch::Tensor& block_tables, torch::Tensor& seq_lens,
-                        int block_size, int max_seq_len,
-                        const c10::optional<torch::Tensor>& alibi_slopes,
-                        const std::string& kv_cache_dtype, float kv_scale);
+void paged_attention_v1(
+    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
+    int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step);
 
 
-void paged_attention_v2(torch::Tensor& out, torch::Tensor& exp_sums,
-                        torch::Tensor& max_logits, torch::Tensor& tmp_out,
-                        torch::Tensor& query, torch::Tensor& key_cache,
-                        torch::Tensor& value_cache, int num_kv_heads,
-                        float scale, torch::Tensor& block_tables,
-                        torch::Tensor& seq_lens, int block_size,
-                        int max_seq_len,
-                        const c10::optional<torch::Tensor>& alibi_slopes,
-                        const std::string& kv_cache_dtype, float kv_scale);
+void paged_attention_v2(
+    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
+    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& seq_lens, int block_size,
+    int max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, float kv_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step);
 
 
 void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
 void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
               float epsilon);
               float epsilon);