neuron_worker.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. """A Neuron worker class."""
  2. from typing import List, Optional, Tuple
  3. import torch
  4. import torch.distributed
  5. from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
  6. ParallelConfig, SchedulerConfig)
  7. from aphrodite.common.sequence import ExecuteModelRequest
  8. from aphrodite.distributed import (ensure_model_parallel_initialized,
  9. init_distributed_environment)
  10. from aphrodite.modeling import set_random_seed
  11. from aphrodite.task_handler.neuron_model_runner import NeuronModelRunner
  12. from aphrodite.task_handler.worker_base import (LocalOrDistributedWorkerBase,
  13. LoraNotSupportedWorkerBase,
  14. WorkerInput)
  15. class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
  16. """A worker class that executes the model on a group of neuron cores.
  17. """
  18. def __init__(
  19. self,
  20. model_config: ModelConfig,
  21. parallel_config: ParallelConfig,
  22. scheduler_config: SchedulerConfig,
  23. device_config: DeviceConfig,
  24. cache_config: CacheConfig,
  25. local_rank: int,
  26. rank: int,
  27. distributed_init_method: str,
  28. ) -> None:
  29. self.model_config = model_config
  30. self.parallel_config = parallel_config
  31. self.scheduler_config = scheduler_config
  32. self.device_config = device_config
  33. self.cache_config = cache_config
  34. self.local_rank = local_rank
  35. self.rank = rank
  36. self.distributed_init_method = distributed_init_method
  37. if self.model_config.trust_remote_code:
  38. # note: lazy import to avoid importing torch before initializing
  39. from aphrodite.common.utils import init_cached_hf_modules
  40. init_cached_hf_modules()
  41. self.model_runner: NeuronModelRunner = NeuronModelRunner(
  42. model_config, parallel_config, scheduler_config, device_config)
  43. self.is_driver_worker = True
  44. def init_device(self) -> None:
  45. self.init_distributed_environment()
  46. # Set random seed.
  47. set_random_seed(self.model_config.seed)
  48. def load_model(self):
  49. self.model_runner.load_model()
  50. def determine_num_available_blocks(self) -> Tuple[int, int]:
  51. """Determine the number of available KV blocks.
  52. Swapping is not yet supported, so always return num_cpu_blocks=0.
  53. We configure num_gpu_blocks to be equal to max_num_seqs.
  54. """
  55. # Set the number of GPU blocks to be the same as the maximum number of
  56. # sequences that can be processed in a single batch. This is equivalent
  57. # to schedule without PagedAttention.
  58. num_gpu_blocks = self.scheduler_config.max_num_seqs
  59. # Swap not yet supported with Neuron backend.
  60. num_cpu_blocks = 0
  61. return num_gpu_blocks, num_cpu_blocks
  62. def initialize_cache(self, num_gpu_blocks: int,
  63. num_cpu_blocks: int) -> None:
  64. """Initialize the KV cache.
  65. """
  66. # Different values are not tested.
  67. assert num_cpu_blocks == 0
  68. assert num_gpu_blocks == self.scheduler_config.max_num_seqs
  69. self.cache_config.num_gpu_blocks = num_gpu_blocks
  70. self.cache_config.num_cpu_blocks = num_cpu_blocks
  71. @property
  72. def do_metadata_broadcast(self) -> bool:
  73. return False
  74. @property
  75. def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
  76. return None
  77. @torch.inference_mode()
  78. def prepare_worker_input(
  79. self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
  80. return WorkerInput(num_seq_groups=len(
  81. execute_model_req.seq_group_metadata_list), )
  82. def execute_worker(self, worker_input: WorkerInput) -> None:
  83. pass
  84. def get_cache_block_size_bytes(self) -> int:
  85. """Determine the size in bytes of a cache block.
  86. This is required for speculative decoding; it is not yet implemented.
  87. """
  88. raise NotImplementedError
  89. def init_distributed_environment(self):
  90. """Neuron uses transformers-neuronx for tensor parallelism.
  91. Aphrodite still needs the environment inited when TP/PP > 1
  92. """
  93. init_distributed_environment(
  94. world_size=1,
  95. rank=self.rank,
  96. local_rank=self.local_rank,
  97. distributed_init_method=self.distributed_init_method,
  98. backend="gloo",
  99. )
  100. ensure_model_parallel_initialized(
  101. 1,
  102. 1,
  103. )