tpu_worker.py

import os
from typing import List, Optional, Tuple

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     ModelConfig, ParallelConfig,
                                     SchedulerConfig, VisionLanguageConfig)
from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from aphrodite.distributed import (ensure_model_parallel_initialized,
                                   init_distributed_environment)
from aphrodite.modeling import set_random_seed
from aphrodite.task_handler.tpu_model_runner import TPUModelRunner
from aphrodite.task_handler.worker_base import LoraNotSupportedWorkerBase


class TPUWorker(LoraNotSupportedWorkerBase):

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        vision_language_config: Optional[VisionLanguageConfig],
        local_rank: int,
        rank: int,
        distributed_init_method: str,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.vision_language_config = vision_language_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        assert self.device_config.device_type == "tpu"

        if self.cache_config.cache_dtype == "auto":
            self.cache_dtype = self.model_config.dtype
        else:
            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype]

        self.model_runner = TPUModelRunner(model_config, parallel_config,
                                           scheduler_config, device_config,
                                           cache_config, load_config,
                                           vision_language_config)

    def init_device(self) -> None:
        os.environ["PJRT_DEVICE"] = "TPU"
        self.device = xm.xla_device()
        self.device_config.device = self.device
        torch.set_grad_enabled(False)
        torch.set_default_dtype(self.model_config.dtype)

        # NOTE: This is just a hack to initialize the TP group.
        # This cannot perform the actual communication ops.
        init_distributed_environment(
            world_size=self.parallel_config.world_size,
            rank=self.rank,
            local_rank=self.local_rank,
            distributed_init_method=self.distributed_init_method,
            backend="gloo",
        )
        ensure_model_parallel_initialized(
            self.parallel_config.tensor_parallel_size,
            self.parallel_config.pipeline_parallel_size)

        # Set random seed.
        set_random_seed(self.model_config.seed)
        xm.set_rng_state(self.model_config.seed, self.device)

        # Increase the cache size limit, which is the maximum number of
        # dynamo graphs that can be compiled.
        # NOTE: Usually, we compile 10-15 graphs for prefill and
        # 30-40 graphs for decode. 128 is an arbitrary safe number.
        torch._dynamo.config.cache_size_limit = 128

        # Use persistent cache to avoid XLA recompilation.
        # NOTE: This does not completely eliminate the recompilation
        # overhead because dynamo does not cache the compiled results.
        APHRODITE_XLA_CACHE_PATH = os.getenv("APHRODITE_XLA_CACHE_PATH",
                                             "~/.aphrodite/xla_cache/")
        xr.initialize_cache(os.path.expanduser(APHRODITE_XLA_CACHE_PATH),
                            readonly=False)
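
    # Usage note (illustrative, not part of the original module): the
    # persistent compilation cache above can be redirected before launch,
    # e.g. `export APHRODITE_XLA_CACHE_PATH=/mnt/ssd/xla_cache`, so compiled
    # XLA programs are reused across process restarts.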

    def load_model(self):
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        head_size = self.model_config.get_head_size()
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)

        # Profile peak memory with a dummy prefill run that uses no KV cache.
        kv_caches = [(None, None) for _ in range(num_layers)]
        self.model_runner._dummy_run(
            batch_size=1,
            seq_len=self.scheduler_config.max_num_batched_tokens,
            kv_caches=kv_caches,
            is_prompt=True,
        )

        # Synchronize before measuring the memory usage.
        xm.wait_device_ops()

        m = xm.get_memory_info(self.device)
        program_size = 1024 * 1024 * 1024  # 1GB
        free_bytes = max(m["bytes_limit"] - m["bytes_used"] - program_size, 0)
        kv_cache_bytes = int(free_bytes *
                             self.cache_config.gpu_memory_utilization)
        kv_cache_dtype_bytes = get_dtype_size(self.cache_dtype)
        block_size = self.cache_config.block_size
        num_tpu_blocks = (kv_cache_bytes //
                          (kv_cache_dtype_bytes * block_size * num_layers * 2 *
                           head_size * num_kv_heads))
        num_tpu_blocks = (num_tpu_blocks // 8) * 8  # Round down to 8.
        return num_tpu_blocks, 0
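
    # A rough worked example of the sizing above (illustrative numbers, not
    # from the original): with bf16 caches (2 bytes), block_size=16,
    # num_layers=32, head_size=128, num_kv_heads=8, each block costs
    # 2 * 16 * 32 * 2 * 128 * 8 = 2,097,152 bytes (2 MiB). With ~10 GiB free
    # after profiling and gpu_memory_utilization=0.9, that yields
    # int(10 GiB * 0.9) // 2 MiB = 4608 blocks (already a multiple of 8).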

    def initialize_cache(
        self,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks
        self.block_size = self.cache_config.block_size

        dtype = self.cache_dtype
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        head_size = self.model_config.get_head_size()

        self.tpu_cache = []
        tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
            num_gpu_blocks, self.block_size, num_kv_heads, head_size)
        for _ in range(num_layers):
            key_cache = torch.zeros(tpu_cache_shape,
                                    dtype=dtype,
                                    device=self.device)
            value_cache = torch.zeros_like(key_cache)
            self.tpu_cache.append((key_cache, value_cache))
        self._warmup_model()
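
    # After initialize_cache, self.tpu_cache holds one (key_cache, value_cache)
    # pair per layer, each allocated with the shape reported by the attention
    # backend's get_kv_cache_shape().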

    def _warmup_model(self) -> None:
        # FIXME: Here we are abusing `enforce_eager` which is defined
        # for CUDA graphs. We should refactor this part.
        if not self.model_config.enforce_eager:
            # Warm up the model with all possible input shapes so that
            # compilation never happens during the actual execution.
            # This may take ~30 mins for the first run and ~20 mins for the
            # subsequent runs.
            # If `enforce_eager` is True, the ahead-of-time compilation is
            # skipped and the compilation happens during the actual execution,
            # which is bad for performance but useful for development.
            self.model_runner.warmup_model(self.tpu_cache)

    def get_cache_block_size_bytes(self) -> int:
        head_size = self.model_config.get_head_size()
        num_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        num_layers = self.model_config.get_num_layers(self.parallel_config)

        key_cache_block = self.cache_config.block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        dtype_size = get_dtype_size(self.cache_dtype)
        return dtype_size * total
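
    # Equivalently: block_bytes = 2 (key + value) * block_size * num_kv_heads
    # * head_size * num_layers * dtype_size, which matches the divisor used in
    # determine_num_available_blocks().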

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        if execute_model_req is None:
            return []

        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
        num_seq_groups = len(seq_group_metadata_list)
        if num_seq_groups == 0:
            return []

        # Currently, TPUWorker does not support swapping.
        # TODO: Support block copying.
        assert len(execute_model_req.blocks_to_swap_in) == 0, (
            "Swapping is not supported for the TPU backend.")
        assert len(execute_model_req.blocks_to_swap_out) == 0, (
            "Swapping is not supported for the TPU backend.")
        assert len(execute_model_req.blocks_to_copy) == 0

        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.tpu_cache)
        return [output]
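

# Minimal usage sketch (illustrative only; the config objects and the
# driving executor are assumptions, not part of this module). A higher-level
# executor typically drives the worker in this order:
#
#   worker = TPUWorker(model_config, parallel_config, scheduler_config,
#                      device_config, cache_config, load_config,
#                      vision_language_config=None, local_rank=0, rank=0,
#                      distributed_init_method="env://")
#   worker.init_device()
#   worker.load_model()
#   num_tpu_blocks, _ = worker.determine_num_available_blocks()
#   worker.initialize_cache(num_gpu_blocks=num_tpu_blocks, num_cpu_blocks=0)
#   outputs = worker.execute_model(execute_model_req)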