xpu_worker.py

  1. """A XPU worker class."""
import gc
import os
from typing import List, Optional, Tuple

import intel_extension_for_pytorch  # noqa: F401
import oneccl_bindings_for_pytorch  # noqa: F401
import torch
import torch.distributed

from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                     LoRAConfig, ModelConfig, ParallelConfig,
                                     PromptAdapterConfig, SchedulerConfig,
                                     SpeculativeConfig)
from aphrodite.common.utils import is_xpu
from aphrodite.distributed import (ensure_model_parallel_initialized,
                                   init_distributed_environment)
from aphrodite.modeling import set_random_seed
from aphrodite.task_handler.cache_engine import CacheEngine
from aphrodite.task_handler.worker import Worker
from aphrodite.task_handler.worker_base import LoraNotSupportedWorkerBase
from aphrodite.task_handler.xpu_model_runner import XPUModelRunner


class XPUWorker(LoraNotSupportedWorkerBase, Worker):
    """A worker class that executes (a partition of) the model on an XPU.

    Each worker is associated with a single XPU device. The worker is
    responsible for maintaining the KV cache and executing the model on the
    XPU. In case of distributed inference, each worker is assigned a partition
    of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        speculative_config: Optional[SpeculativeConfig] = None,
        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
        is_driver_worker: bool = False,
    ) -> None:
        assert device_config.device_type == "xpu"
        assert is_xpu()

        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.load_config = load_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.prompt_adapter_config = prompt_adapter_config
        self.is_driver_worker = is_driver_worker
        if parallel_config and is_driver_worker:
            assert rank % parallel_config.tensor_parallel_size == 0, \
                "Driver worker should be rank 0 of tensor parallel group."

        self.model_runner = XPUModelRunner(  # type: ignore
            model_config,
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config=self.load_config,
            lora_config=self.lora_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
        )
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CacheEngine]
        self.gpu_cache: Optional[List[List[torch.Tensor]]]

    def init_device(self) -> None:
        if self.device_config.device.type == "xpu" and is_xpu():
            self.device = torch.device(f"xpu:{self.local_rank}")
            torch.xpu.set_device(self.device)
            torch.xpu.empty_cache()
            # Total device memory at startup; used later as the baseline for
            # the memory profiling in determine_num_available_blocks().
            self.init_gpu_memory = torch.xpu.get_device_properties(
                self.local_rank).total_memory
        else:
            raise RuntimeError(
                f"Unsupported device type: {self.device_config.device}")
        # Initialize the distributed environment.
        self.init_worker_distributed_environment()
        # Initialize the model.
        set_random_seed(self.model_config.seed)

    # keep this method for `empty_cache` and `synchronize` api
    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory
        usage. Then, it calculates the maximum possible number of GPU and CPU
        blocks that can be allocated with the remaining free memory.

        .. tip::
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.xpu.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.xpu.synchronize()
        used_memory = torch.xpu.memory_allocated()
        total_gpu_memory = torch.xpu.get_device_properties(
            self.local_rank).total_memory
        free_gpu_memory = total_gpu_memory - used_memory

        # NOTE: Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        peak_memory = self.init_gpu_memory - free_gpu_memory
        assert peak_memory > 0, (
            "Error in memory profiling. This happens when the GPU memory was "
            "not properly cleaned up before initializing Aphrodite.")

        cache_block_size = self.get_cache_block_size_bytes()
        num_gpu_blocks = int(
            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
             peak_memory) // cache_block_size)
        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                             cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
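        # Worked example with hypothetical numbers (not taken from any real
        # device): 16 GiB of total memory, gpu_memory_utilization=0.9, a
        # profiled peak of 6 GiB, and a 2 MiB cache block give
        # int((16 * 0.9 - 6) GiB / 2 MiB) = 4300 GPU blocks; a 4 GiB swap
        # space gives 4096 MiB / 2 MiB = 2048 CPU blocks.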
        gc.collect()
        torch.xpu.empty_cache()
        return num_gpu_blocks, num_cpu_blocks

    def _warm_up_model(self) -> None:
        # IPEX doesn't support graph capture yet, so there is nothing to
        # warm up here.
        pass

    def init_worker_distributed_environment(self) -> None:
        """Initialize the distributed environment."""
        parallel_config = self.parallel_config
        rank = self.rank
        distributed_init_method = self.distributed_init_method

        if torch.distributed.is_initialized():
            torch_world_size = torch.distributed.get_world_size()
            if torch_world_size != parallel_config.world_size:
                raise RuntimeError(
                    "torch.distributed is already initialized but the torch "
                    "world size does not match parallel_config.world_size "
                    f"({torch_world_size} vs. {parallel_config.world_size}).")
        elif not distributed_init_method:
            raise ValueError(
                "distributed_init_method must be set if torch.distributed "
                "is not already initialized")
        else:
            # Use sockets as the default Level Zero IPC exchange backend. By
            # default, oneCCL uses `drmfd`, which requires extra dependencies
            # (libdrm and drm headers) on the system.
            ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE",
                                                "sockets")
            ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
                                             str(parallel_config.world_size))
            os.environ["CCL_ZE_IPC_EXCHANGE"] = ENV_CCL_ZE_IPC_EXCHANGE
            os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
            os.environ["LOCAL_RANK"] = str(self.local_rank)
            init_distributed_environment(
                world_size=parallel_config.world_size,
                rank=rank,
                distributed_init_method=distributed_init_method,
                local_rank=self.local_rank,
                backend="ccl")

        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)
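

# Illustrative lifecycle sketch: an executor is expected to drive this worker
# roughly as below. The config objects and the init-method URL are
# placeholders, and `initialize_cache` is assumed to be inherited from the
# base `Worker` class rather than defined in this file.
#
#     worker = XPUWorker(model_config, parallel_config, scheduler_config,
#                        device_config, cache_config, load_config,
#                        local_rank=0, rank=0,
#                        distributed_init_method="tcp://127.0.0.1:29500",
#                        is_driver_worker=True)
#     worker.init_device()
#     num_gpu_blocks, num_cpu_blocks = worker.determine_num_available_blocks()
#     worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)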