|
@@ -1,4 +1,3 @@
|
|
|
-import ray
|
|
|
import torch
|
|
|
import torch.distributed as dist
|
|
|
from torch.distributed import ProcessGroup
|
|
@@ -6,6 +5,7 @@ from torch.distributed import ProcessGroup
|
|
|
from aphrodite.platforms import current_platform
|
|
|
|
|
|
if current_platform.is_tpu():
|
|
|
+ import ray
|
|
|
import torch_xla.core.xla_model as xm
|
|
|
import torch_xla.runtime as xr
|
|
|
from torch_xla._internal import pjrt
|