
suppress tpu import warning (#696)

AlpinDale, 6 months ago
Commit 5d37ec1016

+ 6 - 4
aphrodite/_custom_ops.py

@@ -6,11 +6,13 @@ import torch
 from loguru import logger
 
 from aphrodite._core_ext import ScalarType
+from aphrodite.platforms import current_platform
 
-try:
-    import aphrodite._C
-except ImportError as e:
-    logger.warning(f"Failed to import from aphrodite._C with {e}")
+if not current_platform.is_tpu():
+    try:
+        import aphrodite._C
+    except ImportError as e:
+        logger.warning(f"Failed to import from aphrodite._C with {e}")
 
 with contextlib.suppress(ImportError):
     # ruff: noqa: F401
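
The guard above is what actually silences the warning: TPU installs do not ship the compiled `aphrodite._C` extension, so the import is skipped there instead of failing and logging. A minimal, self-contained sketch of the same pattern (the `_probably_tpu` helper is illustrative, not part of the codebase):

import importlib.util

from loguru import logger


def _probably_tpu() -> bool:
    # Same heuristic the commit uses in aphrodite/platforms/__init__.py:
    # libtpu is only installed on hosts with a TPU runtime.
    return importlib.util.find_spec("libtpu") is not None


if not _probably_tpu():
    try:
        import aphrodite._C  # noqa: F401  (compiled CUDA/ROCm kernels)
    except ImportError as e:
        # Warn only where the extension is actually expected to exist.
        logger.warning(f"Failed to import from aphrodite._C with {e}")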

+ 2 - 2
aphrodite/attention/selector.py

@@ -9,7 +9,7 @@ from loguru import logger
 
 from aphrodite.attention.backends.abstract import AttentionBackend
 from aphrodite.common.utils import (STR_BACKEND_ENV_VAR, is_cpu, is_hip,
-                                    is_openvino, is_tpu, is_xpu)
+                                    is_openvino, is_xpu)
 from aphrodite.platforms import current_platform
 
 APHRODITE_ATTENTION_BACKEND = os.getenv("APHRODITE_ATTENTION_BACKEND", None)
@@ -201,7 +201,7 @@ def which_attn_to_use(
             logger.info(f"Cannot use {selected_backend} backend on XPU.")
         return _Backend.IPEX
 
-    if is_tpu():
+    if current_platform.is_tpu():
         if selected_backend != _Backend.PALLAS:
             logger.info(f"Cannot use {selected_backend} backend on TPU.")
         return _Backend.PALLAS

+ 3 - 4
aphrodite/common/config.py

@@ -12,8 +12,7 @@ from transformers import PretrainedConfig
 from aphrodite.common.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes,
                                     cuda_device_count_stateless,
                                     get_cpu_memory, is_cpu, is_hip, is_neuron,
-                                    is_openvino, is_tpu, is_xpu,
-                                    print_warning_once)
+                                    is_openvino, is_xpu, print_warning_once)
 from aphrodite.distributed import get_current_tp_rank_partition_size
 from aphrodite.modeling.models import ModelRegistry
 from aphrodite.platforms import current_platform
@@ -307,7 +306,7 @@ class ModelConfig:
                 raise ValueError(
                     f"{self.quantization} quantization is currently not "
                     "supported in ROCm.")
-            if is_tpu(
+            if current_platform.is_tpu(
             ) and self.quantization not in tpu_supported_quantization:
                 raise ValueError(
                     f"{self.quantization} quantization is currently not "
@@ -988,7 +987,7 @@ class DeviceConfig:
                 self.device_type = "neuron"
             elif is_openvino():
                 self.device_type = "openvino"
-            elif is_tpu():
+            elif current_platform.is_tpu():
                 self.device_type = "tpu"
             elif is_cpu():
                 self.device_type = "cpu"

+ 1 - 10
aphrodite/common/utils.py

@@ -31,7 +31,6 @@ from rich.progress import (BarColumn, MofNCompleteColumn, Progress,
                            SpinnerColumn, TextColumn, TimeElapsedColumn)
 from typing_extensions import ParamSpec, TypeIs, assert_never
 
-from aphrodite import _custom_ops as ops
 from aphrodite.common.logger import enable_trace_function_call
 from aphrodite.distributed import get_tensor_model_parallel_rank
 
@@ -334,15 +333,6 @@ def is_neuron() -> bool:
     return transformers_neuronx is not None
 
 
-@lru_cache(maxsize=None)
-def is_tpu() -> bool:
-    try:
-        import libtpu
-    except ImportError:
-        libtpu = None
-    return libtpu is not None
-
-
 @lru_cache(maxsize=None)
 def is_xpu() -> bool:
     from importlib.metadata import version
@@ -366,6 +356,7 @@ def is_xpu() -> bool:
 @lru_cache(maxsize=None)
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
+    from aphrodite import _custom_ops as ops
     max_shared_mem = (
         ops.get_max_shared_memory_per_block_device_attribute(gpu))
     # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
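
Moving the `_custom_ops` import from module scope into the one function that uses it is the usual way to sidestep an import cycle at startup, likely needed here because `_custom_ops` now pulls in `aphrodite.platforms` at load time. A generic sketch of the deferred-import pattern, with hypothetical module names:

# config.py -- heavy_ops.py imports config.py at load time, so a top-level
# "import heavy_ops" here would create the cycle config -> heavy_ops -> config.

def max_shared_memory_bytes(gpu: int = 0) -> int:
    # Deferred import: resolved on first call, after both modules have
    # finished loading, so the cycle never bites at import time.
    from myproject import heavy_ops  # hypothetical package
    return heavy_ops.get_max_shared_memory_per_block_device_attribute(gpu)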

+ 3 - 2
aphrodite/executor/ray_utils.py

@@ -2,7 +2,8 @@ from typing import List, Optional, Tuple, Union
 
 from aphrodite.common.config import ParallelConfig
 from aphrodite.common.sequence import ExecuteModelRequest, IntermediateTensors
-from aphrodite.common.utils import get_ip, is_hip, is_tpu, is_xpu
+from aphrodite.common.utils import get_ip, is_hip, is_xpu
+from aphrodite.platforms import current_platform
 from aphrodite.task_handler.worker_base import WorkerWrapperBase
 
 try:
@@ -107,7 +108,7 @@ def initialize_ray_cluster(
         # Placement group is already set.
         return
 
-    device_str = "GPU" if not is_tpu() else "TPU"
+    device_str = "GPU" if not current_platform.is_tpu() else "TPU"
     # Create placement group for worker processes
     current_placement_group = ray.util.get_current_placement_group()
     if current_placement_group:
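
`device_str` then determines which Ray resource the worker bundles reserve, so TPU workers claim "TPU" slots instead of "GPU" ones. A hedged sketch of the downstream use (bundle shape and `world_size` value are illustrative):

import ray

from aphrodite.platforms import current_platform

device_str = "GPU" if not current_platform.is_tpu() else "TPU"
world_size = 4  # number of workers; illustrative

# One bundle per worker, each reserving a single accelerator of the
# matching Ray resource type.
pg = ray.util.placement_group([{device_str: 1}] * world_size,
                              strategy="PACK")
ray.get(pg.ready())  # block until Ray has reserved the resources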

+ 3 - 2
aphrodite/modeling/_custom_op.py

@@ -1,6 +1,7 @@
 import torch.nn as nn
 
-from aphrodite.common.utils import is_cpu, is_hip, is_tpu, is_xpu
+from aphrodite.common.utils import is_cpu, is_hip, is_xpu
+from aphrodite.platforms import current_platform
 
 
 class CustomOp(nn.Module):
@@ -53,7 +54,7 @@ class CustomOp(nn.Module):
             return self.forward_hip
         elif is_cpu():
             return self.forward_cpu
-        elif is_tpu():
+        elif current_platform.is_tpu():
             return self.forward_tpu
         elif is_xpu():
             return self.forward_xpu
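
`dispatch_forward` returns the platform-specific implementation, so the platform check does not have to run on every call. A hedged sketch of how a subclass typically fills in those hooks (the `forward_tpu`/`forward_cpu`/`forward_xpu` names come from the hunk; `forward_native`, `forward_cuda`, and the class body are assumptions following the same pattern):

import torch
import torch.nn.functional as F

from aphrodite.modeling._custom_op import CustomOp


class SiluAndMul(CustomOp):
    # Pure-PyTorch reference implementation; works everywhere.
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    # TPU path: keep to ops that XLA can trace, no compiled kernels.
    def forward_tpu(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_native(x)

    # A CUDA path would normally call into aphrodite._C; the native
    # version stands in here to keep the sketch self-contained.
    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_native(x)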

+ 2 - 2
aphrodite/modeling/layers/rotary_embedding.py

@@ -28,8 +28,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
-from aphrodite.common.utils import is_tpu
 from aphrodite.modeling._custom_op import CustomOp
+from aphrodite.platforms import current_platform
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -80,7 +80,7 @@ class RotaryEmbedding(CustomOp):
         self.dtype = dtype
 
         cache = self._compute_cos_sin_cache()
-        self.use_native2 = is_tpu() and is_neox_style
+        self.use_native2 = current_platform.is_tpu() and is_neox_style
         if not self.use_native2:
             cache = cache.to(dtype)
             self.register_buffer("cos_sin_cache", cache, persistent=False)

+ 3 - 3
aphrodite/modeling/model_loader/loader.py

@@ -23,7 +23,7 @@ from aphrodite.common.config import (APHRODITE_USE_MODELSCOPE, CacheConfig,
                                      DeviceConfig, LoadConfig, LoadFormat,
                                      LoRAConfig, ModelConfig, MultiModalConfig,
                                      ParallelConfig, SchedulerConfig)
-from aphrodite.common.utils import is_pin_memory_available, is_tpu
+from aphrodite.common.utils import is_pin_memory_available
 from aphrodite.modeling.model_loader.tensorizer import (
     TensorizerConfig, is_aphrodite_tensorized, load_with_tensorizer,
     serialize_aphrodite_model, tensorizer_weights_iterator)
@@ -90,7 +90,7 @@ def _get_quantization_config(
     """Get the quantization config."""
     if model_config.quantization is not None:
         quant_config = get_quant_config(model_config, load_config)
-        if not is_tpu():
+        if not current_platform.is_tpu():
             capability = current_platform.get_device_capability()
             capability = capability[0] * 10 + capability[1]
             if capability < quant_config.get_min_capability():
@@ -316,7 +316,7 @@ class DefaultModelLoader(BaseModelLoader):
         else:
             weights_iterator = pt_weights_iterator(hf_weights_files)
 
-        if is_tpu():
+        if current_platform.is_tpu():
             # In PyTorch XLA, we should call `xm.mark_step` frequently so that
             # not too many ops are accumulated in the XLA program.
             import torch_xla.core.xla_model as xm
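
Calling `xm.mark_step()` between weights cuts the lazily traced XLA graph into small pieces; without it, every load-time op would accumulate into one enormous program before anything executes. A hedged sketch of how the iterator is typically wrapped (the wrapper name is illustrative):

import torch_xla.core.xla_model as xm


def _xla_weights_iterator(iterator):
    """Yield weights unchanged, flushing the XLA graph after each one."""
    for weights in iterator:
        yield weights
        # Compile and run whatever has been traced so far, keeping the
        # pending program (and host memory) small during loading.
        xm.mark_step()


weights_iterator = _xla_weights_iterator(weights_iterator)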

+ 12 - 9
aphrodite/platforms/__init__.py

@@ -1,22 +1,25 @@
-from typing import Optional
-
 import torch
 
-from aphrodite.common.utils import is_tpu
-
 from .interface import Platform, PlatformEnum, UnspecifiedPlatform
 
-current_platform: Optional[Platform]
+current_platform: Platform
 
-if torch.version.cuda is not None:
+try:
+    import libtpu
+except ImportError:
+    libtpu = None
+
+if libtpu is not None:
+    # people might install pytorch built with cuda but run on tpu
+    # so we need to check tpu first
+    from .tpu import TpuPlatform
+    current_platform = TpuPlatform()
+elif torch.version.cuda is not None:
     from .cuda import CudaPlatform
     current_platform = CudaPlatform()
 elif torch.version.hip is not None:
     from .rocm import RocmPlatform
     current_platform = RocmPlatform()
-elif is_tpu():
-    from .tpu import TpuPlatform
-    current_platform = TpuPlatform()
 else:
     current_platform = UnspecifiedPlatform()
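
Checking for libtpu before `torch.version.cuda` matters because PyTorch wheels built with CUDA are commonly installed on TPU hosts; the old order would have mis-detected those machines as CUDA. `current_platform.is_tpu()` then works uniformly across backends because the check lives on the shared base class. The `.interface` module is not part of this diff, but a plausible enum-backed sketch looks like:

import enum


class PlatformEnum(enum.Enum):
    CUDA = enum.auto()
    ROCM = enum.auto()
    TPU = enum.auto()
    UNSPECIFIED = enum.auto()


class Platform:
    _enum: PlatformEnum

    def is_tpu(self) -> bool:
        return self._enum == PlatformEnum.TPU


class TpuPlatform(Platform):
    _enum = PlatformEnum.TPU


class UnspecifiedPlatform(Platform):
    _enum = PlatformEnum.UNSPECIFIED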