utils.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168
  1. import argparse
  2. import asyncio
  3. import contextlib
  4. import datetime
  5. import enum
  6. import gc
  7. import math
  8. import os
  9. import socket
  10. import subprocess
  11. import sys
  12. import tempfile
  13. import threading
  14. import uuid
  15. import warnings
  16. from asyncio import FIRST_COMPLETED, ensure_future
  17. from functools import lru_cache, partial, wraps
  18. from platform import uname
  19. from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
  20. Hashable, Iterable, List, Literal, Optional, OrderedDict,
  21. Set, Tuple, Type, TypeVar, Union, overload)
  22. from uuid import uuid4
  23. import numpy as np
  24. import numpy.typing as npt
  25. import psutil
  26. import torch
  27. import torch.types
  28. from loguru import logger
  29. from packaging.version import Version
  30. from rich.progress import (BarColumn, MofNCompleteColumn, Progress,
  31. SpinnerColumn, TextColumn, TimeElapsedColumn)
  32. from typing_extensions import ParamSpec, TypeIs, assert_never
  33. import aphrodite.common.envs as envs
  34. from aphrodite.common.logger import enable_trace_function_call
  35. from aphrodite.distributed import get_tensor_model_parallel_rank
  36. # Exception strings for non-implemented encoder/decoder scenarios
  37. STR_NOT_IMPL_ENC_DEC_SWA = \
  38. "Sliding window attention for encoder/decoder models " + \
  39. "is not currently supported."
  40. STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
  41. "Prefix caching for encoder/decoder models " + \
  42. "is not currently supported."
  43. STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \
  44. "Chunked prefill for encoder/decoder models " + \
  45. "is not currently supported."
  46. STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
  47. "Models with logits_soft_cap "
  48. "require FlashInfer backend, which is "
  49. "currently not supported for encoder/decoder "
  50. "models.")
  51. STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is currently not currently "
  52. "supported with encoder/decoder "
  53. "models.")
  54. STR_NOT_IMPL_ENC_DEC_PP = ("Pipeline parallelism is not "
  55. "currently supported with "
  56. "encoder/decoder models.")
  57. STR_NOT_IMPL_ENC_DEC_MM = ("Multimodal is not currently "
  58. "supported with encoder/decoder "
  59. "models.")
  60. STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
  61. "currently supported with encoder/"
  62. "decoder models.")
  63. STR_NOT_IMPL_ENC_DEC_CUDAGRAPH = ("CUDAGraph is not "
  64. "currently supported with encoder/"
  65. "decoder models.")
  66. STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend "
  67. "currently supported with encoder/"
  68. "decoder models.")
  69. STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
  70. "currently supported with encoder/"
  71. "decoder models.")
  72. # Efficiently import all enc/dec error strings
  73. # rather than having to import all of the above
  74. STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
  75. "STR_NOT_IMPL_ENC_DEC_SWA": STR_NOT_IMPL_ENC_DEC_SWA,
  76. "STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE": STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
  77. "STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL":
  78. STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL,
  79. "STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP": STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP,
  80. "STR_NOT_IMPL_ENC_DEC_LORA": STR_NOT_IMPL_ENC_DEC_LORA,
  81. "STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP,
  82. "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM,
  83. "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
  84. "STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH": STR_NOT_IMPL_ENC_DEC_CUDAGRAPH,
  85. "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
  86. "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
  87. }
  88. # Constants related to forcing the attention backend selection
  89. # String name of register which may be set in order to
  90. # force auto-selection of attention backend by Attention
  91. # wrapper
  92. STR_BACKEND_ENV_VAR: str = "APHRODITE_ATTENTION_BACKEND"
  93. # Possible string values of STR_BACKEND_ENV_VAR
  94. # register, corresponding to possible backends
  95. STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
  96. STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
  97. STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
  98. STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
  99. STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
  100. STR_INVALID_VAL: str = "INVALID"
  101. GiB_bytes = 1 << 30
  102. """The number of bytes in one gibibyte (GiB)."""
  103. STR_DTYPE_TO_TORCH_DTYPE = {
  104. "half": torch.half,
  105. "bfloat16": torch.bfloat16,
  106. "float": torch.float,
  107. "fp8": torch.uint8,
  108. "fp8_e4m3": torch.uint8,
  109. "fp8_e5m2": torch.uint8,
  110. }
  111. TORCH_DTYPE_TO_NUMPY_DTYPE = {
  112. torch.float16: np.float16,
  113. torch.float32: np.float32,
  114. torch.float64: np.float64,
  115. torch.uint8: np.uint8,
  116. torch.int32: np.int32,
  117. torch.int64: np.int64,
  118. }
  119. P = ParamSpec('P')
  120. K = TypeVar("K")
  121. T = TypeVar("T")
  122. U = TypeVar("U")
  123. class _Sentinel:
  124. ...
  125. ALL_PINNED_SENTINEL = _Sentinel()
  126. class Device(enum.Enum):
  127. GPU = enum.auto()
  128. CPU = enum.auto()
  129. class Counter:
  130. def __init__(self, start: int = 0) -> None:
  131. self.counter = start
  132. def __next__(self) -> int:
  133. i = self.counter
  134. self.counter += 1
  135. return i
  136. def reset(self) -> None:
  137. self.counter = 0
  138. class LRUCache(Generic[T]):
  139. def __init__(self, capacity: int):
  140. self.cache: OrderedDict[Hashable, T] = OrderedDict()
  141. self.pinned_items: Set[Hashable] = set()
  142. self.capacity = capacity
  143. def __contains__(self, key: Hashable) -> bool:
  144. return key in self.cache
  145. def __len__(self) -> int:
  146. return len(self.cache)
  147. def __getitem__(self, key: Hashable) -> T:
  148. value = self.cache[key] # Raise KeyError if not exists
  149. self.cache.move_to_end(key)
  150. return value
  151. def __setitem__(self, key: Hashable, value: T) -> None:
  152. self.put(key, value)
  153. def __delitem__(self, key: Hashable) -> None:
  154. self.pop(key)
  155. def touch(self, key: Hashable) -> None:
  156. self.cache.move_to_end(key)
  157. def get(self,
  158. key: Hashable,
  159. default_value: Optional[T] = None) -> Optional[T]:
  160. value: Optional[T]
  161. if key in self.cache:
  162. value = self.cache[key]
  163. self.cache.move_to_end(key)
  164. else:
  165. value = default_value
  166. return value
  167. def put(self, key: Hashable, value: T) -> None:
  168. self.cache[key] = value
  169. self.cache.move_to_end(key)
  170. self._remove_old_if_needed()
  171. def pin(self, key: Hashable) -> None:
  172. """
  173. Pins a key in the cache preventing it from being
  174. evicted in the LRU order.
  175. """
  176. if key not in self.cache:
  177. raise ValueError(f"Cannot pin key: {key} not in cache.")
  178. self.pinned_items.add(key)
  179. def _unpin(self, key: Hashable) -> None:
  180. self.pinned_items.remove(key)
  181. def _on_remove(self, key: Hashable, value: Optional[T]):
  182. pass
  183. def remove_oldest(self, remove_pinned=False):
  184. if not self.cache:
  185. return
  186. if not remove_pinned:
  187. # pop the oldest item in the cache that is not pinned
  188. lru_key = next(
  189. (key for key in self.cache if key not in self.pinned_items),
  190. ALL_PINNED_SENTINEL)
  191. if lru_key is ALL_PINNED_SENTINEL:
  192. raise RuntimeError("All items are pinned, "
  193. "cannot remove oldest from the cache.")
  194. else:
  195. lru_key = next(iter(self.cache))
  196. self.pop(lru_key)
  197. def _remove_old_if_needed(self) -> None:
  198. while len(self.cache) > self.capacity:
  199. self.remove_oldest()
  200. def pop(self,
  201. key: Hashable,
  202. default_value: Optional[T] = None) -> Optional[T]:
  203. run_on_remove = key in self.cache
  204. value: Optional[T] = self.cache.pop(key, default_value)
  205. # remove from pinned items
  206. if key in self.pinned_items:
  207. self._unpin(key)
  208. if run_on_remove:
  209. self._on_remove(key, value)
  210. return value
  211. def clear(self):
  212. while len(self.cache) > 0:
  213. self.remove_oldest(remove_pinned=True)
  214. self.cache.clear()
  215. class PyObjectCache:
  216. """Used to cache python objects to avoid object allocations
  217. across scheduler iterations.
  218. """
  219. def __init__(self, obj_builder):
  220. self._obj_builder = obj_builder
  221. self._index = 0
  222. self._obj_cache = []
  223. for _ in range(128):
  224. self._obj_cache.append(self._obj_builder())
  225. def _grow_cache(self):
  226. # Double the size of the cache
  227. num_objs = len(self._obj_cache)
  228. for _ in range(num_objs):
  229. self._obj_cache.append(self._obj_builder())
  230. def get_object(self):
  231. """Returns a pre-allocated cached object. If there is not enough
  232. objects, then the cache size will double.
  233. """
  234. if self._index >= len(self._obj_cache):
  235. self._grow_cache()
  236. assert self._index < len(self._obj_cache)
  237. obj = self._obj_cache[self._index]
  238. self._index += 1
  239. return obj
  240. def reset(self):
  241. """Makes all cached-objects available for the next scheduler iteration.
  242. """
  243. self._index = 0
  244. def is_hip() -> bool:
  245. return torch.version.hip is not None
  246. @lru_cache(maxsize=None)
  247. def is_cpu() -> bool:
  248. from importlib.metadata import PackageNotFoundError, version
  249. try:
  250. return "cpu" in version("aphrodite-engine")
  251. except PackageNotFoundError:
  252. return False
  253. @lru_cache(maxsize=None)
  254. def is_openvino() -> bool:
  255. from importlib.metadata import PackageNotFoundError, version
  256. try:
  257. return "openvino" in version("aphrodite-engine")
  258. except PackageNotFoundError:
  259. return False
  260. @lru_cache(maxsize=None)
  261. def is_neuron() -> bool:
  262. try:
  263. import transformers_neuronx
  264. except ImportError:
  265. transformers_neuronx = None
  266. return transformers_neuronx is not None
  267. @lru_cache(maxsize=None)
  268. def is_xpu() -> bool:
  269. from importlib.metadata import version
  270. is_xpu_flag = "xpu" in version("aphrodite-engine")
  271. # aphrodite is not build with xpu
  272. if not is_xpu_flag:
  273. return False
  274. try:
  275. import intel_extension_for_pytorch as ipex # noqa: F401
  276. _import_ipex = True
  277. except ImportError as e:
  278. logger.warning(f"Import Error for IPEX: {e.msg}")
  279. _import_ipex = False
  280. # ipex dependency is not ready
  281. if not _import_ipex:
  282. logger.warning("not found ipex lib")
  283. return False
  284. return hasattr(torch, "xpu") and torch.xpu.is_available()
  285. @lru_cache(maxsize=None)
  286. def get_max_shared_memory_bytes(gpu: int = 0) -> int:
  287. """Returns the maximum shared memory per thread block in bytes."""
  288. from aphrodite import _custom_ops as ops
  289. max_shared_mem = (
  290. ops.get_max_shared_memory_per_block_device_attribute(gpu))
  291. # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
  292. # will fail
  293. assert max_shared_mem > 0, "max_shared_mem can not be zero"
  294. return int(max_shared_mem)
  295. def get_cpu_memory() -> int:
  296. """Returns the total CPU memory of the node in bytes."""
  297. return psutil.virtual_memory().total
  298. def random_uuid() -> str:
  299. return str(uuid.uuid4().hex)
  300. @lru_cache(maxsize=None)
  301. def get_aphrodite_instance_id():
  302. """
  303. If the environment variable APHRODITE_INSTANCE_ID is set, return it.
  304. Otherwise, return a random UUID.
  305. Instance id represents an instance of the Aphrodite. All processes in the
  306. same instance should have the same instance id.
  307. """
  308. return envs.APHRODITE_INSTANCE_ID or f"aphrodite-instance-{random_uuid()}"
  309. @lru_cache(maxsize=None)
  310. def in_wsl() -> bool:
  311. # Reference: https://github.com/microsoft/WSL/issues/4071
  312. return "microsoft" in " ".join(uname()).lower()
  313. @lru_cache(maxsize=None)
  314. def in_windows() -> bool:
  315. return sys.platform.startswith("win32")
  316. def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
  317. """Take a blocking function, and run it on in an executor thread.
  318. This function prevents the blocking function from blocking the
  319. asyncio event loop.
  320. The code in this function needs to be thread safe.
  321. """
  322. def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
  323. loop = asyncio.get_event_loop()
  324. p_func = partial(func, *args, **kwargs)
  325. return loop.run_in_executor(executor=None, func=p_func)
  326. return _async_wrapper
  327. async def iterate_with_cancellation(
  328. iterator: AsyncGenerator[T, None],
  329. is_cancelled: Callable[[], Awaitable[bool]],
  330. ) -> AsyncGenerator[T, None]:
  331. """Convert async iterator into one that polls the provided function
  332. at least once per second to check for client cancellation.
  333. """
  334. # Can use anext() in python >= 3.10
  335. awaits = [ensure_future(iterator.__anext__())]
  336. while True:
  337. done, pending = await asyncio.wait(awaits, timeout=1)
  338. if await is_cancelled():
  339. with contextlib.suppress(BaseException):
  340. awaits[0].cancel()
  341. await iterator.aclose()
  342. raise asyncio.CancelledError("client cancelled")
  343. if done:
  344. try:
  345. item = await awaits[0]
  346. awaits[0] = ensure_future(iterator.__anext__())
  347. yield item
  348. except StopAsyncIteration:
  349. # we are done
  350. return
  351. async def merge_async_iterators(
  352. *iterators: AsyncGenerator[T, None],
  353. is_cancelled: Optional[Callable[[], Awaitable[bool]]] = None,
  354. ) -> AsyncGenerator[Tuple[int, T], None]:
  355. """Merge multiple asynchronous iterators into a single iterator.
  356. This method handle the case where some iterators finish before others.
  357. When it yields, it yields a tuple (i, item) where i is the index of the
  358. iterator that yields the item.
  359. It also optionally polls a provided function at least once per second
  360. to check for client cancellation.
  361. """
  362. # Can use anext() in python >= 3.10
  363. awaits = {
  364. ensure_future(pair[1].__anext__()): pair
  365. for pair in enumerate(iterators)
  366. }
  367. timeout = None if is_cancelled is None else 1
  368. try:
  369. while awaits:
  370. done, pending = await asyncio.wait(awaits.keys(),
  371. return_when=FIRST_COMPLETED,
  372. timeout=timeout)
  373. if is_cancelled is not None and await is_cancelled():
  374. raise asyncio.CancelledError("client cancelled")
  375. for d in done:
  376. pair = awaits.pop(d)
  377. try:
  378. item = await d
  379. i, it = pair
  380. awaits[ensure_future(it.__anext__())] = pair
  381. yield i, item
  382. except StopAsyncIteration:
  383. pass
  384. finally:
  385. # Cancel any remaining iterators
  386. for f, (_, it) in awaits.items():
  387. with contextlib.suppress(BaseException):
  388. f.cancel()
  389. await it.aclose()
  390. def get_ip() -> str:
  391. host_ip = os.environ.get("HOST_IP")
  392. if host_ip:
  393. return host_ip
  394. # IP is not set, try to get it from the network interface
  395. # try ipv4
  396. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  397. try:
  398. s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
  399. return s.getsockname()[0]
  400. except Exception:
  401. pass
  402. # try ipv6
  403. try:
  404. s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
  405. # Google's public DNS server, see
  406. # https://developers.google.com/speed/public-dns/docs/using#addresses
  407. s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
  408. return s.getsockname()[0]
  409. except Exception:
  410. pass
  411. warnings.warn(
  412. "Failed to get the IP address, using 0.0.0.0 by default."
  413. "The value can be set by the environment variable HOST_IP.",
  414. stacklevel=2)
  415. return "0.0.0.0"
  416. def get_distributed_init_method(ip: str, port: int) -> str:
  417. # Brackets are not permitted in ipv4 addresses,
  418. # see https://github.com/python/cpython/issues/103848
  419. return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
  420. def get_open_zmq_ipc_path() -> str:
  421. if not in_windows():
  422. base_rpc_path = envs.APHRODITE_RPC_BASE_PATH
  423. return f"ipc://{base_rpc_path}/{uuid4()}"
  424. else:
  425. # windows doesn't support ipc://
  426. # use tcp:// instead
  427. return f"tcp://127.0.0.1:{get_open_port()}"
  428. def get_open_port(port: Optional[int] = None) -> int:
  429. port = envs.APHRODITE_PORT
  430. if port is not None:
  431. while True:
  432. try:
  433. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  434. s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  435. s.bind(("", port))
  436. return port
  437. except OSError:
  438. port += 1 # Increment port number if already in use
  439. logger.info(f"Port {port - 1} is already in use, trying port "
  440. f"{port}")
  441. # try ipv4
  442. try:
  443. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  444. s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  445. s.bind(("", 0))
  446. return s.getsockname()[1]
  447. except OSError:
  448. # try ipv6
  449. with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
  450. s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  451. s.bind(("", 0))
  452. return s.getsockname()[1]
  453. def find_process_using_port(port: int) -> Optional[psutil.Process]:
  454. for conn in psutil.net_connections():
  455. if conn.laddr.port == port:
  456. try:
  457. return psutil.Process(conn.pid)
  458. except psutil.NoSuchProcess:
  459. return None
  460. return None
  461. def update_environment_variables(envs: Dict[str, str]):
  462. for k, v in envs.items():
  463. if k in os.environ and os.environ[k] != v:
  464. logger.warning(f"Overwriting environment variable {k} "
  465. f"from '{os.environ[k]}' to '{v}'")
  466. os.environ[k] = v
  467. def chunk_list(lst: List[T], chunk_size: int):
  468. """Yield successive chunk_size chunks from lst."""
  469. for i in range(0, len(lst), chunk_size):
  470. yield lst[i:i + chunk_size]
  471. def cdiv(a: int, b: int) -> int:
  472. """Ceiling division."""
  473. return -(a // -b)
  474. def _generate_random_fp8(
  475. tensor: torch.Tensor,
  476. low: float,
  477. high: float,
  478. ) -> None:
  479. # NOTE: Due to NaN and Inf representation for fp8 data type,
  480. # it may occur Inf or NaN if we directly use torch.randint
  481. # to generate random data for fp8 data.
  482. # For example, s.11111.00 in fp8e5m2 format represents Inf.
  483. # | E4M3 | E5M2
  484. #-----|-------------|-------------------
  485. # Inf | N/A | s.11111.00
  486. # NaN | s.1111.111 | s.11111.{01,10,11}
  487. from aphrodite import _custom_ops as ops
  488. tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
  489. tensor_tmp.uniform_(low, high)
  490. ops.convert_fp8(tensor, tensor_tmp)
  491. del tensor_tmp
  492. def get_kv_cache_torch_dtype(
  493. cache_dtype: Optional[Union[str, torch.dtype]],
  494. model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
  495. if isinstance(cache_dtype, str):
  496. if cache_dtype == "auto":
  497. if isinstance(model_dtype, str):
  498. torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
  499. elif isinstance(model_dtype, torch.dtype):
  500. torch_dtype = model_dtype
  501. else:
  502. raise ValueError(f"Invalid model dtype: {model_dtype}")
  503. elif cache_dtype in ["half", "bfloat16", "float"]:
  504. torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
  505. elif cache_dtype == "fp8":
  506. torch_dtype = torch.uint8
  507. else:
  508. raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
  509. elif isinstance(cache_dtype, torch.dtype):
  510. torch_dtype = cache_dtype
  511. else:
  512. raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
  513. return torch_dtype
  514. def create_kv_caches_with_random_flash(
  515. num_blocks: int,
  516. block_size: int,
  517. num_layers: int,
  518. num_heads: int,
  519. head_size: int,
  520. cache_dtype: Optional[Union[str, torch.dtype]],
  521. model_dtype: Optional[Union[str, torch.dtype]] = None,
  522. seed: int = 0,
  523. device: Optional[str] = "cuda",
  524. ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
  525. torch.random.manual_seed(seed)
  526. if torch.cuda.is_available():
  527. torch.cuda.manual_seed(seed)
  528. torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
  529. key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
  530. scale = head_size**-0.5
  531. key_caches: List[torch.Tensor] = []
  532. value_caches: List[torch.Tensor] = []
  533. for _ in range(num_layers):
  534. key_value_cache = torch.empty(size=key_value_cache_shape,
  535. dtype=torch_dtype,
  536. device=device)
  537. if cache_dtype in ["auto", "half", "bfloat16", "float"]:
  538. key_value_cache.uniform_(-scale, scale)
  539. elif cache_dtype == 'fp8':
  540. _generate_random_fp8(key_value_cache, -scale, scale)
  541. else:
  542. raise ValueError(
  543. f"Does not support key cache of type {cache_dtype}")
  544. key_caches.append(key_value_cache[:, 0])
  545. value_caches.append(key_value_cache[:, 1])
  546. return key_caches, value_caches
  547. def create_kv_caches_with_random(
  548. num_blocks: int,
  549. block_size: int,
  550. num_layers: int,
  551. num_heads: int,
  552. head_size: int,
  553. cache_dtype: Optional[Union[str, torch.dtype]],
  554. model_dtype: Optional[Union[str, torch.dtype]] = None,
  555. seed: int = 0,
  556. device: Optional[str] = "cuda",
  557. ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
  558. if cache_dtype == "fp8" and head_size % 16:
  559. raise ValueError(
  560. f"Does not support key cache of type fp8 with head_size "
  561. f"{head_size}")
  562. torch.random.manual_seed(seed)
  563. if torch.cuda.is_available():
  564. torch.cuda.manual_seed(seed)
  565. torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
  566. scale = head_size**-0.5
  567. x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
  568. key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
  569. key_caches: List[torch.Tensor] = []
  570. for _ in range(num_layers):
  571. key_cache = torch.empty(size=key_cache_shape,
  572. dtype=torch_dtype,
  573. device=device)
  574. if cache_dtype in ["auto", "half", "bfloat16", "float"]:
  575. key_cache.uniform_(-scale, scale)
  576. elif cache_dtype == 'fp8':
  577. _generate_random_fp8(key_cache, -scale, scale)
  578. else:
  579. raise ValueError(
  580. f"Does not support key cache of type {cache_dtype}")
  581. key_caches.append(key_cache)
  582. value_cache_shape = (num_blocks, num_heads, head_size, block_size)
  583. value_caches: List[torch.Tensor] = []
  584. for _ in range(num_layers):
  585. value_cache = torch.empty(size=value_cache_shape,
  586. dtype=torch_dtype,
  587. device=device)
  588. if cache_dtype in ["auto", "half", "bfloat16", "float"]:
  589. value_cache.uniform_(-scale, scale)
  590. elif cache_dtype == 'fp8':
  591. _generate_random_fp8(value_cache, -scale, scale)
  592. else:
  593. raise ValueError(
  594. f"Does not support value cache of type {cache_dtype}")
  595. value_caches.append(value_cache)
  596. return key_caches, value_caches
  597. @lru_cache
  598. def print_warning_once(msg: str) -> None:
  599. logger.warning(msg)
  600. @lru_cache(maxsize=None)
  601. def is_pin_memory_available() -> bool:
  602. if in_wsl():
  603. # Pinning memory in WSL is not supported.
  604. # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
  605. print_warning_once("Using 'pin_memory=False' as WSL is detected. "
  606. "This may slow down the performance.")
  607. return False
  608. elif is_xpu():
  609. print_warning_once("Pin memory is not supported on XPU.")
  610. return False
  611. elif is_neuron():
  612. print_warning_once("Pin memory is not supported on Neuron.")
  613. return False
  614. elif is_cpu() or is_openvino():
  615. return False
  616. return True
  617. class CudaMemoryProfiler:
  618. def __init__(self, device: Optional[torch.types.Device] = None):
  619. self.device = device
  620. def current_memory_usage(self) -> float:
  621. # Return the memory usage in bytes.
  622. if torch.cuda.is_available():
  623. torch.cuda.reset_peak_memory_stats(self.device)
  624. mem = torch.cuda.max_memory_allocated(self.device)
  625. elif is_xpu():
  626. torch.xpu.reset_peak_memory_stats(self.device) # type: ignore
  627. mem = torch.xpu.max_memory_allocated(self.device) # type: ignore
  628. return mem
  629. def __enter__(self):
  630. self.initial_memory = self.current_memory_usage()
  631. # This allows us to call methods of the context manager if needed
  632. return self
  633. def __exit__(self, exc_type, exc_val, exc_tb):
  634. self.final_memory = self.current_memory_usage()
  635. self.consumed_memory = self.final_memory - self.initial_memory
  636. # Force garbage collection
  637. gc.collect()
  638. def make_ndarray_with_pad(
  639. x: List[List[T]],
  640. pad: T,
  641. dtype: npt.DTypeLike,
  642. *,
  643. max_len: Optional[int] = None,
  644. ) -> npt.NDArray:
  645. """
  646. Make a padded array from 2D inputs.
  647. The padding is applied to the end of each inner list until it reaches
  648. `max_len`.
  649. """
  650. if max_len is None:
  651. # Unlike for most functions, map is faster than a genexpr over `len`
  652. max_len = max(map(len, x), default=0)
  653. padded_x = np.full((len(x), max_len), pad, dtype=dtype)
  654. for ind, blocktb in enumerate(x):
  655. assert len(blocktb) <= max_len
  656. padded_x[ind, :len(blocktb)] = blocktb
  657. return padded_x
  658. def make_tensor_with_pad(
  659. x: List[List[T]],
  660. pad: T,
  661. dtype: torch.dtype,
  662. *,
  663. max_len: Optional[int] = None,
  664. device: Optional[Union[str, torch.device]] = None,
  665. pin_memory: bool = False,
  666. ) -> torch.Tensor:
  667. """
  668. Make a padded tensor from 2D inputs.
  669. The padding is applied to the end of each inner list until it reaches
  670. `max_len`.
  671. """
  672. np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
  673. padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len)
  674. tensor = torch.from_numpy(padded_x).to(device)
  675. if pin_memory:
  676. tensor = tensor.pin_memory()
  677. return tensor
  678. def async_tensor_h2d(
  679. data: list,
  680. dtype: torch.dtype,
  681. target_device: Union[str, torch.device],
  682. pin_memory: bool,
  683. ) -> torch.Tensor:
  684. """Asynchronously create a tensor and copy it from host to device."""
  685. t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
  686. return t.to(device=target_device, non_blocking=True)
  687. def maybe_expand_dim(tensor: torch.Tensor,
  688. target_dims: int,
  689. size: int = 1) -> torch.Tensor:
  690. """Expand the tensor to the target_dims."""
  691. if tensor.ndim < target_dims:
  692. tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
  693. return tensor
  694. def get_dtype_size(dtype: torch.dtype) -> int:
  695. """Get the size of the data type in bytes."""
  696. return torch.tensor([], dtype=dtype).element_size()
  697. # `collections` helpers
  698. def is_list_of(
  699. value: object,
  700. typ: Type[T],
  701. *,
  702. check: Literal["first", "all"] = "first",
  703. ) -> TypeIs[List[T]]:
  704. if not isinstance(value, list):
  705. return False
  706. if check == "first":
  707. return len(value) == 0 or isinstance(value[0], typ)
  708. elif check == "all":
  709. return all(isinstance(v, typ) for v in value)
  710. assert_never(check)
  711. JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
  712. Tuple["JSONTree[T]", ...], T]
  713. """A nested JSON structure where the leaves need not be JSON-serializable."""
  714. @overload
  715. def json_map_leaves(
  716. func: Callable[[T], U],
  717. value: Dict[str, JSONTree[T]],
  718. ) -> Dict[str, JSONTree[U]]:
  719. ...
  720. @overload
  721. def json_map_leaves(
  722. func: Callable[[T], U],
  723. value: List[JSONTree[T]],
  724. ) -> List[JSONTree[U]]:
  725. ...
  726. @overload
  727. def json_map_leaves(
  728. func: Callable[[T], U],
  729. value: Tuple[JSONTree[T], ...],
  730. ) -> Tuple[JSONTree[U], ...]:
  731. ...
  732. @overload
  733. def json_map_leaves(
  734. func: Callable[[T], U],
  735. value: JSONTree[T],
  736. ) -> JSONTree[U]:
  737. ...
  738. def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
  739. if isinstance(value, dict):
  740. return {k: json_map_leaves(func, v) for k, v in value.items()}
  741. elif isinstance(value, list):
  742. return [json_map_leaves(func, v) for v in value]
  743. elif isinstance(value, tuple):
  744. return tuple(json_map_leaves(func, v) for v in value)
  745. else:
  746. return func(value)
  747. def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
  748. """Flatten a list of lists to a single list."""
  749. return [item for sublist in lists for item in sublist]
  750. def init_cached_hf_modules() -> None:
  751. """
  752. Lazy initialization of the Hugging Face modules.
  753. """
  754. from transformers.dynamic_module_utils import init_hf_modules
  755. init_hf_modules()
  756. @lru_cache(maxsize=None)
  757. def find_library(lib_name: str) -> str:
  758. """
  759. Find the library file in the system.
  760. `lib_name` is full filename, with both prefix and suffix.
  761. This function resolves `lib_name` to the full path of the library.
  762. """
  763. # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
  764. # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
  765. # `/sbin/ldconfig` should exist in all Linux systems.
  766. # `/sbin/ldconfig` searches the library in the system
  767. libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
  768. # each line looks like the following:
  769. # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
  770. locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
  771. # `LD_LIBRARY_PATH` searches the library in the user-defined paths
  772. env_ld_library_path = envs.LD_LIBRARY_PATH
  773. if not locs and env_ld_library_path:
  774. locs = [
  775. os.path.join(dir, lib_name)
  776. for dir in env_ld_library_path.split(":")
  777. if os.path.exists(os.path.join(dir, lib_name))
  778. ]
  779. if not locs:
  780. raise ValueError(f"Cannot find {lib_name} in the system.")
  781. return locs[0]
  782. def find_nccl_library() -> str:
  783. """
  784. We either use the library file specified by the `APHRODITE_NCCL_SO_PATH`
  785. environment variable, or we find the library file brought by PyTorch.
  786. After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be
  787. found by `ctypes` automatically.
  788. """
  789. so_file = envs.APHRODITE_NCCL_SO_PATH
  790. # manually load the nccl library
  791. if so_file:
  792. logger.debug("Found nccl from environment variable "
  793. f"APHRODITE_NCCL_SO_PATH={so_file}")
  794. else:
  795. if torch.version.cuda is not None:
  796. so_file = "libnccl.so.2"
  797. elif torch.version.hip is not None:
  798. so_file = "librccl.so.1"
  799. else:
  800. raise ValueError("NCCL only supports CUDA and ROCm backends.")
  801. logger.debug(f"Found nccl from library {so_file}")
  802. return so_file
  803. def enable_trace_function_call_for_thread() -> None:
  804. if envs.APHRODITE_TRACE_FUNCTION:
  805. tmp_dir = tempfile.gettempdir()
  806. filename = (f"APHRODITE_TRACE_FUNCTION_for_process_{os.getpid()}"
  807. f"_thread_{threading.get_ident()}_"
  808. f"at_{datetime.datetime.now()}.log").replace(" ", "_")
  809. log_path = os.path.join(tmp_dir, "aphrodite",
  810. get_aphrodite_instance_id(), filename)
  811. os.makedirs(os.path.dirname(log_path), exist_ok=True)
  812. enable_trace_function_call(log_path)
def identity(value: T) -> T:
    """Return ``value`` unchanged."""
    return value
  815. F = TypeVar('F', bound=Callable[..., Any])
  816. def deprecate_kwargs(
  817. *kws: str,
  818. is_deprecated: Union[bool, Callable[[], bool]] = True,
  819. additional_message: Optional[str] = None) -> Callable[[F], F]:
  820. deprecated_kws = set(kws)
  821. if not callable(is_deprecated):
  822. is_deprecated = partial(identity, is_deprecated)
  823. def wrapper(fn: F) -> F:
  824. @wraps(fn)
  825. def inner(*args, **kwargs):
  826. if is_deprecated():
  827. deprecated_kwargs = kwargs.keys() & deprecated_kws
  828. if deprecated_kwargs:
  829. msg = (
  830. f"The keyword arguments {deprecated_kwargs} are "
  831. "deprecated and will be removed in a future update.")
  832. if additional_message is not None:
  833. msg += f" {additional_message}"
  834. warnings.warn(
  835. DeprecationWarning(msg),
  836. stacklevel=3, # The inner function takes up one level
  837. )
  838. return fn(*args, **kwargs)
  839. return inner # type: ignore
  840. return wrapper
  841. @lru_cache(maxsize=8)
  842. def _cuda_device_count_stateless(
  843. cuda_visible_devices: Optional[str] = None) -> int:
  844. # Note: cuda_visible_devices is not used, but we keep it as an argument for
  845. # LRU Cache purposes.
  846. # Code below is based on
  847. # https://github.com/pytorch/pytorch/blob/
  848. # c1cd946818442aca8c7f812b16d187ce1586c3bc/
  849. # torch/cuda/__init__.py#L831C1-L831C17
  850. import torch.cuda
  851. import torch.version
  852. if not torch.cuda._is_compiled():
  853. return 0
  854. if is_hip():
  855. # ROCm uses amdsmi instead of nvml for stateless device count
  856. # This requires a sufficiently modern version of Torch 2.4.0
  857. raw_count = torch.cuda._device_count_amdsmi() if (hasattr(
  858. torch.cuda, "_device_count_amdsmi")) else -1
  859. else:
  860. raw_count = torch.cuda._device_count_nvml()
  861. r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
  862. return r
  863. def cuda_device_count_stateless() -> int:
  864. """Get number of CUDA devices, caching based on the value of
  865. CUDA_VISIBLE_DEVICES at the time of call.
  866. This should be used instead of torch.cuda.device_count()
  867. unless CUDA_VISIBLE_DEVICES has already been set to the desired
  868. value."""
  869. # This can be removed and simply replaced with torch.cuda.get_device_count
  870. # after https://github.com/pytorch/pytorch/pull/122815 is released.
  871. return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
  872. #From: https://stackoverflow.com/a/4104188/2749989
  873. def run_once(f):
  874. def wrapper(*args, **kwargs) -> Any:
  875. if not wrapper.has_run: # type: ignore[attr-defined]
  876. wrapper.has_run = True # type: ignore[attr-defined]
  877. return f(*args, **kwargs)
  878. wrapper.has_run = False # type: ignore[attr-defined]
  879. return wrapper
  880. class FlexibleArgumentParser(argparse.ArgumentParser):
  881. """ArgumentParser that allows both underscore and dash in names."""
  882. def parse_args(self, args=None, namespace=None):
  883. if args is None:
  884. args = sys.argv[1:]
  885. # Convert underscores to dashes and vice versa in argument names
  886. processed_args = []
  887. for arg in args:
  888. if arg.startswith('--'):
  889. if '=' in arg:
  890. key, value = arg.split('=', 1)
  891. key = '--' + key[len('--'):].replace('_', '-')
  892. processed_args.append(f'{key}={value}')
  893. else:
  894. processed_args.append('--' +
  895. arg[len('--'):].replace('_', '-'))
  896. else:
  897. processed_args.append(arg)
  898. return super().parse_args(processed_args, namespace)
  899. async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
  900. **kwargs):
  901. """Utility function to run async task in a lock"""
  902. async with lock:
  903. return await task(*args, **kwargs)
  904. def progress_bar(iterable, desc="Processing"):
  905. show_progress = get_tensor_model_parallel_rank() == 0
  906. if show_progress:
  907. with Progress(
  908. SpinnerColumn(),
  909. TextColumn("[progress.description]{task.description}"),
  910. BarColumn(),
  911. MofNCompleteColumn(),
  912. TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
  913. TimeElapsedColumn(),
  914. ) as progress:
  915. task = progress.add_task(f"[cyan]{desc}", total=len(iterable))
  916. for item in iterable:
  917. yield item
  918. progress.update(task, advance=1)
  919. else:
  920. yield from iterable
  921. def tensor_progress_bar(iterable:Iterable[Tuple[str, torch.Tensor]],
  922. final_bytes:int, desc="Processing"):
  923. show_progress = get_tensor_model_parallel_rank() == 0
  924. units = 1024 ** (int(math.log2(final_bytes)) // 10)
  925. if show_progress:
  926. with Progress(
  927. SpinnerColumn(),
  928. TextColumn("[progress.description]{task.description}"),
  929. BarColumn(),
  930. # MofNCompleteColumn(),
  931. TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
  932. TextColumn("{task.completed:.2f}/{task.total:.2f} GiB"),
  933. TimeElapsedColumn(),
  934. ) as progress:
  935. task = progress.add_task(f"[cyan]{desc}", total=final_bytes/units)
  936. for item in iterable:
  937. steps = item[1].element_size() * item[1].nelement() / units
  938. yield item
  939. progress.update(task, advance=steps)
  940. else:
  941. yield from iterable
  942. # Using dynamo with Aphrodite doesn't really work well with PyTorch
  943. # versions < 2.4.0.
  944. # In particular, the FakeScalarType is not supported for earlier versions of
  945. # PyTorch which breaks dynamo for any ops registered using ScalarType.
  946. def supports_dynamo() -> bool:
  947. base_torch_version = Version(Version(torch.__version__).base_version)
  948. return base_torch_version >= Version("2.4.0")