tpu_worker.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. import os
  2. from typing import List, Optional, Tuple, Union
  3. import torch
  4. import torch_xla.core.xla_model as xm
  5. import torch_xla.runtime as xr
  6. import aphrodite.common.envs as envs
  7. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  8. ModelConfig, ParallelConfig,
  9. SchedulerConfig)
  10. from aphrodite.common.sequence import ExecuteModelRequest
  11. from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
  12. from aphrodite.distributed import (ensure_model_parallel_initialized,
  13. init_distributed_environment)
  14. from aphrodite.modeling import set_random_seed
  15. from aphrodite.task_handler.tpu_model_runner import TPUModelRunner
  16. from aphrodite.task_handler.worker_base import (LocalOrDistributedWorkerBase,
  17. LoraNotSupportedWorkerBase,
  18. WorkerInput)
  19. class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
  20. def __init__(
  21. self,
  22. model_config: ModelConfig,
  23. parallel_config: ParallelConfig,
  24. scheduler_config: SchedulerConfig,
  25. device_config: DeviceConfig,
  26. cache_config: CacheConfig,
  27. load_config: LoadConfig,
  28. local_rank: int,
  29. rank: int,
  30. distributed_init_method: str,
  31. is_driver_worker: bool,
  32. ) -> None:
  33. self.model_config = model_config
  34. self.parallel_config = parallel_config
  35. self.parallel_config.rank = rank
  36. self.scheduler_config = scheduler_config
  37. self.device_config = device_config
  38. self.cache_config = cache_config
  39. self.load_config = load_config
  40. self.local_rank = local_rank
  41. self.rank = rank
  42. self.distributed_init_method = distributed_init_method
  43. self.is_driver_worker = is_driver_worker
  44. assert self.device_config.device_type == "tpu"
  45. if self.cache_config.cache_dtype == "auto":
  46. self.cache_dtype = self.model_config.dtype
  47. else:
  48. self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
  49. self.cache_config.cache_dtype]
  50. self.model_runner: TPUModelRunner = TPUModelRunner(
  51. model_config,
  52. parallel_config,
  53. scheduler_config,
  54. device_config,
  55. cache_config,
  56. load_config,
  57. is_driver_worker=is_driver_worker)
  58. def init_device(self) -> None:
  59. os.environ["PJRT_DEVICE"] = "TPU"
  60. torch.set_grad_enabled(False)
  61. torch.set_default_dtype(self.model_config.dtype)
  62. # NOTE: This is just to initialize the TP group and broadcast
  63. # the input objects on CPU. The all-reduce and all-gather ops on TPU
  64. # are invoked by `xm.all_reduce` and `xm.all_gather` which use their
  65. # own context.
  66. init_distributed_environment(
  67. world_size=self.parallel_config.world_size,
  68. rank=self.rank,
  69. local_rank=self.local_rank,
  70. distributed_init_method=self.distributed_init_method,
  71. backend="gloo",
  72. )
  73. ensure_model_parallel_initialized(
  74. self.parallel_config.tensor_parallel_size,
  75. self.parallel_config.pipeline_parallel_size)
  76. # Device initialization should happen after initializing the distributed
  77. # runtime.
  78. self.device = xm.xla_device()
  79. self.device_config.device = self.device
  80. # Set random seed.
  81. set_random_seed(self.model_config.seed)
  82. xm.set_rng_state(self.model_config.seed, self.device)
  83. # Increase the cache size limit, which is the maximum number of
  84. # dynamo graphs that can be compiled.
  85. # NOTE: Usually, we compile 10-15 graphs for prefill and
  86. # 30-40 graphs for decode. 128 is an arbitrary safe number.
  87. torch._dynamo.config.cache_size_limit = 128
  88. # Use persistent cache to avoid XLA recompilation.
  89. # NOTE: Set per-rank cache path since different ranks
  90. # can have slightly different XLA graphs.
  91. APHRODITE_XLA_CACHE_PATH = envs.APHRODITE_XLA_CACHE_PATH
  92. world_size = self.parallel_config.world_size
  93. per_rank_path = os.path.join(APHRODITE_XLA_CACHE_PATH,
  94. f"tp{world_size}_rank{self.rank}")
  95. xr.initialize_cache(per_rank_path, readonly=False)
  96. def load_model(self):
  97. self.model_runner.load_model()
  98. def determine_num_available_blocks(self) -> Tuple[int, int]:
  99. num_layers = self.model_config.get_num_layers(self.parallel_config)
  100. head_size = self.model_config.get_head_size()
  101. num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
  102. kv_caches = [(None, None) for _ in range(num_layers)]
  103. self.model_runner._dummy_run(
  104. batch_size=1,
  105. seq_len=self.scheduler_config.max_num_batched_tokens,
  106. kv_caches=kv_caches,
  107. is_prompt=True,
  108. )
  109. # Synchronize before measuring the memory usage.
  110. xm.wait_device_ops()
  111. dtype_btyes = get_dtype_size(self.cache_dtype)
  112. block_size = self.cache_config.block_size
  113. block_size_bytes = (dtype_btyes * block_size * num_layers * 2 *
  114. head_size * num_kv_heads)
  115. # Calculate the TPU KV cache size based on profiling.
  116. m = xm.get_memory_info(self.device)
  117. total_memory_size = m["bytes_limit"]
  118. usable_memory_size = int(total_memory_size *
  119. self.cache_config.gpu_memory_utilization)
  120. profiled = m["bytes_used"] # Weights + intermediate activations.
  121. tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
  122. num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
  123. num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8.
  124. # Calculate the CPU KV cache size based on the config.
  125. num_cpu_blocks = int(self.cache_config.swap_space_bytes //
  126. block_size_bytes)
  127. num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8.
  128. # reset and discard the guard and compiled bytecode for profiling runs
  129. torch._dynamo.reset()
  130. return num_tpu_blocks, num_cpu_blocks
  131. def initialize_cache(
  132. self,
  133. num_gpu_blocks: int,
  134. num_cpu_blocks: int,
  135. ) -> None:
  136. self.cache_config.num_gpu_blocks = num_gpu_blocks
  137. self.cache_config.num_cpu_blocks = num_cpu_blocks
  138. self.block_size = self.cache_config.block_size
  139. dtype = self.cache_dtype
  140. num_layers = self.model_config.get_num_layers(self.parallel_config)
  141. num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
  142. head_size = self.model_config.get_head_size()
  143. self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
  144. self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
  145. tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
  146. num_gpu_blocks, self.block_size, num_kv_heads, head_size)
  147. cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
  148. num_cpu_blocks, self.block_size, num_kv_heads, head_size)
  149. for _ in range(num_layers):
  150. tpu_k_cache = torch.zeros(tpu_cache_shape,
  151. dtype=dtype,
  152. device=self.device)
  153. tpu_v_cache = torch.zeros_like(tpu_k_cache)
  154. self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
  155. cpu_k_cache = torch.zeros(cpu_cache_shape,
  156. dtype=dtype,
  157. device="cpu")
  158. cpu_v_cache = torch.zeros_like(cpu_k_cache)
  159. self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
  160. self._warmup_model()
  161. def _warmup_model(self) -> None:
  162. # FIXME: Here we are abusing `enforce_eager` which is defined
  163. # for CUDA graphs. We should refactor this part.
  164. if not self.model_config.enforce_eager:
  165. # Warm up the model with all possible input shapes so that
  166. # compilation never happens during the actual execution.
  167. # This may take ~30 mins for the first run and ~20 mins for the
  168. # subsequent runs.
  169. # If `enforce_eager` is True, the ahead-of-time compilation is
  170. # skipped and the compilation happens during the actual execution,
  171. # which is bad for performance but useful for development.
  172. self.model_runner.warmup_model(self.tpu_cache)
  173. def get_cache_block_size_bytes(self) -> int:
  174. head_size = self.model_config.get_head_size()
  175. num_heads = self.model_config.get_num_kv_heads(self.parallel_config)
  176. num_layers = self.model_config.get_num_layers(self.parallel_config)
  177. key_cache_block = self.cache_config.block_size * num_heads * head_size
  178. value_cache_block = key_cache_block
  179. total = num_layers * (key_cache_block + value_cache_block)
  180. dtype_size = get_dtype_size(self.cache_dtype)
  181. return dtype_size * total
  182. @property
  183. def do_metadata_broadcast(self) -> bool:
  184. return self.parallel_config.tensor_parallel_size > 1
  185. @property
  186. def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
  187. # NOTE: This assumes virtual_engine == 0, i.e., no pipeline
  188. # parallelism.
  189. return [self.tpu_cache]
  190. def prepare_worker_input(
  191. self,
  192. execute_model_req: ExecuteModelRequest,
  193. ) -> WorkerInput:
  194. virtual_engine = execute_model_req.virtual_engine
  195. num_seq_groups = len(execute_model_req.seq_group_metadata_list)
  196. blocks_to_swap_in = _make_src_to_dst(
  197. execute_model_req.blocks_to_swap_in, "cpu", self.device)
  198. blocks_to_swap_out = _make_src_to_dst(
  199. execute_model_req.blocks_to_swap_out, self.device, "cpu")
  200. blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy,
  201. self.device, self.device)
  202. return WorkerInput(
  203. num_seq_groups=num_seq_groups,
  204. blocks_to_swap_in=blocks_to_swap_in,
  205. blocks_to_swap_out=blocks_to_swap_out,
  206. blocks_to_copy=blocks_to_copy,
  207. virtual_engine=virtual_engine,
  208. )
  209. def execute_worker(self, worker_input: WorkerInput) -> None:
  210. virtual_engine = worker_input.virtual_engine
  211. assert virtual_engine == 0
  212. attn_backend = self.model_runner.attn_backend
  213. num_layers = self.model_config.get_num_layers(self.parallel_config)
  214. # Issue cache operations.
  215. if worker_input.blocks_to_swap_in is not None:
  216. src_indices, dst_indices = worker_input.blocks_to_swap_in
  217. if src_indices.numel() > 0:
  218. # Swap from CPU to TPU.
  219. for i in range(num_layers):
  220. tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
  221. cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
  222. k = cpu_k_cache[:, src_indices].to(self.device)
  223. v = cpu_v_cache[:, src_indices].to(self.device)
  224. _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
  225. if worker_input.blocks_to_swap_out is not None:
  226. src_indices, dst_indices = worker_input.blocks_to_swap_out
  227. if src_indices.numel() > 0:
  228. # Swap from TPU to CPU.
  229. for i in range(num_layers):
  230. tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
  231. cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
  232. cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices]
  233. cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices]
  234. if worker_input.blocks_to_copy is not None:
  235. src_indices, dst_indices = worker_input.blocks_to_copy
  236. if src_indices.numel() > 0:
  237. attn_backend.copy_blocks(self.tpu_cache,
  238. (src_indices, dst_indices))
  239. def _make_src_to_dst(
  240. mapping: List[Tuple[int, int]],
  241. src_device: Union[torch.device, str],
  242. dst_device: Union[torch.device, str],
  243. ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
  244. if not mapping:
  245. return None
  246. src_indices = [i for i, _ in mapping]
  247. dst_indices = [i for _, i in mapping]
  248. src_indices = torch.tensor(src_indices,
  249. device=src_device,
  250. dtype=torch.int64)
  251. dst_indices = torch.tensor(dst_indices,
  252. device=dst_device,
  253. dtype=torch.int64)
  254. return src_indices, dst_indices
@torch.compile(backend="openxla")
def _insert_kv(
    k: torch.Tensor,
    v: torch.Tensor,
    indices: torch.Tensor,
    tpu_k_cache: torch.Tensor,
    tpu_v_cache: torch.Tensor,
) -> None:
    """Write key/value blocks into the TPU caches in place.

    Compiled with the ``openxla`` torch.compile backend.

    Args:
        k: Key blocks, assigned to ``tpu_k_cache[:, indices]``.
        v: Value blocks, assigned to ``tpu_v_cache[:, indices]``.
        indices: Destination block indices along dim 1 of both caches.
        tpu_k_cache: Key cache; mutated in place.
        tpu_v_cache: Value cache; mutated in place.
    """
    # Mark both caches as buffer donors before the in-place writes below.
    # NOTE(review): presumably this lets XLA reuse the cache buffers for
    # the updated result instead of allocating copies — confirm against
    # the torch_xla documentation for `dynamo_set_buffer_donor_`.
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True)
    tpu_k_cache[:, indices] = k
    tpu_v_cache[:, indices] = v