xpu_worker.py

  1. """A XPU worker class."""
  2. import gc
  3. import os
  4. from typing import List, Optional, Tuple
  5. import intel_extension_for_pytorch # noqa: F401
  6. import oneccl_bindings_for_pytorch # noqa: F401
  7. import torch
  8. import torch.distributed
  9. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  10. LoRAConfig, ModelConfig, ParallelConfig,
  11. PromptAdapterConfig, SchedulerConfig,
  12. SpeculativeConfig)
  13. from aphrodite.common.utils import is_xpu
  14. from aphrodite.distributed import (ensure_model_parallel_initialized,
  15. init_distributed_environment)
  16. from aphrodite.distributed.parallel_state import get_pp_group
  17. from aphrodite.modeling import set_random_seed
  18. from aphrodite.worker.cache_engine import CacheEngine
  19. from aphrodite.worker.worker import Worker
  20. from aphrodite.worker.worker_base import LoraNotSupportedWorkerBase
  21. from aphrodite.worker.xpu_model_runner import XPUModelRunner


class XPUWorker(LoraNotSupportedWorkerBase, Worker):
    """A worker class that executes (a partition of) the model on an XPU.

    Each worker is associated with a single XPU device. The worker is
    responsible for maintaining the KV cache and executing the model on the
    XPU. In case of distributed inference, each worker is assigned a partition
    of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        speculative_config: Optional[SpeculativeConfig] = None,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        is_driver_worker: bool = False,
    ) -> None:
        assert device_config.device_type == "xpu"
        assert is_xpu()

        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.prompt_adapter_config = prompt_adapter_config
        self.is_driver_worker = is_driver_worker
        if parallel_config and is_driver_worker:
            assert rank % parallel_config.tensor_parallel_size == 0, \
                "Driver worker should be rank 0 of tensor parallel group."

        self.model_runner = XPUModelRunner(  # type: ignore
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config=self.load_config,
            lora_config=self.lora_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
        )
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CacheEngine]
        self.gpu_cache: Optional[List[List[torch.Tensor]]]

    def init_device(self) -> None:
        if self.device_config.device.type == "xpu" and is_xpu():
            self.device = torch.device(f"xpu:{self.local_rank}")
            torch.xpu.set_device(self.device)
            torch.xpu.empty_cache()
            self.init_gpu_memory = torch.xpu.get_device_properties(
                self.local_rank).total_memory
        else:
            raise RuntimeError(
                f"Unsupported device type: {self.device_config.device}")
        # Initialize the distributed environment.
        self.init_worker_distributed_environment()
        # Initialize the model.
        set_random_seed(self.model_config.seed)

    # Keep this override so the XPU-specific `empty_cache` and `synchronize`
    # APIs are used.
    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory
        usage. Then, it calculates the maximum possible number of GPU and CPU
        blocks that can be allocated with the remaining free memory.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.xpu.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.xpu.synchronize()
        used_memory = torch.xpu.memory_allocated()
        total_gpu_memory = torch.xpu.get_device_properties(
            self.local_rank).total_memory
        free_gpu_memory = total_gpu_memory - used_memory

        # NOTE: Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        peak_memory = self.init_gpu_memory - free_gpu_memory
        assert peak_memory > 0, (
            "Error in memory profiling. This happens when the GPU memory was "
            "not properly cleaned up before initializing Aphrodite.")

        cache_block_size = self.get_cache_block_size_bytes()
        num_gpu_blocks = int(
            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
             peak_memory) // cache_block_size)
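        # Illustrative (hypothetical) numbers: with 16 GiB of device memory,
        # gpu_memory_utilization=0.9, a 4 GiB profiled peak, and a 2 MiB cache
        # block, this gives (16 * 0.9 - 4) GiB / 2 MiB = 5324 blocks
        # (truncated).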
        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                             cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        gc.collect()
        torch.xpu.empty_cache()
        return num_gpu_blocks, num_cpu_blocks

    def _warm_up_model(self) -> None:
        # IPEX does not support graph capture yet.
        pass

    def init_worker_distributed_environment(self) -> None:
        """Initialize the distributed environment."""

        parallel_config = self.parallel_config
        rank = self.rank
        distributed_init_method = self.distributed_init_method

        if torch.distributed.is_initialized():
            torch_world_size = torch.distributed.get_world_size()
            if torch_world_size != parallel_config.world_size:
                raise RuntimeError(
                    "torch.distributed is already initialized but the torch "
                    "world size does not match parallel_config.world_size "
                    f"({torch_world_size} vs. {parallel_config.world_size}).")
        elif not distributed_init_method:
            raise ValueError(
                "distributed_init_method must be set if torch.distributed "
                "is not already initialized")
        else:
            # Use sockets as the default Level Zero IPC exchange backend. By
            # default, oneCCL uses the `drmfd` mechanism, which needs extra
            # dependencies (libdrm and drm headers) on the system.
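            # The value set below is only a fallback: exporting
            # CCL_ZE_IPC_EXCHANGE (e.g. to `drmfd`) before launch still takes
            # precedence, since os.getenv is consulted first.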
            ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE",
                                                "sockets")
            ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
                                             str(parallel_config.world_size))
            os.environ["CCL_ZE_IPC_EXCHANGE"] = ENV_CCL_ZE_IPC_EXCHANGE
            os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
            os.environ["LOCAL_RANK"] = str(self.local_rank)
            init_distributed_environment(
                world_size=parallel_config.world_size,
                rank=rank,
                distributed_init_method=distributed_init_method,
                local_rank=self.local_rank,
                backend="ccl")

        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)

        if parallel_config.pipeline_parallel_size > 1:
            # torch-ccl on XPU needs a collective API warm-up
            # before calling the send/recv APIs.
            get_pp_group().all_reduce(torch.zeros(1).xpu())
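

# NOTE: A minimal, hypothetical driver-side sketch (not executed here). It
# assumes the config objects have already been built by the engine and that
# the base Worker methods (`load_model`, `initialize_cache`, `execute_model`)
# are driven in the usual order; the init method address is illustrative:
#
#     worker = XPUWorker(model_config, parallel_config, scheduler_config,
#                        device_config, cache_config, load_config,
#                        local_rank=0, rank=0,
#                        distributed_init_method="tcp://127.0.0.1:29500")
#     worker.init_device()
#     worker.load_model()
#     num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
#     worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
#     # Then, per scheduler step:
#     # output = worker.execute_model(execute_model_req)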