utils.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630
  1. import asyncio
  2. import datetime
  3. import enum
  4. import gc
  5. import glob
  6. import os
  7. import socket
  8. import subprocess
  9. import tempfile
  10. import threading
  11. import uuid
  12. import warnings
  13. from collections import defaultdict
  14. from functools import lru_cache, partial
  15. from platform import uname
  16. from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
  17. Hashable, List, Optional, OrderedDict, Tuple, TypeVar,
  18. Union)
  19. import psutil
  20. import torch
  21. from loguru import logger
  22. from packaging.version import Version, parse
  23. from aphrodite.common.logger import enable_trace_function_call
  24. T = TypeVar("T")
  25. STR_DTYPE_TO_TORCH_DTYPE = {
  26. "half": torch.half,
  27. "bfloat16": torch.bfloat16,
  28. "float": torch.float,
  29. "fp8": torch.uint8,
  30. }
  31. class Device(enum.Enum):
  32. GPU = enum.auto()
  33. CPU = enum.auto()
  34. class Counter:
  35. def __init__(self, start: int = 0) -> None:
  36. self.counter = start
  37. def __next__(self) -> int:
  38. i = self.counter
  39. self.counter += 1
  40. return i
  41. def reset(self) -> None:
  42. self.counter = 0
  43. class LRUCache(Generic[T]):
  44. def __init__(self, capacity: int):
  45. self.cache: OrderedDict[Hashable, T] = OrderedDict()
  46. self.capacity = capacity
  47. def __contains__(self, key: Hashable) -> bool:
  48. return key in self.cache
  49. def __len__(self) -> int:
  50. return len(self.cache)
  51. def __getitem__(self, key: Hashable) -> Optional[T]:
  52. return self.get(key)
  53. def __setitem__(self, key: Hashable, value: T) -> None:
  54. self.put(key, value)
  55. def __delitem__(self, key: Hashable) -> None:
  56. self.pop(key)
  57. def touch(self, key: Hashable) -> None:
  58. self.cache.move_to_end(key)
  59. def get(self,
  60. key: Hashable,
  61. default_value: Optional[T] = None) -> Optional[T]:
  62. if key in self.cache:
  63. value: Optional[T] = self.cache[key]
  64. self.cache.move_to_end(key)
  65. else:
  66. value = default_value
  67. return value
  68. def put(self, key: Hashable, value: T) -> None:
  69. self.cache[key] = value
  70. self.cache.move_to_end(key)
  71. self._remove_old_if_needed()
  72. def _on_remove(self, key: Hashable, value: Optional[T]):
  73. pass
  74. def remove_oldest(self):
  75. if not self.cache:
  76. return
  77. key, value = self.cache.popitem(last=False)
  78. self._on_remove(key, value)
  79. def _remove_old_if_needed(self) -> None:
  80. while len(self.cache) > self.capacity:
  81. self.remove_oldest()
  82. def pop(self,
  83. key: Hashable,
  84. default_value: Optional[T] = None) -> Optional[T]:
  85. run_on_remove = key in self.cache
  86. value: Optional[T] = self.cache.pop(key, default_value)
  87. if run_on_remove:
  88. self._on_remove(key, value)
  89. return value
  90. def clear(self):
  91. while len(self.cache) > 0:
  92. self.remove_oldest()
  93. self.cache.clear()
  94. def is_hip() -> bool:
  95. return torch.version.hip is not None
  96. @lru_cache(maxsize=None)
  97. def is_cpu() -> bool:
  98. from importlib.metadata import PackageNotFoundError, version
  99. try:
  100. return "cpu" in version("aphrodite-engine")
  101. except PackageNotFoundError:
  102. return False
  103. @lru_cache(maxsize=None)
  104. def is_neuron() -> bool:
  105. try:
  106. import transformers_neuronx
  107. except ImportError:
  108. transformers_neuronx = None
  109. return transformers_neuronx is not None
  110. @lru_cache(maxsize=None)
  111. def get_max_shared_memory_bytes(gpu: int = 0) -> int:
  112. """Returns the maximum shared memory per thread block in bytes."""
  113. # NOTE: This import statement should be executed lazily since
  114. # the Neuron-X backend does not have the `cuda_utils` module.
  115. from aphrodite._C import cuda_utils
  116. max_shared_mem = (
  117. cuda_utils.get_max_shared_memory_per_block_device_attribute(gpu))
  118. # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
  119. # will fail
  120. assert max_shared_mem > 0, "max_shared_mem can not be zero"
  121. return int(max_shared_mem)
  122. def get_cpu_memory() -> int:
  123. """Returns the total CPU memory of the node in bytes."""
  124. return psutil.virtual_memory().total
  125. def random_uuid() -> str:
  126. return str(uuid.uuid4().hex)
  127. @lru_cache(maxsize=None)
  128. def get_aphrodite_instance_id():
  129. """
  130. If the environment variable APHRODITE_INSTANCE_ID is set, return it.
  131. Otherwise, return a random UUID.
  132. Instance id represents an instance of the Aphrodite. All processes in the
  133. same instance should have the same instance id.
  134. """
  135. return os.environ.get("APHRODITE_INSTANCE_ID",
  136. f"aphrodite-instance-{random_uuid()}")
  137. @lru_cache(maxsize=None)
  138. def in_wsl() -> bool:
  139. # Reference: https://github.com/microsoft/WSL/issues/4071
  140. return "microsoft" in " ".join(uname()).lower()
  141. def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
  142. """Take a blocking function, and run it on in an executor thread.
  143. This function prevents the blocking function from blocking the
  144. asyncio event loop.
  145. The code in this function needs to be thread safe.
  146. """
  147. def _async_wrapper(*args, **kwargs) -> asyncio.Future:
  148. loop = asyncio.get_event_loop()
  149. p_func = partial(func, *args, **kwargs)
  150. return loop.run_in_executor(executor=None, func=p_func)
  151. return _async_wrapper
  152. def merge_async_iterators(
  153. *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
  154. """Merge multiple asynchronous iterators into a single iterator.
  155. This method handle the case where some iterators finish before others.
  156. When it yields, it yields a tuple (i, item) where i is the index of the
  157. iterator that yields the item.
  158. """
  159. queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue()
  160. finished = [False] * len(iterators)
  161. async def producer(i: int, iterator: AsyncIterator[T]):
  162. try:
  163. async for item in iterator:
  164. await queue.put((i, item))
  165. except Exception as e:
  166. await queue.put(e)
  167. finished[i] = True
  168. _tasks = [
  169. asyncio.create_task(producer(i, iterator))
  170. for i, iterator in enumerate(iterators)
  171. ]
  172. async def consumer():
  173. try:
  174. while not all(finished) or not queue.empty():
  175. item = await queue.get()
  176. if isinstance(item, Exception):
  177. raise item
  178. yield item
  179. except (Exception, asyncio.CancelledError) as e:
  180. for task in _tasks:
  181. # NOTE: Pass the error msg in cancel()
  182. # when only Python 3.9+ is supported.
  183. task.cancel()
  184. raise e
  185. await asyncio.gather(*_tasks)
  186. return consumer()
  187. def get_ip() -> str:
  188. host_ip = os.environ.get("HOST_IP")
  189. if host_ip:
  190. return host_ip
  191. # IP is not set, try to get it from the network interface
  192. # try ipv4
  193. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  194. try:
  195. s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
  196. return s.getsockname()[0]
  197. except Exception:
  198. pass
  199. # try ipv6
  200. try:
  201. s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
  202. # Google's public DNS server, see
  203. # https://developers.google.com/speed/public-dns/docs/using#addresses
  204. s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
  205. return s.getsockname()[0]
  206. except Exception:
  207. pass
  208. warnings.warn(
  209. "Failed to get the IP address, using 0.0.0.0 by default."
  210. "The value can be set by the environment variable HOST_IP.",
  211. stacklevel=2)
  212. return "0.0.0.0"
  213. def get_distributed_init_method(ip: str, port: int) -> str:
  214. # Brackets are not permitted in ipv4 addresses,
  215. # see https://github.com/python/cpython/issues/103848
  216. return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
  217. def get_open_port() -> int:
  218. # try ipv4
  219. try:
  220. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  221. s.bind(("", 0))
  222. return s.getsockname()[1]
  223. except OSError:
  224. # try ipv6
  225. with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
  226. s.bind(("", 0))
  227. return s.getsockname()[1]
  228. def update_environment_variables(envs: Dict[str, str]):
  229. for k, v in envs.items():
  230. if k in os.environ and os.environ[k] != v:
  231. logger.warning(f"Overwriting environment variable {k} "
  232. f"from '{os.environ[k]}' to '{v}'")
  233. os.environ[k] = v
  234. def chunk_list(lst, chunk_size):
  235. """Yield successive chunk_size chunks from lst."""
  236. return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]
  237. def cdiv(a: int, b: int) -> int:
  238. """Ceiling division."""
  239. return -(a // -b)
  240. @lru_cache(maxsize=None)
  241. def get_nvcc_cuda_version() -> Optional[Version]:
  242. cuda_home = os.environ.get('CUDA_HOME')
  243. if not cuda_home:
  244. cuda_home = '/usr/local/cuda'
  245. if os.path.isfile(cuda_home + '/bin/nvcc'):
  246. logger.info(f'CUDA_HOME is not found in the environment. '
  247. f'Using {cuda_home} as CUDA_HOME.')
  248. else:
  249. logger.warning(
  250. f'Not found nvcc in {cuda_home}. Skip cuda version check!')
  251. return None
  252. nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"],
  253. universal_newlines=True)
  254. output = nvcc_output.split()
  255. release_idx = output.index("release") + 1
  256. nvcc_cuda_version = parse(output[release_idx].split(",")[0])
  257. return nvcc_cuda_version
  258. def _generate_random_fp8(
  259. tensor: torch.tensor,
  260. low: float,
  261. high: float,
  262. ) -> None:
  263. # NOTE: Due to NaN and Inf representation for fp8 data type,
  264. # it may occur Inf or NaN if we directly use torch.randint
  265. # to generate random data for fp8 data.
  266. # For example, s.11111.00 in fp8e5m2 format represents Inf.
  267. # | E4M3 | E5M2
  268. #-----|-------------|-------------------
  269. # Inf | N/A | s.11111.00
  270. # NaN | s.1111.111 | s.11111.{01,10,11}
  271. from aphrodite._C import cache_ops
  272. tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
  273. tensor_tmp.uniform_(low, high)
  274. cache_ops.convert_fp8(tensor_tmp, tensor)
  275. del tensor_tmp
  276. def create_kv_caches_with_random(
  277. num_blocks: int,
  278. block_size: int,
  279. num_layers: int,
  280. num_heads: int,
  281. head_size: int,
  282. cache_dtype: Optional[Union[str, torch.dtype]],
  283. model_dtype: Optional[Union[str, torch.dtype]] = None,
  284. seed: int = 0,
  285. device: Optional[str] = "cuda",
  286. ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
  287. torch.random.manual_seed(seed)
  288. if torch.cuda.is_available():
  289. torch.cuda.manual_seed(seed)
  290. if isinstance(cache_dtype, str):
  291. if cache_dtype == "auto":
  292. if isinstance(model_dtype, str):
  293. torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
  294. elif isinstance(model_dtype, torch.dtype):
  295. torch_dtype = model_dtype
  296. else:
  297. raise ValueError(f"Invalid model dtype: {model_dtype}")
  298. elif cache_dtype in ["half", "bfloat16", "float"]:
  299. torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
  300. elif cache_dtype == "fp8":
  301. torch_dtype = torch.uint8
  302. else:
  303. raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
  304. elif isinstance(cache_dtype, torch.dtype):
  305. torch_dtype = cache_dtype
  306. else:
  307. raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
  308. scale = head_size**-0.5
  309. x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
  310. key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
  311. key_caches = []
  312. for _ in range(num_layers):
  313. key_cache = torch.empty(size=key_cache_shape,
  314. dtype=torch_dtype,
  315. device=device)
  316. if cache_dtype in ["auto", "half", "bfloat16", "float"]:
  317. key_cache.uniform_(-scale, scale)
  318. elif cache_dtype == 'fp8':
  319. _generate_random_fp8(key_cache, -scale, scale)
  320. else:
  321. raise ValueError(
  322. f"Does not support key cache of type {cache_dtype}")
  323. key_caches.append(key_cache)
  324. value_cache_shape = (num_blocks, num_heads, head_size, block_size)
  325. value_caches = []
  326. for _ in range(num_layers):
  327. value_cache = torch.empty(size=value_cache_shape,
  328. dtype=torch_dtype,
  329. device=device)
  330. if cache_dtype in ["auto", "half", "bfloat16", "float"]:
  331. value_cache.uniform_(-scale, scale)
  332. elif cache_dtype == 'fp8':
  333. _generate_random_fp8(value_cache, -scale, scale)
  334. else:
  335. raise ValueError(
  336. f"Does not support value cache of type {cache_dtype}")
  337. value_caches.append(value_cache)
  338. return key_caches, value_caches
  339. @lru_cache
  340. def print_warning_once(msg: str) -> None:
  341. logger.warning(msg)
  342. @lru_cache(maxsize=None)
  343. def is_pin_memory_available() -> bool:
  344. if in_wsl():
  345. # Pinning memory in WSL is not supported.
  346. # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications
  347. print_warning_once("Using 'pin_memory=False' as WSL is detected. "
  348. "This may slow down the performance.")
  349. return False
  350. elif is_neuron():
  351. print_warning_once("Pin memory is not supported on Neuron.")
  352. return False
  353. elif is_cpu():
  354. return False
  355. return True
  356. class CudaMemoryProfiler:
  357. def __init__(self, device=None):
  358. self.device = device
  359. def current_memory_usage(self) -> float:
  360. # Return the memory usage in bytes.
  361. torch.cuda.reset_peak_memory_stats(self.device)
  362. mem = torch.cuda.max_memory_allocated(self.device)
  363. return mem
  364. def __enter__(self):
  365. self.initial_memory = self.current_memory_usage()
  366. # This allows us to call methods of the context manager if needed
  367. return self
  368. def __exit__(self, exc_type, exc_val, exc_tb):
  369. self.final_memory = self.current_memory_usage()
  370. self.consumed_memory = self.final_memory - self.initial_memory
  371. # Force garbage collection
  372. gc.collect()
  373. def str_to_int_tuple(s: str) -> Tuple[int, ...]:
  374. """Convert a string to a tuple of integers."""
  375. try:
  376. return tuple(map(int, s.split(",")))
  377. except ValueError as e:
  378. raise ValueError(
  379. "String must be a series of integers separated by commas "
  380. f"(e.g., 1, 2, 3). Given input: {s}") from e
  381. def pad_to_max_length(x: List[int], max_len: int, pad: int) -> List[int]:
  382. assert len(x) <= max_len
  383. return x + [pad] * (max_len - len(x))
  384. def make_tensor_with_pad(
  385. x: List[List[int]],
  386. max_len: int,
  387. pad: int,
  388. dtype: torch.dtype,
  389. device: Optional[Union[str, torch.device]],
  390. ) -> torch.Tensor:
  391. """Make a padded tensor of a 2D inputs.
  392. The padding is applied to the end of each inner list until it reaches
  393. `max_len`.
  394. """
  395. padded_x = [pad_to_max_length(x_i, max_len, pad) for x_i in x]
  396. return torch.tensor(padded_x, dtype=dtype, device=device)
  397. def async_tensor_h2d(
  398. data: list,
  399. dtype: torch.dtype,
  400. target_device: Union[str, torch.device],
  401. pin_memory: bool,
  402. ) -> torch.Tensor:
  403. """Asynchronously create a tensor and copy it from host to device."""
  404. t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
  405. return t.to(device=target_device, non_blocking=True)
  406. def maybe_expand_dim(tensor: torch.Tensor,
  407. target_dims: int,
  408. size: int = 1) -> torch.Tensor:
  409. """Expand the tensor to the target_dims."""
  410. if tensor.ndim < target_dims:
  411. tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
  412. return tensor
  413. def merge_dicts(dict1: Dict[Any, List[Any]],
  414. dict2: Dict[Any, List[Any]]) -> Dict[Any, List[Any]]:
  415. """Merge 2 dicts that have key -> List of items.
  416. When a key conflicts, the values in dict1 is prioritized.
  417. """
  418. merged_dict = defaultdict(list)
  419. for key, value in dict1.items():
  420. merged_dict[key].extend(value)
  421. for key, value in dict2.items():
  422. merged_dict[key].extend(value)
  423. return dict(merged_dict)
  424. def init_cached_hf_modules():
  425. """
  426. Lazy initialization of the Hugging Face modules.
  427. """
  428. from transformers.dynamic_module_utils import init_hf_modules
  429. init_hf_modules()
  430. def nccl_integrity_check(filepath):
  431. """
  432. when the library is corrupted, we cannot catch
  433. the exception in python. it will crash the process.
  434. instead, we use the exit code of `ldd` to check
  435. if the library is corrupted. if not, we will return
  436. the version of the library.
  437. """
  438. exit_code = os.system(f"ldd {filepath} 2>&1 > /dev/null")
  439. if exit_code != 0:
  440. raise RuntimeError(f"Failed to load NCCL library from {filepath} .")
  441. import ctypes
  442. nccl = ctypes.CDLL(filepath)
  443. version = ctypes.c_int()
  444. nccl.ncclGetVersion.restype = ctypes.c_int
  445. nccl.ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
  446. result = nccl.ncclGetVersion(ctypes.byref(version))
  447. assert result == 0
  448. return version.value
  449. @lru_cache(maxsize=None)
  450. def find_library(lib_name: str) -> str:
  451. """
  452. Find the library file in the system.
  453. `lib_name` is full filename, with both prefix and suffix.
  454. This function resolves `lib_name` to the full path of the library.
  455. """
  456. # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa
  457. # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard
  458. # `/sbin/ldconfig` should exist in all Linux systems.
  459. # `/sbin/ldconfig` searches the library in the system
  460. libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
  461. # each line looks like the following:
  462. # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
  463. locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line]
  464. # `LD_LIBRARY_PATH` searches the library in the user-defined paths
  465. env_ld_library_path = os.getenv("LD_LIBRARY_PATH")
  466. if not locs and env_ld_library_path:
  467. locs = [
  468. os.path.join(dir, lib_name)
  469. for dir in env_ld_library_path.split(":")
  470. if os.path.exists(os.path.join(dir, lib_name))
  471. ]
  472. if not locs:
  473. raise ValueError(f"Cannot find {lib_name} in the system.")
  474. return locs[0]
  475. def find_nccl_library():
  476. so_file = os.environ.get("APHRODITE_NCCL_SO_PATH", "")
  477. aphrodite_nccl_path = None
  478. if torch.version.cuda is not None:
  479. cuda_major = torch.version.cuda.split(".")[0]
  480. path = os.path.expanduser(
  481. f"~/.config/aphrodite/nccl/cu{cuda_major}/libnccl.so.*")
  482. files = glob.glob(path)
  483. aphrodite_nccl_path = files[0] if files else None
  484. # manually load the nccl library
  485. if so_file:
  486. logger.info("Found nccl from environment variable "
  487. f"APHRODITE_NCCL_SO_PATH={so_file}")
  488. else:
  489. if torch.version.cuda is not None:
  490. so_file = aphrodite_nccl_path or find_library("libnccl.so.2")
  491. elif torch.version.hip is not None:
  492. so_file = find_library("librccl.so.1")
  493. else:
  494. raise ValueError("NCCL only supports CUDA and ROCm backends.")
  495. logger.info(f"Found nccl from library {so_file}")
  496. return so_file
  497. def enable_trace_function_call_for_thread() -> None:
  498. if int(os.getenv("APHRODITE_TRACE_FUNCTION", "0")):
  499. tmp_dir = tempfile.gettempdir()
  500. filename = (f"APHRODITE_TRACE_FUNCTION_for_process_{os.getpid()}"
  501. f"_thread_{threading.get_ident()}_"
  502. f"at_{datetime.datetime.now()}.log").replace(" ", "_")
  503. log_path = os.path.join(tmp_dir, "aphrodite",
  504. get_aphrodite_instance_id(), filename)
  505. os.makedirs(os.path.dirname(log_path), exist_ok=True)
  506. enable_trace_function_call(log_path)