# utils.py

import argparse
import asyncio
import contextlib
import datetime
import enum
import gc
import inspect
import ipaddress
import math
import os
import random
import socket
import subprocess
import sys
import tempfile
import threading
import uuid
import warnings
import weakref
from asyncio import FIRST_COMPLETED, ensure_future
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic,
                    Hashable, Iterable, List, Literal, Optional, OrderedDict,
                    Set, Tuple, Type, TypeVar, Union, overload)
from uuid import uuid4

import numpy as np
import numpy.typing as npt
import psutil
import torch
import torch.types
from loguru import logger
from packaging.version import Version
from rich.progress import (BarColumn, MofNCompleteColumn, Progress,
                           SpinnerColumn, TextColumn, TimeElapsedColumn)
from typing_extensions import ParamSpec, TypeIs, assert_never

import aphrodite.common.envs as envs
from aphrodite.common.logger import enable_trace_function_call
from aphrodite.distributed import get_tensor_model_parallel_rank
from aphrodite.platforms import current_platform

# Exception strings for non-implemented encoder/decoder scenarios
STR_NOT_IMPL_ENC_DEC_SWA = \
    "Sliding window attention for encoder/decoder models " + \
    "is not currently supported."

STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \
    "Prefix caching for encoder/decoder models " + \
    "is not currently supported."

STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \
    "Chunked prefill for encoder/decoder models " + \
    "is not currently supported."

STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = (
    "Models with logits_soft_cap "
    "require FlashInfer backend, which is "
    "currently not supported for encoder/decoder "
    "models.")

STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is not currently "
                             "supported with encoder/decoder "
                             "models.")

STR_NOT_IMPL_ENC_DEC_PP = ("Pipeline parallelism is not "
                           "currently supported with "
                           "encoder/decoder models.")

STR_NOT_IMPL_ENC_DEC_MM = ("Multimodal is not currently "
                           "supported with encoder/decoder "
                           "models.")

STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not "
                                 "currently supported with encoder/"
                                 "decoder models.")

STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend "
                                "currently supported with encoder/"
                                "decoder models.")

STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
                                       "currently supported with encoder/"
                                       "decoder models.")

STR_NOT_IMPL_ENC_DEC_CPU = ("CPU is not currently supported with "
                            "encoder/decoder models.")

# Export all enc/dec error strings in a single mapping so that consumers
# can import one name rather than each of the strings above.
STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
    "STR_NOT_IMPL_ENC_DEC_SWA": STR_NOT_IMPL_ENC_DEC_SWA,
    "STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE": STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE,
    "STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL":
    STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL,
    "STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP": STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP,
    "STR_NOT_IMPL_ENC_DEC_LORA": STR_NOT_IMPL_ENC_DEC_LORA,
    "STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP,
    "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM,
    "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
    "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
    "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
    "STR_NOT_IMPL_ENC_DEC_CPU": STR_NOT_IMPL_ENC_DEC_CPU
}

# Constants related to forcing the attention backend selection

# Name of the environment variable which may be set in order to force
# the Attention wrapper to select a particular attention backend
STR_BACKEND_ENV_VAR: str = "APHRODITE_ATTENTION_BACKEND"

# Possible string values of STR_BACKEND_ENV_VAR,
# corresponding to the possible backends
STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER"
STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"

GiB_bytes = 1 << 30
"""The number of bytes in one gibibyte (GiB)."""

STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8": torch.uint8,
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
}

TORCH_DTYPE_TO_NUMPY_DTYPE = {
    torch.float16: np.float16,
    torch.float32: np.float32,
    torch.float64: np.float64,
    torch.uint8: np.uint8,
    torch.int32: np.int32,
    torch.int64: np.int64,
}

P = ParamSpec('P')
K = TypeVar("K")
T = TypeVar("T")
U = TypeVar("U")

class _Sentinel:
    ...

ALL_PINNED_SENTINEL = _Sentinel()

class Device(enum.Enum):
    GPU = enum.auto()
    CPU = enum.auto()

class Counter:

    def __init__(self, start: int = 0) -> None:
        self.counter = start

    def __next__(self) -> int:
        i = self.counter
        self.counter += 1
        return i

    def reset(self) -> None:
        self.counter = 0

class LRUCache(Generic[T]):

    def __init__(self, capacity: int):
        self.cache: OrderedDict[Hashable, T] = OrderedDict()
        self.pinned_items: Set[Hashable] = set()
        self.capacity = capacity

    def __contains__(self, key: Hashable) -> bool:
        return key in self.cache

    def __len__(self) -> int:
        return len(self.cache)

    def __getitem__(self, key: Hashable) -> T:
        value = self.cache[key]  # Raise KeyError if not exists
        self.cache.move_to_end(key)
        return value

    def __setitem__(self, key: Hashable, value: T) -> None:
        self.put(key, value)

    def __delitem__(self, key: Hashable) -> None:
        self.pop(key)

    def touch(self, key: Hashable) -> None:
        self.cache.move_to_end(key)

    def get(self,
            key: Hashable,
            default_value: Optional[T] = None) -> Optional[T]:
        value: Optional[T]
        if key in self.cache:
            value = self.cache[key]
            self.cache.move_to_end(key)
        else:
            value = default_value
        return value

    def put(self, key: Hashable, value: T) -> None:
        self.cache[key] = value
        self.cache.move_to_end(key)
        self._remove_old_if_needed()

    def pin(self, key: Hashable) -> None:
        """
        Pins a key in the cache preventing it from being
        evicted in the LRU order.
        """
        if key not in self.cache:
            raise ValueError(f"Cannot pin key: {key} not in cache.")
        self.pinned_items.add(key)

    def _unpin(self, key: Hashable) -> None:
        self.pinned_items.remove(key)

    def _on_remove(self, key: Hashable, value: Optional[T]):
        pass

    def remove_oldest(self, remove_pinned=False):
        if not self.cache:
            return
        if not remove_pinned:
            # pop the oldest item in the cache that is not pinned
            lru_key = next(
                (key for key in self.cache if key not in self.pinned_items),
                ALL_PINNED_SENTINEL)
            if lru_key is ALL_PINNED_SENTINEL:
                raise RuntimeError("All items are pinned, "
                                   "cannot remove oldest from the cache.")
        else:
            lru_key = next(iter(self.cache))
        self.pop(lru_key)

    def _remove_old_if_needed(self) -> None:
        while len(self.cache) > self.capacity:
            self.remove_oldest()

    def pop(self,
            key: Hashable,
            default_value: Optional[T] = None) -> Optional[T]:
        run_on_remove = key in self.cache
        value: Optional[T] = self.cache.pop(key, default_value)
        # remove from pinned items
        if key in self.pinned_items:
            self._unpin(key)
        if run_on_remove:
            self._on_remove(key, value)
        return value

    def clear(self):
        while len(self.cache) > 0:
            self.remove_oldest(remove_pinned=True)
        self.cache.clear()
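
# Illustrative usage sketch for LRUCache: subclasses may override `_on_remove`
# to run cleanup when an entry is evicted. The `_EvictionLogger` name below is
# hypothetical and only for demonstration.
#
# class _EvictionLogger(LRUCache[int]):
#     def _on_remove(self, key, value):
#         logger.debug(f"evicted {key} -> {value}")
#
# cache = _EvictionLogger(capacity=2)
# cache.put("a", 1)
# cache.put("b", 2)
# cache.pin("a")        # "a" cannot be evicted in LRU order
# cache.put("c", 3)     # evicts "b", the oldest unpinned entry
# assert "a" in cache and "c" in cache and "b" not in cache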

class PyObjectCache:
    """Used to cache python objects to avoid object allocations
    across scheduler iterations.
    """

    def __init__(self, obj_builder):
        self._obj_builder = obj_builder
        self._index = 0

        self._obj_cache = []
        for _ in range(128):
            self._obj_cache.append(self._obj_builder())

    def _grow_cache(self):
        # Double the size of the cache
        num_objs = len(self._obj_cache)
        for _ in range(num_objs):
            self._obj_cache.append(self._obj_builder())

    def get_object(self):
        """Returns a pre-allocated cached object. If there are not enough
        objects, the cache size is doubled.
        """
        if self._index >= len(self._obj_cache):
            self._grow_cache()
            assert self._index < len(self._obj_cache)

        obj = self._obj_cache[self._index]
        self._index += 1

        return obj

    def reset(self):
        """Makes all cached objects available for the next scheduler
        iteration.
        """
        self._index = 0
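
# Minimal usage sketch (illustrative only): the builder is called eagerly to
# pre-allocate 128 objects, `get_object` hands them out in order, and `reset`
# recycles the whole pool for the next scheduler iteration.
#
# pool = PyObjectCache(dict)      # builder producing empty dicts
# obj = pool.get_object()
# obj["key"] = "value"
# pool.reset()                    # all objects become reusable again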

def is_hip() -> bool:
    return torch.version.hip is not None

@lru_cache(maxsize=None)
def is_cpu() -> bool:
    from importlib.metadata import PackageNotFoundError, version
    try:
        return "cpu" in version("aphrodite-engine")
    except PackageNotFoundError:
        return False

@lru_cache(maxsize=None)
def is_openvino() -> bool:
    from importlib.metadata import PackageNotFoundError, version
    try:
        return "openvino" in version("aphrodite-engine")
    except PackageNotFoundError:
        return False

@lru_cache(maxsize=None)
def is_neuron() -> bool:
    try:
        import transformers_neuronx
    except ImportError:
        transformers_neuronx = None
    return transformers_neuronx is not None

@lru_cache(maxsize=None)
def is_xpu() -> bool:
    from importlib.metadata import version
    is_xpu_flag = "xpu" in version("aphrodite-engine")
    # aphrodite was not built with xpu support
    if not is_xpu_flag:
        return False
    try:
        import intel_extension_for_pytorch as ipex  # noqa: F401
        _import_ipex = True
    except ImportError as e:
        logger.warning(f"Import Error for IPEX: {e.msg}")
        _import_ipex = False
    # ipex dependency is not ready
    if not _import_ipex:
        logger.warning("IPEX library not found.")
        return False
    return hasattr(torch, "xpu") and torch.xpu.is_available()

@lru_cache(maxsize=None)
def is_triton() -> bool:
    if envs.APHRODITE_USE_TRITON_BACKEND:
        try:
            import triton  # noqa: F401
        except ImportError:
            return False
        return True
    return False

@lru_cache(maxsize=None)
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes."""
    from aphrodite import _custom_ops as ops
    max_shared_mem = (
        ops.get_max_shared_memory_per_block_device_attribute(gpu))
    # a value of 0 would make MAX_SEQ_LEN negative and cause
    # test_attention.py to fail
    assert max_shared_mem > 0, "max_shared_mem can not be zero"
    return int(max_shared_mem)

def get_cpu_memory() -> int:
    """Returns the total CPU memory of the node in bytes."""
    return psutil.virtual_memory().total

def seed_everything(seed: int) -> None:
    """
    Set the seed of each random module.

    Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20
    """
    random.seed(seed)
    np.random.seed(seed)

    if current_platform.is_cuda():
        torch.cuda.manual_seed_all(seed)

    if is_xpu():
        torch.xpu.manual_seed_all(seed)

def random_uuid() -> str:
    return str(uuid.uuid4().hex)

@lru_cache(maxsize=None)
def get_aphrodite_instance_id():
    """
    If the environment variable APHRODITE_INSTANCE_ID is set, return it.
    Otherwise, return a random UUID.

    The instance id identifies a single instance of the Aphrodite engine.
    All processes belonging to the same instance should share the same
    instance id.
    """
    return envs.APHRODITE_INSTANCE_ID or f"aphrodite-instance-{random_uuid()}"

@lru_cache(maxsize=None)
def in_wsl() -> bool:
    # Reference: https://github.com/microsoft/WSL/issues/4071
    return "microsoft" in " ".join(uname()).lower()

@lru_cache(maxsize=None)
def in_windows() -> bool:
    return sys.platform.startswith("win32")

def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
    """Take a blocking function, and run it in an executor thread.

    This function prevents the blocking function from blocking the
    asyncio event loop.
    The code in this function needs to be thread safe.
    """

    def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
        loop = asyncio.get_event_loop()
        p_func = partial(func, *args, **kwargs)
        return loop.run_in_executor(executor=None, func=p_func)

    return _async_wrapper

async def iterate_with_cancellation(
    iterator: AsyncGenerator[T, None],
    is_cancelled: Callable[[], Awaitable[bool]],
) -> AsyncGenerator[T, None]:
    """Convert an async iterator into one that polls the provided function
    at least once per second to check for client cancellation.
    """

    # Can use anext() in python >= 3.10
    awaits = [ensure_future(iterator.__anext__())]
    while True:
        done, pending = await asyncio.wait(awaits, timeout=1)
        if await is_cancelled():
            with contextlib.suppress(BaseException):
                awaits[0].cancel()
                await iterator.aclose()
            raise asyncio.CancelledError("client cancelled")
        if done:
            try:
                item = await awaits[0]
                awaits[0] = ensure_future(iterator.__anext__())
                yield item
            except StopAsyncIteration:
                # we are done
                return

async def merge_async_iterators(
    *iterators: AsyncGenerator[T, None],
    is_cancelled: Optional[Callable[[], Awaitable[bool]]] = None,
) -> AsyncGenerator[Tuple[int, T], None]:
    """Merge multiple asynchronous iterators into a single iterator.

    This method handles the case where some iterators finish before others.
    When it yields, it yields a tuple (i, item) where i is the index of the
    iterator that yielded the item.

    It also optionally polls a provided function at least once per second
    to check for client cancellation.
    """

    # Can use anext() in python >= 3.10
    awaits = {
        ensure_future(pair[1].__anext__()): pair
        for pair in enumerate(iterators)
    }
    timeout = None if is_cancelled is None else 1
    try:
        while awaits:
            done, pending = await asyncio.wait(awaits.keys(),
                                               return_when=FIRST_COMPLETED,
                                               timeout=timeout)
            if is_cancelled is not None and await is_cancelled():
                raise asyncio.CancelledError("client cancelled")
            for d in done:
                pair = awaits.pop(d)
                try:
                    item = await d
                    i, it = pair
                    awaits[ensure_future(it.__anext__())] = pair
                    yield i, item
                except StopAsyncIteration:
                    pass
    finally:
        # Cancel any remaining iterators
        for f, (_, it) in awaits.items():
            with contextlib.suppress(BaseException):
                f.cancel()
                await it.aclose()
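
# Sketch of how the merged stream is typically consumed (illustrative only;
# `gen_a` and `gen_b` are hypothetical async generators):
#
# async def consume(gen_a, gen_b):
#     async for i, item in merge_async_iterators(gen_a, gen_b):
#         # `i` is 0 for items from gen_a and 1 for items from gen_b,
#         # in whichever order the underlying generators produce them.
#         print(i, item)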

def get_ip() -> str:
    host_ip = os.environ.get("HOST_IP")
    if host_ip:
        return host_ip

    # IP is not set, try to get it from the network interface

    # try ipv4
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
        return s.getsockname()[0]
    except Exception:
        pass

    # try ipv6
    try:
        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
        # Google's public DNS server, see
        # https://developers.google.com/speed/public-dns/docs/using#addresses
        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
        return s.getsockname()[0]
    except Exception:
        pass

    warnings.warn(
        "Failed to get the IP address, using 0.0.0.0 by default. "
        "The value can be set by the environment variable HOST_IP.",
        stacklevel=2)
    return "0.0.0.0"

def is_valid_ipv6_address(address: str) -> bool:
    try:
        ipaddress.IPv6Address(address)
        return True
    except ValueError:
        return False

def get_distributed_init_method(ip: str, port: int) -> str:
    # Brackets are not permitted in ipv4 addresses,
    # see https://github.com/python/cpython/issues/103848
    return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"

def get_open_zmq_ipc_path() -> str:
    if not in_windows():
        base_rpc_path = envs.APHRODITE_RPC_BASE_PATH
        return f"ipc://{base_rpc_path}/{uuid4()}"
    else:
        # windows doesn't support ipc://
        # use tcp:// instead
        return f"tcp://127.0.0.1:{get_open_port()}"

def get_open_port(port: Optional[int] = None) -> int:
    # Fall back to the APHRODITE_PORT environment variable only when no
    # explicit port was requested by the caller.
    if port is None:
        port = envs.APHRODITE_PORT
    if port is not None:
        while True:
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                    s.bind(("", port))
                    return port
            except OSError:
                port += 1  # Increment port number if already in use
                logger.info(f"Port {port - 1} is already in use, trying port "
                            f"{port}")
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(("", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            s.bind(("", 0))
            return s.getsockname()[1]

def find_process_using_port(port: int) -> Optional[psutil.Process]:
    for conn in psutil.net_connections():
        if conn.laddr.port == port:
            try:
                return psutil.Process(conn.pid)
            except psutil.NoSuchProcess:
                return None
    return None

def update_environment_variables(envs: Dict[str, str]):
    for k, v in envs.items():
        if k in os.environ and os.environ[k] != v:
            logger.warning(f"Overwriting environment variable {k} "
                           f"from '{os.environ[k]}' to '{v}'")
        os.environ[k] = v

def chunk_list(lst: List[T], chunk_size: int):
    """Yield successive chunk_size chunks from lst."""
    for i in range(0, len(lst), chunk_size):
        yield lst[i:i + chunk_size]

def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(a // -b)
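
# Quick sanity examples (illustrative): cdiv rounds up, matching the number
# of chunks chunk_list produces for the same chunk size.
#
# assert cdiv(10, 4) == 3                       # ceil(10 / 4)
# assert list(chunk_list([1, 2, 3, 4, 5], 2)) == [[1, 2], [3, 4], [5]]
# assert cdiv(5, 2) == len(list(chunk_list([1, 2, 3, 4, 5], 2)))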

def _generate_random_fp8(
    tensor: torch.Tensor,
    low: float,
    high: float,
) -> None:
    # NOTE: Due to the NaN and Inf representations of the fp8 data types,
    # Inf or NaN may be produced if we directly use torch.randint to
    # generate random fp8 data.
    # For example, s.11111.00 in fp8e5m2 format represents Inf.
    #     | E4M3        | E5M2
    # ----|-------------|-------------------
    # Inf | N/A         | s.11111.00
    # NaN | s.1111.111  | s.11111.{01,10,11}
    from aphrodite import _custom_ops as ops
    tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
    tensor_tmp.uniform_(low, high)
    ops.convert_fp8(tensor, tensor_tmp)
    del tensor_tmp

def get_kv_cache_torch_dtype(
        cache_dtype: Optional[Union[str, torch.dtype]],
        model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
    if isinstance(cache_dtype, str):
        if cache_dtype == "auto":
            if isinstance(model_dtype, str):
                torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
            elif isinstance(model_dtype, torch.dtype):
                torch_dtype = model_dtype
            else:
                raise ValueError(f"Invalid model dtype: {model_dtype}")
        elif cache_dtype in ["half", "bfloat16", "float"]:
            torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
        elif cache_dtype == "fp8":
            torch_dtype = torch.uint8
        else:
            raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    elif isinstance(cache_dtype, torch.dtype):
        torch_dtype = cache_dtype
    else:
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    return torch_dtype

def create_kv_caches_with_random_flash(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
    scale = head_size**-0.5

    key_caches: List[torch.Tensor] = []
    value_caches: List[torch.Tensor] = []

    for _ in range(num_layers):
        key_value_cache = torch.empty(size=key_value_cache_shape,
                                      dtype=torch_dtype,
                                      device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            key_value_cache.uniform_(-scale, scale)
        elif cache_dtype == 'fp8':
            _generate_random_fp8(key_value_cache, -scale, scale)
        else:
            raise ValueError(
                f"Does not support key cache of type {cache_dtype}")
        key_caches.append(key_value_cache[:, 0])
        value_caches.append(key_value_cache[:, 1])
    return key_caches, value_caches

def create_kv_caches_with_random(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:

    if cache_dtype == "fp8" and head_size % 16:
        raise ValueError(
            f"Does not support key cache of type fp8 with head_size "
            f"{head_size}")

    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)

    scale = head_size**-0.5
    x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        key_cache = torch.empty(size=key_cache_shape,
                                dtype=torch_dtype,
                                device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            key_cache.uniform_(-scale, scale)
        elif cache_dtype == 'fp8':
            _generate_random_fp8(key_cache, -scale, scale)
        else:
            raise ValueError(
                f"Does not support key cache of type {cache_dtype}")
        key_caches.append(key_cache)

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        value_cache = torch.empty(size=value_cache_shape,
                                  dtype=torch_dtype,
                                  device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            value_cache.uniform_(-scale, scale)
        elif cache_dtype == 'fp8':
            _generate_random_fp8(value_cache, -scale, scale)
        else:
            raise ValueError(
                f"Does not support value cache of type {cache_dtype}")
        value_caches.append(value_cache)
    return key_caches, value_caches

@lru_cache
def print_warning_once(msg: str) -> None:
    logger.warning(msg)

@lru_cache(maxsize=None)
def is_pin_memory_available() -> bool:
    if in_wsl():
        # Pinning memory in WSL is not supported.
        # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
        print_warning_once("Using 'pin_memory=False' as WSL is detected. "
                           "This may slow down the performance.")
        return False
    elif is_xpu():
        print_warning_once("Pin memory is not supported on XPU.")
        return False
    elif is_neuron():
        print_warning_once("Pin memory is not supported on Neuron.")
        return False
    elif is_cpu() or is_openvino():
        return False
    return True

class DeviceMemoryProfiler:

    def __init__(self, device: Optional[torch.types.Device] = None):
        self.device = device

    def current_memory_usage(self) -> float:
        # Return the memory usage in bytes.
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(self.device)
            mem = torch.cuda.max_memory_allocated(self.device)
        elif is_xpu():
            torch.xpu.reset_peak_memory_stats(self.device)  # type: ignore
            mem = torch.xpu.max_memory_allocated(self.device)  # type: ignore
        return mem

    def __enter__(self):
        self.initial_memory = self.current_memory_usage()
        # This allows us to call methods of the context manager if needed
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.final_memory = self.current_memory_usage()
        self.consumed_memory = self.final_memory - self.initial_memory

        # Force garbage collection
        gc.collect()

def make_ndarray_with_pad(
    x: List[List[T]],
    pad: T,
    dtype: npt.DTypeLike,
    *,
    max_len: Optional[int] = None,
) -> npt.NDArray:
    """
    Make a padded array from 2D inputs.

    The padding is applied to the end of each inner list until it reaches
    `max_len`.
    """
    if max_len is None:
        # Unlike for most functions, map is faster than a genexpr over `len`
        max_len = max(map(len, x), default=0)

    padded_x = np.full((len(x), max_len), pad, dtype=dtype)
    for ind, blocktb in enumerate(x):
        assert len(blocktb) <= max_len
        padded_x[ind, :len(blocktb)] = blocktb

    return padded_x

def make_tensor_with_pad(
    x: List[List[T]],
    pad: T,
    dtype: torch.dtype,
    *,
    max_len: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    pin_memory: bool = False,
) -> torch.Tensor:
    """
    Make a padded tensor from 2D inputs.

    The padding is applied to the end of each inner list until it reaches
    `max_len`.
    """
    np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
    padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len)

    tensor = torch.from_numpy(padded_x).to(device)
    if pin_memory:
        tensor = tensor.pin_memory()

    return tensor
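
# Illustrative example: three variable-length rows padded to the longest row
# (or to `max_len` when given), then copied to the target device.
#
# block_tables = make_tensor_with_pad(
#     [[1, 2, 3], [4], []],
#     pad=0,
#     dtype=torch.int32,
#     device="cpu",
# )
# # tensor([[1, 2, 3],
# #         [4, 0, 0],
# #         [0, 0, 0]], dtype=torch.int32)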

def async_tensor_h2d(
    data: list,
    dtype: torch.dtype,
    target_device: Union[str, torch.device],
    pin_memory: bool,
) -> torch.Tensor:
    """Asynchronously create a tensor and copy it from host to device."""
    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
    return t.to(device=target_device, non_blocking=True)

def get_dtype_size(dtype: torch.dtype) -> int:
    """Get the size of the data type in bytes."""
    return torch.tensor([], dtype=dtype).element_size()

# `collections` helpers

def is_list_of(
    value: object,
    typ: Type[T],
    *,
    check: Literal["first", "all"] = "first",
) -> TypeIs[List[T]]:
    if not isinstance(value, list):
        return False

    if check == "first":
        return len(value) == 0 or isinstance(value[0], typ)
    elif check == "all":
        return all(isinstance(v, typ) for v in value)

    assert_never(check)

JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
                 Tuple["JSONTree[T]", ...], T]
"""A nested JSON structure where the leaves need not be JSON-serializable."""

@overload
def json_map_leaves(
    func: Callable[[T], U],
    value: Dict[str, JSONTree[T]],
) -> Dict[str, JSONTree[U]]:
    ...

@overload
def json_map_leaves(
    func: Callable[[T], U],
    value: List[JSONTree[T]],
) -> List[JSONTree[U]]:
    ...

@overload
def json_map_leaves(
    func: Callable[[T], U],
    value: Tuple[JSONTree[T], ...],
) -> Tuple[JSONTree[U], ...]:
    ...

@overload
def json_map_leaves(
    func: Callable[[T], U],
    value: JSONTree[T],
) -> JSONTree[U]:
    ...

def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
    if isinstance(value, dict):
        return {k: json_map_leaves(func, v) for k, v in value.items()}
    elif isinstance(value, list):
        return [json_map_leaves(func, v) for v in value]
    elif isinstance(value, tuple):
        return tuple(json_map_leaves(func, v) for v in value)
    else:
        return func(value)
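
# Illustrative example: json_map_leaves applies `func` to every leaf while
# preserving the dict/list/tuple structure of the tree.
#
# tree = {"a": [1, 2], "b": {"c": (3, 4)}}
# doubled = json_map_leaves(lambda x: x * 2, tree)
# # {"a": [2, 4], "b": {"c": (6, 8)}}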

def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
    """Flatten a list of lists to a single list."""
    return [item for sublist in lists for item in sublist]

def init_cached_hf_modules() -> None:
    """
    Lazy initialization of the Hugging Face modules.
    """
    from transformers.dynamic_module_utils import init_hf_modules
    init_hf_modules()

@lru_cache(maxsize=None)
def find_library(lib_name: str) -> str:
    """
    Find the library file in the system.
    `lib_name` is the full filename, with both prefix and suffix.
    This function resolves `lib_name` to the full path of the library.
    """
    # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
    # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
    # `/sbin/ldconfig` should exist in all Linux systems.
    # `/sbin/ldconfig` searches the library in the system
    libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
    # each line looks like the following:
    # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
    locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
    # `LD_LIBRARY_PATH` searches the library in the user-defined paths
    env_ld_library_path = envs.LD_LIBRARY_PATH
    if not locs and env_ld_library_path:
        locs = [
            os.path.join(dir, lib_name)
            for dir in env_ld_library_path.split(":")
            if os.path.exists(os.path.join(dir, lib_name))
        ]
    if not locs:
        raise ValueError(f"Cannot find {lib_name} in the system.")
    return locs[0]

def find_nccl_library() -> str:
    """
    We either use the library file specified by the `APHRODITE_NCCL_SO_PATH`
    environment variable, or we find the library file brought by PyTorch.
    After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be
    found by `ctypes` automatically.
    """
    so_file = envs.APHRODITE_NCCL_SO_PATH

    # manually load the nccl library
    if so_file:
        logger.debug("Found nccl from environment variable "
                     f"APHRODITE_NCCL_SO_PATH={so_file}")
    else:
        if torch.version.cuda is not None:
            so_file = "libnccl.so.2"
        elif torch.version.hip is not None:
            so_file = "librccl.so.1"
        else:
            raise ValueError("NCCL only supports CUDA and ROCm backends.")
        logger.debug(f"Found nccl from library {so_file}")
    return so_file

def enable_trace_function_call_for_thread() -> None:
    if envs.APHRODITE_TRACE_FUNCTION:
        tmp_dir = tempfile.gettempdir()
        filename = (f"APHRODITE_TRACE_FUNCTION_for_process_{os.getpid()}"
                    f"_thread_{threading.get_ident()}_"
                    f"at_{datetime.datetime.now()}.log").replace(" ", "_")
        log_path = os.path.join(tmp_dir, "aphrodite",
                                get_aphrodite_instance_id(), filename)
        os.makedirs(os.path.dirname(log_path), exist_ok=True)
        enable_trace_function_call(log_path)

def identity(value: T) -> T:
    return value

F = TypeVar('F', bound=Callable[..., Any])

def deprecate_kwargs(
        *kws: str,
        is_deprecated: Union[bool, Callable[[], bool]] = True,
        additional_message: Optional[str] = None) -> Callable[[F], F]:
    deprecated_kws = set(kws)

    if not callable(is_deprecated):
        is_deprecated = partial(identity, is_deprecated)

    def wrapper(fn: F) -> F:

        @wraps(fn)
        def inner(*args, **kwargs):
            if is_deprecated():
                deprecated_kwargs = kwargs.keys() & deprecated_kws
                if deprecated_kwargs:
                    msg = (
                        f"The keyword arguments {deprecated_kwargs} are "
                        "deprecated and will be removed in a future update.")
                    if additional_message is not None:
                        msg += f" {additional_message}"

                    warnings.warn(
                        DeprecationWarning(msg),
                        stacklevel=3,  # The inner function takes up one level
                    )

            return fn(*args, **kwargs)

        return inner  # type: ignore

    return wrapper
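
# Illustrative usage sketch: warn when callers still pass the old keyword.
# The function and keyword names below are hypothetical.
#
# @deprecate_kwargs("legacy_arg",
#                   additional_message="Use 'new_arg' instead.")
# def configure(*, new_arg: int = 0, legacy_arg: Optional[int] = None):
#     return new_arg if legacy_arg is None else legacy_arg
#
# configure(legacy_arg=1)  # emits a DeprecationWarning at the call site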

@lru_cache(maxsize=8)
def _cuda_device_count_stateless(
        cuda_visible_devices: Optional[str] = None) -> int:
    # Note: cuda_visible_devices is not used, but we keep it as an argument
    # for LRU cache purposes.

    # Code below is based on
    # https://github.com/pytorch/pytorch/blob/
    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
    # torch/cuda/__init__.py#L831C1-L831C17
    import torch.cuda
    import torch.version

    if not torch.cuda._is_compiled():
        return 0
    if is_hip():
        # ROCm uses amdsmi instead of nvml for stateless device count.
        # This requires a sufficiently recent torch (2.4.0 or later).
        raw_count = torch.cuda._device_count_amdsmi() if (hasattr(
            torch.cuda, "_device_count_amdsmi")) else -1
    else:
        raw_count = torch.cuda._device_count_nvml()
    r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
    return r

def cuda_device_count_stateless() -> int:
    """Get number of CUDA devices, caching based on the value of
    CUDA_VISIBLE_DEVICES at the time of call.

    This should be used instead of torch.cuda.device_count()
    unless CUDA_VISIBLE_DEVICES has already been set to the desired
    value."""

    # This can be removed and simply replaced with torch.cuda.device_count
    # after https://github.com/pytorch/pytorch/pull/122815 is released.
    return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)

def weak_bind(bound_method: Callable[..., Any]) -> Callable[..., None]:
    """Make an instance method that weakly references
    its associated instance and no-ops once that
    instance is collected."""
    ref = weakref.ref(bound_method.__self__)  # type: ignore[attr-defined]
    unbound = bound_method.__func__  # type: ignore[attr-defined]

    def weak_bound(*args, **kwargs) -> None:
        if inst := ref():
            unbound(inst, *args, **kwargs)

    return weak_bound

# From: https://stackoverflow.com/a/4104188/2749989
def run_once(f):

    def wrapper(*args, **kwargs) -> Any:
        if not wrapper.has_run:  # type: ignore[attr-defined]
            wrapper.has_run = True  # type: ignore[attr-defined]
            return f(*args, **kwargs)

    wrapper.has_run = False  # type: ignore[attr-defined]
    return wrapper

class FlexibleArgumentParser(argparse.ArgumentParser):
    """ArgumentParser that allows both underscore and dash in names."""

    def parse_args(self, args=None, namespace=None):
        if args is None:
            args = sys.argv[1:]

        # Normalize underscores to dashes in argument names so that both
        # spellings are accepted
        processed_args = []
        for arg in args:
            if arg.startswith('--'):
                if '=' in arg:
                    key, value = arg.split('=', 1)
                    key = '--' + key[len('--'):].replace('_', '-')
                    processed_args.append(f'{key}={value}')
                else:
                    processed_args.append('--' +
                                          arg[len('--'):].replace('_', '-'))
            else:
                processed_args.append(arg)

        return super().parse_args(processed_args, namespace)
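
# Illustrative sketch: both spellings below parse to the same namespace,
# because underscores in option names are normalized to dashes first.
#
# parser = FlexibleArgumentParser()
# parser.add_argument("--max-model-len", type=int)
# parser.parse_args(["--max_model_len", "4096"])  # Namespace(max_model_len=4096)
# parser.parse_args(["--max-model-len=4096"])     # same result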

async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args,
                              **kwargs):
    """Utility function to run an async task while holding a lock."""
    async with lock:
        return await task(*args, **kwargs)

def progress_bar(iterable, desc="Processing"):
    show_progress = get_tensor_model_parallel_rank() == 0
    if show_progress:
        with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                BarColumn(),
                MofNCompleteColumn(),
                TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
                TimeElapsedColumn(),
        ) as progress:
            task = progress.add_task(f"[cyan]{desc}", total=len(iterable))
            for item in iterable:
                yield item
                progress.update(task, advance=1)
    else:
        yield from iterable

def tensor_progress_bar(iterable: Iterable[Tuple[str, torch.Tensor]],
                        final_bytes: int, desc="Processing"):
    show_progress = get_tensor_model_parallel_rank() == 0
    units = 1024 ** (int(math.log2(final_bytes)) // 10)
    if show_progress:
        with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                BarColumn(),
                # MofNCompleteColumn(),
                TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
                TextColumn("{task.completed:.2f}/{task.total:.2f} GiB"),
                TimeElapsedColumn(),
        ) as progress:
            task = progress.add_task(f"[cyan]{desc}", total=final_bytes/units)
            for item in iterable:
                steps = item[1].element_size() * item[1].nelement() / units
                yield item
                progress.update(task, advance=steps)
    else:
        yield from iterable

def get_allowed_kwarg_only_overrides(
    callable: Callable[..., object],
    overrides: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
    """
    Given a callable which has one or more keyword only params and a dict
    mapping param names to values, drop values that cannot be kwarg
    expanded to overwrite one or more keyword-only args. This is used in a
    few places to handle custom processor overrides for multimodal models,
    e.g., for profiling when processor options provided by the user
    may affect the number of mm tokens per instance.

    Args:
        callable: Callable which takes 0 or more keyword only arguments.
        overrides: Potential overrides to be used when invoking the callable.

    Returns:
        Dictionary containing the kwargs to be leveraged which may be used
        to overwrite one or more keyword only arguments when invoking the
        callable.
    """
    if not overrides:
        return {}

    allowed_override_names = [
        name for name, param in inspect.signature(callable).parameters.items()
        if param.kind == inspect.Parameter.KEYWORD_ONLY
    ]

    # Drop any mm_processor_kwargs provided by the user that are
    # not kwarg names accepted by the provided input processor.
    filtered_overrides = {
        kwarg_name: val
        for kwarg_name, val in overrides.items()
        if kwarg_name in allowed_override_names
    }

    # If anything is dropped, log a warning
    dropped_keys = overrides.keys() - filtered_overrides.keys()
    if dropped_keys:
        logger.warning(
            "The following intended overrides are not keyword-only args "
            f"and will be dropped: {dropped_keys}")

    return filtered_overrides

# Using dynamo with Aphrodite doesn't really work well with PyTorch
# versions < 2.4.0.
# In particular, the FakeScalarType is not supported for earlier versions of
# PyTorch which breaks dynamo for any ops registered using ScalarType.
def supports_dynamo() -> bool:
    base_torch_version = Version(Version(torch.__version__).base_version)
    return base_torch_version >= Version("2.4.0")