import argparse
import asyncio
import contextlib
import datetime
import enum
import gc
import os
import socket
import subprocess
import sys
import tempfile
import threading
import uuid
import warnings
from asyncio import FIRST_COMPLETED, ensure_future
from collections import defaultdict
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
                    Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
                    Union, overload)

import numpy as np
import numpy.typing as npt
import psutil
import torch
import torch.types
from loguru import logger
from typing_extensions import ParamSpec

from aphrodite import _custom_ops as ops
from aphrodite.common.logger import enable_trace_function_call
from aphrodite.inputs import (ExplicitEncoderDecoderPrompt, PromptInputs,
                              SingletonPromptInputs)
# Exception strings for non-implemented encoder/decoder scenarios

STR_NOT_IMPL_ENC_DEC_SWA = \
    "Sliding window attention for encoder/decoder models " + \
    "is not currently supported."

STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
    "Prefix caching for encoder/decoder models " + \
    "is not currently supported."

STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \
    "Chunked prefill for encoder/decoder models " + \
    "is not currently supported."

STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
    "Models with logits_soft_cap "
    "require FlashInfer backend, which is "
    "currently not supported for encoder/decoder "
    "models.")
STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is not currently "
                             "supported with encoder/decoder "
                             "models.")
STR_NOT_IMPL_ENC_DEC_PP = ("Pipeline parallelism is not "
                           "currently supported with "
                           "encoder/decoder models.")

STR_NOT_IMPL_ENC_DEC_MM = ("Multimodal is not currently "
                           "supported with encoder/decoder "
                           "models.")

STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
                                 "currently supported with encoder/"
                                 "decoder models.")

STR_NOT_IMPL_ENC_DEC_CUDAGRAPH = ("CUDAGraph is not "
                                  "currently supported with encoder/"
                                  "decoder models.")

STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend "
                                "currently supported with encoder/"
                                "decoder models.")

STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
                                       "currently supported with encoder/"
                                       "decoder models.")

# Efficiently import all enc/dec error strings
# rather than having to import all of the above
STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
    "STR_NOT_IMPL_ENC_DEC_SWA": STR_NOT_IMPL_ENC_DEC_SWA,
    "STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE": STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
    "STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL":
    STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL,
    "STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP": STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP,
    "STR_NOT_IMPL_ENC_DEC_LORA": STR_NOT_IMPL_ENC_DEC_LORA,
    "STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP,
    "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM,
    "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
    "STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH": STR_NOT_IMPL_ENC_DEC_CUDAGRAPH,
    "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
    "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
}
# Constants related to forcing the attention backend selection

# Name of the environment variable which may be set in order to
# override the automatic attention backend selection performed by
# the Attention wrapper
STR_BACKEND_ENV_VAR: str = "APHRODITE_ATTENTION_BACKEND"

# Possible string values of the STR_BACKEND_ENV_VAR
# environment variable, corresponding to the possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"

STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.uint8,
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
}

TORCH_DTYPE_TO_NUMPY_DTYPE = {
    torch.float16: np.float16,
    torch.float32: np.float32,
    torch.float64: np.float64,
    torch.uint8: np.uint8,
    torch.int32: np.int32,
    torch.int64: np.int64,
}
P = ParamSpec('P')
K = TypeVar("K")
T = TypeVar("T")
U = TypeVar("U")


class _Sentinel:
    ...


ALL_PINNED_SENTINEL = _Sentinel()


class Device(enum.Enum):
    GPU = enum.auto()
    CPU = enum.auto()


class Counter:

    def __init__(self, start: int = 0) -> None:
        self.counter = start

    def __next__(self) -> int:
        i = self.counter
        self.counter += 1
        return i

    def reset(self) -> None:
        self.counter = 0
class LRUCache(Generic[T]):

    def __init__(self, capacity: int):
        self.cache: OrderedDict[Hashable, T] = OrderedDict()
        self.pinned_items: Set[Hashable] = set()
        self.capacity = capacity

    def __contains__(self, key: Hashable) -> bool:
        return key in self.cache

    def __len__(self) -> int:
        return len(self.cache)

    def __getitem__(self, key: Hashable) -> T:
        value = self.cache[key]  # Raise KeyError if not exists
        self.cache.move_to_end(key)
        return value

    def __setitem__(self, key: Hashable, value: T) -> None:
        self.put(key, value)

    def __delitem__(self, key: Hashable) -> None:
        self.pop(key)

    def touch(self, key: Hashable) -> None:
        self.cache.move_to_end(key)

    def get(self,
            key: Hashable,
            default_value: Optional[T] = None) -> Optional[T]:
        value: Optional[T]
        if key in self.cache:
            value = self.cache[key]
            self.cache.move_to_end(key)
        else:
            value = default_value
        return value

    def put(self, key: Hashable, value: T) -> None:
        self.cache[key] = value
        self.cache.move_to_end(key)
        self._remove_old_if_needed()

    def pin(self, key: Hashable) -> None:
        """
        Pins a key in the cache preventing it from being
        evicted in the LRU order.
        """
        if key not in self.cache:
            raise ValueError(f"Cannot pin key: {key} not in cache.")
        self.pinned_items.add(key)

    def _unpin(self, key: Hashable) -> None:
        self.pinned_items.remove(key)

    def _on_remove(self, key: Hashable, value: Optional[T]):
        pass

    def remove_oldest(self, remove_pinned=False):
        if not self.cache:
            return

        if not remove_pinned:
            # pop the oldest item in the cache that is not pinned
            lru_key = next(
                (key for key in self.cache if key not in self.pinned_items),
                ALL_PINNED_SENTINEL)
            if lru_key is ALL_PINNED_SENTINEL:
                raise RuntimeError("All items are pinned, "
                                   "cannot remove oldest from the cache.")
        else:
            lru_key = next(iter(self.cache))
        self.pop(lru_key)

    def _remove_old_if_needed(self) -> None:
        while len(self.cache) > self.capacity:
            self.remove_oldest()

    def pop(self,
            key: Hashable,
            default_value: Optional[T] = None) -> Optional[T]:
        run_on_remove = key in self.cache
        value: Optional[T] = self.cache.pop(key, default_value)
        # remove from pinned items
        if key in self.pinned_items:
            self._unpin(key)
        if run_on_remove:
            self._on_remove(key, value)
        return value

    def clear(self):
        while len(self.cache) > 0:
            self.remove_oldest(remove_pinned=True)
        self.cache.clear()
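

# Usage sketch (illustrative, not part of the library API): exercises the LRU
# eviction and pinning behavior of LRUCache above. The capacity and keys are
# arbitrary example values; the function is never called by this module.
def _lru_cache_usage_example() -> None:
    cache = LRUCache(capacity=2)
    cache.put("a", 1)
    cache.put("b", 2)
    cache.pin("a")  # "a" is now exempt from LRU eviction
    cache.put("c", 3)  # over capacity: "b", the oldest unpinned key, is evicted
    assert "a" in cache and "c" in cache and "b" not in cache
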
def is_hip() -> bool:
    return torch.version.hip is not None


@lru_cache(maxsize=None)
def is_cpu() -> bool:
    from importlib.metadata import PackageNotFoundError, version
    try:
        return "cpu" in version("aphrodite-engine")
    except PackageNotFoundError:
        return False


@lru_cache(maxsize=None)
def is_openvino() -> bool:
    from importlib.metadata import PackageNotFoundError, version
    try:
        return "openvino" in version("aphrodite-engine")
    except PackageNotFoundError:
        return False


@lru_cache(maxsize=None)
def is_neuron() -> bool:
    try:
        import transformers_neuronx
    except ImportError:
        transformers_neuronx = None
    return transformers_neuronx is not None


@lru_cache(maxsize=None)
def is_tpu() -> bool:
    try:
        import libtpu
    except ImportError:
        libtpu = None
    return libtpu is not None
@lru_cache(maxsize=None)
def is_xpu() -> bool:
    from importlib.metadata import version
    is_xpu_flag = "xpu" in version("aphrodite-engine")
    # Aphrodite is not built with XPU support
    if not is_xpu_flag:
        return False
    try:
        import intel_extension_for_pytorch as ipex  # noqa: F401
        _import_ipex = True
    except ImportError as e:
        logger.warning(f"Import Error for IPEX: {e.msg}")
        _import_ipex = False
    # The IPEX dependency is not available
    if not _import_ipex:
        logger.warning("IPEX library not found.")
        return False
    return hasattr(torch, "xpu") and torch.xpu.is_available()
@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes."""
    max_shared_mem = (
        ops.get_max_shared_memory_per_block_device_attribute(gpu))
    # A value of 0 would cause MAX_SEQ_LEN to become negative and
    # test_attention.py to fail
    assert max_shared_mem > 0, "max_shared_mem can not be zero"
    return int(max_shared_mem)


def get_cpu_memory() -> int:
    """Returns the total CPU memory of the node in bytes."""
    return psutil.virtual_memory().total


def random_uuid() -> str:
    return str(uuid.uuid4().hex)


@lru_cache(maxsize=None)
def get_aphrodite_instance_id():
    """
    If the environment variable APHRODITE_INSTANCE_ID is set, return it.
    Otherwise, return a random UUID.
    The instance id represents an instance of Aphrodite. All processes in the
    same instance should have the same instance id.
    """
    return os.environ.get("APHRODITE_INSTANCE_ID",
                          f"aphrodite-instance-{random_uuid()}")


@lru_cache(maxsize=None)
def in_wsl() -> bool:
    # Reference: https://github.com/microsoft/WSL/issues/4071
    return "microsoft" in " ".join(uname()).lower()


def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
    """Take a blocking function, and run it in an executor thread.
    This function prevents the blocking function from blocking the
    asyncio event loop.
    The code in this function needs to be thread safe.
    """

    def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
        loop = asyncio.get_event_loop()
        p_func = partial(func, *args, **kwargs)
        return loop.run_in_executor(executor=None, func=p_func)

    return _async_wrapper
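

# Usage sketch for make_async (illustrative, not part of the library API; the
# blocking helper below is hypothetical and exists only for this example).
async def _make_async_example() -> int:

    def _blocking_add(a: int, b: int) -> int:
        return a + b

    async_add = make_async(_blocking_add)
    # The wrapped call runs in the default thread-pool executor instead of
    # blocking the event loop.
    return await async_add(2, 3)
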
async def iterate_with_cancellation(
    iterator: AsyncGenerator[T, None],
    is_cancelled: Callable[[], Awaitable[bool]],
) -> AsyncGenerator[T, None]:
    """Convert async iterator into one that polls the provided function
    at least once per second to check for client cancellation.
    """

    # Can use anext() in python >= 3.10
    awaits = [ensure_future(iterator.__anext__())]
    while True:
        done, pending = await asyncio.wait(awaits, timeout=1)
        if await is_cancelled():
            with contextlib.suppress(BaseException):
                awaits[0].cancel()
                await iterator.aclose()
            raise asyncio.CancelledError("client cancelled")
        if done:
            try:
                item = await awaits[0]
                awaits[0] = ensure_future(iterator.__anext__())
                yield item
            except StopAsyncIteration:
                # we are done
                return
async def merge_async_iterators(
    *iterators: AsyncGenerator[T, None],
    is_cancelled: Callable[[], Awaitable[bool]],
) -> AsyncGenerator[Tuple[int, T], None]:
    """Merge multiple asynchronous iterators into a single iterator.
    This method handles the case where some iterators finish before others.
    When it yields, it yields a tuple (i, item) where i is the index of the
    iterator that yields the item.
    It also polls the provided function at least once per second to check
    for client cancellation.
    """

    # Can use anext() in python >= 3.10
    awaits = {
        ensure_future(pair[1].__anext__()): pair
        for pair in enumerate(iterators)
    }
    try:
        while awaits:
            done, pending = await asyncio.wait(awaits.keys(),
                                               return_when=FIRST_COMPLETED,
                                               timeout=1)
            if await is_cancelled():
                raise asyncio.CancelledError("client cancelled")
            for d in done:
                pair = awaits.pop(d)
                try:
                    item = await d
                    i, it = pair
                    awaits[ensure_future(it.__anext__())] = pair
                    yield i, item
                except StopAsyncIteration:
                    pass
    finally:
        # Cancel any remaining iterators
        for f, (_, it) in awaits.items():
            with contextlib.suppress(BaseException):
                f.cancel()
                await it.aclose()
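

# Usage sketch for merge_async_iterators (illustrative, not part of the
# library API; the generators and the cancellation callback below are
# hypothetical example inputs).
async def _merge_async_iterators_example() -> List[Tuple[int, str]]:

    async def _never_cancelled() -> bool:
        return False

    async def _stream(prefix: str) -> AsyncGenerator[str, None]:
        for n in range(2):
            yield f"{prefix}-{n}"

    results: List[Tuple[int, str]] = []
    merged = merge_async_iterators(_stream("a"), _stream("b"),
                                   is_cancelled=_never_cancelled)
    async for i, item in merged:
        # `i` is the index of the source iterator (0 for "a", 1 for "b")
        results.append((i, item))
    return results
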
def get_ip() -> str:
    host_ip = os.environ.get("HOST_IP")
    if host_ip:
        return host_ip

    # IP is not set, try to get it from the network interface

    # try ipv4
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
        return s.getsockname()[0]
    except Exception:
        pass

    # try ipv6
    try:
        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
        # Google's public DNS server, see
        # https://developers.google.com/speed/public-dns/docs/using#addresses
        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
        return s.getsockname()[0]
    except Exception:
        pass

    warnings.warn(
        "Failed to get the IP address, using 0.0.0.0 by default. "
        "The value can be set by the environment variable HOST_IP.",
        stacklevel=2)
    return "0.0.0.0"
def get_distributed_init_method(ip: str, port: int) -> str:
    # Brackets are not permitted in ipv4 addresses,
    # see https://github.com/python/cpython/issues/103848
    return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
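

# Illustrative check of the bracket handling above (arbitrary example values,
# not part of the library API).
def _get_distributed_init_method_example() -> None:
    assert (get_distributed_init_method("192.168.1.5", 29500) ==
            "tcp://192.168.1.5:29500")
    # IPv6 addresses contain ":" and are therefore wrapped in brackets:
    assert get_distributed_init_method("::1", 29500) == "tcp://[::1]:29500"
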
def get_open_port(port: Optional[int] = None) -> int:
    if port is None:
        # Default behavior here is to return a port for multi-gpu communication
        port = (int(os.getenv("APHRODITE_PORT", 0))
                if "APHRODITE_PORT" in os.environ else None)
    if port is not None:
        while True:
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                    s.bind(("", port))
                    return port
            except OSError:
                port += 1  # Increment port number if already in use
                logger.info(f"Port {port - 1} is already in use, trying port "
                            f"{port}")
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(("", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(("", 0))
            return s.getsockname()[1]
def update_environment_variables(envs: Dict[str, str]):
    for k, v in envs.items():
        if k in os.environ and os.environ[k] != v:
            logger.warning(f"Overwriting environment variable {k} "
                           f"from '{os.environ[k]}' to '{v}'")
        os.environ[k] = v


def chunk_list(lst: List[T], chunk_size: int):
    """Yield successive chunk_size chunks from lst."""
    for i in range(0, len(lst), chunk_size):
        yield lst[i:i + chunk_size]


def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(a // -b)
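

# Illustrative checks for the two helpers above (arbitrary example values,
# not part of the library API).
def _chunking_helpers_example() -> None:
    assert list(chunk_list([1, 2, 3, 4, 5], 2)) == [[1, 2], [3, 4], [5]]
    assert cdiv(10, 3) == 4  # ceil(10 / 3)
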
def _generate_random_fp8(
    tensor: torch.Tensor,
    low: float,
    high: float,
) -> None:
    # NOTE: Due to the NaN and Inf representations of the fp8 data types,
    # Inf or NaN may occur if we directly use torch.randint
    # to generate random data for fp8.
    # For example, s.11111.00 in fp8e5m2 format represents Inf.
    #     | E4M3        | E5M2
    #-----|-------------|-------------------
    # Inf | N/A         | s.11111.00
    # NaN | s.1111.111  | s.11111.{01,10,11}
    from aphrodite import _custom_ops as ops
    tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
    tensor_tmp.uniform_(low, high)
    ops.convert_fp8(tensor, tensor_tmp)
    del tensor_tmp


def get_kv_cache_torch_dtype(
        cache_dtype: Optional[Union[str, torch.dtype]],
        model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
    if isinstance(cache_dtype, str):
        if cache_dtype == "auto":
            if isinstance(model_dtype, str):
                torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
            elif isinstance(model_dtype, torch.dtype):
                torch_dtype = model_dtype
            else:
                raise ValueError(f"Invalid model dtype: {model_dtype}")
        elif cache_dtype in ["half", "bfloat16", "float"]:
            torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
        elif cache_dtype == "fp8":
            torch_dtype = torch.uint8
        else:
            raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    elif isinstance(cache_dtype, torch.dtype):
        torch_dtype = cache_dtype
    else:
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    return torch_dtype
def create_kv_caches_with_random_flash(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
    scale = head_size**-0.5

    key_caches: List[torch.Tensor] = []
    value_caches: List[torch.Tensor] = []

    for _ in range(num_layers):
        key_value_cache = torch.empty(size=key_value_cache_shape,
                                      dtype=torch_dtype,
                                      device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            key_value_cache.uniform_(-scale, scale)
        elif cache_dtype == 'fp8':
            _generate_random_fp8(key_value_cache, -scale, scale)
        else:
            raise ValueError(
                f"Does not support key cache of type {cache_dtype}")
        key_caches.append(key_value_cache[:, 0])
        value_caches.append(key_value_cache[:, 1])
    return key_caches, value_caches
def create_kv_caches_with_random(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:

    if cache_dtype == "fp8" and head_size % 16:
        raise ValueError(
            f"Does not support key cache of type fp8 with head_size "
            f"{head_size}")

    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)

    scale = head_size**-0.5
    x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        key_cache = torch.empty(size=key_cache_shape,
                                dtype=torch_dtype,
                                device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            key_cache.uniform_(-scale, scale)
        elif cache_dtype == 'fp8':
            _generate_random_fp8(key_cache, -scale, scale)
        else:
            raise ValueError(
                f"Does not support key cache of type {cache_dtype}")
        key_caches.append(key_cache)

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        value_cache = torch.empty(size=value_cache_shape,
                                  dtype=torch_dtype,
                                  device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            value_cache.uniform_(-scale, scale)
        elif cache_dtype == 'fp8':
            _generate_random_fp8(value_cache, -scale, scale)
        else:
            raise ValueError(
                f"Does not support value cache of type {cache_dtype}")
        value_caches.append(value_cache)
    return key_caches, value_caches
@lru_cache
def print_warning_once(msg: str) -> None:
    logger.warning(msg)


@lru_cache(maxsize=None)
def is_pin_memory_available() -> bool:
    if in_wsl():
        # Pinning memory in WSL is not supported.
        # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
        print_warning_once("Using 'pin_memory=False' as WSL is detected. "
                           "This may slow down the performance.")
        return False
    elif is_xpu():
        print_warning_once("Pin memory is not supported on XPU.")
        return False
    elif is_neuron():
        print_warning_once("Pin memory is not supported on Neuron.")
        return False
    elif is_cpu() or is_openvino():
        return False
    return True
class CudaMemoryProfiler:

    def __init__(self, device: Optional[torch.types.Device] = None):
        self.device = device

    def current_memory_usage(self) -> float:
        # Return the memory usage in bytes.
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(self.device)
            mem = torch.cuda.max_memory_allocated(self.device)
        elif is_xpu():
            torch.xpu.reset_peak_memory_stats(self.device)  # type: ignore
            mem = torch.xpu.max_memory_allocated(self.device)  # type: ignore
        return mem

    def __enter__(self):
        self.initial_memory = self.current_memory_usage()
        # This allows us to call methods of the context manager if needed
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.final_memory = self.current_memory_usage()
        self.consumed_memory = self.final_memory - self.initial_memory

        # Force garbage collection
        gc.collect()
def str_to_int_tuple(s: str) -> Tuple[int, ...]:
    """Convert a string to a tuple of integers."""
    try:
        return tuple(map(int, s.split(",")))
    except ValueError as e:
        raise ValueError(
            "String must be a series of integers separated by commas "
            f"(e.g., 1, 2, 3). Given input: {s}") from e
def make_ndarray_with_pad(
    x: List[List[T]],
    pad: T,
    dtype: npt.DTypeLike,
    *,
    max_len: Optional[int] = None,
) -> npt.NDArray:
    """
    Make a padded array from 2D inputs.
    The padding is applied to the end of each inner list until it reaches
    `max_len`.
    """
    if max_len is None:
        # Unlike for most functions, map is faster than a genexpr over `len`
        max_len = max(map(len, x), default=0)

    padded_x = np.full((len(x), max_len), pad, dtype=dtype)
    for ind, blocktb in enumerate(x):
        assert len(blocktb) <= max_len
        padded_x[ind, :len(blocktb)] = blocktb

    return padded_x
def make_tensor_with_pad(
    x: List[List[T]],
    pad: T,
    dtype: torch.dtype,
    *,
    max_len: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    pin_memory: bool = False,
) -> torch.Tensor:
    """
    Make a padded tensor from 2D inputs.
    The padding is applied to the end of each inner list until it reaches
    `max_len`.
    """
    np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
    padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len)

    tensor = torch.from_numpy(padded_x).to(device)
    if pin_memory:
        tensor = tensor.pin_memory()

    return tensor
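

# Usage sketch for the padding helpers above (illustrative, not part of the
# library API; shapes and values are arbitrary example inputs).
def _make_tensor_with_pad_example() -> torch.Tensor:
    t = make_tensor_with_pad([[1, 2, 3], [4]], pad=0, dtype=torch.int64)
    # Rows shorter than the longest one are right-padded with `pad`:
    # tensor([[1, 2, 3],
    #         [4, 0, 0]])
    return t
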
def async_tensor_h2d(
    data: list,
    dtype: torch.dtype,
    target_device: Union[str, torch.device],
    pin_memory: bool,
) -> torch.Tensor:
    """Asynchronously create a tensor and copy it from host to device."""
    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
    return t.to(device=target_device, non_blocking=True)


def maybe_expand_dim(tensor: torch.Tensor,
                     target_dims: int,
                     size: int = 1) -> torch.Tensor:
    """Expand the tensor to the target_dims."""
    if tensor.ndim < target_dims:
        tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
    return tensor


def get_dtype_size(dtype: torch.dtype) -> int:
    """Get the size of the data type in bytes."""
    return torch.tensor([], dtype=dtype).element_size()
def merge_dicts(dict1: Dict[K, List[T]],
                dict2: Dict[K, List[T]]) -> Dict[K, List[T]]:
    """Merge 2 dicts that have key -> List of items.
    When a key conflicts, the values in dict1 are prioritized.
    """
    merged_dict: Dict[K, List[T]] = defaultdict(list)

    for key, value in dict1.items():
        merged_dict[key].extend(value)

    for key, value in dict2.items():
        merged_dict[key].extend(value)

    return dict(merged_dict)
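

# Illustrative check for merge_dicts (arbitrary example values, not part of
# the library API).
def _merge_dicts_example() -> None:
    merged = merge_dicts({"a": [1], "b": [2]}, {"a": [3], "c": [4]})
    # Values from dict1 come first for conflicting keys:
    assert merged == {"a": [1, 3], "b": [2], "c": [4]}
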
JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
                 Tuple["JSONTree[T]", ...], T]
"""A nested JSON structure where the leaves need not be JSON-serializable."""


@overload
def json_map_leaves(
    func: Callable[[T], U],
    value: Dict[str, JSONTree[T]],
) -> Dict[str, JSONTree[U]]:
    ...


@overload
def json_map_leaves(
    func: Callable[[T], U],
    value: List[JSONTree[T]],
) -> List[JSONTree[U]]:
    ...


@overload
def json_map_leaves(
    func: Callable[[T], U],
    value: Tuple[JSONTree[T], ...],
) -> Tuple[JSONTree[U], ...]:
    ...


@overload
def json_map_leaves(
    func: Callable[[T], U],
    value: JSONTree[T],
) -> JSONTree[U]:
    ...


def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
    if isinstance(value, dict):
        return {k: json_map_leaves(func, v) for k, v in value.items()}
    elif isinstance(value, list):
        return [json_map_leaves(func, v) for v in value]
    elif isinstance(value, tuple):
        return tuple(json_map_leaves(func, v) for v in value)
    else:
        return func(value)
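

# Illustrative check for json_map_leaves (arbitrary example values, not part
# of the library API).
def _json_map_leaves_example() -> None:
    doubled = json_map_leaves(lambda x: x * 2, {"a": [1, 2], "b": (3,)})
    # Containers are preserved; only the leaves are transformed:
    assert doubled == {"a": [2, 4], "b": (6,)}
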
def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
    """Flatten a list of lists to a single list."""
    return [item for sublist in lists for item in sublist]


def init_cached_hf_modules() -> None:
    """
    Lazy initialization of the Hugging Face modules.
    """
    from transformers.dynamic_module_utils import init_hf_modules
    init_hf_modules()
@lru_cache(maxsize=None)
def find_library(lib_name: str) -> str:
    """
    Find the library file in the system.
    `lib_name` is the full filename, with both prefix and suffix.
    This function resolves `lib_name` to the full path of the library.
    """
    # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
    # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
    # `/sbin/ldconfig` should exist in all Linux systems.
    # `/sbin/ldconfig` searches the library in the system
    libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
    # each line looks like the following:
    # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
    locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
    # `LD_LIBRARY_PATH` searches the library in the user-defined paths
    env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
    if not locs and env_ld_library_path:
        locs = [
            os.path.join(dir, lib_name)
            for dir in env_ld_library_path.split(":")
            if os.path.exists(os.path.join(dir, lib_name))
        ]
    if not locs:
        raise ValueError(f"Cannot find {lib_name} in the system.")
    return locs[0]
def find_nccl_library() -> str:
    """
    We either use the library file specified by the `APHRODITE_NCCL_SO_PATH`
    environment variable, or we find the library file brought by PyTorch.
    After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be
    found by `ctypes` automatically.
    """
    so_file = os.environ.get("APHRODITE_NCCL_SO_PATH", "")

    # manually load the nccl library
    if so_file:
        logger.debug("Found nccl from environment variable "
                     f"APHRODITE_NCCL_SO_PATH={so_file}")
    else:
        if torch.version.cuda is not None:
            so_file = "libnccl.so.2"
        elif torch.version.hip is not None:
            so_file = "librccl.so.1"
        else:
            raise ValueError("NCCL only supports CUDA and ROCm backends.")
        logger.debug(f"Found nccl from library {so_file}")
    return so_file
def enable_trace_function_call_for_thread() -> None:
    if int(os.getenv("APHRODITE_TRACE_FUNCTION", "0")):
        tmp_dir = tempfile.gettempdir()
        filename = (f"APHRODITE_TRACE_FUNCTION_for_process_{os.getpid()}"
                    f"_thread_{threading.get_ident()}_"
                    f"at_{datetime.datetime.now()}.log").replace(" ", "_")
        log_path = os.path.join(tmp_dir, "aphrodite",
                                get_aphrodite_instance_id(), filename)
        os.makedirs(os.path.dirname(log_path), exist_ok=True)
        enable_trace_function_call(log_path)
def identity(value: T) -> T:
    return value


F = TypeVar('F', bound=Callable[..., Any])


def deprecate_kwargs(
        *kws: str,
        is_deprecated: Union[bool, Callable[[], bool]] = True,
        additional_message: Optional[str] = None) -> Callable[[F], F]:
    deprecated_kws = set(kws)

    if not callable(is_deprecated):
        is_deprecated = partial(identity, is_deprecated)

    def wrapper(fn: F) -> F:

        @wraps(fn)
        def inner(*args, **kwargs):
            if is_deprecated():
                deprecated_kwargs = kwargs.keys() & deprecated_kws
                if deprecated_kwargs:
                    msg = (
                        f"The keyword arguments {deprecated_kwargs} are "
                        "deprecated and will be removed in a future update.")
                    if additional_message is not None:
                        msg += f" {additional_message}"

                    warnings.warn(
                        DeprecationWarning(msg),
                        stacklevel=3,  # The inner function takes up one level
                    )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
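

# Usage sketch for deprecate_kwargs (illustrative, not part of the library
# API; the function and keyword names below are hypothetical).
def _deprecate_kwargs_example() -> None:

    @deprecate_kwargs("legacy_arg",
                      additional_message="Use `new_arg` instead.")
    def do_work(*, new_arg: int = 0, legacy_arg: int = 0) -> int:
        return new_arg or legacy_arg

    # Passing the deprecated keyword emits a DeprecationWarning but still
    # forwards the call unchanged:
    assert do_work(legacy_arg=1) == 1
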
@lru_cache(maxsize=8)
def _cuda_device_count_stateless(
        cuda_visible_devices: Optional[str] = None) -> int:
    # Note: cuda_visible_devices is not used, but we keep it as an argument for
    # LRU Cache purposes.

    # Code below is based on
    # https://github.com/pytorch/pytorch/blob/
    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
    # torch/cuda/__init__.py#L831C1-L831C17
    import torch.cuda
    import torch.version

    if not torch.cuda._is_compiled():
        return 0
    if is_hip():
        # ROCm uses amdsmi instead of nvml for stateless device count.
        # This requires a sufficiently recent version of Torch (2.4.0 or newer).
        raw_count = torch.cuda._device_count_amdsmi() if (hasattr(
            torch.cuda, "_device_count_amdsmi")) else -1
    else:
        raw_count = torch.cuda._device_count_nvml()
    r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
    return r


def cuda_device_count_stateless() -> int:
    """Get number of CUDA devices, caching based on the value of
    CUDA_VISIBLE_DEVICES at the time of call.

    This should be used instead of torch.cuda.device_count()
    unless CUDA_VISIBLE_DEVICES has already been set to the desired
    value."""

    # This can be removed and simply replaced with torch.cuda.get_device_count
    # after https://github.com/pytorch/pytorch/pull/122815 is released.
    return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES"))
# From: https://stackoverflow.com/a/4104188/2749989
def run_once(f):

    def wrapper(*args, **kwargs) -> Any:
        if not wrapper.has_run:  # type: ignore[attr-defined]
            wrapper.has_run = True  # type: ignore[attr-defined]
            return f(*args, **kwargs)

    wrapper.has_run = False  # type: ignore[attr-defined]
    return wrapper
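

# Usage sketch for run_once (illustrative, not part of the library API; the
# counter only exists to make the single-execution behavior observable).
def _run_once_example() -> None:
    calls = []

    @run_once
    def _initialize() -> None:
        calls.append(1)

    _initialize()
    _initialize()  # second call is a no-op
    assert len(calls) == 1
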
class FlexibleArgumentParser(argparse.ArgumentParser):
    """ArgumentParser that allows both underscore and dash in names."""

    def parse_args(self, args=None, namespace=None):
        if args is None:
            args = sys.argv[1:]

        # Normalize underscores to dashes in argument names
        processed_args = []
        for arg in args:
            if arg.startswith('--'):
                if '=' in arg:
                    key, value = arg.split('=', 1)
                    key = '--' + key[len('--'):].replace('_', '-')
                    processed_args.append(f'{key}={value}')
                else:
                    processed_args.append('--' +
                                          arg[len('--'):].replace('_', '-'))
            else:
                processed_args.append(arg)

        return super().parse_args(processed_args, namespace)
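

# Usage sketch for FlexibleArgumentParser (illustrative, not part of the
# library API; the option name is a hypothetical example).
def _flexible_argument_parser_example() -> None:
    parser = FlexibleArgumentParser()
    parser.add_argument("--max-model-len", type=int)
    # Both spellings resolve to the same destination:
    assert parser.parse_args(["--max_model_len=4096"]).max_model_len == 4096
    assert parser.parse_args(["--max-model-len", "4096"]).max_model_len == 4096
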
async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
                              **kwargs):
    """Utility function to run async task in a lock"""
    async with lock:
        return await task(*args, **kwargs)


def is_encoder_decoder_model_config(model_config) -> bool:
    '''
    Extract the HF encoder/decoder model flag from the ModelConfig instance.
    Return False if model_config is None.
    '''
    return model_config is not None and \
        getattr(model_config.hf_config,
                "is_encoder_decoder",
                False)


def is_embedding_model_config(model_config) -> bool:
    '''
    Extract the embedding model flag from the ModelConfig instance.
    Return False if model_config is None.
    '''
    return model_config is not None and \
        model_config.embedding_mode
def build_explicit_enc_dec_prompt(
    encoder_prompt: SingletonPromptInputs,
    decoder_prompt: SingletonPromptInputs,
) -> ExplicitEncoderDecoderPrompt:
    return ExplicitEncoderDecoderPrompt(encoder_prompt=encoder_prompt,
                                        decoder_prompt=decoder_prompt)


def zip_enc_dec_prompt_lists(
    enc_prompt_list: List[SingletonPromptInputs],
    dec_prompt_list: List[SingletonPromptInputs],
) -> List[ExplicitEncoderDecoderPrompt]:
    return [
        build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt)
        for (encoder_prompt,
             decoder_prompt) in zip(enc_prompt_list, dec_prompt_list)
    ]


def to_enc_dec_tuple_list(
    enc_dec_prompts: List[ExplicitEncoderDecoderPrompt],
) -> List[Tuple[PromptInputs, PromptInputs]]:
    return [(enc_dec_prompt['encoder_prompt'],
             enc_dec_prompt['decoder_prompt'])
            for enc_dec_prompt in enc_dec_prompts]
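

# Usage sketch for the encoder/decoder prompt helpers above (illustrative,
# not part of the library API; the prompt strings are arbitrary, and the
# equality check assumes ExplicitEncoderDecoderPrompt behaves like a mapping,
# as implied by the key access in to_enc_dec_tuple_list).
def _enc_dec_prompt_helpers_example() -> None:
    prompts = zip_enc_dec_prompt_lists(["enc 0", "enc 1"], ["dec 0", "dec 1"])
    pairs = to_enc_dec_tuple_list(prompts)
    assert pairs == [("enc 0", "dec 0"), ("enc 1", "dec 1")]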