Browse Source

imports in sampler

AlpinDale 1 year ago
parent
commit
70dbf7de03
2 changed files with 2 additions and 2 deletions
  1. 1 1
      aphrodite/modeling/layers/sampler.py
  2. 1 1
      aphrodite/task_handler/worker.py

+ 1 - 1
aphrodite/modeling/layers/sampler.py

@@ -8,7 +8,7 @@ import torch.nn as nn
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.megatron.tensor_parallel import (
     gather_from_tensor_model_parallel_region)
-from aphrodite.common.sampling_params import SamplingParams
+from aphrodite.common.sampling_params import SamplingParams, SamplingType
 from aphrodite.common.sequence import SamplerOutput, SequenceOutputs, SequenceData
 
 __SAMPLING_EPS = 1e-5

+ 1 - 1
aphrodite/task_handler/worker.py

@@ -365,7 +365,7 @@ def _check_if_can_support_max_seq_len(max_seq_len: int,
     required_shared_mem = (padded_max_seq_len + 512) * float32_bytes
     if padded_max_seq_len * float32_bytes > max_shared_mem:
         raise RuntimeError(
-            f"vLLM cannot currently support max_model_len={max_seq_len} "
+            f"Aphrodite cannot currently support max_model_len={max_seq_len} "
             f"with block_size={block_size} on GPU with compute "
             f"capability {torch.cuda.get_device_capability()} "
             f"(required shared memory {required_shared_mem} > "