tpu_worker.py

import os
from typing import List, Optional, Tuple, Union

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     ModelConfig, MultiModalConfig,
                                     ParallelConfig, SchedulerConfig)
from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from aphrodite.distributed import (ensure_model_parallel_initialized,
                                   init_distributed_environment)
from aphrodite.modeling import set_random_seed
from aphrodite.task_handler.tpu_model_runner import TPUModelRunner
from aphrodite.task_handler.worker_base import (LocalOrDistributedWorkerBase,
                                                LoraNotSupportedWorkerBase,
                                                WorkerInput)


class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
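    """A worker that runs the model on a single TPU device.

    The KV cache lives in TPU memory, with a CPU-side copy used as swap
    space; LoRA is not supported (see `LoraNotSupportedWorkerBase`).
    """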

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        multimodal_config: Optional[MultiModalConfig],
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.parallel_config.rank = rank
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.multimodal_config = multimodal_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker

        assert self.device_config.device_type == "tpu"
        if self.cache_config.cache_dtype == "auto":
            self.cache_dtype = self.model_config.dtype
        else:
            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype]

        self.model_runner: TPUModelRunner = TPUModelRunner(
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config,
            multimodal_config,
            is_driver_worker=is_driver_worker)

    def init_device(self) -> None:
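        """Set up the TPU device, distributed state, seeds, and XLA caches.

        The distributed environment uses the gloo backend only for CPU-side
        broadcasts of input metadata; collectives on the TPU itself go
        through `xm.all_reduce`/`xm.all_gather`.
        """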
        os.environ["PJRT_DEVICE"] = "TPU"
        torch.set_grad_enabled(False)
        torch.set_default_dtype(self.model_config.dtype)

        # NOTE: This is just to initialize the TP group and broadcast
        # the input objects on CPU. The all-reduce and all-gather ops on TPU
        # are invoked by `xm.all_reduce` and `xm.all_gather` which use their
        # own context.
        init_distributed_environment(
            world_size=self.parallel_config.world_size,
            rank=self.rank,
            local_rank=self.local_rank,
            distributed_init_method=self.distributed_init_method,
            backend="gloo",
        )
        ensure_model_parallel_initialized(
            self.parallel_config.tensor_parallel_size,
            self.parallel_config.pipeline_parallel_size)

        # Device initialization should happen after initializing the
        # distributed runtime.
        self.device = xm.xla_device()
        self.device_config.device = self.device

        # Set random seed.
        set_random_seed(self.model_config.seed)
        xm.set_rng_state(self.model_config.seed, self.device)

        # Increase the cache size limit, which is the maximum number of
        # dynamo graphs that can be compiled.
        # NOTE: Usually, we compile 10-15 graphs for prefill and
        # 30-40 graphs for decode. 128 is an arbitrary safe number.
        torch._dynamo.config.cache_size_limit = 128

        # Use persistent cache to avoid XLA recompilation.
        # NOTE: Set per-rank cache path since different ranks
        # can have slightly different XLA graphs.
        APHRODITE_XLA_CACHE_PATH = os.getenv("APHRODITE_XLA_CACHE_PATH",
                                             "~/.aphrodite/xla_cache/")
        world_size = self.parallel_config.world_size
        per_rank_path = os.path.join(APHRODITE_XLA_CACHE_PATH,
                                     f"tp{world_size}_rank{self.rank}")
        xr.initialize_cache(per_rank_path, readonly=False)

    def load_model(self):
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
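        """Profile peak memory usage and size the KV cache.

        Runs a dummy max-length prefill, measures how much device memory
        the weights and activations take, and returns the number of TPU
        and CPU KV-cache blocks that fit in the remaining budget, each
        rounded down to a multiple of 8.
        """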
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        head_size = self.model_config.get_head_size()
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)

        kv_caches = [(None, None) for _ in range(num_layers)]
        self.model_runner._dummy_run(
            batch_size=1,
            seq_len=self.scheduler_config.max_num_batched_tokens,
            kv_caches=kv_caches,
            is_prompt=True,
        )
        # Synchronize before measuring the memory usage.
        xm.wait_device_ops()

        dtype_bytes = get_dtype_size(self.cache_dtype)
        block_size = self.cache_config.block_size
        block_size_bytes = (dtype_bytes * block_size * num_layers * 2 *
                            head_size * num_kv_heads)
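        # Example with made-up numbers: for a bfloat16 cache (2 bytes),
        # block_size=16, 32 layers, 8 KV heads, and head_size=128, one
        # block takes 2 * 16 * 32 * 2 * 128 * 8 = 2 MiB.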

        # Calculate the TPU KV cache size based on profiling.
        m = xm.get_memory_info(self.device)
        total_memory_size = m["bytes_limit"]
        usable_memory_size = int(total_memory_size *
                                 self.cache_config.gpu_memory_utilization)
        profiled = m["bytes_used"]  # Weights + intermediate activations.
        tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
        num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
        num_tpu_blocks = (num_tpu_blocks // 8) * 8  # Round down to a multiple of 8.

        # Calculate the CPU KV cache size based on the config.
        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                             block_size_bytes)
        num_cpu_blocks = (num_cpu_blocks // 8) * 8  # Round down to a multiple of 8.
        return num_tpu_blocks, num_cpu_blocks

    def initialize_cache(
        self,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
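        """Allocate the TPU and CPU KV caches, then warm up the model.

        Each cache is a list of per-layer (key, value) tensor pairs whose
        shape comes from the attention backend's `get_kv_cache_shape`.
        """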
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks
        self.block_size = self.cache_config.block_size

        dtype = self.cache_dtype
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        head_size = self.model_config.get_head_size()

        self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
        self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
        tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
            num_gpu_blocks, self.block_size, num_kv_heads, head_size)
        cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
            num_cpu_blocks, self.block_size, num_kv_heads, head_size)
        for _ in range(num_layers):
            tpu_k_cache = torch.zeros(tpu_cache_shape,
                                      dtype=dtype,
                                      device=self.device)
            tpu_v_cache = torch.zeros_like(tpu_k_cache)
            self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
            cpu_k_cache = torch.zeros(cpu_cache_shape,
                                      dtype=dtype,
                                      device="cpu")
            cpu_v_cache = torch.zeros_like(cpu_k_cache)
            self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
        self._warmup_model()

    def _warmup_model(self) -> None:
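        """Ahead-of-time compile XLA graphs for all possible input shapes."""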
        # FIXME: Here we are abusing `enforce_eager` which is defined
        # for CUDA graphs. We should refactor this part.
        if not self.model_config.enforce_eager:
            # Warm up the model with all possible input shapes so that
            # compilation never happens during the actual execution.
            # This may take ~30 mins for the first run and ~20 mins for the
            # subsequent runs.
            # If `enforce_eager` is True, the ahead-of-time compilation is
            # skipped and the compilation happens during the actual execution,
            # which is bad for performance but useful for development.
            self.model_runner.warmup_model(self.tpu_cache)

    def get_cache_block_size_bytes(self) -> int:
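        """Return the size of a single KV-cache block in bytes.

        This mirrors the per-block size computed in
        `determine_num_available_blocks`.
        """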
        head_size = self.model_config.get_head_size()
        num_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        num_layers = self.model_config.get_num_layers(self.parallel_config)

        key_cache_block = self.cache_config.block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        dtype_size = get_dtype_size(self.cache_dtype)
        return dtype_size * total

    @property
    def do_metadata_broadcast(self) -> bool:
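        """Input metadata is broadcast only when tensor parallelism > 1."""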
        return self.parallel_config.tensor_parallel_size > 1

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        # NOTE: This assumes virtual_engine == 0, i.e., no pipeline
        # parallelism.
        return [self.tpu_cache]

    def prepare_worker_input(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> WorkerInput:
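        """Convert the scheduler's block mappings into index tensors.

        The swap-in, swap-out, and copy lists become (src, dst) index
        tensor pairs via `_make_src_to_dst` so that `execute_worker` can
        apply them.
        """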
        virtual_engine = execute_model_req.virtual_engine
        num_seq_groups = len(execute_model_req.seq_group_metadata_list)
        blocks_to_swap_in = _make_src_to_dst(
            execute_model_req.blocks_to_swap_in, "cpu", self.device)
        blocks_to_swap_out = _make_src_to_dst(
            execute_model_req.blocks_to_swap_out, self.device, "cpu")
        blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy,
                                          self.device, self.device)
        return WorkerInput(
            num_seq_groups=num_seq_groups,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            virtual_engine=virtual_engine,
        )

    def execute_worker(self, worker_input: WorkerInput) -> None:
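        """Issue the KV-cache swap-in, swap-out, and copy operations."""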
        virtual_engine = worker_input.virtual_engine
        assert virtual_engine == 0
        attn_backend = self.model_runner.attn_backend
        num_layers = self.model_config.get_num_layers(self.parallel_config)

        # Issue cache operations.
        if worker_input.blocks_to_swap_in is not None:
            src_indices, dst_indices = worker_input.blocks_to_swap_in
            if src_indices.numel() > 0:
                # Swap from CPU to TPU.
                for i in range(num_layers):
                    tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
                    cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
                    k = cpu_k_cache[:, src_indices].to(self.device)
                    v = cpu_v_cache[:, src_indices].to(self.device)
                    _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)

        if worker_input.blocks_to_swap_out is not None:
            src_indices, dst_indices = worker_input.blocks_to_swap_out
            if src_indices.numel() > 0:
                # Swap from TPU to CPU.
                for i in range(num_layers):
                    tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
                    cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
                    cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices]
                    cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices]

        if worker_input.blocks_to_copy is not None:
            src_indices, dst_indices = worker_input.blocks_to_copy
            if src_indices.numel() > 0:
                attn_backend.copy_blocks(self.tpu_cache,
                                         (src_indices, dst_indices))


def _make_src_to_dst(
    mapping: List[Tuple[int, int]],
    src_device: Union[torch.device, str],
    dst_device: Union[torch.device, str],
) -> Tuple[torch.Tensor, torch.Tensor]:
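    """Split (src, dst) block-number pairs into two index tensors.

    Illustrative example: [(0, 3), (1, 5)] becomes a source tensor
    [0, 1] on `src_device` and a destination tensor [3, 5] on `dst_device`.
    """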
    src_indices = [i for i, _ in mapping]
    dst_indices = [i for _, i in mapping]
    src_indices = torch.tensor(src_indices,
                               device=src_device,
                               dtype=torch.int64)
    dst_indices = torch.tensor(dst_indices,
                               device=dst_device,
                               dtype=torch.int64)
    return src_indices, dst_indices


@torch.compile(backend="openxla")
def _insert_kv(
    k: torch.Tensor,
    v: torch.Tensor,
    indices: torch.Tensor,
    tpu_k_cache: torch.Tensor,
    tpu_v_cache: torch.Tensor,
) -> None:
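    """Write swapped-in key/value blocks into the TPU cache at `indices`.

    Marking the cache tensors as buffer donors lets XLA reuse their
    storage and update them in place.
    """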
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True)
    tpu_k_cache[:, indices] = k
    tpu_v_cache[:, indices] = v