ソースを参照

misc: extend cuda graph capture size for H200 (#957)

AlpinDale 2 ヶ月 前
コミット
22b8096006
1 ファイル変更30 行追加7 行削除
  1. 30 7
      aphrodite/task_handler/model_runner.py

+ 30 - 7
aphrodite/task_handler/model_runner.py

@@ -61,10 +61,14 @@ if TYPE_CHECKING:
 
 LORA_WARMUP_RANK = 8
 _BATCH_SIZE_ALIGNMENT = 8
-# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
+# all the token sizes that **can** be captured by cudagraph.
+# they can be arbitrarily large.
+# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
+# the actual sizes to capture will be determined by the model,
+# depending on the model's max_num_seqs.
 # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
 _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
-    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
+    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
 ]
 _NUM_WARMUP_ITERS = 2
 
@@ -660,7 +664,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
     def _use_captured_graph(self, batch_size: int,
                             max_decode_seq_len: int) -> bool:
         return (self.decode_only and not self.runner.model_config.enforce_eager
-                and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
+                and batch_size <= self.runner.max_batchsize_to_capture
                 and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
 
     def build(self) -> ModelInputForGPU:
@@ -842,6 +846,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
         self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
+        self.max_batchsize_to_capture = _get_max_graph_batch_size(
+            self.scheduler_config.max_num_seqs)
 
         self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
             {} for _ in range(self.parallel_config.pipeline_parallel_size)
@@ -858,7 +864,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         # The shape of the cached block table will be
         # (max batch size to capture, max context len to capture / block size).
         self.graph_block_tables = np.zeros(
-            (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
+            (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
             dtype=np.int32)
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
@@ -1271,7 +1277,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         start_time = time.perf_counter()
 
         # Prepare dummy inputs. These will be reused for all batch sizes.
-        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
+        max_batch_size = self.max_batchsize_to_capture
         input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
         input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
         # Prepare dummy previous_hidden_states only if needed by the model.
@@ -1297,8 +1303,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             None
         ] * self.parallel_config.pipeline_parallel_size
 
-        graph_batch_size = _get_graph_batch_size(
-            self.scheduler_config.max_num_seqs)
+
+        graph_batch_size = self.max_batchsize_to_capture
         batch_size_capture_list = [
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]
@@ -1685,3 +1691,20 @@ def _get_graph_batch_size(batch_size: int) -> int:
     else:
         return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
                 _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
+
+
+def _get_max_graph_batch_size(max_num_seqs: int) -> int:
+    """
+    max_num_seqs: Maximum number of sequences in a batch.
+    _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
+    pad the max_num_seqs if necessary by calling _get_graph_batch_size,
+    which will deal with some edge cases like 1, 2, 4.
+    if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size.
+    if not, it means the padded size is larger than the largest size in
+    _BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE.
+    """
+    padded_size = _get_graph_batch_size(max_num_seqs)
+    if padded_size in _BATCH_SIZES_TO_CAPTURE:
+        return padded_size
+    assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
+    return _BATCH_SIZES_TO_CAPTURE[-1]