Browse Source

distributed: allow IPv6 in APHRODITE_HOST_IP with ZMQ (#1105)

AlpinDale 1 month ago
parent
commit
76088aa43a

+ 7 - 0
aphrodite/common/utils.py

@@ -5,6 +5,7 @@ import datetime
 import enum
 import gc
 import inspect
+import ipaddress
 import math
 import os
 import random
@@ -531,6 +532,12 @@ def get_ip() -> str:
         stacklevel=2)
     return "0.0.0.0"
 
+def is_valid_ipv6_address(address: str) -> bool:
+    try:
+        ipaddress.IPv6Address(address)
+        return True
+    except ValueError:
+        return False
 
 def get_distributed_init_method(ip: str, port: int) -> str:
     # Brackets are not permitted in ipv4 addresses,

+ 6 - 1
aphrodite/distributed/device_communicators/shm_broadcast.py

@@ -10,10 +10,11 @@ import torch
 import torch.distributed as dist
 from loguru import logger
 from torch.distributed import ProcessGroup
+from zmq import IPV6  # type: ignore
 from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore
 
 import aphrodite.common.envs as envs
-from aphrodite.common.utils import get_ip, get_open_port
+from aphrodite.common.utils import get_ip, get_open_port, is_valid_ipv6_address
 
 APHRODITE_RINGBUFFER_WARNING_INTERVAL = (
     envs.APHRODITE_RINGBUFFER_WARNING_INTERVAL)
@@ -212,6 +213,8 @@ class MessageQueue:
             self.remote_socket = context.socket(XPUB)
             self.remote_socket.setsockopt(XPUB_VERBOSE, True)
             remote_subscribe_port = get_open_port()
+            if is_valid_ipv6_address(connect_ip):
+                self.remote_socket.setsockopt(IPV6, 1)
             socket_addr = f"tcp://*:{remote_subscribe_port}"
             self.remote_socket.bind(socket_addr)
 
@@ -273,6 +276,8 @@ class MessageQueue:
 
             self.remote_socket = context.socket(SUB)
             self.remote_socket.setsockopt_string(SUBSCRIBE, "")
+            if is_valid_ipv6_address(handle.connect_ip):
+                self.remote_socket.setsockopt(IPV6, 1)
             socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}"
             logger.debug(f"Connecting to {socket_addr}")
             self.remote_socket.connect(socket_addr)