@@ -8,6 +8,7 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,

 import numpy as np
 import torch
+import torch.distributed
 import torch.nn as nn
 from loguru import logger

@@ -27,11 +28,13 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      SchedulerConfig, VisionLanguageConfig)
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
+                                       SequenceGroupMetadata)
 from aphrodite.common.utils import (CudaMemoryProfiler,
                                     get_kv_cache_torch_dtype, is_hip,
                                     is_pin_memory_available,
                                     make_tensor_with_pad)
+from aphrodite.distributed import get_pp_group
 from aphrodite.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
     graph_capture)
@@ -83,6 +86,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
     lora_requests: Optional[Set[LoRARequest]] = None
     attn_metadata: Optional["AttentionMetadata"] = None
     multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None
+    virtual_engine: int = 0

     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
@@ -91,6 +95,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
             "lora_requests": self.lora_requests,
             "lora_mapping": self.lora_mapping,
             "multi_modal_kwargs": self.multi_modal_kwargs,
+            "virtual_engine": self.virtual_engine,
         }
         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
         return tensor_dict
@@ -124,6 +129,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
             "lora_requests": self.lora_requests,
             "lora_mapping": self.lora_mapping,
             "multi_modal_kwargs": self.multi_modal_kwargs,
+            "virtual_engine": self.virtual_engine,
         }
         _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
         _add_sampling_metadata_broadcastable_dict(tensor_dict,
@@ -181,7 +187,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
         self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
-        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
+
+        self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
+            {} for _ in range(self.parallel_config.pipeline_parallel_size)
+        ]
         self.graph_memory_pool: Optional[Tuple[
             int, int]] = None  # Set during graph capture.
         # When using CUDA graph, the input block tables must be padded to
@@ -806,9 +815,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             max_num_seqs = min(
                 max_num_seqs,
                 int(max_num_batched_tokens / vlm_config.image_feature_size))
+        batch_size = 0
         for group_id in range(max_num_seqs):
             seq_len = (max_num_batched_tokens // max_num_seqs +
                        (group_id < max_num_batched_tokens % max_num_seqs))
+            batch_size += seq_len

             seq_data, dummy_multi_modal_data = INPUT_REGISTRY \
                 .dummy_data_for_profiling(model_config, seq_len)
@@ -830,7 +841,13 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         kv_caches = [None] * num_layers
         model_input = self.prepare_model_input(seqs)
-        self.execute_model(model_input, kv_caches)
+        intermediate_tensors = None
+        if not get_pp_group().is_first_rank:
+            intermediate_tensors = self.model.make_empty_intermediate_tensors(
+                batch_size=batch_size,
+                dtype=self.model_config.dtype,
+                device=self.device)
+        self.execute_model(model_input, kv_caches, intermediate_tensors)
         torch.cuda.synchronize()
         return

@@ -866,7 +883,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         return self.lora_manager.list_loras()

     @torch.inference_mode()
-    def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
+    def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
         """Cuda graph capture a model.

         Note that CUDA graph's performance gain is negligible if number
@@ -899,10 +916,18 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         slot_mapping.fill_(_PAD_SLOT_ID)
         seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
         block_tables = torch.from_numpy(self.graph_block_tables).cuda()
+        intermediate_inputs = None
+        if not get_pp_group().is_first_rank:
+            intermediate_inputs = self.model.make_empty_intermediate_tensors(
+                batch_size=max_batch_size,
+                dtype=self.model_config.dtype,
+                device=self.device)

         # Prepare buffer for outputs. These will be reused for all batch sizes.
         # It will be filled after the first graph capture.
-        hidden_states: Optional[torch.Tensor] = None
+        hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [
+            None
+        ] * self.parallel_config.pipeline_parallel_size

         graph_batch_size = _get_graph_batch_size(
             self.scheduler_config.max_num_seqs)
@@ -931,109 +956,120 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         with graph_capture() as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
-            for batch_size in reversed(batch_size_capture_list):
-                if self.attn_backend.get_name() == "flashinfer":
-                    indptr_buffer = indptr_buffer[:batch_size + 1]
-                    last_page_len_buffer = last_page_len_buffer[:batch_size]
-
-                    num_qo_heads = self.model_config.get_num_attention_heads(
-                        self.parallel_config)
-                    num_kv_heads = self.model_config.get_num_kv_heads(
-                        self.parallel_config)
-                    if num_qo_heads // num_kv_heads >= 4:
-                        use_tensor_cores = True
+            for virtual_engine in range(
+                    self.parallel_config.pipeline_parallel_size):
+                for batch_size in reversed(batch_size_capture_list):
+                    if self.attn_backend.get_name() == "flashinfer":
+                        indptr_buffer = indptr_buffer[:batch_size + 1]
+                        last_page_len_buffer = last_page_len_buffer[:
+                                                                    batch_size]
+
+                        num_qo_heads = (
+                            self.model_config.get_num_attention_heads(
+                                self.parallel_config))
+                        num_kv_heads = self.model_config.get_num_kv_heads(
+                            self.parallel_config)
+                        if num_qo_heads // num_kv_heads >= 4:
+                            use_tensor_cores = True
+                        else:
+                            use_tensor_cores = False
+                        decode_wrapper = \
+                            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
+                            decode_workspace_buffer, indptr_buffer,
+                            indices_buffer, last_page_len_buffer, "NHD",
+                            use_tensor_cores)
+                        kv_cache_dtype = get_kv_cache_torch_dtype(
+                            self.kv_cache_dtype, self.model_config.dtype)
+
+                        paged_kv_indptr_tensor_host = torch.arange(
+                            0, batch_size + 1, dtype=torch.int32)
+                        paged_kv_indices_tensor_host = torch.arange(
+                            0, batch_size, dtype=torch.int32)
+                        paged_kv_last_page_len_tensor_host = torch.full(
+                            (batch_size, ), self.block_size, dtype=torch.int32)
+                        query_start_loc_host = torch.arange(0,
+                                                            batch_size + 1,
+                                                            dtype=torch.int32)
+
+                        attn_metadata = self.attn_backend.make_metadata(
+                            num_prefills=0,
+                            slot_mapping=slot_mapping[:batch_size],
+                            num_prefill_tokens=0,
+                            num_decode_tokens=batch_size,
+                            max_prefill_seq_len=0,
+                            block_tables=block_tables,
+                            paged_kv_indptr=paged_kv_indptr_tensor_host,
+                            paged_kv_indices=paged_kv_indices_tensor_host,
+                            paged_kv_last_page_len=
+                            paged_kv_last_page_len_tensor_host,
+                            num_qo_heads=num_qo_heads,
+                            num_kv_heads=num_kv_heads,
+                            head_dim=self.model_config.get_head_size(),
+                            page_size=self.block_size,
+                            seq_start_loc=None,
+                            query_start_loc=query_start_loc_host,
+                            device=self.device,
+                            data_type=kv_cache_dtype,
+                            use_cuda_graph=True,
+                            decode_wrapper=decode_wrapper,
+                            prefill_wrapper=None)
+                        attn_metadata.begin_forward()
                     else:
-                        use_tensor_cores = False
-                    decode_wrapper = \
-                        CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
-                        decode_workspace_buffer, indptr_buffer, indices_buffer,
-                        last_page_len_buffer, "NHD", use_tensor_cores)
-                    kv_cache_dtype = get_kv_cache_torch_dtype(
-                        self.kv_cache_dtype, self.model_config.dtype)
-
-                    paged_kv_indptr_tensor_host = torch.arange(
-                        0, batch_size + 1, dtype=torch.int32)
-                    paged_kv_indices_tensor_host = torch.arange(
-                        0, batch_size, dtype=torch.int32)
-                    paged_kv_last_page_len_tensor_host = torch.full(
-                        (batch_size, ), self.block_size, dtype=torch.int32)
-                    query_start_loc_host = torch.arange(0,
-                                                        batch_size + 1,
-                                                        dtype=torch.int32)
-
-                    attn_metadata = self.attn_backend.make_metadata(
-                        num_prefills=0,
-                        slot_mapping=slot_mapping[:batch_size],
-                        num_prefill_tokens=0,
-                        num_decode_tokens=batch_size,
-                        max_prefill_seq_len=0,
-                        block_tables=block_tables,
-                        paged_kv_indptr=paged_kv_indptr_tensor_host,
-                        paged_kv_indices=paged_kv_indices_tensor_host,
-                        paged_kv_last_page_len=
-                        paged_kv_last_page_len_tensor_host,
-                        num_qo_heads=num_qo_heads,
-                        num_kv_heads=num_kv_heads,
-                        head_dim=self.model_config.get_head_size(),
-                        page_size=self.block_size,
-                        seq_start_loc=None,
-                        query_start_loc=query_start_loc_host,
-                        device=self.device,
-                        data_type=kv_cache_dtype,
-                        use_cuda_graph=True,
-                        decode_wrapper=decode_wrapper,
-                        prefill_wrapper=None)
-                    attn_metadata.begin_forward()
-                else:
-                    attn_metadata = self.attn_backend.make_metadata(
-                        num_prefills=0,
-                        num_prefill_tokens=0,
-                        num_decode_tokens=batch_size,
-                        slot_mapping=slot_mapping[:batch_size],
-                        seq_lens=None,
-                        seq_lens_tensor=seq_lens[:batch_size],
-                        max_query_len=None,
-                        max_prefill_seq_len=0,
-                        max_decode_seq_len=self.max_seq_len_to_capture,
-                        query_start_loc=None,
-                        seq_start_loc=None,
-                        context_lens_tensor=None,
-                        block_tables=block_tables[:batch_size],
-                        use_cuda_graph=True,
+                        attn_metadata = self.attn_backend.make_metadata(
+                            num_prefills=0,
+                            num_prefill_tokens=0,
+                            num_decode_tokens=batch_size,
+                            slot_mapping=slot_mapping[:batch_size],
+                            seq_lens=None,
+                            seq_lens_tensor=seq_lens[:batch_size],
+                            max_query_len=None,
+                            max_prefill_seq_len=0,
+                            max_decode_seq_len=self.max_seq_len_to_capture,
+                            query_start_loc=None,
+                            seq_start_loc=None,
+                            context_lens_tensor=None,
+                            block_tables=block_tables[:batch_size],
+                            use_cuda_graph=True,
+                        )
+
+                    if self.lora_config:
+                        lora_mapping = LoRAMapping(
+                            [0] * batch_size,
+                            [0] * batch_size,
+                        )
+                        self.set_active_loras(set(), lora_mapping)
+
+                    graph_runner = CUDAGraphRunner(
+                        self.model, self.attn_backend.get_name())
+
+                    if self.attn_backend.get_name() == "flashinfer":
+                        graph_runner.flashinfer_indptr_buffer = indptr_buffer
+                        graph_runner.flashinfer_indices_buffer = indices_buffer
+                        graph_runner.flashinfer_last_page_len_buffer = \
+                            last_page_len_buffer
+                        graph_runner.flashinfer_decode_workspace_buffer = \
+                            decode_workspace_buffer
+                        graph_runner.flashinfer_decode_wrapper = \
+                            decode_wrapper
+
+                    graph_runner.capture(
+                        input_tokens[:batch_size],
+                        input_positions[:batch_size],
+                        hidden_or_intermediate_states[
+                            virtual_engine]  # type: ignore
+                        [:batch_size]
+                        if hidden_or_intermediate_states[virtual_engine]
+                        is not None else None,
+                        intermediate_inputs[:batch_size]
+                        if intermediate_inputs is not None else None,
+                        kv_caches[virtual_engine],
+                        attn_metadata,
+                        memory_pool=self.graph_memory_pool,
+                        stream=graph_capture_context.stream,
                     )
-
-                if self.lora_config:
-                    lora_mapping = LoRAMapping(
-                        [0] * batch_size,
-                        [0] * batch_size,
-                    )
-                    self.set_active_loras(set(), lora_mapping)
-
-                graph_runner = CUDAGraphRunner(self.model,
-                                               self.attn_backend.get_name())
-
-                if self.attn_backend.get_name() == "flashinfer":
-                    graph_runner.flashinfer_indptr_buffer = indptr_buffer
-                    graph_runner.flashinfer_indices_buffer = indices_buffer
-                    graph_runner.flashinfer_last_page_len_buffer = \
-                        last_page_len_buffer
-                    graph_runner.flashinfer_decode_workspace_buffer = \
-                        decode_workspace_buffer
-                    graph_runner.flashinfer_decode_wrapper = \
-                        decode_wrapper
-
-                graph_runner.capture(
-                    input_tokens[:batch_size],
-                    input_positions[:batch_size],
-                    hidden_states[:batch_size]
-                    if hidden_states is not None else None,
-                    kv_caches,
-                    attn_metadata,
-                    memory_pool=self.graph_memory_pool,
-                    stream=graph_capture_context.stream,
-                )
-                self.graph_memory_pool = graph_runner.graph.pool()
-                self.graph_runners[batch_size] = graph_runner
+                    self.graph_memory_pool = graph_runner.graph.pool()
+                    self.graph_runners[virtual_engine][batch_size] = (
+                        graph_runner)

         end_time = time.perf_counter()
         elapsed_time = end_time - start_time
@@ -1066,6 +1102,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
     def prepare_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
+        virtual_engine: int = 0,
     ) -> ModelInputForGPUWithSamplingMetadata:
         """Prepare the model input based on a given sequence group, including
         metadata for the sampling step.
@@ -1091,15 +1128,17 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
                      if seq_group_metadata_list else None)
         return dataclasses.replace(model_input,
                                    sampling_metadata=sampling_metadata,
-                                   is_prompt=is_prompt)
+                                   is_prompt=is_prompt,
+                                   virtual_engine=virtual_engine)

     @torch.inference_mode()
     def execute_model(
         self,
         model_input: ModelInputForGPUWithSamplingMetadata,
         kv_caches: List[torch.Tensor],
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         num_steps: int = 1,
-    ) -> Optional[List[SamplerOutput]]:
+    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
         if num_steps > 1:
             raise ValueError("num_steps > 1 is not supported in ModelRunner")

@@ -1143,27 +1182,34 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         assert model_input.attn_metadata is not None
         prefill_meta = model_input.attn_metadata.prefill_metadata
         decode_meta = model_input.attn_metadata.decode_metadata
+        # TODO: We can remove this once all
+        # virtual engines share the same kv cache.
+        virtual_engine = model_input.virtual_engine
         if prefill_meta is None and decode_meta.use_cuda_graph:
             assert model_input.input_tokens is not None
             graph_batch_size = model_input.input_tokens.shape[0]
-            model_executable = self.graph_runners[graph_batch_size]
+            model_executable = self.graph_runners[virtual_engine][
+                graph_batch_size]
         else:
             model_executable = self.model

         multi_modal_kwargs = model_input.multi_modal_kwargs or {}
-        hidden_states = model_executable(
+        hidden_or_intermediate_states = model_executable(
             input_ids=model_input.input_tokens,
             positions=model_input.input_positions,
             kv_caches=kv_caches,
             attn_metadata=model_input.attn_metadata,
+            intermediate_tensors=intermediate_tensors,
             **multi_modal_kwargs,
         )

-        # Compute the logits.
-        logits = self.model.compute_logits(hidden_states,
+        # Compute the logits in the last pipeline stage.
+        if not get_pp_group().is_last_rank:
+            return hidden_or_intermediate_states
+
+        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                            model_input.sampling_metadata)

-        # Only perform sampling in the driver worker.
         if not self.is_driver_worker:
             return []

@@ -1178,9 +1224,12 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
             assert model_input.sampling_metadata is not None
             indices = model_input.sampling_metadata.selected_token_indices
             if model_input.is_prompt:
-                hidden_states = hidden_states.index_select(0, indices)
+                hidden_states = hidden_or_intermediate_states.index_select(
+                    0, indices)
             elif decode_meta.use_cuda_graph:
-                hidden_states = hidden_states[:len(indices)]
+                hidden_states = hidden_or_intermediate_states[:len(indices)]
+            else:
+                hidden_states = hidden_or_intermediate_states

             output.hidden_states = hidden_states

@@ -1214,13 +1263,15 @@ class CUDAGraphRunner:
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        hidden_states: Optional[torch.Tensor],
+        hidden_or_intermediate_states: Optional[Union[IntermediateTensors,
+                                                      torch.Tensor]],
+        intermediate_inputs: Optional[IntermediateTensors],
        kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         memory_pool: Optional[Tuple[int, int]],
         stream: torch.cuda.Stream,
         **kwargs,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, IntermediateTensors]:
         assert self._graph is None
         # Run the model a few times without capturing the graph.
         # This is to make sure that the captured graph does not include the
@@ -1232,6 +1283,7 @@ class CUDAGraphRunner:
                 positions,
                 kv_caches,
                 attn_metadata,
+                intermediate_inputs,
                 **kwargs,
             )
         torch.cuda.synchronize()
@@ -1239,18 +1291,27 @@ class CUDAGraphRunner:
         # Capture the graph.
         self._graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
-            output_hidden_states = self.model(
+            output_hidden_or_intermediate_states = self.model(
                 input_ids,
                 positions,
                 kv_caches,
                 attn_metadata,
+                intermediate_inputs,
                 **kwargs,
             )
-            if hidden_states is not None:
-                hidden_states.copy_(output_hidden_states)
+            if hidden_or_intermediate_states is not None:
+                if get_pp_group().is_last_rank:
+                    hidden_or_intermediate_states.copy_(
+                        output_hidden_or_intermediate_states)
+                else:
+                    for key in hidden_or_intermediate_states.tensors:
+                        hidden_or_intermediate_states[key].copy_(
+                            output_hidden_or_intermediate_states[key])
             else:
-                hidden_states = output_hidden_states
-            del output_hidden_states
+                hidden_or_intermediate_states = (
+                    output_hidden_or_intermediate_states)
+
+            del output_hidden_or_intermediate_states
             # make sure `output_hidden_states` is deleted
             # in the graph's memory pool
             gc.collect()
@@ -1274,8 +1335,15 @@ class CUDAGraphRunner:
             attn_metadata.decode_metadata.seq_lens_tensor,
             "block_tables": attn_metadata.decode_metadata.block_tables,
         }
-        self.output_buffers = {"hidden_states": hidden_states}
-        return hidden_states
+        if intermediate_inputs is not None:
+            self.input_buffers.update(intermediate_inputs.tensors)
+        if get_pp_group().is_last_rank:
+            self.output_buffers = {
+                "hidden_states": hidden_or_intermediate_states
+            }
+        else:
+            self.output_buffers = hidden_or_intermediate_states
+        return hidden_or_intermediate_states

     def forward(
         self,
@@ -1283,6 +1351,7 @@ class CUDAGraphRunner:
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors],
         **kwargs,
     ) -> torch.Tensor:
         # KV caches are fixed tensors, so we don't need to copy them.
@@ -1299,11 +1368,18 @@ class CUDAGraphRunner:
             non_blocking=True)
         self.input_buffers["block_tables"].copy_(
             attn_metadata.decode_metadata.block_tables, non_blocking=True)
+        if intermediate_tensors is not None:
+            for key in intermediate_tensors.tensors:
+                self.input_buffers[key].copy_(intermediate_tensors[key],
+                                              non_blocking=True)
         # Run the graph.
         self.graph.replay()

         # Return the output tensor.
-        return self.output_buffers["hidden_states"]
+        if get_pp_group().is_last_rank:
+            return self.output_buffers["hidden_states"]
+
+        return self.output_buffers

     def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)