smaller_tp_proposer_worker.py

from typing import List, Optional, Set, Tuple

import torch
from loguru import logger

from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.distributed.parallel_state import (get_tp_group,
                                                  init_model_parallel_group,
                                                  patch_tensor_parallel_group)
from aphrodite.modeling.layers.sampler import SamplerOutput
from aphrodite.spec_decode.interfaces import SpeculativeProposals
from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
from aphrodite.spec_decode.proposer_worker_base import ProposerWorkerBase


class SmallerTpProposerWorker(ProposerWorkerBase):
    """Class which allows a speculative draft model to run with a smaller
    tensor parallel degree than the target model.
    This reduces the communication overhead of small draft models.

    To implement this feature, this class changes its behavior based on the
    is_dummy flag, where a dummy worker is one that does not participate in
    draft generation. Participating workers use a smaller tp group by
    temporarily patching Aphrodite's tensor parallel group during forward
    passes of the draft model.
    """

    @classmethod
    def maybe_wrap_worker(cls, worker, draft_tensor_parallel_size: int,
                          target_tensor_parallel_size: int):
        """Wrap the worker in a SmallerTpProposerWorker if necessary."""
        if draft_tensor_parallel_size == target_tensor_parallel_size:
            return worker

        # GPU ranks that will generate draft tokens together
        draft_ranks = list(range(draft_tensor_parallel_size))

        logger.info(f"Wrapping {type(worker)} in {cls}")
        return cls(worker, draft_ranks)
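
    # Illustrative usage sketch (not part of the original code; the exact
    # call site in the spec-decode setup path and the worker construction
    # below are assumptions):
    #
    #   draft_worker = MultiStepWorker(...)  # hypothetical construction
    #   draft_worker = SmallerTpProposerWorker.maybe_wrap_worker(
    #       draft_worker,
    #       draft_tensor_parallel_size=1,   # draft model runs on one rank
    #       target_tensor_parallel_size=4,  # target model sharded over 4 ranks
    #   )
    #
    #   # With equal tensor parallel sizes the worker is returned unwrapped;
    #   # otherwise ranks outside range(draft_tensor_parallel_size) become
    #   # dummy workers.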

    def __init__(self, worker: MultiStepWorker, draft_ranks: List[int]):
        """Create a SmallerTpProposerWorker.

        Args:
            worker (MultiStepWorker): an actual worker wrapped with this class
            draft_ranks (List[int]): if this value is given, only the GPU
                ranks listed in this value participate in draft generation
        """
        self._worker = worker
        self._draft_ranks = draft_ranks

        # init during init_device
        self._is_dummy = False
        self._tp_group = None

    def _patch_tensor_parallel_group(self):
        """Temporarily patch the global tp group state with its own tp group
        state.
        """
        return patch_tensor_parallel_group(self._tp_group)

    def init_device(self) -> None:
        self._is_dummy = get_tp_group().rank not in self._draft_ranks

        # dummy workers do nothing
        if self._is_dummy:
            return

        # creates tp process group containing only a subset of gpu ranks
        local_rank = get_tp_group().local_rank
        tp_backend = torch.distributed.get_backend(get_tp_group().device_group)
        self._tp_group = init_model_parallel_group([self._draft_ranks],
                                                   local_rank, tp_backend)

        with self._patch_tensor_parallel_group():
            self._worker.init_device()

    def set_include_gpu_probs_tensor(self) -> None:
        if self._is_dummy:
            return

        # Need include_gpu_probs_tensor for multi_step_worker
        self._worker.set_include_gpu_probs_tensor()

    def set_should_modify_greedy_probs_inplace(self) -> None:
        if self._is_dummy:
            return

        self._worker.set_should_modify_greedy_probs_inplace()

    def load_model(self) -> None:
        if self._is_dummy:
            return

        with self._patch_tensor_parallel_group():
            self._worker.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        if self._is_dummy:
            # this case is not used now
            return -1, -1

        with self._patch_tensor_parallel_group():
            return self._worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        if self._is_dummy:
            return

        with self._patch_tensor_parallel_group():
            self._worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def sampler_output(
        self,
        execute_model_req: ExecuteModelRequest,
        sample_len: int,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> Tuple[List[SamplerOutput], bool]:
        # Do not check _is_dummy, as it's always called by get_spec_proposals
        return self._worker.sampler_output(
            execute_model_req, sample_len,
            seq_ids_with_bonus_token_in_last_step)

    def get_spec_proposals(
        self,
        execute_model_req: ExecuteModelRequest,
        seq_ids_with_bonus_token_in_last_step: Set[int],
    ) -> SpeculativeProposals:
        """Produce speculations given an input batch of sequences. The number
        of speculative tokens per sequence is determined by max_proposal_len.
        """
        if self._is_dummy:
            return SpeculativeProposals(None, None, None)

        with self._patch_tensor_parallel_group():
            return self._worker.get_spec_proposals(
                execute_model_req, seq_ids_with_bonus_token_in_last_step)

    def execute_model(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        if self._is_dummy:
            return []

        with self._patch_tensor_parallel_group():
            return self._worker.execute_model(execute_model_req)

    def get_cache_block_size_bytes(self) -> int:
        if self._is_dummy:
            # By returning zero, the target worker can use the entire KV cache
            # space.
            return 0

        return self._worker.get_cache_block_size_bytes()

    @property
    def vocab_size(self) -> int:
        return self._worker.vocab_size
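
# Example rank layout (illustrative; the concrete tensor parallel sizes below
# are assumptions, not values used in this file):
#
#   target_tensor_parallel_size = 4, draft_tensor_parallel_size = 2
#   draft_ranks = list(range(2))  -> [0, 1]
#
#   ranks 0-1: participate in draft generation inside the smaller tp group
#              created and patched in via init_device().
#   ranks 2-3: _is_dummy is True; they skip loading and running the draft
#              model, report 0 from get_cache_block_size_bytes(), and return
#              empty proposals/outputs from get_spec_proposals() and
#              execute_model().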