
chore: fix datatype check (#65)

AlpinDale committed 1 year ago
parent commit c1fa7e8567
2 changed files with 14 additions and 10 deletions
  1. aphrodite/common/config.py (+0 −9)
  2. aphrodite/task_handler/worker.py (+14 −1)

aphrodite/common/config.py (+0 −9)

@@ -331,15 +331,6 @@ def _get_and_verify_dtype(
             # Casting between float16 and bfloat16 is allowed with a warning.
             logger.warning(f"Casting {config_dtype} to {torch_dtype}.")
 
-    # Check if the GPU supports the dtype.
-    if torch_dtype == torch.bfloat16:
-        compute_capability = torch.cuda.get_device_capability()
-        if compute_capability[0] < 8:
-            gpu_name = torch.cuda.get_device_name()
-            raise ValueError(
-                "Bfloat16 is only supported on GPUs with compute capability "
-                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
-                f"{compute_capability[0]}.{compute_capability[1]}.")
     return torch_dtype
 
 

aphrodite/task_handler/worker.py (+14 −1)

@@ -59,6 +59,8 @@ class Worker:
             raise ValueError("Invalid or unspecified rank.")
         torch.cuda.set_device(self.device)
 
+        _check_if_gpu_supports_dtype(self.model_config.dtype)
+
         # Initialize the distributed environment.
         _init_distributed_environment(self.parallel_config, self.rank,
                                       self.distributed_init_method)
@@ -387,4 +389,15 @@ def _check_if_can_support_max_seq_len(max_seq_len: int,
             f"capability {torch.cuda.get_device_capability()} "
             f"(required shared memory {required_shared_mem} > "
             f"available shared memory {max_shared_mem}). "
-            "This will be fixed in a future release.")
+            "This will be fixed in a future release.")
+
+def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
+    if torch_dtype == torch.bfloat16:
+        compute_capability = torch.cuda.get_device_capability()
+        if compute_capability[0] < 8:
+            gpu_name = torch.cuda.get_device_name()
+            raise ValueError(
+                "Bfloat16 is only supported on GPUs with compute capability "
+                f"of at least 8.0. You {gpu_name} GPU has compute capability "
+                f"{compute_capability[0]}.{compute_capability[1]}. Please "
+                "use the `--dtype float16` argument when launching the engine.")