@@ -59,6 +59,8 @@ class Worker:
             raise ValueError("Invalid or unspecified rank.")
         torch.cuda.set_device(self.device)
 
+        _check_if_gpu_supports_dtype(self.model_config.dtype)
+
         # Initialize the distributed environment.
         _init_distributed_environment(self.parallel_config, self.rank,
                                       self.distributed_init_method)
@@ -387,4 +389,15 @@ def _check_if_can_support_max_seq_len(max_seq_len: int,
             f"capability {torch.cuda.get_device_capability()} "
             f"(required shared memory {required_shared_mem} > "
             f"available shared memory {max_shared_mem}). "
-            "This will be fixed in a future release.")
+            "This will be fixed in a future release.")
+
+
+def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
+    if torch_dtype == torch.bfloat16:
+        compute_capability = torch.cuda.get_device_capability()
+        if compute_capability[0] < 8:
+            gpu_name = torch.cuda.get_device_name()
+            raise ValueError(
+                "Bfloat16 is only supported on GPUs with compute capability "
+                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
+                f"{compute_capability[0]}.{compute_capability[1]}. Please "
+                "use the `--dtype float16` argument when launching the engine.")
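
Usage note (illustrative, not part of the patch): a minimal sketch of how the new check behaves at startup, assuming a CUDA-enabled PyTorch build and at least one visible GPU. It reproduces the capability test inline rather than importing the private helper, so names here are for illustration only.

    import torch

    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        gpu_name = torch.cuda.get_device_name()
        if major < 8:
            # Pre-Ampere GPUs (e.g. V100 at 7.0, T4 at 7.5) lack native bfloat16,
            # so the worker raises ValueError and suggests `--dtype float16`.
            print(f"{gpu_name} (compute {major}.{minor}): bfloat16 unsupported, "
                  "launch with --dtype float16")
        else:
            print(f"{gpu_name} (compute {major}.{minor}): bfloat16 supported")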