utils.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. """Utils."""
  2. from os import path
  3. import enum
  4. import socket
  5. from platform import uname
  6. import uuid
  7. import psutil
  8. import torch
  9. from aphrodite._C import cuda_utils
  10. class Device(enum.Enum):
  11. GPU = enum.auto()
  12. CPU = enum.auto()
  13. class Counter:
  14. def __init__(self, start: int = 0) -> None:
  15. self.counter = start
  16. def __next__(self) -> int:
  17. i = self.counter
  18. self.counter += 1
  19. return i
  20. def reset(self) -> None:
  21. self.counter = 0
  22. def is_hip() -> bool:
  23. return torch.version.hip is not None
  24. def get_max_shared_memory_bytes(gpu: int = 0) -> int:
  25. """Returns the maximum shared memory per thread block in bytes."""
  26. # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
  27. # pylint: disable=invalid-name
  28. cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
  29. max_shared_mem = cuda_utils.get_device_attribute(
  30. cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
  31. return int(max_shared_mem)
  32. def get_cpu_memory() -> int:
  33. """Returns the total CPU memory of the node or container in bytes."""
  34. memory_limit = psutil.virtual_memory().total
  35. for limit_file in [
  36. "/sys/fs/cgroup/memory/memory.limit_in_bytes", # v1
  37. "/sys/fs/cgroup/memory.max" # v2
  38. ]:
  39. if path.exists(limit_file):
  40. with open(limit_file) as f:
  41. content = f.read().strip()
  42. if content.isnumeric(): # v2 can have "max" as limit
  43. memory_limit = min(memory_limit, int(content))
  44. return memory_limit
  45. def random_uuid() -> str:
  46. return str(uuid.uuid4().hex)
  47. def in_wsl() -> bool:
  48. # Reference: https://github.com/microsoft/WSL/issues/4071
  49. return "microsoft" in " ".join(uname()).lower()
  50. def get_open_port():
  51. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  52. s.bind(("", 0))
  53. return s.getsockname()[1]