utils.py

import argparse
import asyncio
import contextlib
import datetime
import enum
import gc
import os
import socket
import subprocess
import sys
import tempfile
import threading
import uuid
import warnings
from collections import defaultdict
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
                    Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
                    Union)

import numpy as np
import psutil
import torch
from loguru import logger

from aphrodite.common.logger import enable_trace_function_call

T = TypeVar("T")


class _Sentinel:
    ...


ALL_PINNED_SENTINEL = _Sentinel()

STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.uint8,
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
}


class Device(enum.Enum):
    GPU = enum.auto()
    CPU = enum.auto()


class Counter:

    def __init__(self, start: int = 0) -> None:
        self.counter = start

    def __next__(self) -> int:
        i = self.counter
        self.counter += 1
        return i

    def reset(self) -> None:
        self.counter = 0


class LRUCache(Generic[T]):

    def __init__(self, capacity: int):
        self.cache: OrderedDict[Hashable, T] = OrderedDict()
        self.pinned_items: Set[Hashable] = set()
        self.capacity = capacity

    def __contains__(self, key: Hashable) -> bool:
        return key in self.cache

    def __len__(self) -> int:
        return len(self.cache)

    def __getitem__(self, key: Hashable) -> Optional[T]:
        return self.get(key)

    def __setitem__(self, key: Hashable, value: T) -> None:
        self.put(key, value)

    def __delitem__(self, key: Hashable) -> None:
        self.pop(key)

    def touch(self, key: Hashable) -> None:
        self.cache.move_to_end(key)

    def get(self,
            key: Hashable,
            default_value: Optional[T] = None) -> Optional[T]:
        if key in self.cache:
            value: Optional[T] = self.cache[key]
            self.cache.move_to_end(key)
        else:
            value = default_value
        return value

    def put(self, key: Hashable, value: T) -> None:
        self.cache[key] = value
        self.cache.move_to_end(key)
        self._remove_old_if_needed()

    def pin(self, key: Hashable) -> None:
        """
        Pins a key in the cache preventing it from being
        evicted in the LRU order.
        """
        if key not in self.cache:
            raise ValueError(f"Cannot pin key: {key} not in cache.")
        self.pinned_items.add(key)

    def _unpin(self, key: Hashable) -> None:
        self.pinned_items.remove(key)

    def _on_remove(self, key: Hashable, value: Optional[T]):
        pass

    def remove_oldest(self, remove_pinned=False):
        if not self.cache:
            return

        if not remove_pinned:
            # pop the oldest item in the cache that is not pinned
            lru_key = next(
                (key for key in self.cache if key not in self.pinned_items),
                ALL_PINNED_SENTINEL)
            if lru_key is ALL_PINNED_SENTINEL:
                raise RuntimeError("All items are pinned, "
                                   "cannot remove oldest from the cache.")
        else:
            lru_key = next(iter(self.cache))
        self.pop(lru_key)

    def _remove_old_if_needed(self) -> None:
        while len(self.cache) > self.capacity:
            self.remove_oldest()

    def pop(self,
            key: Hashable,
            default_value: Optional[T] = None) -> Optional[T]:
        run_on_remove = key in self.cache
        value: Optional[T] = self.cache.pop(key, default_value)
        # remove from pinned items
        if key in self.pinned_items:
            self._unpin(key)
        if run_on_remove:
            self._on_remove(key, value)
        return value

    def clear(self):
        while len(self.cache) > 0:
            self.remove_oldest(remove_pinned=True)
        self.cache.clear()
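# Illustrative usage of LRUCache with pinning (a sketch, not executed at
# import time; the keys and values are hypothetical):
#
#     cache: LRUCache[int] = LRUCache(capacity=2)
#     cache.put("a", 1)
#     cache.put("b", 2)
#     cache.pin("a")       # "a" is now exempt from LRU eviction
#     cache.put("c", 3)    # evicts "b", the oldest unpinned entry
#     assert "a" in cache and "c" in cache and "b" not in cache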
def is_hip() -> bool:
    return torch.version.hip is not None


@lru_cache(maxsize=None)
def is_cpu() -> bool:
    from importlib.metadata import PackageNotFoundError, version
    try:
        return "cpu" in version("aphrodite-engine")
    except PackageNotFoundError:
        return False


@lru_cache(maxsize=None)
def is_openvino() -> bool:
    from importlib.metadata import PackageNotFoundError, version
    try:
        return "openvino" in version("aphrodite-engine")
    except PackageNotFoundError:
        return False


@lru_cache(maxsize=None)
def is_neuron() -> bool:
    try:
        import transformers_neuronx
    except ImportError:
        transformers_neuronx = None
    return transformers_neuronx is not None


@lru_cache(maxsize=None)
def is_tpu() -> bool:
    try:
        import libtpu
    except ImportError:
        libtpu = None
    return libtpu is not None


@lru_cache(maxsize=None)
def is_xpu() -> bool:
    from importlib.metadata import version
    is_xpu_flag = "xpu" in version("aphrodite-engine")
    # aphrodite is not built with xpu
    if not is_xpu_flag:
        return False
    try:
        import intel_extension_for_pytorch as ipex  # noqa: F401
        _import_ipex = True
    except ImportError as e:
        logger.warning(f"Import Error for IPEX: {e.msg}")
        _import_ipex = False
    # ipex dependency is not ready
    if not _import_ipex:
        logger.warning("IPEX library not found.")
        return False
    return hasattr(torch, "xpu") and torch.xpu.is_available()


@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes."""
    # NOTE: This import statement should be executed lazily since
    # the Neuron-X backend does not have the `cuda_utils` module.
    from aphrodite import _custom_ops as ops
    max_shared_mem = (
        ops.get_max_shared_memory_per_block_device_attribute(gpu))
    # A value of 0 will cause MAX_SEQ_LEN to become negative and
    # test_attention.py will fail.
    assert max_shared_mem > 0, "max_shared_mem can not be zero"
    return int(max_shared_mem)


def get_cpu_memory() -> int:
    """Returns the total CPU memory of the node in bytes."""
    return psutil.virtual_memory().total


def random_uuid() -> str:
    return str(uuid.uuid4().hex)


@lru_cache(maxsize=None)
def get_aphrodite_instance_id():
    """
    If the environment variable APHRODITE_INSTANCE_ID is set, return it.
    Otherwise, return a random UUID.
    The instance id represents an instance of Aphrodite. All processes in the
    same instance should have the same instance id.
    """
    return os.environ.get("APHRODITE_INSTANCE_ID",
                          f"aphrodite-instance-{random_uuid()}")


@lru_cache(maxsize=None)
def in_wsl() -> bool:
    # Reference: https://github.com/microsoft/WSL/issues/4071
    return "microsoft" in " ".join(uname()).lower()


def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
    """Take a blocking function and run it in an executor thread.
    This function prevents the blocking function from blocking the
    asyncio event loop.
    The code in this function needs to be thread safe.
    """

    def _async_wrapper(*args, **kwargs) -> asyncio.Future:
        loop = asyncio.get_event_loop()
        p_func = partial(func, *args, **kwargs)
        return loop.run_in_executor(executor=None, func=p_func)

    return _async_wrapper
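# Illustrative usage of make_async (a sketch; `load_weights_from_disk` is a
# hypothetical blocking function, not part of this module):
#
#     async_loader = make_async(load_weights_from_disk)
#
#     async def handler(path: str):
#         # Runs in the default executor, keeping the event loop responsive.
#         return await async_loader(path)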
def merge_async_iterators(
        *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
    """Merge multiple asynchronous iterators into a single iterator.
    This method handles the case where some iterators finish before others.
    When it yields, it yields a tuple (i, item) where i is the index of the
    iterator that yielded the item.
    """
    queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue()

    finished = [False] * len(iterators)

    async def producer(i: int, iterator: AsyncIterator[T]):
        try:
            async for item in iterator:
                await queue.put((i, item))
        except Exception as e:
            await queue.put(e)
        finished[i] = True

    _tasks = [
        asyncio.create_task(producer(i, iterator))
        for i, iterator in enumerate(iterators)
    ]

    async def consumer():
        try:
            while not all(finished) or not queue.empty():
                item = await queue.get()
                if isinstance(item, Exception):
                    raise item
                yield item
        except (Exception, asyncio.CancelledError) as e:
            for task in _tasks:
                if sys.version_info >= (3, 9):
                    # msg parameter only supported in Python 3.9+
                    task.cancel(e)
                else:
                    task.cancel()
            raise e
        await asyncio.gather(*_tasks)

    return consumer()
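# Illustrative usage of merge_async_iterators (a sketch; `stream_a` and
# `stream_b` are hypothetical async generators, not part of this module):
#
#     async def consume(stream_a: AsyncIterator[str],
#                       stream_b: AsyncIterator[str]) -> None:
#         async for i, item in merge_async_iterators(stream_a, stream_b):
#             print(f"iterator {i} produced {item!r}")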
def get_ip() -> str:
    host_ip = os.environ.get("HOST_IP")
    if host_ip:
        return host_ip

    # IP is not set, try to get it from the network interface

    # try ipv4
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
        return s.getsockname()[0]
    except Exception:
        pass

    # try ipv6
    try:
        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
        # Google's public DNS server, see
        # https://developers.google.com/speed/public-dns/docs/using#addresses
        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
        return s.getsockname()[0]
    except Exception:
        pass

    warnings.warn(
        "Failed to get the IP address, using 0.0.0.0 by default. "
        "The value can be set by the environment variable HOST_IP.",
        stacklevel=2)
    return "0.0.0.0"


def get_distributed_init_method(ip: str, port: int) -> str:
    # Brackets are not permitted in ipv4 addresses,
    # see https://github.com/python/cpython/issues/103848
    return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"


def get_open_port() -> int:
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]


def update_environment_variables(envs: Dict[str, str]):
    for k, v in envs.items():
        if k in os.environ and os.environ[k] != v:
            logger.warning(f"Overwriting environment variable {k} "
                           f"from '{os.environ[k]}' to '{v}'")
        os.environ[k] = v


def init_kmp_env():
    if not is_cpu():
        return

    ld_preload_str = os.getenv("LD_PRELOAD", "")
    if "libiomp5.so" not in ld_preload_str:
        return
    # The time (in milliseconds) that a thread should wait after completing
    # the execution of a parallel region, before sleeping.
    os.environ['KMP_BLOCKTIME'] = "1"
    # Dump settings on start up.
    os.environ['KMP_SETTINGS'] = "1"
    # Prevents the CPU from running into a low performance state.
    os.environ['KMP_TPAUSE'] = "0"
    # Provides fine granularity parallelism.
    os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist"
    os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist"
    os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"


def chunk_list(lst: List[T], chunk_size: int):
    """Yield successive chunk_size chunks from lst."""
    for i in range(0, len(lst), chunk_size):
        yield lst[i:i + chunk_size]


def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(a // -b)
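# Quick sanity check for chunk_list and cdiv (illustrative only):
#
#     list(chunk_list([1, 2, 3, 4, 5], 2))  # -> [[1, 2], [3, 4], [5]]
#     cdiv(10, 4)                           # -> 3, i.e. ceil(10 / 4)
#     cdiv(8, 4)                            # -> 2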
def _generate_random_fp8(
    tensor: torch.Tensor,
    low: float,
    high: float,
) -> None:
    # NOTE: Due to NaN and Inf representation for the fp8 data type,
    # Inf or NaN may occur if we directly use torch.randint
    # to generate random data for fp8.
    # For example, s.11111.00 in fp8e5m2 format represents Inf.
    #     | E4M3       | E5M2
    # ----|------------|-------------------
    # Inf | N/A        | s.11111.00
    # NaN | s.1111.111 | s.11111.{01,10,11}
    from aphrodite import _custom_ops as ops
    tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
    tensor_tmp.uniform_(low, high)
    ops.convert_fp8(tensor, tensor_tmp)
    del tensor_tmp


def get_kv_cache_torch_dtype(
        cache_dtype: Optional[Union[str, torch.dtype]],
        model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
    if isinstance(cache_dtype, str):
        if cache_dtype == "auto":
            if isinstance(model_dtype, str):
                torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
            elif isinstance(model_dtype, torch.dtype):
                torch_dtype = model_dtype
            else:
                raise ValueError(f"Invalid model dtype: {model_dtype}")
        elif cache_dtype in ["half", "bfloat16", "float"]:
            torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
        elif cache_dtype == "fp8":
            torch_dtype = torch.uint8
        else:
            raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    elif isinstance(cache_dtype, torch.dtype):
        torch_dtype = cache_dtype
    else:
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    return torch_dtype


def create_kv_caches_with_random_flash(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    assert cache_dtype != "fp8"
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
    scale = head_size**-0.5
    key_caches, value_caches = [], []
    for _ in range(num_layers):
        key_value_cache = torch.empty(size=key_value_cache_shape,
                                      dtype=torch_dtype,
                                      device=device)
        key_value_cache.uniform_(-scale, scale)
        key_caches.append(key_value_cache[:, 0])
        value_caches.append(key_value_cache[:, 1])
    return key_caches, value_caches


def create_kv_caches_with_random(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)

    scale = head_size**-0.5
    x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches = []
    for _ in range(num_layers):
        key_cache = torch.empty(size=key_cache_shape,
                                dtype=torch_dtype,
                                device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            key_cache.uniform_(-scale, scale)
        elif cache_dtype == 'fp8':
            _generate_random_fp8(key_cache, -scale, scale)
        else:
            raise ValueError(
                f"Does not support key cache of type {cache_dtype}")
        key_caches.append(key_cache)

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches = []
    for _ in range(num_layers):
        value_cache = torch.empty(size=value_cache_shape,
                                  dtype=torch_dtype,
                                  device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            value_cache.uniform_(-scale, scale)
        elif cache_dtype == 'fp8':
            _generate_random_fp8(value_cache, -scale, scale)
        else:
            raise ValueError(
                f"Does not support value cache of type {cache_dtype}")
        value_caches.append(value_cache)
    return key_caches, value_caches


@lru_cache
def print_warning_once(msg: str) -> None:
    logger.warning(msg)


@lru_cache(maxsize=None)
def is_pin_memory_available() -> bool:
    if in_wsl():
        # Pinning memory in WSL is not supported.
        # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
        print_warning_once("Using 'pin_memory=False' as WSL is detected. "
                           "This may slow down the performance.")
        return False
    elif is_xpu():
        print_warning_once("Pin memory is not supported on XPU.")
        return False
    elif is_neuron():
        print_warning_once("Pin memory is not supported on Neuron.")
        return False
    elif is_cpu() or is_openvino():
        return False
    return True


class CudaMemoryProfiler:

    def __init__(self, device=None):
        self.device = device

    def current_memory_usage(self) -> float:
        # Return the memory usage in bytes.
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(self.device)
            mem = torch.cuda.max_memory_allocated(self.device)
        elif is_xpu():
            torch.xpu.reset_peak_memory_stats(self.device)
            mem = torch.xpu.max_memory_allocated(self.device)
        return mem

    def __enter__(self):
        self.initial_memory = self.current_memory_usage()
        # This allows us to call methods of the context manager if needed
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.final_memory = self.current_memory_usage()
        self.consumed_memory = self.final_memory - self.initial_memory

        # Force garbage collection
        gc.collect()


def str_to_int_tuple(s: str) -> Tuple[int, ...]:
    """Convert a string to a tuple of integers."""
    try:
        return tuple(map(int, s.split(",")))
    except ValueError as e:
        raise ValueError(
            "String must be a series of integers separated by commas "
            f"(e.g., 1, 2, 3). Given input: {s}") from e


def make_tensor_with_pad(
    x: List[List[int]],
    max_len: int,
    pad: int,
    dtype: torch.dtype,
    device: Optional[Union[str, torch.device]],
) -> torch.Tensor:
    """Make a padded tensor from 2D inputs.
    The padding is applied to the end of each inner list until it reaches
    `max_len`.
    """
    padded_x = np.zeros([len(x), max_len], dtype=np.int32) + pad
    for ind, blocktb in enumerate(x):
        assert len(blocktb) <= max_len
        padded_x[ind, :len(blocktb)] = blocktb
    return torch.tensor(padded_x, dtype=dtype, device=device)
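# Illustrative usage of make_tensor_with_pad (a sketch):
#
#     make_tensor_with_pad([[1, 2, 3], [4]], max_len=4, pad=0,
#                          dtype=torch.int32, device="cpu")
#     # -> tensor([[1, 2, 3, 0],
#     #            [4, 0, 0, 0]], dtype=torch.int32)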
def async_tensor_h2d(
    data: list,
    dtype: torch.dtype,
    target_device: Union[str, torch.device],
    pin_memory: bool,
) -> torch.Tensor:
    """Asynchronously create a tensor and copy it from host to device."""
    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
    return t.to(device=target_device, non_blocking=True)


def maybe_expand_dim(tensor: torch.Tensor,
                     target_dims: int,
                     size: int = 1) -> torch.Tensor:
    """Expand the tensor to the target_dims."""
    if tensor.ndim < target_dims:
        tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
    return tensor


def get_dtype_size(dtype: torch.dtype) -> int:
    """Get the size of the data type in bytes."""
    return torch.tensor([], dtype=dtype).element_size()


def merge_dicts(dict1: Dict[Any, List[Any]],
                dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]:
    """Merge 2 dicts that map key -> List of items.
    When a key conflicts, the values in dict1 are prioritized.
    """
    merged_dict = defaultdict(list)

    for key, value in dict1.items():
        merged_dict[key].extend(value)

    for key, value in dict2.items():
        merged_dict[key].extend(value)

    return dict(merged_dict)
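# Illustrative behavior of merge_dicts (a sketch):
#
#     merge_dicts({"a": [1], "b": [2]}, {"a": [3], "c": [4]})
#     # -> {"a": [1, 3], "b": [2], "c": [4]}   (dict1 values come first)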
def init_cached_hf_modules():
    """
    Lazy initialization of the Hugging Face modules.
    """
    from transformers.dynamic_module_utils import init_hf_modules
    init_hf_modules()


@lru_cache(maxsize=None)
def find_library(lib_name: str) -> str:
    """
    Find the library file in the system.
    `lib_name` is the full filename, with both prefix and suffix.
    This function resolves `lib_name` to the full path of the library.
    """
    # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
    # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
    # `/sbin/ldconfig` should exist in all Linux systems.
    # `/sbin/ldconfig` searches the library in the system
    libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
    # each line looks like the following:
    # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
    locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
    # `LD_LIBRARY_PATH` searches the library in the user-defined paths
    env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
    if not locs and env_ld_library_path:
        locs = [
            os.path.join(dir, lib_name)
            for dir in env_ld_library_path.split(":")
            if os.path.exists(os.path.join(dir, lib_name))
        ]
    if not locs:
        raise ValueError(f"Cannot find {lib_name} in the system.")
    return locs[0]


def find_nccl_library():
    """
    We either use the library file specified by the `APHRODITE_NCCL_SO_PATH`
    environment variable, or we find the library file brought by PyTorch.
    After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be
    found by `ctypes` automatically.
    """
    so_file = os.environ.get("APHRODITE_NCCL_SO_PATH", "")

    # manually load the nccl library
    if so_file:
        logger.info("Found nccl from environment variable "
                    f"APHRODITE_NCCL_SO_PATH={so_file}")
    else:
        if torch.version.cuda is not None:
            so_file = "libnccl.so.2"
        elif torch.version.hip is not None:
            so_file = "librccl.so.1"
        else:
            raise ValueError("NCCL only supports CUDA and ROCm backends.")
        logger.info(f"Found nccl from library {so_file}")
    return so_file


def enable_trace_function_call_for_thread() -> None:
    if int(os.getenv("APHRODITE_TRACE_FUNCTION", "0")):
        tmp_dir = tempfile.gettempdir()
        filename = (f"APHRODITE_TRACE_FUNCTION_for_process_{os.getpid()}"
                    f"_thread_{threading.get_ident()}_"
                    f"at_{datetime.datetime.now()}.log").replace(" ", "_")
        log_path = os.path.join(tmp_dir, "aphrodite",
                                get_aphrodite_instance_id(), filename)
        os.makedirs(os.path.dirname(log_path), exist_ok=True)
        enable_trace_function_call(log_path)


def identity(value: T) -> T:
    return value


F = TypeVar('F', bound=Callable[..., Any])


def deprecate_kwargs(
        *kws: str,
        is_deprecated: Union[bool, Callable[[], bool]] = True,
        additional_message: Optional[str] = None) -> Callable[[F], F]:
    deprecated_kws = set(kws)

    if not callable(is_deprecated):
        is_deprecated = partial(identity, is_deprecated)

    def wrapper(fn: F) -> F:

        @wraps(fn)
        def inner(*args, **kwargs):
            if is_deprecated():
                deprecated_kwargs = kwargs.keys() & deprecated_kws
                if deprecated_kwargs:
                    msg = (
                        f"The keyword arguments {deprecated_kwargs} are "
                        "deprecated and will be removed in a future update.")
                    if additional_message is not None:
                        msg += f" {additional_message}"

                    warnings.warn(
                        DeprecationWarning(msg),
                        stacklevel=3,  # The inner function takes up one level
                    )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
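# Illustrative usage of deprecate_kwargs (a sketch; `old_arg` and `new_arg`
# are hypothetical parameter names):
#
#     @deprecate_kwargs("old_arg",
#                       additional_message="Use 'new_arg' instead.")
#     def do_work(*, old_arg: int = 0, new_arg: int = 0) -> int:
#         return old_arg + new_arg
#
#     do_work(old_arg=1)  # emits a DeprecationWarning
#     do_work(new_arg=1)  # no warning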
@lru_cache(maxsize=8)
def _cuda_device_count_stateless(
        cuda_visible_devices: Optional[str] = None) -> int:
    # Note: cuda_visible_devices is not used, but we keep it as an argument
    # for LRU Cache purposes.

    # Code below is based on
    # https://github.com/pytorch/pytorch/blob/
    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
    # torch/cuda/__init__.py#L831C1-L831C17
    import torch.cuda
    import torch.version

    if not torch.cuda._is_compiled():
        return 0
    if is_hip():
        # ROCm uses amdsmi instead of nvml for stateless device count
        # This requires a sufficiently modern version of Torch 2.4.0
        raw_count = torch.cuda._device_count_amdsmi() if (hasattr(
            torch.cuda, "_device_count_amdsmi")) else -1
    else:
        raw_count = torch.cuda._device_count_nvml()
    r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
    return r


def cuda_device_count_stateless() -> int:
    """Get the number of CUDA devices, caching based on the value of
    CUDA_VISIBLE_DEVICES at the time of call.
    This should be used instead of torch.cuda.device_count()
    unless CUDA_VISIBLE_DEVICES has already been set to the desired
    value."""

    # This can be removed and simply replaced with torch.cuda.device_count
    # after https://github.com/pytorch/pytorch/pull/122815 is released.
    return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES"))


def error_on_invalid_device_count_status():
    cache_entries = 0
    with contextlib.suppress(Exception):
        # future pytorch will fix the issue, device_count will not be cached
        # at that time, `.cache_info().currsize` will error out
        cache_entries = torch.cuda.device_count.cache_info().currsize
    if cache_entries != 0:
        # the function has already been called, and the result is cached
        remembered = torch.cuda.device_count()
        current = cuda_device_count_stateless()
        if remembered > current:
            raise RuntimeError(
                "The number of CUDA devices has changed since the first "
                "call to torch.cuda.device_count(). This is not allowed "
                "and may result in undefined behavior. Please set "
                "CUDA_VISIBLE_DEVICES to the GPUs you want to use.")


# NVML utils
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`;
# all the related functions work on real physical device ids.
# The major benefit of using NVML is that it will not initialize CUDA.
try:
    import pynvml
except ImportError:
    # For non-NV devices
    pynvml = None


def with_nvml_context(fn):

    @wraps(fn)
    def wrapper(*args, **kwargs):
        if pynvml is not None:
            pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            if pynvml is not None:
                pynvml.nvmlShutdown()

    return wrapper


@with_nvml_context
def is_full_nvlink(device_ids: List[int]) -> bool:
    """
    Query whether the set of GPUs is fully connected by NVLink (1 hop).
    """
    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
    for i, handle in enumerate(handles):
        for j, peer_handle in enumerate(handles):
            if i < j:
                try:
                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
                        return False
                except pynvml.NVMLError as error:
                    logger.error(
                        "NVLink detection failed. This is normal if your"
                        " machine has no NVLink equipped.",
                        exc_info=error)
                    return False
    return True


class FlexibleArgumentParser(argparse.ArgumentParser):
    """ArgumentParser that allows both underscore and dash in names."""

    def parse_args(self, args=None, namespace=None):
        if args is None:
            args = sys.argv[1:]

        # Convert underscores to dashes and vice versa in argument names
        processed_args = []
        for arg in args:
            if arg.startswith('--'):
                if '=' in arg:
                    key, value = arg.split('=', 1)
                    key = '--' + key[len('--'):].replace('_', '-')
                    processed_args.append(f'{key}={value}')
                else:
                    processed_args.append('--' +
                                          arg[len('--'):].replace('_', '-'))
            else:
                processed_args.append(arg)

        return super().parse_args(processed_args, namespace)
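# Illustrative behavior of FlexibleArgumentParser (a sketch; `--max-len` is a
# hypothetical option):
#
#     parser = FlexibleArgumentParser()
#     parser.add_argument("--max-len", type=int)
#     parser.parse_args(["--max_len=32"])     # -> Namespace(max_len=32)
#     parser.parse_args(["--max-len", "32"])  # -> Namespace(max_len=32)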
async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
                              **kwargs):
    """Utility function to run an async task while holding a lock."""
    async with lock:
        return await task(*args, **kwargs)