utils.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027
  1. import argparse
  2. import asyncio
  3. import datetime
  4. import enum
  5. import gc
  6. import os
  7. import socket
  8. import subprocess
  9. import sys
  10. import tempfile
  11. import threading
  12. import uuid
  13. import warnings
  14. from collections import defaultdict
  15. from functools import lru_cache, partial, wraps
  16. from platform import uname
  17. from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
  18. Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
  19. Union, overload)
  20. import numpy as np
  21. import numpy.typing as npt
  22. import psutil
  23. import torch
  24. import torch.types
  25. from loguru import logger
  26. from typing_extensions import ParamSpec
  27. from aphrodite import _custom_ops as ops
  28. from aphrodite.common.logger import enable_trace_function_call
  29. STR_DTYPE_TO_TORCH_DTYPE = {
  30. "half": torch.half,
  31. "bfloat16": torch.bfloat16,
  32. "float": torch.float,
  33. "fp8": torch.uint8,
  34. "fp8_e4m3": torch.uint8,
  35. "fp8_e5m2": torch.uint8,
  36. }
  37. TORCH_DTYPE_TO_NUMPY_DTYPE = {
  38. torch.float16: np.float16,
  39. torch.float32: np.float32,
  40. torch.float64: np.float64,
  41. torch.uint8: np.uint8,
  42. torch.int32: np.int32,
  43. torch.int64: np.int64,
  44. }
  45. P = ParamSpec('P')
  46. K = TypeVar("K")
  47. T = TypeVar("T")
  48. U = TypeVar("U")
  49. class _Sentinel:
  50. ...
  51. ALL_PINNED_SENTINEL = _Sentinel()
  52. class Device(enum.Enum):
  53. GPU = enum.auto()
  54. CPU = enum.auto()
  55. class Counter:
  56. def __init__(self, start: int = 0) -> None:
  57. self.counter = start
  58. def __next__(self) -> int:
  59. i = self.counter
  60. self.counter += 1
  61. return i
  62. def reset(self) -> None:
  63. self.counter = 0
  64. class LRUCache(Generic[T]):
  65. def __init__(self, capacity: int):
  66. self.cache: OrderedDict[Hashable, T] = OrderedDict()
  67. self.pinned_items: Set[Hashable] = set()
  68. self.capacity = capacity
  69. def __contains__(self, key: Hashable) -> bool:
  70. return key in self.cache
  71. def __len__(self) -> int:
  72. return len(self.cache)
  73. def __getitem__(self, key: Hashable) -> T:
  74. value = self.cache[key] # Raise KeyError if not exists
  75. self.cache.move_to_end(key)
  76. return value
  77. def __setitem__(self, key: Hashable, value: T) -> None:
  78. self.put(key, value)
  79. def __delitem__(self, key: Hashable) -> None:
  80. self.pop(key)
  81. def touch(self, key: Hashable) -> None:
  82. self.cache.move_to_end(key)
  83. def get(self,
  84. key: Hashable,
  85. default_value: Optional[T] = None) -> Optional[T]:
  86. value: Optional[T]
  87. if key in self.cache:
  88. value = self.cache[key]
  89. self.cache.move_to_end(key)
  90. else:
  91. value = default_value
  92. return value
  93. def put(self, key: Hashable, value: T) -> None:
  94. self.cache[key] = value
  95. self.cache.move_to_end(key)
  96. self._remove_old_if_needed()
  97. def pin(self, key: Hashable) -> None:
  98. """
  99. Pins a key in the cache preventing it from being
  100. evicted in the LRU order.
  101. """
  102. if key not in self.cache:
  103. raise ValueError(f"Cannot pin key: {key} not in cache.")
  104. self.pinned_items.add(key)
  105. def _unpin(self, key: Hashable) -> None:
  106. self.pinned_items.remove(key)
  107. def _on_remove(self, key: Hashable, value: Optional[T]):
  108. pass
  109. def remove_oldest(self, remove_pinned=False):
  110. if not self.cache:
  111. return
  112. if not remove_pinned:
  113. # pop the oldest item in the cache that is not pinned
  114. lru_key = next(
  115. (key for key in self.cache if key not in self.pinned_items),
  116. ALL_PINNED_SENTINEL)
  117. if lru_key is ALL_PINNED_SENTINEL:
  118. raise RuntimeError("All items are pinned, "
  119. "cannot remove oldest from the cache.")
  120. else:
  121. lru_key = next(iter(self.cache))
  122. self.pop(lru_key)
  123. def _remove_old_if_needed(self) -> None:
  124. while len(self.cache) > self.capacity:
  125. self.remove_oldest()
  126. def pop(self,
  127. key: Hashable,
  128. default_value: Optional[T] = None) -> Optional[T]:
  129. run_on_remove = key in self.cache
  130. value: Optional[T] = self.cache.pop(key, default_value)
  131. # remove from pinned items
  132. if key in self.pinned_items:
  133. self._unpin(key)
  134. if run_on_remove:
  135. self._on_remove(key, value)
  136. return value
  137. def clear(self):
  138. while len(self.cache) > 0:
  139. self.remove_oldest(remove_pinned=True)
  140. self.cache.clear()
  141. def is_hip() -> bool:
  142. return torch.version.hip is not None
  143. @lru_cache(maxsize=None)
  144. def is_cpu() -> bool:
  145. from importlib.metadata import PackageNotFoundError, version
  146. try:
  147. return "cpu" in version("aphrodite-engine")
  148. except PackageNotFoundError:
  149. return False
  150. @lru_cache(maxsize=None)
  151. def is_openvino() -> bool:
  152. from importlib.metadata import PackageNotFoundError, version
  153. try:
  154. return "openvino" in version("aphrodite-engine")
  155. except PackageNotFoundError:
  156. return False
  157. @lru_cache(maxsize=None)
  158. def is_neuron() -> bool:
  159. try:
  160. import transformers_neuronx
  161. except ImportError:
  162. transformers_neuronx = None
  163. return transformers_neuronx is not None
  164. @lru_cache(maxsize=None)
  165. def is_tpu() -> bool:
  166. try:
  167. import libtpu
  168. except ImportError:
  169. libtpu = None
  170. return libtpu is not None
  171. @lru_cache(maxsize=None)
  172. def is_xpu() -> bool:
  173. from importlib.metadata import version
  174. is_xpu_flag = "xpu" in version("aphrodite-engine")
  175. # aphrodite is not build with xpu
  176. if not is_xpu_flag:
  177. return False
  178. try:
  179. import intel_extension_for_pytorch as ipex # noqa: F401
  180. _import_ipex = True
  181. except ImportError as e:
  182. logger.warning(f"Import Error for IPEX: {e.msg}")
  183. _import_ipex = False
  184. # ipex dependency is not ready
  185. if not _import_ipex:
  186. logger.warning("not found ipex lib")
  187. return False
  188. return hasattr(torch, "xpu") and torch.xpu.is_available()
  189. @lru_cache(maxsize=None)
  190. def get_max_shared_memory_bytes(gpu: int = 0) -> int:
  191. """Returns the maximum shared memory per thread block in bytes."""
  192. max_shared_mem = (
  193. ops.get_max_shared_memory_per_block_device_attribute(gpu))
  194. # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
  195. # will fail
  196. assert max_shared_mem > 0, "max_shared_mem can not be zero"
  197. return int(max_shared_mem)
  198. def get_cpu_memory() -> int:
  199. """Returns the total CPU memory of the node in bytes."""
  200. return psutil.virtual_memory().total
  201. def random_uuid() -> str:
  202. return str(uuid.uuid4().hex)
  203. @lru_cache(maxsize=None)
  204. def get_aphrodite_instance_id():
  205. """
  206. If the environment variable APHRODITE_INSTANCE_ID is set, return it.
  207. Otherwise, return a random UUID.
  208. Instance id represents an instance of the Aphrodite. All processes in the
  209. same instance should have the same instance id.
  210. """
  211. return os.environ.get("APHRODITE_INSTANCE_ID",
  212. f"aphrodite-instance-{random_uuid()}")
  213. @lru_cache(maxsize=None)
  214. def in_wsl() -> bool:
  215. # Reference: https://github.com/microsoft/WSL/issues/4071
  216. return "microsoft" in " ".join(uname()).lower()
  217. def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
  218. """Take a blocking function, and run it on in an executor thread.
  219. This function prevents the blocking function from blocking the
  220. asyncio event loop.
  221. The code in this function needs to be thread safe.
  222. """
  223. def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
  224. loop = asyncio.get_event_loop()
  225. p_func = partial(func, *args, **kwargs)
  226. return loop.run_in_executor(executor=None, func=p_func)
  227. return _async_wrapper
  228. class ProducerFinished:
  229. pass
  230. def merge_async_iterators(
  231. *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
  232. """Merge multiple asynchronous iterators into a single iterator.
  233. This method handle the case where some iterators finish before others.
  234. When it yields, it yields a tuple (i, item) where i is the index of the
  235. iterator that yields the item.
  236. """
  237. queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished,
  238. Exception]] = asyncio.Queue()
  239. producers = len(iterators)
  240. async def producer(i: int, iterator: AsyncIterator[T]):
  241. try:
  242. async for item in iterator:
  243. await queue.put((i, item))
  244. except Exception as e:
  245. await queue.put(e)
  246. # Signal to the consumer that we've finished
  247. await queue.put(ProducerFinished())
  248. _tasks = [
  249. asyncio.create_task(producer(i, iterator))
  250. for i, iterator in enumerate(iterators)
  251. ]
  252. async def consumer():
  253. remaining = producers
  254. try:
  255. while remaining or not queue.empty():
  256. # we think there is a race condition here
  257. item = await queue.get()
  258. if isinstance(item, ProducerFinished):
  259. # Signal that a producer finished- not a real item
  260. remaining -= 1
  261. continue
  262. if isinstance(item, Exception):
  263. raise item
  264. yield item
  265. except (Exception, asyncio.CancelledError) as e:
  266. for task in _tasks:
  267. if sys.version_info >= (3, 9):
  268. # msg parameter only supported in Python 3.9+
  269. task.cancel(e)
  270. else:
  271. task.cancel()
  272. raise e
  273. await asyncio.gather(*_tasks)
  274. return consumer()
  275. def get_ip() -> str:
  276. host_ip = os.environ.get("HOST_IP")
  277. if host_ip:
  278. return host_ip
  279. # IP is not set, try to get it from the network interface
  280. # try ipv4
  281. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  282. try:
  283. s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
  284. return s.getsockname()[0]
  285. except Exception:
  286. pass
  287. # try ipv6
  288. try:
  289. s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
  290. # Google's public DNS server, see
  291. # https://developers.google.com/speed/public-dns/docs/using#addresses
  292. s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
  293. return s.getsockname()[0]
  294. except Exception:
  295. pass
  296. warnings.warn(
  297. "Failed to get the IP address, using 0.0.0.0 by default."
  298. "The value can be set by the environment variable HOST_IP.",
  299. stacklevel=2)
  300. return "0.0.0.0"
  301. def get_distributed_init_method(ip: str, port: int) -> str:
  302. # Brackets are not permitted in ipv4 addresses,
  303. # see https://github.com/python/cpython/issues/103848
  304. return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
  305. def get_open_port(port: Optional[int] = None) -> int:
  306. if port is None:
  307. # Default behavior here is to return a port for multi-gpu communication
  308. port = int(os.getenv("APHRODITE_PORT", 2242))
  309. if port is not None:
  310. while True:
  311. try:
  312. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  313. s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  314. s.bind(("", port))
  315. return port
  316. except OSError:
  317. port += 1 # Increment port number if already in use
  318. logger.info(f"Port {port - 1} is already in use, trying port "
  319. f"{port}")
  320. # try ipv4
  321. try:
  322. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  323. s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  324. s.bind(("", 0))
  325. return s.getsockname()[1]
  326. except OSError:
  327. # try ipv6
  328. with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
  329. s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  330. s.bind(("", 0))
  331. return s.getsockname()[1]
  332. def update_environment_variables(envs: Dict[str, str]):
  333. for k, v in envs.items():
  334. if k in os.environ and os.environ[k] != v:
  335. logger.warning(f"Overwriting environment variable {k} "
  336. f"from '{os.environ[k]}' to '{v}'")
  337. os.environ[k] = v
  338. def chunk_list(lst: List[T], chunk_size: int):
  339. """Yield successive chunk_size chunks from lst."""
  340. for i in range(0, len(lst), chunk_size):
  341. yield lst[i:i + chunk_size]
  342. def cdiv(a: int, b: int) -> int:
  343. """Ceiling division."""
  344. return -(a // -b)
  345. def _generate_random_fp8(
  346. tensor: torch.Tensor,
  347. low: float,
  348. high: float,
  349. ) -> None:
  350. # NOTE: Due to NaN and Inf representation for fp8 data type,
  351. # it may occur Inf or NaN if we directly use torch.randint
  352. # to generate random data for fp8 data.
  353. # For example, s.11111.00 in fp8e5m2 format represents Inf.
  354. # | E4M3 | E5M2
  355. #-----|-------------|-------------------
  356. # Inf | N/A | s.11111.00
  357. # NaN | s.1111.111 | s.11111.{01,10,11}
  358. from aphrodite import _custom_ops as ops
  359. tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
  360. tensor_tmp.uniform_(low, high)
  361. ops.convert_fp8(tensor, tensor_tmp)
  362. del tensor_tmp
  363. def get_kv_cache_torch_dtype(
  364. cache_dtype: Optional[Union[str, torch.dtype]],
  365. model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
  366. if isinstance(cache_dtype, str):
  367. if cache_dtype == "auto":
  368. if isinstance(model_dtype, str):
  369. torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
  370. elif isinstance(model_dtype, torch.dtype):
  371. torch_dtype = model_dtype
  372. else:
  373. raise ValueError(f"Invalid model dtype: {model_dtype}")
  374. elif cache_dtype in ["half", "bfloat16", "float"]:
  375. torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
  376. elif cache_dtype == "fp8":
  377. torch_dtype = torch.uint8
  378. else:
  379. raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
  380. elif isinstance(cache_dtype, torch.dtype):
  381. torch_dtype = cache_dtype
  382. else:
  383. raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
  384. return torch_dtype
  385. def create_kv_caches_with_random_flash(
  386. num_blocks: int,
  387. block_size: int,
  388. num_layers: int,
  389. num_heads: int,
  390. head_size: int,
  391. cache_dtype: Optional[Union[str, torch.dtype]],
  392. model_dtype: Optional[Union[str, torch.dtype]] = None,
  393. seed: int = 0,
  394. device: Optional[str] = "cuda",
  395. ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
  396. torch.random.manual_seed(seed)
  397. if torch.cuda.is_available():
  398. torch.cuda.manual_seed(seed)
  399. torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
  400. key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
  401. scale = head_size**-0.5
  402. key_caches: List[torch.Tensor] = []
  403. value_caches: List[torch.Tensor] = []
  404. for _ in range(num_layers):
  405. key_value_cache = torch.empty(size=key_value_cache_shape,
  406. dtype=torch_dtype,
  407. device=device)
  408. if cache_dtype in ["auto", "half", "bfloat16", "float"]:
  409. key_value_cache.uniform_(-scale, scale)
  410. elif cache_dtype == 'fp8':
  411. _generate_random_fp8(key_value_cache, -scale, scale)
  412. else:
  413. raise ValueError(
  414. f"Does not support key cache of type {cache_dtype}")
  415. key_caches.append(key_value_cache[:, 0])
  416. value_caches.append(key_value_cache[:, 1])
  417. return key_caches, value_caches
  418. def create_kv_caches_with_random(
  419. num_blocks: int,
  420. block_size: int,
  421. num_layers: int,
  422. num_heads: int,
  423. head_size: int,
  424. cache_dtype: Optional[Union[str, torch.dtype]],
  425. model_dtype: Optional[Union[str, torch.dtype]] = None,
  426. seed: int = 0,
  427. device: Optional[str] = "cuda",
  428. ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
  429. if cache_dtype == "fp8" and head_size % 16:
  430. raise ValueError(
  431. f"Does not support key cache of type fp8 with head_size "
  432. f"{head_size}")
  433. torch.random.manual_seed(seed)
  434. if torch.cuda.is_available():
  435. torch.cuda.manual_seed(seed)
  436. torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
  437. scale = head_size**-0.5
  438. x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
  439. key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
  440. key_caches: List[torch.Tensor] = []
  441. for _ in range(num_layers):
  442. key_cache = torch.empty(size=key_cache_shape,
  443. dtype=torch_dtype,
  444. device=device)
  445. if cache_dtype in ["auto", "half", "bfloat16", "float"]:
  446. key_cache.uniform_(-scale, scale)
  447. elif cache_dtype == 'fp8':
  448. _generate_random_fp8(key_cache, -scale, scale)
  449. else:
  450. raise ValueError(
  451. f"Does not support key cache of type {cache_dtype}")
  452. key_caches.append(key_cache)
  453. value_cache_shape = (num_blocks, num_heads, head_size, block_size)
  454. value_caches: List[torch.Tensor] = []
  455. for _ in range(num_layers):
  456. value_cache = torch.empty(size=value_cache_shape,
  457. dtype=torch_dtype,
  458. device=device)
  459. if cache_dtype in ["auto", "half", "bfloat16", "float"]:
  460. value_cache.uniform_(-scale, scale)
  461. elif cache_dtype == 'fp8':
  462. _generate_random_fp8(value_cache, -scale, scale)
  463. else:
  464. raise ValueError(
  465. f"Does not support value cache of type {cache_dtype}")
  466. value_caches.append(value_cache)
  467. return key_caches, value_caches
  468. @lru_cache
  469. def print_warning_once(msg: str) -> None:
  470. logger.warning(msg)
  471. @lru_cache(maxsize=None)
  472. def is_pin_memory_available() -> bool:
  473. if in_wsl():
  474. # Pinning memory in WSL is not supported.
  475. # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
  476. print_warning_once("Using 'pin_memory=False' as WSL is detected. "
  477. "This may slow down the performance.")
  478. return False
  479. elif is_xpu():
  480. print_warning_once("Pin memory is not supported on XPU.")
  481. return False
  482. elif is_neuron():
  483. print_warning_once("Pin memory is not supported on Neuron.")
  484. return False
  485. elif is_cpu() or is_openvino():
  486. return False
  487. return True
  488. class CudaMemoryProfiler:
  489. def __init__(self, device: Optional[torch.types.Device] = None):
  490. self.device = device
  491. def current_memory_usage(self) -> float:
  492. # Return the memory usage in bytes.
  493. if torch.cuda.is_available():
  494. torch.cuda.reset_peak_memory_stats(self.device)
  495. mem = torch.cuda.max_memory_allocated(self.device)
  496. elif is_xpu():
  497. torch.xpu.reset_peak_memory_stats(self.device) # type: ignore
  498. mem = torch.xpu.max_memory_allocated(self.device) # type: ignore
  499. return mem
  500. def __enter__(self):
  501. self.initial_memory = self.current_memory_usage()
  502. # This allows us to call methods of the context manager if needed
  503. return self
  504. def __exit__(self, exc_type, exc_val, exc_tb):
  505. self.final_memory = self.current_memory_usage()
  506. self.consumed_memory = self.final_memory - self.initial_memory
  507. # Force garbage collection
  508. gc.collect()
  509. def str_to_int_tuple(s: str) -> Tuple[int, ...]:
  510. """Convert a string to a tuple of integers."""
  511. try:
  512. return tuple(map(int, s.split(",")))
  513. except ValueError as e:
  514. raise ValueError(
  515. "String must be a series of integers separated by commas "
  516. f"(e.g., 1, 2, 3). Given input: {s}") from e
  517. def make_ndarray_with_pad(
  518. x: List[List[T]],
  519. pad: T,
  520. dtype: npt.DTypeLike,
  521. *,
  522. max_len: Optional[int] = None,
  523. ) -> npt.NDArray:
  524. """
  525. Make a padded array from 2D inputs.
  526. The padding is applied to the end of each inner list until it reaches
  527. `max_len`.
  528. """
  529. if max_len is None:
  530. # Unlike for most functions, map is faster than a genexpr over `len`
  531. max_len = max(map(len, x), default=0)
  532. padded_x = np.full((len(x), max_len), pad, dtype=dtype)
  533. for ind, blocktb in enumerate(x):
  534. assert len(blocktb) <= max_len
  535. padded_x[ind, :len(blocktb)] = blocktb
  536. return padded_x
  537. def make_tensor_with_pad(
  538. x: List[List[T]],
  539. pad: T,
  540. dtype: torch.dtype,
  541. *,
  542. max_len: Optional[int] = None,
  543. device: Optional[Union[str, torch.device]] = None,
  544. pin_memory: bool = False,
  545. ) -> torch.Tensor:
  546. """
  547. Make a padded tensor from 2D inputs.
  548. The padding is applied to the end of each inner list until it reaches
  549. `max_len`.
  550. """
  551. np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
  552. padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len)
  553. tensor = torch.from_numpy(padded_x).to(device)
  554. if pin_memory:
  555. tensor = tensor.pin_memory()
  556. return tensor
  557. def async_tensor_h2d(
  558. data: list,
  559. dtype: torch.dtype,
  560. target_device: Union[str, torch.device],
  561. pin_memory: bool,
  562. ) -> torch.Tensor:
  563. """Asynchronously create a tensor and copy it from host to device."""
  564. t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
  565. return t.to(device=target_device, non_blocking=True)
  566. def maybe_expand_dim(tensor: torch.Tensor,
  567. target_dims: int,
  568. size: int = 1) -> torch.Tensor:
  569. """Expand the tensor to the target_dims."""
  570. if tensor.ndim < target_dims:
  571. tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
  572. return tensor
  573. def get_dtype_size(dtype: torch.dtype) -> int:
  574. """Get the size of the data type in bytes."""
  575. return torch.tensor([], dtype=dtype).element_size()
  576. def merge_dicts(dict1: Dict[K, List[T]],
  577. dict2: Dict[K, List[T]]) -> Dict[K, List[T]]:
  578. """Merge 2 dicts that have key -> List of items.
  579. When a key conflicts, the values in dict1 is prioritized.
  580. """
  581. merged_dict: Dict[K, List[T]] = defaultdict(list)
  582. for key, value in dict1.items():
  583. merged_dict[key].extend(value)
  584. for key, value in dict2.items():
  585. merged_dict[key].extend(value)
  586. return dict(merged_dict)
  587. JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
  588. Tuple["JSONTree[T]", ...], T]
  589. """A nested JSON structure where the leaves need not be JSON-serializable."""
  590. @overload
  591. def json_map_leaves(
  592. func: Callable[[T], U],
  593. value: Dict[str, JSONTree[T]],
  594. ) -> Dict[str, JSONTree[U]]:
  595. ...
  596. @overload
  597. def json_map_leaves(
  598. func: Callable[[T], U],
  599. value: List[JSONTree[T]],
  600. ) -> List[JSONTree[U]]:
  601. ...
  602. @overload
  603. def json_map_leaves(
  604. func: Callable[[T], U],
  605. value: Tuple[JSONTree[T], ...],
  606. ) -> Tuple[JSONTree[U], ...]:
  607. ...
  608. @overload
  609. def json_map_leaves(
  610. func: Callable[[T], U],
  611. value: JSONTree[T],
  612. ) -> JSONTree[U]:
  613. ...
  614. def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
  615. if isinstance(value, dict):
  616. return {k: json_map_leaves(func, v) for k, v in value.items()}
  617. elif isinstance(value, list):
  618. return [json_map_leaves(func, v) for v in value]
  619. elif isinstance(value, tuple):
  620. return tuple(json_map_leaves(func, v) for v in value)
  621. else:
  622. return func(value)
  623. def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
  624. """Flatten a list of lists to a single list."""
  625. return [item for sublist in lists for item in sublist]
  626. def init_cached_hf_modules() -> None:
  627. """
  628. Lazy initialization of the Hugging Face modules.
  629. """
  630. from transformers.dynamic_module_utils import init_hf_modules
  631. init_hf_modules()
  632. @lru_cache(maxsize=None)
  633. def find_library(lib_name: str) -> str:
  634. """
  635. Find the library file in the system.
  636. `lib_name` is full filename, with both prefix and suffix.
  637. This function resolves `lib_name` to the full path of the library.
  638. """
  639. # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
  640. # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
  641. # `/sbin/ldconfig` should exist in all Linux systems.
  642. # `/sbin/ldconfig` searches the library in the system
  643. libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
  644. # each line looks like the following:
  645. # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
  646. locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
  647. # `LD_LIBRARY_PATH` searches the library in the user-defined paths
  648. env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
  649. if not locs and env_ld_library_path:
  650. locs = [
  651. os.path.join(dir, lib_name)
  652. for dir in env_ld_library_path.split(":")
  653. if os.path.exists(os.path.join(dir, lib_name))
  654. ]
  655. if not locs:
  656. raise ValueError(f"Cannot find {lib_name} in the system.")
  657. return locs[0]
  658. def find_nccl_library() -> str:
  659. """
  660. We either use the library file specified by the `APHRODITE_NCCL_SO_PATH`
  661. environment variable, or we find the library file brought by PyTorch.
  662. After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be
  663. found by `ctypes` automatically.
  664. """
  665. so_file = os.environ.get("APHRODITE_NCCL_SO_PATH", "")
  666. # manually load the nccl library
  667. if so_file:
  668. logger.debug("Found nccl from environment variable "
  669. f"APHRODITE_NCCL_SO_PATH={so_file}")
  670. else:
  671. if torch.version.cuda is not None:
  672. so_file = "libnccl.so.2"
  673. elif torch.version.hip is not None:
  674. so_file = "librccl.so.1"
  675. else:
  676. raise ValueError("NCCL only supports CUDA and ROCm backends.")
  677. logger.debug(f"Found nccl from library {so_file}")
  678. return so_file
  679. def enable_trace_function_call_for_thread() -> None:
  680. if int(os.getenv("APHRODITE_TRACE_FUNCTION", "0")):
  681. tmp_dir = tempfile.gettempdir()
  682. filename = (f"APHRODITE_TRACE_FUNCTION_for_process_{os.getpid()}"
  683. f"_thread_{threading.get_ident()}_"
  684. f"at_{datetime.datetime.now()}.log").replace(" ", "_")
  685. log_path = os.path.join(tmp_dir, "aphrodite",
  686. get_aphrodite_instance_id(), filename)
  687. os.makedirs(os.path.dirname(log_path), exist_ok=True)
  688. enable_trace_function_call(log_path)
  689. def identity(value: T) -> T:
  690. return value
  691. F = TypeVar('F', bound=Callable[..., Any])
  692. def deprecate_kwargs(
  693. *kws: str,
  694. is_deprecated: Union[bool, Callable[[], bool]] = True,
  695. additional_message: Optional[str] = None) -> Callable[[F], F]:
  696. deprecated_kws = set(kws)
  697. if not callable(is_deprecated):
  698. is_deprecated = partial(identity, is_deprecated)
  699. def wrapper(fn: F) -> F:
  700. @wraps(fn)
  701. def inner(*args, **kwargs):
  702. if is_deprecated():
  703. deprecated_kwargs = kwargs.keys() & deprecated_kws
  704. if deprecated_kwargs:
  705. msg = (
  706. f"The keyword arguments {deprecated_kwargs} are "
  707. "deprecated and will be removed in a future update.")
  708. if additional_message is not None:
  709. msg += f" {additional_message}"
  710. warnings.warn(
  711. DeprecationWarning(msg),
  712. stacklevel=3, # The inner function takes up one level
  713. )
  714. return fn(*args, **kwargs)
  715. return inner # type: ignore
  716. return wrapper
  717. @lru_cache(maxsize=8)
  718. def _cuda_device_count_stateless(
  719. cuda_visible_devices: Optional[str] = None) -> int:
  720. # Note: cuda_visible_devices is not used, but we keep it as an argument for
  721. # LRU Cache purposes.
  722. # Code below is based on
  723. # https://github.com/pytorch/pytorch/blob/
  724. # c1cd946818442aca8c7f812b16d187ce1586c3bc/
  725. # torch/cuda/__init__.py#L831C1-L831C17
  726. import torch.cuda
  727. import torch.version
  728. if not torch.cuda._is_compiled():
  729. return 0
  730. if is_hip():
  731. # ROCm uses amdsmi instead of nvml for stateless device count
  732. # This requires a sufficiently modern version of Torch 2.4.0
  733. raw_count = torch.cuda._device_count_amdsmi() if (hasattr(
  734. torch.cuda, "_device_count_amdsmi")) else -1
  735. else:
  736. raw_count = torch.cuda._device_count_nvml()
  737. r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
  738. return r
  739. def cuda_device_count_stateless() -> int:
  740. """Get number of CUDA devices, caching based on the value of
  741. CUDA_VISIBLE_DEVICES at the time of call.
  742. This should be used instead of torch.cuda.device_count()
  743. unless CUDA_VISIBLE_DEVICES has already been set to the desired
  744. value."""
  745. # This can be removed and simply replaced with torch.cuda.get_device_count
  746. # after https://github.com/pytorch/pytorch/pull/122815 is released.
  747. return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES"))
  748. # NVML utils
  749. # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
  750. # all the related functions work on real physical device ids.
  751. # the major benefit of using NVML is that it will not initialize CUDA
  752. try:
  753. import pynvml
  754. except ImportError:
  755. # For non-NV devices
  756. pynvml = None
def with_nvml_context(fn):
    """Decorator that brackets each call to ``fn`` with NVML init/shutdown.

    When ``pynvml`` could not be imported (``pynvml is None``), ``fn`` is
    called without any NVML setup. ``nvmlShutdown()`` runs in a ``finally``
    so it happens even if ``fn`` raises. It is skipped if ``nvmlInit()``
    itself raised, because init happens before the ``try``.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        if pynvml is not None:
            pynvml.nvmlInit()
        try:
            return fn(*args, **kwargs)
        finally:
            if pynvml is not None:
                pynvml.nvmlShutdown()

    return wrapper
  768. @with_nvml_context
  769. def is_full_nvlink(device_ids: List[int]) -> bool:
  770. """
  771. query if the set of gpus are fully connected by nvlink (1 hop)
  772. """
  773. handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids]
  774. for i, handle in enumerate(handles):
  775. for j, peer_handle in enumerate(handles):
  776. if i < j:
  777. try:
  778. p2p_status = pynvml.nvmlDeviceGetP2PStatus(
  779. handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK)
  780. if p2p_status != pynvml.NVML_P2P_STATUS_OK:
  781. return False
  782. except pynvml.NVMLError as error:
  783. logger.error(
  784. "NVLink detection failed. This is normal if your"
  785. " machine has no NVLink equipped.",
  786. exc_info=error)
  787. return False
  788. return True
  789. #From: https://stackoverflow.com/a/4104188/2749989
  790. def run_once(f):
  791. def wrapper(*args, **kwargs) -> Any:
  792. if not wrapper.has_run: # type: ignore[attr-defined]
  793. wrapper.has_run = True # type: ignore[attr-defined]
  794. return f(*args, **kwargs)
  795. wrapper.has_run = False # type: ignore[attr-defined]
  796. return wrapper
  797. class FlexibleArgumentParser(argparse.ArgumentParser):
  798. """ArgumentParser that allows both underscore and dash in names."""
  799. def parse_args(self, args=None, namespace=None):
  800. if args is None:
  801. args = sys.argv[1:]
  802. # Convert underscores to dashes and vice versa in argument names
  803. processed_args = []
  804. for arg in args:
  805. if arg.startswith('--'):
  806. if '=' in arg:
  807. key, value = arg.split('=', 1)
  808. key = '--' + key[len('--'):].replace('_', '-')
  809. processed_args.append(f'{key}={value}')
  810. else:
  811. processed_args.append('--' +
  812. arg[len('--'):].replace('_', '-'))
  813. else:
  814. processed_args.append(arg)
  815. return super().parse_args(processed_args, namespace)
  816. async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
  817. **kwargs):
  818. """Utility function to run async task in a lock"""
  819. async with lock:
  820. return await task(*args, **kwargs)