utils.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715
  1. import asyncio
  2. import datetime
  3. import enum
  4. import gc
  5. import os
  6. import socket
  7. import subprocess
  8. import sys
  9. import tempfile
  10. import threading
  11. import uuid
  12. import warnings
  13. from collections import defaultdict
  14. from functools import lru_cache, partial, wraps
  15. from platform import uname
  16. from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
  17. Hashable, List, Optional, OrderedDict, Tuple, TypeVar,
  18. Union)
  19. import numpy as np
  20. import psutil
  21. import torch
  22. from loguru import logger
  23. from aphrodite.common.logger import enable_trace_function_call
  24. T = TypeVar("T")
  25. STR_DTYPE_TO_TORCH_DTYPE = {
  26. "half": torch.half,
  27. "bfloat16": torch.bfloat16,
  28. "float": torch.float,
  29. "fp8": torch.uint8,
  30. "fp8_e4m3": torch.uint8,
  31. "fp8_e5m2": torch.uint8,
  32. }
  33. class Device(enum.Enum):
  34. GPU = enum.auto()
  35. CPU = enum.auto()
  36. class Counter:
  37. def __init__(self, start: int = 0) -> None:
  38. self.counter = start
  39. def __next__(self) -> int:
  40. i = self.counter
  41. self.counter += 1
  42. return i
  43. def reset(self) -> None:
  44. self.counter = 0
  45. class LRUCache(Generic[T]):
  46. def __init__(self, capacity: int):
  47. self.cache: OrderedDict[Hashable, T] = OrderedDict()
  48. self.capacity = capacity
  49. def __contains__(self, key: Hashable) -> bool:
  50. return key in self.cache
  51. def __len__(self) -> int:
  52. return len(self.cache)
  53. def __getitem__(self, key: Hashable) -> Optional[T]:
  54. return self.get(key)
  55. def __setitem__(self, key: Hashable, value: T) -> None:
  56. self.put(key, value)
  57. def __delitem__(self, key: Hashable) -> None:
  58. self.pop(key)
  59. def touch(self, key: Hashable) -> None:
  60. self.cache.move_to_end(key)
  61. def get(self,
  62. key: Hashable,
  63. default_value: Optional[T] = None) -> Optional[T]:
  64. if key in self.cache:
  65. value: Optional[T] = self.cache[key]
  66. self.cache.move_to_end(key)
  67. else:
  68. value = default_value
  69. return value
  70. def put(self, key: Hashable, value: T) -> None:
  71. self.cache[key] = value
  72. self.cache.move_to_end(key)
  73. self._remove_old_if_needed()
  74. def _on_remove(self, key: Hashable, value: Optional[T]):
  75. pass
  76. def remove_oldest(self):
  77. if not self.cache:
  78. return
  79. key, value = self.cache.popitem(last=False)
  80. self._on_remove(key, value)
  81. def _remove_old_if_needed(self) -> None:
  82. while len(self.cache) > self.capacity:
  83. self.remove_oldest()
  84. def pop(self,
  85. key: Hashable,
  86. default_value: Optional[T] = None) -> Optional[T]:
  87. run_on_remove = key in self.cache
  88. value: Optional[T] = self.cache.pop(key, default_value)
  89. if run_on_remove:
  90. self._on_remove(key, value)
  91. return value
  92. def clear(self):
  93. while len(self.cache) > 0:
  94. self.remove_oldest()
  95. self.cache.clear()
  96. def is_hip() -> bool:
  97. return torch.version.hip is not None
  98. @lru_cache(maxsize=None)
  99. def is_cpu() -> bool:
  100. from importlib.metadata import PackageNotFoundError, version
  101. try:
  102. return "cpu" in version("aphrodite-engine")
  103. except PackageNotFoundError:
  104. return False
  105. @lru_cache(maxsize=None)
  106. def is_neuron() -> bool:
  107. try:
  108. import transformers_neuronx
  109. except ImportError:
  110. transformers_neuronx = None
  111. return transformers_neuronx is not None
  112. @lru_cache(maxsize=None)
  113. def is_tpu() -> bool:
  114. try:
  115. import libtpu
  116. except ImportError:
  117. libtpu = None
  118. return libtpu is not None
  119. @lru_cache(maxsize=None)
  120. def get_max_shared_memory_bytes(gpu: int = 0) -> int:
  121. """Returns the maximum shared memory per thread block in bytes."""
  122. # NOTE: This import statement should be executed lazily since
  123. # the Neuron-X backend does not have the `cuda_utils` module.
  124. from aphrodite import _custom_ops as ops
  125. max_shared_mem = (
  126. ops.get_max_shared_memory_per_block_device_attribute(gpu))
  127. # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
  128. # will fail
  129. assert max_shared_mem > 0, "max_shared_mem can not be zero"
  130. return int(max_shared_mem)
  131. def get_cpu_memory() -> int:
  132. """Returns the total CPU memory of the node in bytes."""
  133. return psutil.virtual_memory().total
  134. def random_uuid() -> str:
  135. return str(uuid.uuid4().hex)
  136. @lru_cache(maxsize=None)
  137. def get_aphrodite_instance_id():
  138. """
  139. If the environment variable APHRODITE_INSTANCE_ID is set, return it.
  140. Otherwise, return a random UUID.
  141. Instance id represents an instance of the Aphrodite. All processes in the
  142. same instance should have the same instance id.
  143. """
  144. return os.environ.get("APHRODITE_INSTANCE_ID",
  145. f"aphrodite-instance-{random_uuid()}")
  146. @lru_cache(maxsize=None)
  147. def in_wsl() -> bool:
  148. # Reference: https://github.com/microsoft/WSL/issues/4071
  149. return "microsoft" in " ".join(uname()).lower()
  150. def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
  151. """Take a blocking function, and run it on in an executor thread.
  152. This function prevents the blocking function from blocking the
  153. asyncio event loop.
  154. The code in this function needs to be thread safe.
  155. """
  156. def _async_wrapper(*args, **kwargs) -> asyncio.Future:
  157. loop = asyncio.get_event_loop()
  158. p_func = partial(func, *args, **kwargs)
  159. return loop.run_in_executor(executor=None, func=p_func)
  160. return _async_wrapper
  161. def merge_async_iterators(
  162. *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
  163. """Merge multiple asynchronous iterators into a single iterator.
  164. This method handle the case where some iterators finish before others.
  165. When it yields, it yields a tuple (i, item) where i is the index of the
  166. iterator that yields the item.
  167. """
  168. queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue()
  169. finished = [False] * len(iterators)
  170. async def producer(i: int, iterator: AsyncIterator[T]):
  171. try:
  172. async for item in iterator:
  173. await queue.put((i, item))
  174. except Exception as e:
  175. await queue.put(e)
  176. finished[i] = True
  177. _tasks = [
  178. asyncio.create_task(producer(i, iterator))
  179. for i, iterator in enumerate(iterators)
  180. ]
  181. async def consumer():
  182. try:
  183. while not all(finished) or not queue.empty():
  184. item = await queue.get()
  185. if isinstance(item, Exception):
  186. raise item
  187. yield item
  188. except (Exception, asyncio.CancelledError) as e:
  189. for task in _tasks:
  190. if sys.version_info >= (3, 9):
  191. # msg parameter only supported in Python 3.9+
  192. task.cancel(e)
  193. else:
  194. task.cancel()
  195. raise e
  196. await asyncio.gather(*_tasks)
  197. return consumer()
  198. def get_ip() -> str:
  199. host_ip = os.environ.get("HOST_IP")
  200. if host_ip:
  201. return host_ip
  202. # IP is not set, try to get it from the network interface
  203. # try ipv4
  204. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  205. try:
  206. s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
  207. return s.getsockname()[0]
  208. except Exception:
  209. pass
  210. # try ipv6
  211. try:
  212. s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
  213. # Google's public DNS server, see
  214. # https://developers.google.com/speed/public-dns/docs/using#addresses
  215. s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
  216. return s.getsockname()[0]
  217. except Exception:
  218. pass
  219. warnings.warn(
  220. "Failed to get the IP address, using 0.0.0.0 by default."
  221. "The value can be set by the environment variable HOST_IP.",
  222. stacklevel=2)
  223. return "0.0.0.0"
  224. def get_distributed_init_method(ip: str, port: int) -> str:
  225. # Brackets are not permitted in ipv4 addresses,
  226. # see https://github.com/python/cpython/issues/103848
  227. return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
  228. def get_open_port() -> int:
  229. # try ipv4
  230. try:
  231. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  232. s.bind(("", 0))
  233. return s.getsockname()[1]
  234. except OSError:
  235. # try ipv6
  236. with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
  237. s.bind(("", 0))
  238. return s.getsockname()[1]
  239. def update_environment_variables(envs: Dict[str, str]):
  240. for k, v in envs.items():
  241. if k in os.environ and os.environ[k] != v:
  242. logger.warning(f"Overwriting environment variable {k} "
  243. f"from '{os.environ[k]}' to '{v}'")
  244. os.environ[k] = v
  245. def chunk_list(lst, chunk_size):
  246. """Yield successive chunk_size chunks from lst."""
  247. return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
  248. def cdiv(a: int, b: int) -> int:
  249. """Ceiling division."""
  250. return -(a // -b)
  251. def _generate_random_fp8(
  252. tensor: torch.tensor,
  253. low: float,
  254. high: float,
  255. ) -> None:
  256. # NOTE: Due to NaN and Inf representation for fp8 data type,
  257. # it may occur Inf or NaN if we directly use torch.randint
  258. # to generate random data for fp8 data.
  259. # For example, s.11111.00 in fp8e5m2 format represents Inf.
  260. # | E4M3 | E5M2
  261. #-----|-------------|-------------------
  262. # Inf | N/A | s.11111.00
  263. # NaN | s.1111.111 | s.11111.{01,10,11}
  264. from aphrodite import _custom_ops as ops
  265. tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
  266. tensor_tmp.uniform_(low, high)
  267. ops.convert_fp8(tensor, tensor_tmp)
  268. del tensor_tmp
  269. def get_kv_cache_torch_dtype(
  270. cache_dtype: Optional[Union[str, torch.dtype]],
  271. model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
  272. if isinstance(cache_dtype, str):
  273. if cache_dtype == "auto":
  274. if isinstance(model_dtype, str):
  275. torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
  276. elif isinstance(model_dtype, torch.dtype):
  277. torch_dtype = model_dtype
  278. else:
  279. raise ValueError(f"Invalid model dtype: {model_dtype}")
  280. elif cache_dtype in ["half", "bfloat16", "float"]:
  281. torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
  282. elif cache_dtype == "fp8":
  283. torch_dtype = torch.uint8
  284. else:
  285. raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
  286. elif isinstance(cache_dtype, torch.dtype):
  287. torch_dtype = cache_dtype
  288. else:
  289. raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
  290. return torch_dtype
  291. def create_kv_caches_with_random_flash(
  292. num_blocks: int,
  293. block_size: int,
  294. num_layers: int,
  295. num_heads: int,
  296. head_size: int,
  297. cache_dtype: Optional[Union[str, torch.dtype]],
  298. model_dtype: Optional[Union[str, torch.dtype]] = None,
  299. seed: int = 0,
  300. device: Optional[str] = "cuda",
  301. ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
  302. assert cache_dtype != "fp8"
  303. torch.random.manual_seed(seed)
  304. if torch.cuda.is_available():
  305. torch.cuda.manual_seed(seed)
  306. torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
  307. key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
  308. scale = head_size**-0.5
  309. key_caches, value_caches = [], []
  310. for _ in range(num_layers):
  311. key_value_cache = torch.empty(size=key_value_cache_shape,
  312. dtype=torch_dtype,
  313. device=device)
  314. key_value_cache.uniform_(-scale, scale)
  315. key_caches.append(key_value_cache[:, 0])
  316. value_caches.append(key_value_cache[:, 1])
  317. return key_caches, value_caches
  318. def create_kv_caches_with_random(
  319. num_blocks: int,
  320. block_size: int,
  321. num_layers: int,
  322. num_heads: int,
  323. head_size: int,
  324. cache_dtype: Optional[Union[str, torch.dtype]],
  325. model_dtype: Optional[Union[str, torch.dtype]] = None,
  326. seed: int = 0,
  327. device: Optional[str] = "cuda",
  328. ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
  329. torch.random.manual_seed(seed)
  330. if torch.cuda.is_available():
  331. torch.cuda.manual_seed(seed)
  332. torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
  333. scale = head_size**-0.5
  334. x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
  335. key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
  336. key_caches = []
  337. for _ in range(num_layers):
  338. key_cache = torch.empty(size=key_cache_shape,
  339. dtype=torch_dtype,
  340. device=device)
  341. if cache_dtype in ["auto", "half", "bfloat16", "float"]:
  342. key_cache.uniform_(-scale, scale)
  343. elif cache_dtype == 'fp8':
  344. _generate_random_fp8(key_cache, -scale, scale)
  345. else:
  346. raise ValueError(
  347. f"Does not support key cache of type {cache_dtype}")
  348. key_caches.append(key_cache)
  349. value_cache_shape = (num_blocks, num_heads, head_size, block_size)
  350. value_caches = []
  351. for _ in range(num_layers):
  352. value_cache = torch.empty(size=value_cache_shape,
  353. dtype=torch_dtype,
  354. device=device)
  355. if cache_dtype in ["auto", "half", "bfloat16", "float"]:
  356. value_cache.uniform_(-scale, scale)
  357. elif cache_dtype == 'fp8':
  358. _generate_random_fp8(value_cache, -scale, scale)
  359. else:
  360. raise ValueError(
  361. f"Does not support value cache of type {cache_dtype}")
  362. value_caches.append(value_cache)
  363. return key_caches, value_caches
  364. @lru_cache
  365. def print_warning_once(msg: str) -> None:
  366. logger.warning(msg)
  367. @lru_cache(maxsize=None)
  368. def is_pin_memory_available() -> bool:
  369. if in_wsl():
  370. # Pinning memory in WSL is not supported.
  371. # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
  372. print_warning_once("Using 'pin_memory=False' as WSL is detected. "
  373. "This may slow down the performance.")
  374. return False
  375. elif is_neuron():
  376. print_warning_once("Pin memory is not supported on Neuron.")
  377. return False
  378. elif is_cpu():
  379. return False
  380. return True
  381. class CudaMemoryProfiler:
  382. def __init__(self, device=None):
  383. self.device = device
  384. def current_memory_usage(self) -> float:
  385. # Return the memory usage in bytes.
  386. torch.cuda.reset_peak_memory_stats(self.device)
  387. mem = torch.cuda.max_memory_allocated(self.device)
  388. return mem
  389. def __enter__(self):
  390. self.initial_memory = self.current_memory_usage()
  391. # This allows us to call methods of the context manager if needed
  392. return self
  393. def __exit__(self, exc_type, exc_val, exc_tb):
  394. self.final_memory = self.current_memory_usage()
  395. self.consumed_memory = self.final_memory - self.initial_memory
  396. # Force garbage collection
  397. gc.collect()
  398. def str_to_int_tuple(s: str) -> Tuple[int, ...]:
  399. """Convert a string to a tuple of integers."""
  400. try:
  401. return tuple(map(int, s.split(",")))
  402. except ValueError as e:
  403. raise ValueError(
  404. "String must be a series of integers separated by commas "
  405. f"(e.g., 1, 2, 3). Given input: {s}") from e
  406. def make_tensor_with_pad(
  407. x: List[List[int]],
  408. max_len: int,
  409. pad: int,
  410. dtype: torch.dtype,
  411. device: Optional[Union[str, torch.device]],
  412. ) -> torch.Tensor:
  413. """Make a padded tensor of a 2D inputs.
  414. The padding is applied to the end of each inner list until it reaches
  415. `max_len`.
  416. """
  417. padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad
  418. for ind, blocktb in enumerate(x):
  419. assert len(blocktb) <= max_len
  420. padded_x[ind, :len(blocktb)] = blocktb
  421. return torch.tensor(padded_x, dtype=dtype, device=device)
  422. def async_tensor_h2d(
  423. data: list,
  424. dtype: torch.dtype,
  425. target_device: Union[str, torch.device],
  426. pin_memory: bool,
  427. ) -> torch.Tensor:
  428. """Asynchronously create a tensor and copy it from host to device."""
  429. t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
  430. return t.to(device=target_device, non_blocking=True)
  431. def maybe_expand_dim(tensor: torch.Tensor,
  432. target_dims: int,
  433. size: int = 1) -> torch.Tensor:
  434. """Expand the tensor to the target_dims."""
  435. if tensor.ndim < target_dims:
  436. tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
  437. return tensor
  438. def get_dtype_size(dtype: torch.dtype) -> int:
  439. """Get the size of the data type in bytes."""
  440. return torch.tensor([], dtype=dtype).element_size()
  441. def merge_dicts(dict1: Dict[Any, List[Any]],
  442. dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]:
  443. """Merge 2 dicts that have key -> List of items.
  444. When a key conflicts, the values in dict1 is prioritized.
  445. """
  446. merged_dict = defaultdict(list)
  447. for key, value in dict1.items():
  448. merged_dict[key].extend(value)
  449. for key, value in dict2.items():
  450. merged_dict[key].extend(value)
  451. return dict(merged_dict)
  452. def init_cached_hf_modules():
  453. """
  454. Lazy initialization of the Hugging Face modules.
  455. """
  456. from transformers.dynamic_module_utils import init_hf_modules
  457. init_hf_modules()
  458. @lru_cache(maxsize=None)
  459. def find_library(lib_name: str) -> str:
  460. """
  461. Find the library file in the system.
  462. `lib_name` is full filename, with both prefix and suffix.
  463. This function resolves `lib_name` to the full path of the library.
  464. """
  465. # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
  466. # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
  467. # `/sbin/ldconfig` should exist in all Linux systems.
  468. # `/sbin/ldconfig` searches the library in the system
  469. libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
  470. # each line looks like the following:
  471. # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
  472. locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
  473. # `LD_LIBRARY_PATH` searches the library in the user-defined paths
  474. env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
  475. if not locs and env_ld_library_path:
  476. locs = [
  477. os.path.join(dir, lib_name)
  478. for dir in env_ld_library_path.split(":")
  479. if os.path.exists(os.path.join(dir, lib_name))
  480. ]
  481. if not locs:
  482. raise ValueError(f"Cannot find {lib_name} in the system.")
  483. return locs[0]
  484. def find_nccl_library():
  485. """
  486. We either use the library file specified by the `APHRODITE_NCCL_SO_PATH`
  487. environment variable, or we find the library file brought by PyTorch.
  488. After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be
  489. found by `ctypes` automatically.
  490. """
  491. so_file = os.environ.get("APHRODITE_NCCL_SO_PATH", "")
  492. # manually load the nccl library
  493. if so_file:
  494. logger.info("Found nccl from environment variable "
  495. f"APHRODITE_NCCL_SO_PATH={so_file}")
  496. else:
  497. if torch.version.cuda is not None:
  498. so_file = "libnccl.so.2"
  499. elif torch.version.hip is not None:
  500. so_file = "librccl.so.1"
  501. else:
  502. raise ValueError("NCCL only supports CUDA and ROCm backends.")
  503. logger.info(f"Found nccl from library {so_file}")
  504. return so_file
  505. def enable_trace_function_call_for_thread() -> None:
  506. if int(os.getenv("APHRODITE_TRACE_FUNCTION", "0")):
  507. tmp_dir = tempfile.gettempdir()
  508. filename = (f"APHRODITE_TRACE_FUNCTION_for_process_{os.getpid()}"
  509. f"_thread_{threading.get_ident()}_"
  510. f"at_{datetime.datetime.now()}.log").replace(" ", "_")
  511. log_path = os.path.join(tmp_dir, "aphrodite",
  512. get_aphrodite_instance_id(), filename)
  513. os.makedirs(os.path.dirname(log_path), exist_ok=True)
  514. enable_trace_function_call(log_path)
  515. def identity(value: T) -> T:
  516. return value
  517. F = TypeVar('F', bound=Callable[..., Any])
  518. def deprecate_kwargs(
  519. *kws: str,
  520. is_deprecated: Union[bool, Callable[[], bool]] = True,
  521. additional_message: Optional[str] = None) -> Callable[[F], F]:
  522. deprecated_kws = set(kws)
  523. if not callable(is_deprecated):
  524. is_deprecated = partial(identity, is_deprecated)
  525. def wrapper(fn: F) -> F:
  526. @wraps(fn)
  527. def inner(*args, **kwargs):
  528. if is_deprecated():
  529. deprecated_kwargs = kwargs.keys() & deprecated_kws
  530. if deprecated_kwargs:
  531. msg = (
  532. f"The keyword arguments {deprecated_kwargs} are "
  533. "deprecated and will be removed in a future update.")
  534. if additional_message is not None:
  535. msg += f" {additional_message}"
  536. warnings.warn(
  537. DeprecationWarning(msg),
  538. stacklevel=3, # The inner function takes up one level
  539. )
  540. return fn(*args, **kwargs)
  541. return inner # type: ignore
  542. return wrapper
  543. @lru_cache(maxsize=8)
  544. def _cuda_device_count_stateless(
  545. cuda_visible_devices: Optional[str] = None) -> int:
  546. # Note: cuda_visible_devices is not used, but we keep it as an argument for
  547. # LRU Cache purposes.
  548. # Code below is based on
  549. # https://github.com/pytorch/pytorch/blob/
  550. # c1cd946818442aca8c7f812b16d187ce1586c3bc/
  551. # torch/cuda/__init__.py#L831C1-L831C17
  552. import torch.cuda
  553. import torch.version
  554. if not torch.cuda._is_compiled():
  555. return 0
  556. # bypass _device_count_nvml() if rocm (not supported)
  557. nvml_count = -1 if torch.version.hip else torch.cuda._device_count_nvml()
  558. r = torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count
  559. return r
  560. def cuda_device_count_stateless() -> int:
  561. """Get number of CUDA devices, caching based on the value of
  562. CUDA_VISIBLE_DEVICES at the time of call.
  563. This should be used instead of torch.cuda.device_count()
  564. unless CUDA_VISIBLE_DEVICES has already been set to the desired
  565. value."""
  566. # This can be removed and simply replaced with torch.cuda.get_device_count
  567. # after https://github.com/pytorch/pytorch/pull/122815 is released.
  568. return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES"))