@@ -23,7 +23,7 @@ from aphrodite.common.config import (APHRODITE_USE_MODELSCOPE, CacheConfig,
                                      DeviceConfig, LoadConfig, LoadFormat,
                                      LoRAConfig, ModelConfig, MultiModalConfig,
                                      ParallelConfig, SchedulerConfig)
-from aphrodite.common.utils import is_pin_memory_available, is_tpu
+from aphrodite.common.utils import is_pin_memory_available
 from aphrodite.modeling.model_loader.tensorizer import (
     TensorizerConfig, is_aphrodite_tensorized, load_with_tensorizer,
     serialize_aphrodite_model, tensorizer_weights_iterator)
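This hunk drops the standalone `is_tpu()` helper in favor of the `current_platform` abstraction, so per-device checks live behind one interface. The diff does not show that interface; purely as illustration, a platform object along the following lines would support both call sites changed below. `Platform` and `TpuPlatform` are hypothetical names for this sketch, not Aphrodite's actual API.

# Hypothetical sketch of a platform interface; names are illustrative,
# not Aphrodite's actual implementation.
from typing import Optional, Tuple

class Platform:
    def is_tpu(self) -> bool:
        return False

    def get_device_capability(self) -> Optional[Tuple[int, int]]:
        # e.g. (8, 6) on an Ampere GPU; None where the concept does
        # not apply (such as on TPUs).
        return None

class TpuPlatform(Platform):
    def is_tpu(self) -> bool:
        return True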
@@ -90,7 +90,7 @@ def _get_quantization_config(
     """Get the quantization config."""
     if model_config.quantization is not None:
         quant_config = get_quant_config(model_config, load_config)
-        if not is_tpu():
+        if not current_platform.is_tpu():
             capability = current_platform.get_device_capability()
             capability = capability[0] * 10 + capability[1]
             if capability < quant_config.get_min_capability():
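For context on the guard above: CUDA-style compute capability is a (major, minor) pair folded into a single integer and compared against the quantization method's minimum, and TPUs have no such capability, hence the `is_tpu()` check. A worked example of the encoding (the threshold value here is illustrative, not any specific method's minimum):

# (major, minor) -> two-digit integer, e.g. (8, 0) -> 80 for an A100.
capability = (8, 0)
capability = capability[0] * 10 + capability[1]  # 80

# Suppose the chosen quantization method requires at least 75 (Turing);
# illustrative threshold only.
min_capability = 75
assert capability >= min_capability  # quantization is allowed here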
@@ -316,7 +316,7 @@ class DefaultModelLoader(BaseModelLoader):
         else:
             weights_iterator = pt_weights_iterator(hf_weights_files)
 
-        if is_tpu():
+        if current_platform.is_tpu():
             # In PyTorch XLA, we should call `xm.mark_step` frequently so that
             # not too many ops are accumulated in the XLA program.
             import torch_xla.core.xla_model as xm
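The hunk is truncated here. In this kind of XLA weight-loading path, the usual pattern is to wrap the iterator so `xm.mark_step()` runs after each weight is yielded, flushing the pending ops instead of letting the whole loading loop trace into one huge XLA graph. A minimal sketch of that pattern, assuming the surrounding code from this hunk (the wrapper name is illustrative):

import torch_xla.core.xla_model as xm

def _xla_weights_iterator(iterator):
    # Yield each (name, tensor) pair, then flush the accumulated XLA
    # ops so the lazily-traced graph stays small during loading.
    for weights in iterator:
        yield weights
        xm.mark_step()

weights_iterator = _xla_weights_iterator(weights_iterator)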