# smaller_tp_proposer_worker.py

from typing import List, Optional, Set, Tuple

import torch
from loguru import logger

from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput
from aphrodite.distributed.parallel_state import (get_tp_group,
                                                  init_model_parallel_group,
                                                  patch_tensor_parallel_group)
from aphrodite.spec_decode.interfaces import SpeculativeProposals
from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
from aphrodite.spec_decode.proposer_worker_base import ProposerWorkerBase


class SmallerTpProposerWorker(ProposerWorkerBase):
    """Class which allows a speculative draft model to run with a smaller
    tensor parallel degree than the target model.

    This reduces the communication overhead of small draft models.

    To implement this feature, this class changes its behavior based on the
    is_dummy flag, where a dummy worker does not participate in draft
    generation. Participating workers use a smaller tp group by patching
    Aphrodite's tensor parallel group temporarily during forward passes of
    draft models.
    """

    @classmethod
    def maybe_wrap_worker(cls, worker, draft_tensor_parallel_size: int,
                          target_tensor_parallel_size: int):
        """Wrap the worker in a SmallerTpProposerWorker if necessary."""
        if draft_tensor_parallel_size == target_tensor_parallel_size:
            return worker

        # gpu ranks that will generate draft tokens together
        draft_ranks = list(range(draft_tensor_parallel_size))

        logger.info(f"Wrapping {type(worker)} in {cls}")
        return cls(worker, draft_ranks)
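
    # Illustrative note (added commentary, not in the original source): with
    # an assumed target tensor parallel size of 4 and a draft tensor parallel
    # size of 1, maybe_wrap_worker computes draft_ranks = list(range(1)),
    # i.e. [0], so only global tp rank 0 runs the draft model and ranks 1-3
    # become dummy workers. When the two sizes are equal, the worker is
    # returned unwrapped.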

    def __init__(self, worker: MultiStepWorker, draft_ranks: List[int]):
        """Create a SmallerTpProposerWorker.

        Args:
            worker (MultiStepWorker): the actual worker wrapped by this class
            draft_ranks (List[int]): the GPU ranks that participate in draft
                generation
        """
        self._worker = worker
        self._draft_ranks = draft_ranks

        # initialized during init_device
        self._is_dummy = False
        self._tp_group = None

    def _patch_tensor_parallel_group(self):
        """Temporarily patch the global tp group state with this worker's own
        tp group state.
        """
        return patch_tensor_parallel_group(self._tp_group)

    def init_device(self) -> None:
        self._is_dummy = get_tp_group().rank not in self._draft_ranks

        # dummy workers do nothing
        if self._is_dummy:
            return

        # creates a tp process group containing only a subset of gpu ranks
        local_rank = get_tp_group().local_rank
        tp_backend = torch.distributed.get_backend(get_tp_group().device_group)
        self._tp_group = init_model_parallel_group([self._draft_ranks],
                                                   local_rank, tp_backend)

        with self._patch_tensor_parallel_group():
            self._worker.init_device()
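
    # Note (added commentary, not in the original source): while the
    # _patch_tensor_parallel_group() context manager is active, the wrapped
    # worker's lookups of the global tp group are expected to resolve to the
    # smaller draft group created above, so collective communication during
    # draft forward passes only spans self._draft_ranks, as the class
    # docstring describes.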

    def set_include_gpu_probs_tensor(self) -> None:
        if self._is_dummy:
            return

        # Need include_gpu_probs_tensor for multi_step_worker
        self._worker.set_include_gpu_probs_tensor()

    def set_should_modify_greedy_probs_inplace(self) -> None:
        if self._is_dummy:
            return

        self._worker.set_should_modify_greedy_probs_inplace()

    def load_model(self) -> None:
        if self._is_dummy:
            return

        with self._patch_tensor_parallel_group():
            self._worker.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        if self._is_dummy:
            # this case is not used now
            return -1, -1

        with self._patch_tensor_parallel_group():
            return self._worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        if self._is_dummy:
            return

        with self._patch_tensor_parallel_group():
            self._worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        # Do not check _is_dummy here; this method is only called via
        # get_spec_proposals, which already handles the dummy case.
        return self._worker.sampler_output(
            execute_model_req, sample_len,
            seq_ids_with_bonus_token_in_last_step)

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number
        of speculative tokens per sequence is determined by max_proposal_len.
        """
        if self._is_dummy:
            return SpeculativeProposals(None, None, None)

        with self._patch_tensor_parallel_group():
            return self._worker.get_spec_proposals(
                execute_model_req, seq_ids_with_bonus_token_in_last_step)

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        if self._is_dummy:
            return []

        with self._patch_tensor_parallel_group():
            return self._worker.execute_model(execute_model_req)

    def get_cache_block_size_bytes(self) -> int:
        if self._is_dummy:
            # By returning zero, the target worker can use the entire KV
            # cache space.
            return 0

        return self._worker.get_cache_block_size_bytes()

    @property
    def vocab_size(self) -> int:
        return self._worker.vocab_size
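

# Rough usage sketch (illustrative comments only; the engine-side code that
# builds the wrapped MultiStepWorker and the ExecuteModelRequest is assumed
# here, not shown):
#
#   proposer = SmallerTpProposerWorker.maybe_wrap_worker(
#       multi_step_worker,                # hypothetical MultiStepWorker
#       draft_tensor_parallel_size=1,
#       target_tensor_parallel_size=4)
#   proposer.init_device()                # non-draft ranks become dummies
#   proposer.load_model()
#   proposals = proposer.get_spec_proposals(
#       execute_model_req,                # hypothetical ExecuteModelRequest
#       seq_ids_with_bonus_token_in_last_step=set())
#
# Dummy ranks return SpeculativeProposals(None, None, None), while
# participating ranks run the wrapped worker under the patched tp group.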