utils.py

import enum
import os
import socket
import subprocess
import uuid
from platform import uname

import psutil
import torch
import asyncio
from functools import partial
from typing import (Any, Awaitable, Callable, Hashable, Optional, TypeVar,
                    List, Tuple, Union)
from collections import OrderedDict
from packaging.version import parse, Version

from aphrodite.common.logger import init_logger

T = TypeVar("T")
logger = init_logger(__name__)

STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.half,
    "bfloat16": torch.bfloat16,
    "float": torch.float,
    "fp8_e5m2": torch.uint8,
}


class Device(enum.Enum):
    GPU = enum.auto()
    CPU = enum.auto()


class Counter:

    def __init__(self, start: int = 0) -> None:
        self.counter = start

    def __next__(self) -> int:
        i = self.counter
        self.counter += 1
        return i

    def reset(self) -> None:
        self.counter = 0


class LRUCache:

    def __init__(self, capacity: int):
        self.cache = OrderedDict()
        self.capacity = capacity

    def __contains__(self, key: Hashable) -> bool:
        return key in self.cache

    def __len__(self) -> int:
        return len(self.cache)

    def __getitem__(self, key: Hashable) -> Any:
        return self.get(key)

    def __setitem__(self, key: Hashable, value: Any) -> None:
        self.put(key, value)

    def __delitem__(self, key: Hashable) -> None:
        self.pop(key)

    def touch(self, key: Hashable) -> None:
        self.cache.move_to_end(key)

    def get(self, key: Hashable, default_value: Optional[Any] = None) -> Any:
        if key in self.cache:
            value = self.cache[key]
            self.cache.move_to_end(key)
        else:
            value = default_value
        return value

    def put(self, key: Hashable, value: Any) -> None:
        self.cache[key] = value
        self.cache.move_to_end(key)
        self._remove_old_if_needed()

    def _on_remove(self, key: Hashable, value: Any):
        # Hook for subclasses to react to evictions; no-op by default.
        pass

    def remove_oldest(self):
        if not self.cache:
            return
        key, value = self.cache.popitem(last=False)
        self._on_remove(key, value)

    def _remove_old_if_needed(self) -> None:
        while len(self.cache) > self.capacity:
            self.remove_oldest()

    def pop(self, key: Hashable, default_value: Optional[Any] = None) -> Any:
        run_on_remove = key in self.cache
        value = self.cache.pop(key, default_value)
        if run_on_remove:
            self._on_remove(key, value)
        return value

    def clear(self):
        while len(self.cache) > 0:
            self.remove_oldest()
        self.cache.clear()
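

# Illustrative sketch, not part of the original module: `_on_remove` is the
# extension point for eviction side effects. The subclass below is
# hypothetical and only shows how a caller might observe evictions.
class _EvictionLoggingLRUCache(LRUCache):

    def _on_remove(self, key: Hashable, value: Any):
        # Called both on capacity eviction (remove_oldest) and explicit pop.
        logger.info(f"Evicted cache entry {key!r}")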


def is_hip() -> bool:
    return torch.version.hip is not None


def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes."""
    from aphrodite._C import cuda_utils

    # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
    max_shared_mem = (
        cuda_utils.get_max_shared_memory_per_block_device_attribute(gpu))
    assert max_shared_mem > 0, "max_shared_mem can not be zero"
    return int(max_shared_mem)


def get_cpu_memory() -> int:
    """Returns the total CPU memory of the node in bytes."""
    return psutil.virtual_memory().total


def random_uuid() -> str:
    return str(uuid.uuid4().hex)


def in_wsl() -> bool:
    # Reference: https://github.com/microsoft/WSL/issues/4071
    return "microsoft" in " ".join(uname()).lower()


def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
    """Take a blocking function, and run it in an executor thread.

    This function prevents the blocking function from blocking the
    asyncio event loop.
    The code in this function needs to be thread safe.
    """

    def _async_wrapper(*args, **kwargs) -> asyncio.Future:
        loop = asyncio.get_event_loop()
        p_func = partial(func, *args, **kwargs)
        return loop.run_in_executor(executor=None, func=p_func)

    return _async_wrapper
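

# Illustrative usage sketch, not part of the original module: wrapping a
# blocking call so it can be awaited without stalling the event loop. The
# helper name `_example_blocking_cpu_count` is hypothetical; psutil.cpu_count
# is cheap, but stands in here for any blocking call.
async def _example_blocking_cpu_count() -> Optional[int]:
    async_cpu_count = make_async(psutil.cpu_count)
    return await async_cpu_count(logical=True)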


def get_ip() -> str:
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
    return s.getsockname()[0]


def get_open_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


def set_cuda_visible_devices(device_ids: List[int]) -> None:
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))


def get_nvcc_cuda_version() -> Version:
    cuda_home = os.environ.get("CUDA_HOME")
    if not cuda_home:
        cuda_home = "/usr/local/cuda"
        logger.info(
            f"CUDA_HOME is not found in the environment. Using {cuda_home} as "
            "CUDA_HOME.")
    try:
        nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"],
                                              universal_newlines=True)
    except FileNotFoundError:
        print("nvcc is not found. Please make sure to export CUDA_HOME.")
        return Version("0.0.0")  # return a default Version object
    except subprocess.CalledProcessError:
        print("An error occurred while trying to get nvcc output. Please "
              "make sure to export CUDA_HOME.")
        return Version("0.0.0")
    output = nvcc_output.split()
    release_idx = output.index("release") + 1
    nvcc_cuda_version = parse(output[release_idx].split(",")[0])
    return nvcc_cuda_version


def _generate_random_fp8_e5m2(
    tensor: torch.Tensor,
    low: float,
    high: float,
) -> None:
    # NOTE: Due to NaN and Inf representation for fp8 data type,
    # we may get Inf or NaN if we directly use torch.randint
    # to generate random data for fp8 data.
    # For example, s.11111.00 in fp8e5m2 format represents Inf.
    #     | E4M3       | E5M2
    # ----|------------|--------------------
    # Inf | N/A        | s.11111.00
    # NaN | s.1111.111 | s.11111.{01,10,11}
    from aphrodite._C import cache_ops
    tensor_tmp = torch.empty_like(tensor, dtype=torch.float16)
    tensor_tmp.uniform_(low, high)
    cache_ops.convert_fp8_e5m2(tensor_tmp, tensor)
    del tensor_tmp


def create_kv_caches_with_random(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: Optional[int] = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    if isinstance(cache_dtype, str):
        if cache_dtype == "auto":
            if isinstance(model_dtype, str):
                torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
            elif isinstance(model_dtype, torch.dtype):
                torch_dtype = model_dtype
            else:
                raise ValueError(f"Invalid model dtype: {model_dtype}")
        elif cache_dtype in ["half", "bfloat16", "float"]:
            torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
        elif cache_dtype == "fp8_e5m2":
            torch_dtype = torch.uint8
        else:
            raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
    elif isinstance(cache_dtype, torch.dtype):
        torch_dtype = cache_dtype
    else:
        raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")

    scale = head_size**-0.5
    x = 16 // torch.tensor([], dtype=torch_dtype).element_size()
    key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x)
    key_caches = []
    for _ in range(num_layers):
        key_cache = torch.empty(size=key_cache_shape,
                                dtype=torch_dtype,
                                device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            key_cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8_e5m2":
            _generate_random_fp8_e5m2(key_cache, -scale, scale)
        key_caches.append(key_cache)

    value_cache_shape = (num_blocks, num_heads, head_size, block_size)
    value_caches = []
    for _ in range(num_layers):
        value_cache = torch.empty(size=value_cache_shape,
                                  dtype=torch_dtype,
                                  device=device)
        if cache_dtype in ["auto", "half", "bfloat16", "float"]:
            value_cache.uniform_(-scale, scale)
        elif cache_dtype == "fp8_e5m2":
            _generate_random_fp8_e5m2(value_cache, -scale, scale)
        value_caches.append(value_cache)
    return key_caches, value_caches
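

# Illustrative usage sketch, not part of the original module; the sizes below
# are hypothetical. With `cache_dtype="auto"` and `model_dtype="half"`,
# `torch_dtype` resolves to float16, so `x` above is 16 // 2 = 8 and the key
# cache is laid out as (num_blocks, num_heads, head_size // 8, block_size, 8).
def _example_create_kv_caches() -> None:
    key_caches, value_caches = create_kv_caches_with_random(
        num_blocks=128,
        block_size=16,
        num_layers=2,
        num_heads=8,
        head_size=64,
        cache_dtype="auto",
        model_dtype="half",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    assert key_caches[0].shape == (128, 8, 64 // 8, 16, 8)
    assert value_caches[0].shape == (128, 8, 64, 16)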