utils.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119
  1. import argparse
  2. import asyncio
  3. import contextlib
  4. import datetime
  5. import enum
  6. import gc
  7. import os
  8. import socket
  9. import subprocess
  10. import sys
  11. import tempfile
  12. import threading
  13. import uuid
  14. import warnings
  15. from asyncio import FIRST_COMPLETED, ensure_future
  16. from functools import lru_cache, partial, wraps
  17. from platform import uname
  18. from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
  19. Hashable, List, Literal, Optional, OrderedDict, Set, Tuple,
  20. Type, TypeVar, Union, overload)
  21. from uuid import uuid4
  22. import numpy as np
  23. import numpy.typing as npt
  24. import psutil
  25. import torch
  26. import torch.types
  27. from loguru import logger
  28. from rich.progress import (BarColumn, MofNCompleteColumn, Progress,
  29. SpinnerColumn, TextColumn, TimeElapsedColumn)
  30. from typing_extensions import ParamSpec, TypeIs, assert_never
  31. from aphrodite.common.logger import enable_trace_function_call
  32. from aphrodite.distributed import get_tensor_model_parallel_rank
  33. # Exception strings for non-implemented encoder/decoder scenarios
  34. STR_NOT_IMPL_ENC_DEC_SWA = \
  35. "Sliding window attention for encoder/decoder models " + \
  36. "is not currently supported."
  37. STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
  38. "Prefix caching for encoder/decoder models " + \
  39. "is not currently supported."
  40. STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \
  41. "Chunked prefill for encoder/decoder models " + \
  42. "is not currently supported."
  43. STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
  44. "Models with logits_soft_cap "
  45. "require FlashInfer backend, which is "
  46. "currently not supported for encoder/decoder "
  47. "models.")
  48. STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is currently not currently "
  49. "supported with encoder/decoder "
  50. "models.")
  51. STR_NOT_IMPL_ENC_DEC_PP = ("Pipeline parallelism is not "
  52. "currently supported with "
  53. "encoder/decoder models.")
  54. STR_NOT_IMPL_ENC_DEC_MM = ("Multimodal is not currently "
  55. "supported with encoder/decoder "
  56. "models.")
  57. STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
  58. "currently supported with encoder/"
  59. "decoder models.")
  60. STR_NOT_IMPL_ENC_DEC_CUDAGRAPH = ("CUDAGraph is not "
  61. "currently supported with encoder/"
  62. "decoder models.")
  63. STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend "
  64. "currently supported with encoder/"
  65. "decoder models.")
  66. STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
  67. "currently supported with encoder/"
  68. "decoder models.")
  69. # Efficiently import all enc/dec error strings
  70. # rather than having to import all of the above
  71. STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
  72. "STR_NOT_IMPL_ENC_DEC_SWA": STR_NOT_IMPL_ENC_DEC_SWA,
  73. "STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE": STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
  74. "STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL":
  75. STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL,
  76. "STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP": STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP,
  77. "STR_NOT_IMPL_ENC_DEC_LORA": STR_NOT_IMPL_ENC_DEC_LORA,
  78. "STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP,
  79. "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM,
  80. "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
  81. "STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH": STR_NOT_IMPL_ENC_DEC_CUDAGRAPH,
  82. "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
  83. "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
  84. }
  85. # Constants related to forcing the attention backend selection
  86. # String name of register which may be set in order to
  87. # force auto-selection of attention backend by Attention
  88. # wrapper
  89. STR_BACKEND_ENV_VAR: str = "APHRODITE_ATTENTION_BACKEND"
  90. # Possible string values of STR_BACKEND_ENV_VAR
  91. # register, corresponding to possible backends
  92. STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
  93. STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
  94. STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
  95. STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
  96. STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
  97. STR_INVALID_VAL: str = "INVALID"
  98. GiB_bytes = 1 << 30
  99. """The number of bytes in one gibibyte (GiB)."""
  100. STR_DTYPE_TO_TORCH_DTYPE = {
  101. "half": torch.half,
  102. "bfloat16": torch.bfloat16,
  103. "float": torch.float,
  104. "fp8": torch.uint8,
  105. "fp8_e4m3": torch.uint8,
  106. "fp8_e5m2": torch.uint8,
  107. }
  108. TORCH_DTYPE_TO_NUMPY_DTYPE = {
  109. torch.float16: np.float16,
  110. torch.float32: np.float32,
  111. torch.float64: np.float64,
  112. torch.uint8: np.uint8,
  113. torch.int32: np.int32,
  114. torch.int64: np.int64,
  115. }
  116. P = ParamSpec('P')
  117. K = TypeVar("K")
  118. T = TypeVar("T")
  119. U = TypeVar("U")
  120. class _Sentinel:
  121. ...
  122. ALL_PINNED_SENTINEL = _Sentinel()
  123. class Device(enum.Enum):
  124. GPU = enum.auto()
  125. CPU = enum.auto()
  126. class Counter:
  127. def __init__(self, start: int = 0) -> None:
  128. self.counter = start
  129. def __next__(self) -> int:
  130. i = self.counter
  131. self.counter += 1
  132. return i
  133. def reset(self) -> None:
  134. self.counter = 0
  135. class LRUCache(Generic[T]):
  136. def __init__(self, capacity: int):
  137. self.cache: OrderedDict[Hashable, T] = OrderedDict()
  138. self.pinned_items: Set[Hashable] = set()
  139. self.capacity = capacity
  140. def __contains__(self, key: Hashable) -> bool:
  141. return key in self.cache
  142. def __len__(self) -> int:
  143. return len(self.cache)
  144. def __getitem__(self, key: Hashable) -> T:
  145. value = self.cache[key] # Raise KeyError if not exists
  146. self.cache.move_to_end(key)
  147. return value
  148. def __setitem__(self, key: Hashable, value: T) -> None:
  149. self.put(key, value)
  150. def __delitem__(self, key: Hashable) -> None:
  151. self.pop(key)
  152. def touch(self, key: Hashable) -> None:
  153. self.cache.move_to_end(key)
  154. def get(self,
  155. key: Hashable,
  156. default_value: Optional[T] = None) -> Optional[T]:
  157. value: Optional[T]
  158. if key in self.cache:
  159. value = self.cache[key]
  160. self.cache.move_to_end(key)
  161. else:
  162. value = default_value
  163. return value
  164. def put(self, key: Hashable, value: T) -> None:
  165. self.cache[key] = value
  166. self.cache.move_to_end(key)
  167. self._remove_old_if_needed()
  168. def pin(self, key: Hashable) -> None:
  169. """
  170. Pins a key in the cache preventing it from being
  171. evicted in the LRU order.
  172. """
  173. if key not in self.cache:
  174. raise ValueError(f"Cannot pin key: {key} not in cache.")
  175. self.pinned_items.add(key)
  176. def _unpin(self, key: Hashable) -> None:
  177. self.pinned_items.remove(key)
  178. def _on_remove(self, key: Hashable, value: Optional[T]):
  179. pass
  180. def remove_oldest(self, remove_pinned=False):
  181. if not self.cache:
  182. return
  183. if not remove_pinned:
  184. # pop the oldest item in the cache that is not pinned
  185. lru_key = next(
  186. (key for key in self.cache if key not in self.pinned_items),
  187. ALL_PINNED_SENTINEL)
  188. if lru_key is ALL_PINNED_SENTINEL:
  189. raise RuntimeError("All items are pinned, "
  190. "cannot remove oldest from the cache.")
  191. else:
  192. lru_key = next(iter(self.cache))
  193. self.pop(lru_key)
  194. def _remove_old_if_needed(self) -> None:
  195. while len(self.cache) > self.capacity:
  196. self.remove_oldest()
  197. def pop(self,
  198. key: Hashable,
  199. default_value: Optional[T] = None) -> Optional[T]:
  200. run_on_remove = key in self.cache
  201. value: Optional[T] = self.cache.pop(key, default_value)
  202. # remove from pinned items
  203. if key in self.pinned_items:
  204. self._unpin(key)
  205. if run_on_remove:
  206. self._on_remove(key, value)
  207. return value
  208. def clear(self):
  209. while len(self.cache) > 0:
  210. self.remove_oldest(remove_pinned=True)
  211. self.cache.clear()
  212. class PyObjectCache:
  213. """Used to cache python objects to avoid object allocations
  214. across scheduler iterations.
  215. """
  216. def __init__(self, obj_builder):
  217. self._obj_builder = obj_builder
  218. self._index = 0
  219. self._obj_cache = []
  220. for _ in range(128):
  221. self._obj_cache.append(self._obj_builder())
  222. def _grow_cache(self):
  223. # Double the size of the cache
  224. num_objs = len(self._obj_cache)
  225. for _ in range(num_objs):
  226. self._obj_cache.append(self._obj_builder())
  227. def get_object(self):
  228. """Returns a pre-allocated cached object. If there is not enough
  229. objects, then the cache size will double.
  230. """
  231. if self._index >= len(self._obj_cache):
  232. self._grow_cache()
  233. assert self._index < len(self._obj_cache)
  234. obj = self._obj_cache[self._index]
  235. self._index += 1
  236. return obj
  237. def reset(self):
  238. """Makes all cached-objects available for the next scheduler iteration.
  239. """
  240. self._index = 0
  241. def is_hip() -> bool:
  242. return torch.version.hip is not None
  243. @lru_cache(maxsize=None)
  244. def is_cpu() -> bool:
  245. from importlib.metadata import PackageNotFoundError, version
  246. try:
  247. return "cpu" in version("aphrodite-engine")
  248. except PackageNotFoundError:
  249. return False
  250. @lru_cache(maxsize=None)
  251. def is_openvino() -> bool:
  252. from importlib.metadata import PackageNotFoundError, version
  253. try:
  254. return "openvino" in version("aphrodite-engine")
  255. except PackageNotFoundError:
  256. return False
  257. @lru_cache(maxsize=None)
  258. def is_neuron() -> bool:
  259. try:
  260. import transformers_neuronx
  261. except ImportError:
  262. transformers_neuronx = None
  263. return transformers_neuronx is not None
  264. @lru_cache(maxsize=None)
  265. def is_xpu() -> bool:
  266. from importlib.metadata import version
  267. is_xpu_flag = "xpu" in version("aphrodite-engine")
  268. # aphrodite is not build with xpu
  269. if not is_xpu_flag:
  270. return False
  271. try:
  272. import intel_extension_for_pytorch as ipex # noqa: F401
  273. _import_ipex = True
  274. except ImportError as e:
  275. logger.warning(f"Import Error for IPEX: {e.msg}")
  276. _import_ipex = False
  277. # ipex dependency is not ready
  278. if not _import_ipex:
  279. logger.warning("not found ipex lib")
  280. return False
  281. return hasattr(torch, "xpu") and torch.xpu.is_available()
  282. @lru_cache(maxsize=None)
  283. def get_max_shared_memory_bytes(gpu: int = 0) -> int:
  284. """Returns the maximum shared memory per thread block in bytes."""
  285. from aphrodite import _custom_ops as ops
  286. max_shared_mem = (
  287. ops.get_max_shared_memory_per_block_device_attribute(gpu))
  288. # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
  289. # will fail
  290. assert max_shared_mem > 0, "max_shared_mem can not be zero"
  291. return int(max_shared_mem)
  292. def get_cpu_memory() -> int:
  293. """Returns the total CPU memory of the node in bytes."""
  294. return psutil.virtual_memory().total
  295. def random_uuid() -> str:
  296. return str(uuid.uuid4().hex)
  297. @lru_cache(maxsize=None)
  298. def get_aphrodite_instance_id():
  299. """
  300. If the environment variable APHRODITE_INSTANCE_ID is set, return it.
  301. Otherwise, return a random UUID.
  302. Instance id represents an instance of the Aphrodite. All processes in the
  303. same instance should have the same instance id.
  304. """
  305. return os.environ.get("APHRODITE_INSTANCE_ID",
  306. f"aphrodite-instance-{random_uuid()}")
  307. @lru_cache(maxsize=None)
  308. def in_wsl() -> bool:
  309. # Reference: https://github.com/microsoft/WSL/issues/4071
  310. return "microsoft" in " ".join(uname()).lower()
  311. def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
  312. """Take a blocking function, and run it on in an executor thread.
  313. This function prevents the blocking function from blocking the
  314. asyncio event loop.
  315. The code in this function needs to be thread safe.
  316. """
  317. def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
  318. loop = asyncio.get_event_loop()
  319. p_func = partial(func, *args, **kwargs)
  320. return loop.run_in_executor(executor=None, func=p_func)
  321. return _async_wrapper
  322. async def iterate_with_cancellation(
  323. iterator: AsyncGenerator[T, None],
  324. is_cancelled: Callable[[], Awaitable[bool]],
  325. ) -> AsyncGenerator[T, None]:
  326. """Convert async iterator into one that polls the provided function
  327. at least once per second to check for client cancellation.
  328. """
  329. # Can use anext() in python >= 3.10
  330. awaits = [ensure_future(iterator.__anext__())]
  331. while True:
  332. done, pending = await asyncio.wait(awaits, timeout=1)
  333. if await is_cancelled():
  334. with contextlib.suppress(BaseException):
  335. awaits[0].cancel()
  336. await iterator.aclose()
  337. raise asyncio.CancelledError("client cancelled")
  338. if done:
  339. try:
  340. item = await awaits[0]
  341. awaits[0] = ensure_future(iterator.__anext__())
  342. yield item
  343. except StopAsyncIteration:
  344. # we are done
  345. return
async def merge_async_iterators(
    *iterators: AsyncGenerator[T, None],
    is_cancelled: Optional[Callable[[], Awaitable[bool]]] = None,
) -> AsyncGenerator[Tuple[int, T], None]:
    """Merge multiple asynchronous iterators into a single iterator.

    This method handles the case where some iterators finish before others.
    When it yields, it yields a tuple (i, item) where i is the index of the
    iterator that yielded the item.

    It also optionally polls a provided function at least once per second
    to check for client cancellation.

    Args:
        *iterators: Source async generators; an item's tuple carries the
            positional index of the iterator it came from.
        is_cancelled: Optional async predicate. When given, each wait uses a
            1 s timeout so the predicate is checked even if no source has
            produced an item; when omitted, waits block indefinitely.

    Raises:
        asyncio.CancelledError: when ``is_cancelled`` returns True.
        Exceptions other than ``StopAsyncIteration`` raised by a source
        iterator propagate to the consumer.

    On exit (normal, error or cancellation) every future still tracked is
    cancelled and its iterator closed, with cleanup errors suppressed.
    """
    # Can use anext() in python >= 3.10
    # Map each in-flight __anext__ future to its (index, iterator) pair.
    awaits = {
        ensure_future(pair[1].__anext__()): pair
        for pair in enumerate(iterators)
    }
    # Only poll on a timer when there is a cancellation callback to check.
    timeout = None if is_cancelled is None else 1
    try:
        while awaits:
            done, pending = await asyncio.wait(awaits.keys(),
                                               return_when=FIRST_COMPLETED,
                                               timeout=timeout)
            if is_cancelled is not None and await is_cancelled():
                raise asyncio.CancelledError("client cancelled")
            for d in done:
                pair = awaits.pop(d)
                try:
                    item = await d
                    i, it = pair
                    # Schedule the next read from the same iterator before
                    # yielding, so it is already in flight while the
                    # consumer handles this item.
                    awaits[ensure_future(it.__anext__())] = pair
                    yield i, item
                except StopAsyncIteration:
                    # This iterator is exhausted; it is simply dropped.
                    pass
    finally:
        # Cancel any remaining iterators
        for f, (_, it) in awaits.items():
            with contextlib.suppress(BaseException):
                f.cancel()
                await it.aclose()
  385. def get_ip() -> str:
  386. host_ip = os.environ.get("HOST_IP")
  387. if host_ip:
  388. return host_ip
  389. # IP is not set, try to get it from the network interface
  390. # try ipv4
  391. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  392. try:
  393. s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
  394. return s.getsockname()[0]
  395. except Exception:
  396. pass
  397. # try ipv6
  398. try:
  399. s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
  400. # Google's public DNS server, see
  401. # https://developers.google.com/speed/public-dns/docs/using#addresses
  402. s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
  403. return s.getsockname()[0]
  404. except Exception:
  405. pass
  406. warnings.warn(
  407. "Failed to get the IP address, using 0.0.0.0 by default."
  408. "The value can be set by the environment variable HOST_IP.",
  409. stacklevel=2)
  410. return "0.0.0.0"
  411. def get_distributed_init_method(ip: str, port: int) -> str:
  412. # Brackets are not permitted in ipv4 addresses,
  413. # see https://github.com/python/cpython/issues/103848
  414. return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
  415. def get_open_zmq_ipc_path() -> str:
  416. APHRODITE_RPC_BASE_PATH = os.getenv("APHRODITE_RPC_BASE_PATH",
  417. tempfile.gettempdir())
  418. base_rpc_path = APHRODITE_RPC_BASE_PATH
  419. return f"ipc://{base_rpc_path}/{uuid4()}"
  420. def get_open_port(port: Optional[int] = None) -> int:
  421. port = int(os.getenv("APHRODITE_PORT", 0)
  422. ) if "APHRODITE_PORT" in os.environ else None
  423. if port is not None:
  424. while True:
  425. try:
  426. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  427. s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  428. s.bind(("", port))
  429. return port
  430. except OSError:
  431. port += 1 # Increment port number if already in use
  432. logger.info(f"Port {port - 1} is already in use, trying port "
  433. f"{port}")
  434. # try ipv4
  435. try:
  436. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  437. s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  438. s.bind(("", 0))
  439. return s.getsockname()[1]
  440. except OSError:
  441. # try ipv6
  442. with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
  443. s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  444. s.bind(("", 0))
  445. return s.getsockname()[1]
  446. def update_environment_variables(envs: Dict[str, str]):
  447. for k, v in envs.items():
  448. if k in os.environ and os.environ[k] != v:
  449. logger.warning(f"Overwriting environment variable {k} "
  450. f"from '{os.environ[k]}' to '{v}'")
  451. os.environ[k] = v
  452. def chunk_list(lst: List[T], chunk_size: int):
  453. """Yield successive chunk_size chunks from lst."""
  454. for i in range(0, len(lst), chunk_size):
  455. yield lst[i:i + chunk_size]
  456. def cdiv(a: int, b: int) -> int:
  457. """Ceiling division."""
  458. return -(a // -b)
  459. def _generate_random_fp8(
  460. tensor: torch.Tensor,
  461. low: float,
  462. high: float,
  463. ) -> None:
  464. # NOTE: Due to NaN and Inf representation for fp8 data type,
  465. # it may occur Inf or NaN if we directly use torch.randint
  466. # to generate random data for fp8 data.
  467. # For example, s.11111.00 in fp8e5m2 format represents Inf.
  468. # | E4M3 | E5M2
  469. #-----|-------------|-------------------
  470. # Inf | N/A | s.11111.00
  471. # NaN | s.1111.111 | s.11111.{01,10,11}
  472. from aphrodite import _custom_ops as ops
  473. tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
  474. tensor_tmp.uniform_(low, high)
  475. ops.convert_fp8(tensor, tensor_tmp)
  476. del tensor_tmp
  477. def get_kv_cache_torch_dtype(
  478. cache_dtype: Optional[Union[str, torch.dtype]],
  479. model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
  480. if isinstance(cache_dtype, str):
  481. if cache_dtype == "auto":
  482. if isinstance(model_dtype, str):
  483. torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
  484. elif isinstance(model_dtype, torch.dtype):
  485. torch_dtype = model_dtype
  486. else:
  487. raise ValueError(f"Invalid model dtype: {model_dtype}")
  488. elif cache_dtype in ["half", "bfloat16", "float"]:
  489. torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
  490. elif cache_dtype == "fp8":
  491. torch_dtype = torch.uint8
  492. else:
  493. raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
  494. elif isinstance(cache_dtype, torch.dtype):
  495. torch_dtype = cache_dtype
  496. else:
  497. raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
  498. return torch_dtype
  499. def create_kv_caches_with_random_flash(
  500. num_blocks: int,
  501. block_size: int,
  502. num_layers: int,
  503. num_heads: int,
  504. head_size: int,
  505. cache_dtype: Optional[Union[str, torch.dtype]],
  506. model_dtype: Optional[Union[str, torch.dtype]] = None,
  507. seed: int = 0,
  508. device: Optional[str] = "cuda",
  509. ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
  510. torch.random.manual_seed(seed)
  511. if torch.cuda.is_available():
  512. torch.cuda.manual_seed(seed)
  513. torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
  514. key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
  515. scale = head_size**-0.5
  516. key_caches: List[torch.Tensor] = []
  517. value_caches: List[torch.Tensor] = []
  518. for _ in range(num_layers):
  519. key_value_cache = torch.empty(size=key_value_cache_shape,
  520. dtype=torch_dtype,
  521. device=device)
  522. if cache_dtype in ["auto", "half", "bfloat16", "float"]:
  523. key_value_cache.uniform_(-scale, scale)
  524. elif cache_dtype == 'fp8':
  525. _generate_random_fp8(key_value_cache, -scale, scale)
  526. else:
  527. raise ValueError(
  528. f"Does not support key cache of type {cache_dtype}")
  529. key_caches.append(key_value_cache[:, 0])
  530. value_caches.append(key_value_cache[:, 1])
  531. return key_caches, value_caches
  532. def create_kv_caches_with_random(
  533. num_blocks: int,
  534. block_size: int,
  535. num_layers: int,
  536. num_heads: int,
  537. head_size: int,
  538. cache_dtype: Optional[Union[str, torch.dtype]],
  539. model_dtype: Optional[Union[str, torch.dtype]] = None,
  540. seed: int = 0,
  541. device: Optional[str] = "cuda",
  542. ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
  543. if cache_dtype == "fp8" and head_size % 16:
  544. raise ValueError(
  545. f"Does not support key cache of type fp8 with head_size "
  546. f"{head_size}")
  547. torch.random.manual_seed(seed)
  548. if torch.cuda.is_available():
  549. torch.cuda.manual_seed(seed)
  550. torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
  551. scale = head_size**-0.5
  552. x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
  553. key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
  554. key_caches: List[torch.Tensor] = []
  555. for _ in range(num_layers):
  556. key_cache = torch.empty(size=key_cache_shape,
  557. dtype=torch_dtype,
  558. device=device)
  559. if cache_dtype in ["auto", "half", "bfloat16", "float"]:
  560. key_cache.uniform_(-scale, scale)
  561. elif cache_dtype == 'fp8':
  562. _generate_random_fp8(key_cache, -scale, scale)
  563. else:
  564. raise ValueError(
  565. f"Does not support key cache of type {cache_dtype}")
  566. key_caches.append(key_cache)
  567. value_cache_shape = (num_blocks, num_heads, head_size, block_size)
  568. value_caches: List[torch.Tensor] = []
  569. for _ in range(num_layers):
  570. value_cache = torch.empty(size=value_cache_shape,
  571. dtype=torch_dtype,
  572. device=device)
  573. if cache_dtype in ["auto", "half", "bfloat16", "float"]:
  574. value_cache.uniform_(-scale, scale)
  575. elif cache_dtype == 'fp8':
  576. _generate_random_fp8(value_cache, -scale, scale)
  577. else:
  578. raise ValueError(
  579. f"Does not support value cache of type {cache_dtype}")
  580. value_caches.append(value_cache)
  581. return key_caches, value_caches
  582. @lru_cache
  583. def print_warning_once(msg: str) -> None:
  584. logger.warning(msg)
  585. @lru_cache(maxsize=None)
  586. def is_pin_memory_available() -> bool:
  587. if in_wsl():
  588. # Pinning memory in WSL is not supported.
  589. # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
  590. print_warning_once("Using 'pin_memory=False' as WSL is detected. "
  591. "This may slow down the performance.")
  592. return False
  593. elif is_xpu():
  594. print_warning_once("Pin memory is not supported on XPU.")
  595. return False
  596. elif is_neuron():
  597. print_warning_once("Pin memory is not supported on Neuron.")
  598. return False
  599. elif is_cpu() or is_openvino():
  600. return False
  601. return True
  602. class CudaMemoryProfiler:
  603. def __init__(self, device: Optional[torch.types.Device] = None):
  604. self.device = device
  605. def current_memory_usage(self) -> float:
  606. # Return the memory usage in bytes.
  607. if torch.cuda.is_available():
  608. torch.cuda.reset_peak_memory_stats(self.device)
  609. mem = torch.cuda.max_memory_allocated(self.device)
  610. elif is_xpu():
  611. torch.xpu.reset_peak_memory_stats(self.device) # type: ignore
  612. mem = torch.xpu.max_memory_allocated(self.device) # type: ignore
  613. return mem
  614. def __enter__(self):
  615. self.initial_memory = self.current_memory_usage()
  616. # This allows us to call methods of the context manager if needed
  617. return self
  618. def __exit__(self, exc_type, exc_val, exc_tb):
  619. self.final_memory = self.current_memory_usage()
  620. self.consumed_memory = self.final_memory - self.initial_memory
  621. # Force garbage collection
  622. gc.collect()
  623. def make_ndarray_with_pad(
  624. x: List[List[T]],
  625. pad: T,
  626. dtype: npt.DTypeLike,
  627. *,
  628. max_len: Optional[int] = None,
  629. ) -> npt.NDArray:
  630. """
  631. Make a padded array from 2D inputs.
  632. The padding is applied to the end of each inner list until it reaches
  633. `max_len`.
  634. """
  635. if max_len is None:
  636. # Unlike for most functions, map is faster than a genexpr over `len`
  637. max_len = max(map(len, x), default=0)
  638. padded_x = np.full((len(x), max_len), pad, dtype=dtype)
  639. for ind, blocktb in enumerate(x):
  640. assert len(blocktb) <= max_len
  641. padded_x[ind, :len(blocktb)] = blocktb
  642. return padded_x
  643. def make_tensor_with_pad(
  644. x: List[List[T]],
  645. pad: T,
  646. dtype: torch.dtype,
  647. *,
  648. max_len: Optional[int] = None,
  649. device: Optional[Union[str, torch.device]] = None,
  650. pin_memory: bool = False,
  651. ) -> torch.Tensor:
  652. """
  653. Make a padded tensor from 2D inputs.
  654. The padding is applied to the end of each inner list until it reaches
  655. `max_len`.
  656. """
  657. np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
  658. padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len)
  659. tensor = torch.from_numpy(padded_x).to(device)
  660. if pin_memory:
  661. tensor = tensor.pin_memory()
  662. return tensor
  663. def async_tensor_h2d(
  664. data: list,
  665. dtype: torch.dtype,
  666. target_device: Union[str, torch.device],
  667. pin_memory: bool,
  668. ) -> torch.Tensor:
  669. """Asynchronously create a tensor and copy it from host to device."""
  670. t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
  671. return t.to(device=target_device, non_blocking=True)
  672. def maybe_expand_dim(tensor: torch.Tensor,
  673. target_dims: int,
  674. size: int = 1) -> torch.Tensor:
  675. """Expand the tensor to the target_dims."""
  676. if tensor.ndim < target_dims:
  677. tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
  678. return tensor
  679. def get_dtype_size(dtype: torch.dtype) -> int:
  680. """Get the size of the data type in bytes."""
  681. return torch.tensor([], dtype=dtype).element_size()
  682. # `collections` helpers
  683. def is_list_of(
  684. value: object,
  685. typ: Type[T],
  686. *,
  687. check: Literal["first", "all"] = "first",
  688. ) -> TypeIs[List[T]]:
  689. if not isinstance(value, list):
  690. return False
  691. if check == "first":
  692. return len(value) == 0 or isinstance(value[0], typ)
  693. elif check == "all":
  694. return all(isinstance(v, typ) for v in value)
  695. assert_never(check)
  696. JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
  697. Tuple["JSONTree[T]", ...], T]
  698. """A nested JSON structure where the leaves need not be JSON-serializable."""
  699. @overload
  700. def json_map_leaves(
  701. func: Callable[[T], U],
  702. value: Dict[str, JSONTree[T]],
  703. ) -> Dict[str, JSONTree[U]]:
  704. ...
  705. @overload
  706. def json_map_leaves(
  707. func: Callable[[T], U],
  708. value: List[JSONTree[T]],
  709. ) -> List[JSONTree[U]]:
  710. ...
  711. @overload
  712. def json_map_leaves(
  713. func: Callable[[T], U],
  714. value: Tuple[JSONTree[T], ...],
  715. ) -> Tuple[JSONTree[U], ...]:
  716. ...
  717. @overload
  718. def json_map_leaves(
  719. func: Callable[[T], U],
  720. value: JSONTree[T],
  721. ) -> JSONTree[U]:
  722. ...
  723. def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
  724. if isinstance(value, dict):
  725. return {k: json_map_leaves(func, v) for k, v in value.items()}
  726. elif isinstance(value, list):
  727. return [json_map_leaves(func, v) for v in value]
  728. elif isinstance(value, tuple):
  729. return tuple(json_map_leaves(func, v) for v in value)
  730. else:
  731. return func(value)
  732. def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
  733. """Flatten a list of lists to a single list."""
  734. return [item for sublist in lists for item in sublist]
  735. def init_cached_hf_modules() -> None:
  736. """
  737. Lazy initialization of the Hugging Face modules.
  738. """
  739. from transformers.dynamic_module_utils import init_hf_modules
  740. init_hf_modules()
  741. @lru_cache(maxsize=None)
  742. def find_library(lib_name: str) -> str:
  743. """
  744. Find the library file in the system.
  745. `lib_name` is full filename, with both prefix and suffix.
  746. This function resolves `lib_name` to the full path of the library.
  747. """
  748. # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
  749. # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
  750. # `/sbin/ldconfig` should exist in all Linux systems.
  751. # `/sbin/ldconfig` searches the library in the system
  752. libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
  753. # each line looks like the following:
  754. # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
  755. locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
  756. # `LD_LIBRARY_PATH` searches the library in the user-defined paths
  757. env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
  758. if not locs and env_ld_library_path:
  759. locs = [
  760. os.path.join(dir, lib_name)
  761. for dir in env_ld_library_path.split(":")
  762. if os.path.exists(os.path.join(dir, lib_name))
  763. ]
  764. if not locs:
  765. raise ValueError(f"Cannot find {lib_name} in the system.")
  766. return locs[0]
  767. def find_nccl_library() -> str:
  768. """
  769. We either use the library file specified by the `APHRODITE_NCCL_SO_PATH`
  770. environment variable, or we find the library file brought by PyTorch.
  771. After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be
  772. found by `ctypes` automatically.
  773. """
  774. so_file = os.environ.get("APHRODITE_NCCL_SO_PATH", "")
  775. # manually load the nccl library
  776. if so_file:
  777. logger.debug("Found nccl from environment variable "
  778. f"APHRODITE_NCCL_SO_PATH={so_file}")
  779. else:
  780. if torch.version.cuda is not None:
  781. so_file = "libnccl.so.2"
  782. elif torch.version.hip is not None:
  783. so_file = "librccl.so.1"
  784. else:
  785. raise ValueError("NCCL only supports CUDA and ROCm backends.")
  786. logger.debug(f"Found nccl from library {so_file}")
  787. return so_file
  788. def enable_trace_function_call_for_thread() -> None:
  789. if int(os.getenv("APHRODITE_TRACE_FUNCTION", "0")):
  790. tmp_dir = tempfile.gettempdir()
  791. filename = (f"APHRODITE_TRACE_FUNCTION_for_process_{os.getpid()}"
  792. f"_thread_{threading.get_ident()}_"
  793. f"at_{datetime.datetime.now()}.log").replace(" ", "_")
  794. log_path = os.path.join(tmp_dir, "aphrodite",
  795. get_aphrodite_instance_id(), filename)
  796. os.makedirs(os.path.dirname(log_path), exist_ok=True)
  797. enable_trace_function_call(log_path)
  798. def identity(value: T) -> T:
  799. return value
  800. F = TypeVar('F', bound=Callable[..., Any])
  801. def deprecate_kwargs(
  802. *kws: str,
  803. is_deprecated: Union[bool, Callable[[], bool]] = True,
  804. additional_message: Optional[str] = None) -> Callable[[F], F]:
  805. deprecated_kws = set(kws)
  806. if not callable(is_deprecated):
  807. is_deprecated = partial(identity, is_deprecated)
  808. def wrapper(fn: F) -> F:
  809. @wraps(fn)
  810. def inner(*args, **kwargs):
  811. if is_deprecated():
  812. deprecated_kwargs = kwargs.keys() & deprecated_kws
  813. if deprecated_kwargs:
  814. msg = (
  815. f"The keyword arguments {deprecated_kwargs} are "
  816. "deprecated and will be removed in a future update.")
  817. if additional_message is not None:
  818. msg += f" {additional_message}"
  819. warnings.warn(
  820. DeprecationWarning(msg),
  821. stacklevel=3, # The inner function takes up one level
  822. )
  823. return fn(*args, **kwargs)
  824. return inner # type: ignore
  825. return wrapper
  826. @lru_cache(maxsize=8)
  827. def _cuda_device_count_stateless(
  828. cuda_visible_devices: Optional[str] = None) -> int:
  829. # Note: cuda_visible_devices is not used, but we keep it as an argument for
  830. # LRU Cache purposes.
  831. # Code below is based on
  832. # https://github.com/pytorch/pytorch/blob/
  833. # c1cd946818442aca8c7f812b16d187ce1586c3bc/
  834. # torch/cuda/__init__.py#L831C1-L831C17
  835. import torch.cuda
  836. import torch.version
  837. if not torch.cuda._is_compiled():
  838. return 0
  839. if is_hip():
  840. # ROCm uses amdsmi instead of nvml for stateless device count
  841. # This requires a sufficiently modern version of Torch 2.4.0
  842. raw_count = torch.cuda._device_count_amdsmi() if (hasattr(
  843. torch.cuda, "_device_count_amdsmi")) else -1
  844. else:
  845. raw_count = torch.cuda._device_count_nvml()
  846. r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
  847. return r
  848. def cuda_device_count_stateless() -> int:
  849. """Get number of CUDA devices, caching based on the value of
  850. CUDA_VISIBLE_DEVICES at the time of call.
  851. This should be used instead of torch.cuda.device_count()
  852. unless CUDA_VISIBLE_DEVICES has already been set to the desired
  853. value."""
  854. # This can be removed and simply replaced with torch.cuda.get_device_count
  855. # after https://github.com/pytorch/pytorch/pull/122815 is released.
  856. return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES"))
  857. #From: https://stackoverflow.com/a/4104188/2749989
  858. def run_once(f):
  859. def wrapper(*args, **kwargs) -> Any:
  860. if not wrapper.has_run: # type: ignore[attr-defined]
  861. wrapper.has_run = True # type: ignore[attr-defined]
  862. return f(*args, **kwargs)
  863. wrapper.has_run = False # type: ignore[attr-defined]
  864. return wrapper
  865. class FlexibleArgumentParser(argparse.ArgumentParser):
  866. """ArgumentParser that allows both underscore and dash in names."""
  867. def parse_args(self, args=None, namespace=None):
  868. if args is None:
  869. args = sys.argv[1:]
  870. # Convert underscores to dashes and vice versa in argument names
  871. processed_args = []
  872. for arg in args:
  873. if arg.startswith('--'):
  874. if '=' in arg:
  875. key, value = arg.split('=', 1)
  876. key = '--' + key[len('--'):].replace('_', '-')
  877. processed_args.append(f'{key}={value}')
  878. else:
  879. processed_args.append('--' +
  880. arg[len('--'):].replace('_', '-'))
  881. else:
  882. processed_args.append(arg)
  883. return super().parse_args(processed_args, namespace)
  884. async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
  885. **kwargs):
  886. """Utility function to run async task in a lock"""
  887. async with lock:
  888. return await task(*args, **kwargs)
  889. def progress_bar(iterable, desc="Processing"):
  890. show_progress = get_tensor_model_parallel_rank() == 0
  891. if show_progress:
  892. with Progress(
  893. SpinnerColumn(),
  894. TextColumn("[progress.description]{task.description}"),
  895. BarColumn(),
  896. MofNCompleteColumn(),
  897. TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
  898. TimeElapsedColumn(),
  899. ) as progress:
  900. task = progress.add_task(f"[cyan]{desc}", total=len(iterable))
  901. for item in iterable:
  902. yield item
  903. progress.update(task, advance=1)
  904. else:
  905. yield from iterable