tpu_worker.py

import os
from typing import List, Optional, Tuple, Union

import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.dynamo_set_buffer_donor  # noqa: F401
import torch_xla.runtime as xr

from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     ModelConfig, ParallelConfig,
                                     SchedulerConfig, VisionLanguageConfig)
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from aphrodite.distributed import (ensure_model_parallel_initialized,
                                   init_distributed_environment)
from aphrodite.modeling import set_random_seed
from aphrodite.task_handler.tpu_model_runner import TPUModelRunner
from aphrodite.task_handler.worker_base import LoraNotSupportedWorkerBase


class TPUWorker(LoraNotSupportedWorkerBase):

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        vision_language_config: Optional[VisionLanguageConfig],
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.vision_language_config = vision_language_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker

        assert self.device_config.device_type == "tpu"
        if self.cache_config.cache_dtype == "auto":
            self.cache_dtype = self.model_config.dtype
        else:
            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype]

        self.model_runner = TPUModelRunner(model_config,
                                           parallel_config,
                                           scheduler_config,
                                           device_config,
                                           cache_config,
                                           load_config,
                                           vision_language_config,
                                           is_driver_worker=is_driver_worker)

    def init_device(self) -> None:
        os.environ["PJRT_DEVICE"] = "TPU"
        self.device = xm.xla_device()
        self.device_config.device = self.device
        torch.set_grad_enabled(False)
        torch.set_default_dtype(self.model_config.dtype)

        # NOTE: This is just a hack to initialize the TP group.
        # This cannot perform the actual communication ops.
        init_distributed_environment(
            world_size=self.parallel_config.world_size,
            rank=self.rank,
            local_rank=self.local_rank,
            distributed_init_method=self.distributed_init_method,
            backend="gloo",
        )
        ensure_model_parallel_initialized(
            self.parallel_config.tensor_parallel_size,
            self.parallel_config.pipeline_parallel_size)

        # Set random seed.
        set_random_seed(self.model_config.seed)
        xm.set_rng_state(self.model_config.seed, self.device)

        # Increase the cache size limit, which is the maximum number of
        # dynamo graphs that can be compiled.
        # NOTE: Usually, we compile 10-15 graphs for prefill and
        # 30-40 graphs for decode. 128 is an arbitrary safe number.
        torch._dynamo.config.cache_size_limit = 128

        # Use persistent cache to avoid XLA recompilation.
        # NOTE: This does not completely eliminate the recompilation
        # overhead because dynamo does not cache the compiled results.
        APHRODITE_XLA_CACHE_PATH = os.getenv("APHRODITE_XLA_CACHE_PATH",
                                             "~/.aphrodite/xla_cache/")
        xr.initialize_cache(os.path.expanduser(APHRODITE_XLA_CACHE_PATH),
                            readonly=False)
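
    # Usage note (illustrative): the persistent compilation cache location can
    # be overridden before the worker starts, e.g.
    #   export APHRODITE_XLA_CACHE_PATH=/mnt/disks/xla_cache
    # (the path is only an example) so that compiled XLA programs are reused
    # across runs.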

    def load_model(self):
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profile peak memory with a dummy prefill run, then size the TPU
        and CPU KV caches in units of cache blocks."""
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        head_size = self.model_config.get_head_size()
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)

        kv_caches = [(None, None) for _ in range(num_layers)]
        self.model_runner._dummy_run(
            batch_size=1,
            seq_len=self.scheduler_config.max_num_batched_tokens,
            kv_caches=kv_caches,
            is_prompt=True,
        )
        # Synchronize before measuring the memory usage.
        xm.wait_device_ops()

        dtype_bytes = get_dtype_size(self.cache_dtype)
        block_size = self.cache_config.block_size
        block_size_bytes = (dtype_bytes * block_size * num_layers * 2 *
                            head_size * num_kv_heads)

        # Calculate the TPU KV cache size based on profiling.
        m = xm.get_memory_info(self.device)
        total_memory_size = m["bytes_limit"]
        usable_memory_size = int(total_memory_size *
                                 self.cache_config.gpu_memory_utilization)
        profiled = m["bytes_used"]  # Weights + intermediate activations.
        tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
        num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
        # Round down to a multiple of 8.
        num_tpu_blocks = (num_tpu_blocks // 8) * 8

        # Calculate the CPU KV cache size based on the config.
        num_cpu_blocks = (self.cache_config.swap_space_bytes //
                          block_size_bytes)
        # Round down to a multiple of 8.
        num_cpu_blocks = (num_cpu_blocks // 8) * 8
        return num_tpu_blocks, num_cpu_blocks
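
    # Worked example (illustrative, hypothetical model sizes): with fp16
    # caches (2 bytes per element), block_size=16, num_layers=32,
    # num_kv_heads=8, and head_size=128, each block costs
    # 2 * 16 * 32 * 2 * 128 * 8 = 2,097,152 bytes (2 MiB), so a 16 GiB
    # KV-cache budget yields 8192 TPU blocks before rounding.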

    def initialize_cache(
        self,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks
        self.block_size = self.cache_config.block_size

        dtype = self.cache_dtype
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        head_size = self.model_config.get_head_size()

        self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
        self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
        tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
            num_gpu_blocks, self.block_size, num_kv_heads, head_size)
        cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
            num_cpu_blocks, self.block_size, num_kv_heads, head_size)
        for _ in range(num_layers):
            tpu_k_cache = torch.zeros(tpu_cache_shape,
                                      dtype=dtype,
                                      device=self.device)
            tpu_v_cache = torch.zeros_like(tpu_k_cache)
            self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
            cpu_k_cache = torch.zeros(cpu_cache_shape,
                                      dtype=dtype,
                                      device="cpu")
            cpu_v_cache = torch.zeros_like(cpu_k_cache)
            self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
        self._warmup_model()

    def _warmup_model(self) -> None:
        # FIXME: Here we are abusing `enforce_eager`, which is defined
        # for CUDA graphs. We should refactor this part.
        if not self.model_config.enforce_eager:
            # Warm up the model with all possible input shapes so that
            # compilation never happens during the actual execution.
            # This may take ~30 mins for the first run and ~20 mins for the
            # subsequent runs.
            # If `enforce_eager` is True, the ahead-of-time compilation is
            # skipped and the compilation happens during the actual execution,
            # which is bad for performance but useful for development.
            self.model_runner.warmup_model(self.tpu_cache)

    def get_cache_block_size_bytes(self) -> int:
        head_size = self.model_config.get_head_size()
        num_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        num_layers = self.model_config.get_num_layers(self.parallel_config)

        key_cache_block = self.cache_config.block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        dtype_size = get_dtype_size(self.cache_dtype)
        return dtype_size * total

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> List[SamplerOutput]:
        if not self.is_driver_worker:
            self._execute_model_non_driver()
            return []

        assert execute_model_req is not None
        # Issue cache operations.
        self.cache_swap(
            execute_model_req.blocks_to_swap_in,
            execute_model_req.blocks_to_swap_out,
            execute_model_req.blocks_to_copy,
        )
        # Run the model.
        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
        assert len(seq_group_metadata_list) > 0
        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.tpu_cache)
        return output

    def cache_swap(
        self,
        blocks_to_swap_in: List[Tuple[int, int]],
        blocks_to_swap_out: List[Tuple[int, int]],
        blocks_to_copy: List[Tuple[int, int]],
    ) -> None:
        attn_backend = self.model_runner.attn_backend
        num_layers = self.model_config.get_num_layers(self.parallel_config)

        if blocks_to_swap_in:
            # Swap from CPU to TPU.
            src_indices, dst_indices = _make_src_to_dst(
                blocks_to_swap_in, "cpu", self.device)
            for i in range(num_layers):
                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
                k = cpu_k_cache[:, src_indices].to(self.device)
                v = cpu_v_cache[:, src_indices].to(self.device)
                _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)

        if blocks_to_swap_out:
            # Swap from TPU to CPU.
            src_indices, dst_indices = _make_src_to_dst(
                blocks_to_swap_out, self.device, "cpu")
            for i in range(num_layers):
                tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
                cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
                cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices].cpu()
                cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices].cpu()

        if blocks_to_copy:
            # Copy blocks within the TPU cache.
            src_to_dst = _make_src_to_dst(blocks_to_copy, self.device,
                                          self.device)
            attn_backend.copy_blocks(self.tpu_cache, src_to_dst)
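
    # Illustrative example (hypothetical block numbers): with
    # blocks_to_swap_in=[(0, 5), (2, 7)], cache_swap copies CPU blocks 0 and 2
    # into TPU blocks 5 and 7 for every layer's key and value caches.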

    def start_worker_execution_loop(self) -> None:
        while self._execute_model_non_driver():
            pass

    def _execute_model_non_driver(self) -> bool:
        self.model_runner.execute_model(None, self.tpu_cache)
        return True


def _make_src_to_dst(
    mapping: List[Tuple[int, int]],
    src_device: Union[torch.device, str],
    dst_device: Union[torch.device, str],
) -> Tuple[torch.Tensor, torch.Tensor]:
    src_indices = [i for i, _ in mapping]
    dst_indices = [i for _, i in mapping]
    src_indices = torch.tensor(src_indices,
                               device=src_device,
                               dtype=torch.int64)
    dst_indices = torch.tensor(dst_indices,
                               device=dst_device,
                               dtype=torch.int64)
    return src_indices, dst_indices
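
# Illustrative example: _make_src_to_dst([(0, 5), (2, 7)], "cpu", tpu_device)
# returns a pair of int64 tensors, tensor([0, 2]) on the CPU and
# tensor([5, 7]) on the TPU, i.e. the source block ids on the source device
# and the destination block ids on the destination device.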


@torch.compile(backend="openxla")
def _insert_kv(
    k: torch.Tensor,
    v: torch.Tensor,
    indices: torch.Tensor,
    tpu_k_cache: torch.Tensor,
    tpu_v_cache: torch.Tensor,
) -> None:
    # Mark the cache tensors as donated buffers so that XLA can reuse their
    # memory for the updated caches instead of allocating fresh outputs.
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True)
    tpu_k_cache[:, indices] = k
    tpu_v_cache[:, indices] = v
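

# Typical call order (illustrative; the executor that owns this worker drives
# it): init_device() -> load_model() -> determine_num_available_blocks() ->
# initialize_cache(...) (which also warms up the model), then
# execute_model(...) once per scheduler step on the driver worker, while
# non-driver workers run start_worker_execution_loop().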