utils.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. import enum
  2. import os
  3. import socket
  4. import subprocess
  5. import uuid
  6. import gc
  7. from platform import uname
  8. from loguru import logger
  9. import psutil
  10. import torch
  11. import asyncio
  12. from functools import partial
  13. from typing import (Any, Awaitable, Callable, Hashable, Optional, TypeVar,
  14. List, Tuple, Union)
  15. from collections import OrderedDict
  16. from packaging.version import parse, Version
  17. T = TypeVar("T")
  18. STR_DTYPE_TO_TORCH_DTYPE = {
  19. "half": torch.half,
  20. "bfloat16": torch.bfloat16,
  21. "float": torch.float,
  22. "fp8_e5m2": torch.uint8,
  23. "int8": torch.int8,
  24. }
  25. class Device(enum.Enum):
  26. GPU = enum.auto()
  27. CPU = enum.auto()
  28. class Counter:
  29. def __init__(self, start: int = 0) -> None:
  30. self.counter = start
  31. def __next__(self) -> int:
  32. i = self.counter
  33. self.counter += 1
  34. return i
  35. def reset(self) -> None:
  36. self.counter = 0
  37. class LRUCache:
  38. def __init__(self, capacity: int):
  39. self.cache = OrderedDict()
  40. self.capacity = capacity
  41. def __contains__(self, key: Hashable) -> bool:
  42. return key in self.cache
  43. def __len__(self) -> int:
  44. return len(self.cache)
  45. def __getitem__(self, key: Hashable) -> Any:
  46. return self.get(key)
  47. def __setitem__(self, key: Hashable, value: Any) -> None:
  48. self.put(key, value)
  49. def __delitem__(self, key: Hashable) -> None:
  50. self.pop(key)
  51. def touch(self, key: Hashable) -> None:
  52. self.cache.move_to_end(key)
  53. def get(self, key: Hashable, default_value: Optional[Any] = None) -> int:
  54. if key in self.cache:
  55. value = self.cache[key]
  56. self.cache.move_to_end(key)
  57. else:
  58. value = default_value
  59. return value
  60. def put(self, key: Hashable, value: Any) -> None:
  61. self.cache[key] = value
  62. self.cache.move_to_end(key)
  63. self._remove_old_if_needed()
  64. def _on_remove(self, key: Hashable, value: Any):
  65. pass
  66. def remove_oldest(self):
  67. if not self.cache:
  68. return
  69. key, value = self.cache.popitem(last=False)
  70. self._on_remove(key, value)
  71. def _remove_old_if_needed(self) -> None:
  72. while len(self.cache) > self.capacity:
  73. self.remove_oldest()
  74. def pop(self, key: int, default_value: Optional[Any] = None) -> Any:
  75. run_on_remove = key in self.cache
  76. value = self.cache.pop(key, default_value)
  77. if run_on_remove:
  78. self._on_remove(key, value)
  79. return value
  80. def clear(self):
  81. while len(self.cache) > 0:
  82. self.remove_oldest()
  83. self.cache.clear()
  84. def is_hip() -> bool:
  85. return torch.version.hip is not None
  86. def get_max_shared_memory_bytes(gpu: int = 0) -> int:
  87. """Returns the maximum shared memory per thread block in bytes."""
  88. from aphrodite._C import cuda_utils
  89. # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
  90. max_shared_mem = (
  91. cuda_utils.get_max_shared_memory_per_block_device_attribute(gpu))
  92. assert max_shared_mem > 0, "max_shared_mem can not be zero"
  93. return int(max_shared_mem)
  94. def get_cpu_memory() -> int:
  95. """Returns the total CPU memory of the node in bytes."""
  96. return psutil.virtual_memory().total
  97. def random_uuid() -> str:
  98. return str(uuid.uuid4().hex)
  99. def in_wsl() -> bool:
  100. # Reference: https://github.com/microsoft/WSL/issues/4071
  101. return "microsoft" in " ".join(uname()).lower()
  102. def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
  103. """Take a blocking function, and run it on in an executor thread.
  104. This function prevents the blocking function from blocking the
  105. asyncio event loop.
  106. The code in this function needs to be thread safe.
  107. """
  108. def _async_wrapper(*args, **kwargs) -> asyncio.Future:
  109. loop = asyncio.get_event_loop()
  110. p_func = partial(func, *args, **kwargs)
  111. return loop.run_in_executor(executor=None, func=p_func)
  112. return _async_wrapper
  113. def get_ip() -> str:
  114. s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
  115. s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
  116. return s.getsockname()[0]
  117. def get_open_port() -> int:
  118. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  119. s.bind(("", 0))
  120. return s.getsockname()[1]
  121. def set_cuda_visible_devices(device_ids: List[int]) -> None:
  122. os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
  123. def get_nvcc_cuda_version() -> Optional[Version]:
  124. cuda_home = os.environ.get('CUDA_HOME')
  125. nvcc_path = os.path.join(cuda_home, 'bin', 'nvcc') if cuda_home else 'nvcc'
  126. try:
  127. nvcc_output = subprocess.check_output([nvcc_path, "-V"],
  128. universal_newlines=True)
  129. output = nvcc_output.split()
  130. release_idx = output.index("release") + 1
  131. nvcc_cuda_version = parse(output[release_idx].split(",")[0])
  132. return nvcc_cuda_version
  133. except (FileNotFoundError, subprocess.CalledProcessError):
  134. logger.warning("nvcc not found. Skipping CUDA version check!")
  135. return None
  136. def _generate_random_fp8_e5m2(
  137. tensor: torch.tensor,
  138. low: float,
  139. high: float,
  140. ) -> None:
  141. # NOTE: Due to NaN and Inf representation for fp8 data type,
  142. # we may get Inf or NaN if we directly use torch.randint
  143. # to generate random data for fp8 data.
  144. # For example, s.11111.00 in fp8e5m2 format represents Inf.
  145. # | E4M3 | E5M2
  146. #-----|-------------|-------------------
  147. # Inf | N/A | s.11111.00
  148. # NaN | s.1111.111 | s.11111.{01,10,11}
  149. from aphrodite._C import cache_ops
  150. tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
  151. tensor_tmp.uniform_(low, high)
  152. cache_ops.convert_fp8_e5m2(tensor_tmp, tensor)
  153. del tensor_tmp
  154. def create_kv_caches_with_random(
  155. num_blocks: int,
  156. block_size: int,
  157. num_layers: int,
  158. num_heads: int,
  159. head_size: int,
  160. cache_dtype: Optional[Union[str, torch.dtype]],
  161. model_dtype: Optional[Union[str, torch.dtype]] = None,
  162. seed: Optional[int] = 0,
  163. device: Optional[str] = "cuda",
  164. ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
  165. torch.random.manual_seed(seed)
  166. if torch.cuda.is_available():
  167. torch.cuda.manual_seed(seed)
  168. if isinstance(cache_dtype, str):
  169. if cache_dtype == "auto":
  170. if isinstance(model_dtype, str):
  171. torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
  172. elif isinstance(model_dtype, torch.dtype):
  173. torch_dtype = model_dtype
  174. else:
  175. raise ValueError(f"Invalid model dtype: {model_dtype}")
  176. elif cache_dtype in ["half", "bfloat16", "float"]:
  177. torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
  178. elif cache_dtype == "fp8_e5m2":
  179. torch_dtype = torch.uint8
  180. else:
  181. raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
  182. elif isinstance(cache_dtype, torch.dtype):
  183. torch_dtype = cache_dtype
  184. else:
  185. raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
  186. scale = head_size**-0.5
  187. x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
  188. key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
  189. key_caches = []
  190. for _ in range(num_layers):
  191. key_cache = torch.empty(size=key_cache_shape,
  192. dtype=torch_dtype,
  193. device=device)
  194. if cache_dtype == 'fp8_e5m2':
  195. _generate_random_fp8_e5m2(key_cache, -scale, scale)
  196. elif cache_dtype == 'int8':
  197. torch.randint(-128, 127, key_cache.size(), out=key_cache)
  198. elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
  199. key_cache.uniform_(-scale, scale)
  200. else:
  201. raise ValueError(
  202. f"Does not support key cache of type {cache_dtype}")
  203. key_caches.append(key_cache)
  204. value_cache_shape = (num_blocks, num_heads, head_size, block_size)
  205. value_caches = []
  206. for _ in range(num_layers):
  207. value_cache = torch.empty(size=value_cache_shape,
  208. dtype=torch_dtype,
  209. device=device)
  210. if cache_dtype == 'fp8_e5m2':
  211. _generate_random_fp8_e5m2(value_cache, -scale, scale)
  212. elif cache_dtype == 'int8':
  213. torch.randint(-128, 127, value_cache.size(), out=value_cache)
  214. elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
  215. value_cache.uniform_(-scale, scale)
  216. else:
  217. raise ValueError(
  218. f"Does not support value cache of type {cache_dtype}")
  219. value_caches.append(value_cache)
  220. return key_caches, value_caches
  221. class measure_cuda_memory:
  222. def __init__(self, device=None):
  223. self.device = device
  224. def current_memory_usage(self) -> float:
  225. # Return the memory usage in bytes.
  226. torch.cuda.reset_peak_memory_stats(self.device)
  227. mem = torch.cuda.max_memory_allocated(self.device)
  228. return mem
  229. def __enter__(self):
  230. self.initial_memory = self.current_memory_usage()
  231. # This allows us to call methods of the context manager if needed
  232. return self
  233. def __exit__(self, exc_type, exc_val, exc_tb):
  234. self.final_memory = self.current_memory_usage()
  235. self.consumed_memory = self.final_memory - self.initial_memory
  236. # Force garbage collection
  237. gc.collect()