@@ -15,18 +15,9 @@ import torch.distributed
 import torch.nn as nn
 from loguru import logger
 
-try:
-    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
-    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
-    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
-except ImportError:
-    BatchDecodeWithPagedKVCacheWrapper = None
-    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
-    BatchPrefillWithPagedKVCacheWrapper = None
-    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
-
 from aphrodite.attention import AttentionMetadata, get_attn_backend
+from aphrodite.attention.backends.abstract import AttentionState
+from aphrodite.attention.backends.utils import CommonAttentionState
 from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig)
@@ -34,8 +25,7 @@ from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                        SequenceGroupMetadata)
 from aphrodite.common.utils import (CudaMemoryProfiler, PyObjectCache,
-                                    async_tensor_h2d, flatten_2d_lists,
-                                    get_kv_cache_torch_dtype, is_hip,
+                                    async_tensor_h2d, flatten_2d_lists, is_hip,
                                     is_pin_memory_available)
 from aphrodite.distributed import get_pp_group
 from aphrodite.distributed.parallel_state import (
@@ -67,7 +57,6 @@ from aphrodite.task_handler.model_runner_base import (
 if TYPE_CHECKING:
     from aphrodite.attention.backends.abstract import AttentionBackend
 
-_PAD_SLOT_ID = -1
 LORA_WARMUP_RANK = 8
 _BATCH_SIZE_ALIGNMENT = 8
 # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
@@ -852,6 +841,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             self.block_size,
             self.model_config.is_attention_free(),
         )
+        if self.attn_backend:
+            self.attn_state = self.attn_backend.get_state_cls()(
+                weakref.proxy(self))
+        else:
+            self.attn_state = CommonAttentionState(weakref.proxy(self))
 
         # Multi-modal data support
         self.input_registry = input_registry
@@ -866,11 +860,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
         self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None
 
-        self.flashinfer_decode_workspace_buffer = None
-        self.flashinfer_decode_wrapper = None
-        self.flashinfer_prefill_workspace_buffer = None
-        self.flashinfer_prefill_wrapper = None
-
         set_cpu_offload_max_bytes(
             int(self.cache_config.cpu_offload_gb * 1024**3))
 
@@ -1227,10 +1216,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
         input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
         input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
-        slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
-        slot_mapping.fill_(_PAD_SLOT_ID)
-        seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
-        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
         intermediate_inputs = None
         if not get_pp_group().is_first_rank:
             intermediate_inputs = self.model.make_empty_intermediate_tensors(
@@ -1250,102 +1235,18 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]
 
-        if self.attn_backend.get_name() == "flashinfer":
-            # For flashinfer, different batch sizes will share the
-            # same workspace buffer.
-            decode_workspace_buffer = \
-            torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE,
-                        dtype=torch.uint8,
-                        device=self.device)
-            indices_buffer = torch.empty(max_batch_size *
-                                         self.cache_config.num_gpu_blocks,
-                                         dtype=torch.int32,
-                                         device=self.device)
-            indptr_buffer = torch.empty(max_batch_size + 1,
-                                        dtype=torch.int32,
-                                        device=self.device)
-            last_page_len_buffer = torch.empty(max_batch_size,
-                                               dtype=torch.int32,
-                                               device=self.device)
-
-        with graph_capture() as graph_capture_context:
+        with self.attn_state.graph_capture(
+                max_batch_size), graph_capture() as graph_capture_context:
+
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for virtual_engine in range(
                     self.parallel_config.pipeline_parallel_size):
                 for batch_size in reversed(batch_size_capture_list):
-                    if self.attn_backend.get_name() == "flashinfer":
-                        _indptr_buffer = indptr_buffer[:batch_size + 1]
-                        _last_page_len_buffer = last_page_len_buffer[:
-                                                                     batch_size]
-
-                        num_qo_heads = (
-                            self.model_config.get_num_attention_heads(
-                                self.parallel_config, self.tp_rank))
-                        num_kv_heads = self.model_config.get_num_kv_heads(
-                            self.parallel_config, self.tp_rank)
-                        if num_qo_heads // num_kv_heads >= 4:
-                            use_tensor_cores = True
-                        else:
-                            use_tensor_cores = False
-                        decode_wrapper = \
-                            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
-                            decode_workspace_buffer, _indptr_buffer,
-                            indices_buffer, _last_page_len_buffer, "NHD",
-                            use_tensor_cores)
-                        kv_cache_dtype = get_kv_cache_torch_dtype(
-                            self.kv_cache_dtype, self.model_config.dtype)
-
-                        paged_kv_indptr_tensor_host = torch.arange(
-                            0, batch_size + 1, dtype=torch.int32)
-                        paged_kv_indices_tensor_host = torch.arange(
-                            0, batch_size, dtype=torch.int32)
-                        paged_kv_last_page_len_tensor_host = torch.full(
-                            (batch_size, ), self.block_size, dtype=torch.int32)
-                        query_start_loc_host = torch.arange(0,
-                                                            batch_size + 1,
-                                                            dtype=torch.int32)
-
-                        attn_metadata = self.attn_backend.make_metadata(
-                            num_prefills=0,
-                            slot_mapping=slot_mapping[:batch_size],
-                            num_prefill_tokens=0,
-                            num_decode_tokens=batch_size,
-                            max_prefill_seq_len=0,
-                            block_tables=block_tables,
-                            paged_kv_indptr=paged_kv_indptr_tensor_host,
-                            paged_kv_indices=paged_kv_indices_tensor_host,
-                            paged_kv_last_page_len=
-                            paged_kv_last_page_len_tensor_host,
-                            num_qo_heads=num_qo_heads,
-                            num_kv_heads=num_kv_heads,
-                            head_dim=self.model_config.get_head_size(),
-                            page_size=self.block_size,
-                            seq_start_loc=None,
-                            query_start_loc=query_start_loc_host,
-                            device=self.device,
-                            data_type=kv_cache_dtype,
-                            use_cuda_graph=True,
-                            decode_wrapper=decode_wrapper,
-                            prefill_wrapper=None)
-                        attn_metadata.begin_forward()
-                    else:
-                        attn_metadata = self.attn_backend.make_metadata(
-                            num_prefills=0,
-                            num_prefill_tokens=0,
-                            num_decode_tokens=batch_size,
-                            slot_mapping=slot_mapping[:batch_size],
-                            seq_lens=None,
-                            seq_lens_tensor=seq_lens[:batch_size],
-                            max_query_len=None,
-                            max_prefill_seq_len=0,
-                            max_decode_seq_len=self.max_seq_len_to_capture,
-                            query_start_loc=None,
-                            seq_start_loc=None,
-                            context_lens_tensor=None,
-                            block_tables=block_tables[:batch_size],
-                            use_cuda_graph=True,
-                        )
+                    attn_metadata = (
+                        self.attn_state.graph_capture_get_metadata_for_batch(
+                            batch_size))
+
                     if self.lora_config:
                         lora_mapping = LoRAMapping(
@@ -1363,17 +1264,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                             set(), prompt_adapter_mapping)
 
                     graph_runner = CUDAGraphRunner(
-                        self.model, self.attn_backend.get_name())
-
-                    if self.attn_backend.get_name() == "flashinfer":
-                        graph_runner.flashinfer_indptr_buffer = _indptr_buffer
-                        graph_runner.flashinfer_indices_buffer = indices_buffer
-                        graph_runner.flashinfer_last_page_len_buffer = \
-                            _last_page_len_buffer
-                        graph_runner.flashinfer_decode_workspace_buffer = \
-                            decode_workspace_buffer
-                        graph_runner.flashinfer_decode_wrapper = \
-                            decode_wrapper
+                        self.model, self.attn_backend.get_name(),
+                        self.attn_state.graph_clone(batch_size))
 
                     capture_inputs = {
                         "input_ids":
@@ -1501,36 +1393,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
                 model_input.prompt_adapter_requests,
                 model_input.prompt_adapter_mapping)
 
-        if self.attn_backend.get_name() == "flashinfer":
-            assert model_input.attn_metadata is not None
-            assert model_input.input_tokens is not None
-            if self.flashinfer_decode_workspace_buffer is None:
-                self.flashinfer_decode_workspace_buffer = torch.empty(
-                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
-                    dtype=torch.uint8,
-                    device=self.device)
-                self.flashinfer_decode_wrapper = \
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_decode_workspace_buffer, "NHD")
-                self.flashinfer_prefill_workspace_buffer = torch.empty(
-                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
-                    dtype=torch.uint8,
-                    device=self.device)
-                self.flashinfer_prefill_wrapper = \
-                    BatchPrefillWithPagedKVCacheWrapper(
-                        self.flashinfer_prefill_workspace_buffer, "NHD")
-
-            model_input.attn_metadata.prefill_wrapper = \
-                self.flashinfer_prefill_wrapper
-            if model_input.attn_metadata.use_cuda_graph:
-                batch_size = model_input.input_tokens.shape[0]
-                model_input.attn_metadata.decode_wrapper = self.graph_runners[
-                    model_input.
-                    virtual_engine][batch_size].flashinfer_decode_wrapper
-            else:
-                model_input.attn_metadata.decode_wrapper = \
-                    self.flashinfer_decode_wrapper
-            model_input.attn_metadata.begin_forward()
+        self.attn_state.begin_forward(model_input)
 
         # Currently cuda graph is only supported by the decode phase.
         assert model_input.attn_metadata is not None
@@ -1598,22 +1461,17 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
 
 class CUDAGraphRunner:
 
-    def __init__(self, model: nn.Module, backend_name: str):
+    def __init__(self, model: nn.Module, backend_name: str,
+                 attn_state: AttentionState):
         self.model = model
         self.backend_name = backend_name
+        self.attn_state = attn_state
 
         self.input_buffers: Dict[str, torch.Tensor] = {}
         self.output_buffers: Dict[str, torch.Tensor] = {}
 
         self._graph: Optional[torch.cuda.CUDAGraph] = None
 
-        self.flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None
-        self.flashinfer_indptr_buffer: Optional[torch.Tensor] = None
-        self.flashinfer_indices_buffer: Optional[torch.Tensor] = None
-        self.flashinfer_last_page_len_buffer: Optional[torch.Tensor] = None
-        self.flashinfer_decode_wrapper: Optional[
-            CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None
-
     @property
     def graph(self):
         assert self._graph is not None
@@ -1678,25 +1536,13 @@ class CUDAGraphRunner:
         torch.cuda.synchronize()
 
         # Save the input and output buffers.
-        if self.backend_name == "flashinfer":
-            self.input_buffers = {
-                "input_ids": input_ids,
-                "positions": positions,
-                "kv_caches": kv_caches,
-                "slot_mapping": attn_metadata.slot_mapping,
-                **kwargs,
-            }
-        else:
-            self.input_buffers = {
-                "input_ids": input_ids,
-                "positions": positions,
-                "kv_caches": kv_caches,
-                "slot_mapping": attn_metadata.slot_mapping,
-                "seq_lens_tensor":
-                attn_metadata.decode_metadata.seq_lens_tensor,
-                "block_tables": attn_metadata.decode_metadata.block_tables,
-                **kwargs,
-            }
+        self.input_buffers = {
+            "input_ids": input_ids,
+            "positions": positions,
+            "kv_caches": kv_caches,
+            **self.attn_state.get_graph_input_buffers(attn_metadata),
+            **kwargs,
+        }
         if intermediate_inputs is not None:
             self.input_buffers.update(intermediate_inputs.tensors)
         if get_pp_group().is_last_rank:
@@ -1725,12 +1571,8 @@ class CUDAGraphRunner:
         if self.backend_name != "No attention":
             self.input_buffers["slot_mapping"].copy_(
                 attn_metadata.slot_mapping, non_blocking=True)
-        if self.backend_name != "flashinfer":
-            self.input_buffers["seq_lens_tensor"].copy_(
-                attn_metadata.decode_metadata.seq_lens_tensor,
-                non_blocking=True)
-            self.input_buffers["block_tables"].copy_(
-                attn_metadata.decode_metadata.block_tables, non_blocking=True)
+        self.attn_state.prepare_graph_input_buffers(self.input_buffers,
+                                                    attn_metadata)
         if "seqlen_agnostic_capture_inputs" in self.input_buffers:
             self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
                                                       **kwargs)
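
Reviewer note: the hunks above route all backend-specific CUDA-graph and per-step bookkeeping through a single attn_state object, created via attn_backend.get_state_cls() with CommonAttentionState as the fallback. The outline below is a hypothetical sketch inferred only from the call sites visible in this diff; it is not the class actually defined in aphrodite.attention.backends.abstract, and its signatures and docstrings are assumptions added for orientation.

# Hypothetical sketch (assumption, not the repo's real AttentionState) of the
# call surface that the diff above exercises on attn_state.
from abc import ABC, abstractmethod
from typing import Any, ContextManager, Dict


class AttentionStateSketch(ABC):
    """Outline of the methods a per-backend attention state would expose."""

    def __init__(self, runner: Any) -> None:
        # The diff constructs the state with weakref.proxy(self), so the
        # state must avoid keeping a strong reference cycle to the runner.
        self.runner = runner

    @abstractmethod
    def graph_capture(self, max_batch_size: int) -> ContextManager[None]:
        # Context manager entered around CUDA graph capture; a natural place
        # to allocate backend buffers shared by all captured batch sizes.
        ...

    @abstractmethod
    def graph_clone(self, batch_size: int) -> "AttentionStateSketch":
        # Per-batch-size copy handed to each CUDAGraphRunner.
        ...

    @abstractmethod
    def graph_capture_get_metadata_for_batch(self, batch_size: int) -> Any:
        # Dummy attention metadata used while capturing a graph of this size.
        ...

    @abstractmethod
    def get_graph_input_buffers(self, attn_metadata: Any) -> Dict[str, Any]:
        # Persistent tensors (slot mapping, sequence lengths, block tables,
        # ...) that the captured graph reads on every replay.
        ...

    @abstractmethod
    def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any],
                                    attn_metadata: Any) -> None:
        # Copy the current step's metadata into those persistent buffers.
        ...

    @abstractmethod
    def begin_forward(self, model_input: Any) -> None:
        # Per-step, backend-specific setup that previously lived inline in
        # ModelRunner.execute_model (e.g. FlashInfer wrapper construction).
        ...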