# cfg_worker.py

import copy
from typing import Dict, List, Optional, Tuple

import torch

from aphrodite.cfg.cfg_model_runner import CFGModelRunner
from aphrodite.cfg.separated_worker import SeparatedWorker
from aphrodite.common.config import CFGConfig, ParallelConfig
from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
                                       SequenceData, SequenceGroupMetadata)
from aphrodite.distributed import get_pp_group, get_tp_group
from aphrodite.task_handler.worker_base import (LoraNotSupportedWorkerBase,
                                                WorkerBase)


def create_cfg_worker(*args, **kwargs) -> "CFGWorker":
    assert "cfg_config" in kwargs
    cfg_config: CFGConfig = kwargs.get("cfg_config")
    assert cfg_config is not None
    kwargs.pop("cfg_config")

    # The root (conditional) worker runs the primary model with the CFG-aware
    # model runner.
    kwargs["model_runner_cls"] = CFGModelRunner
    root_worker = SeparatedWorker(*args, **kwargs)

    # The guidance (unconditional) worker reuses the same kwargs with the
    # guidance model and parallel configs swapped in.
    guidance_model_config = cfg_config.guidance_model_config
    guidance_parallel_config = cfg_config.guidance_parallel_config
    kwargs.update(
        model_config=guidance_model_config,
        parallel_config=guidance_parallel_config,
    )
    guidance_worker = SeparatedWorker(*args, **kwargs)

    return CFGWorker(
        root_worker=root_worker,
        guidance_worker=guidance_worker,
        is_driver_worker=kwargs["is_driver_worker"],
        parallel_config=kwargs["parallel_config"],
    )
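

# Hypothetical wiring sketch (names are illustrative, not part of this
# module): the engine is expected to build the composite worker through the
# factory above, forwarding the usual SeparatedWorker kwargs plus cfg_config:
#
#     worker = create_cfg_worker(
#         model_config=model_config,
#         parallel_config=parallel_config,
#         cfg_config=cfg_config,
#         is_driver_worker=True,
#         ...,  # remaining SeparatedWorker kwargs
#     )
#
# The exact keyword set depends on SeparatedWorker's constructor signature.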


class CFGWorker(LoraNotSupportedWorkerBase):
    """Composite worker that runs a root (conditional) model and a guidance
    (unconditional) model for classifier-free guidance."""

    def __init__(
        self,
        root_worker: WorkerBase,
        guidance_worker: WorkerBase,
        is_driver_worker: bool,
        parallel_config: ParallelConfig,
    ):
        self.root_worker = root_worker
        self.guidance_worker = guidance_worker
        self.is_driver_worker = is_driver_worker
        self.parallel_config = parallel_config
        # Pipeline parallelism is not supported by this worker.
        assert self.parallel_config.pipeline_parallel_size == 1

    def init_device(self):
        self.root_worker.init_device()
        self.guidance_worker.init_device()

    def load_model(self):
        self.root_worker.load_model()
        # Share the root worker's loaded model with the guidance worker rather
        # than loading a second copy of the weights.
        self.guidance_worker.share_model(self.root_worker)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        (
            num_gpu_blocks,
            num_cpu_blocks,
        ) = self.root_worker.determine_num_available_blocks()

        root_cache_block_size_bytes = (
            self.root_worker.get_cache_block_size_bytes())
        guidance_cache_block_size_bytes = (
            self.guidance_worker.get_cache_block_size_bytes())

        # Shrink the profiled block count so that a root cache block and a
        # guidance cache block together fit in the memory originally budgeted
        # for the root cache alone.
        new_num_gpu_blocks = int(
            num_gpu_blocks * root_cache_block_size_bytes /
            (guidance_cache_block_size_bytes + root_cache_block_size_bytes))
        return new_num_gpu_blocks, num_cpu_blocks
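
    # Worked example with hypothetical numbers: if profiling reports 1,000 GPU
    # blocks and both caches use the same block size, new_num_gpu_blocks is
    # 500; the same count is then handed to both workers in initialize_cache
    # below, so the two KV caches together stay within the profiled budget.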

    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        self.root_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                          num_cpu_blocks=num_cpu_blocks)
        self.guidance_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                              num_cpu_blocks=num_cpu_blocks)

    @property
    def do_metadata_broadcast(self) -> bool:
        return self.parallel_config.tensor_parallel_size > 1

    @torch.inference_mode()
    def execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> List[SamplerOutput]:
        # Build the negative (unconditional) request by shallow-copying each
        # sequence group's metadata and swapping in its negative prompt data.
        if execute_model_req is not None:
            negative_seq_group_metadata_list: List[SequenceGroupMetadata] = []
            negative_execute_model_req = execute_model_req.clone(
                negative_seq_group_metadata_list)
            for seq_group_metadata in execute_model_req.seq_group_metadata_list:
                negative_seq_group_metadata = copy.copy(seq_group_metadata)
                negative_seq_data: Dict[int, SequenceData] = {}
                negative_block_tables: Dict[int, List[int]] = {}
                assert len(seq_group_metadata.seq_data) == 1
                for seq_id in seq_group_metadata.seq_data:
                    negative_seq_data[
                        seq_id] = seq_group_metadata.negative_seq_data
                    negative_block_tables[
                        seq_id] = seq_group_metadata.negative_block_table

                if negative_seq_group_metadata.is_prompt:
                    negative_seq_group_metadata.token_chunk_size = list(
                        negative_seq_data.values())[0].get_len()

                negative_seq_group_metadata.seq_data = negative_seq_data
                negative_seq_group_metadata.block_tables = negative_block_tables
                negative_seq_group_metadata.negative_seq_data = None
                negative_seq_group_metadata.negative_block_table = None
                negative_seq_group_metadata_list.append(
                    negative_seq_group_metadata)

            negative_execute_model_req.seq_group_metadata_list = (
                negative_seq_group_metadata_list)
        else:
            negative_execute_model_req = None

        inputs = self.root_worker.prepare_input(execute_model_req)
        negative_inputs = self.guidance_worker.prepare_input(
            negative_execute_model_req)
        if inputs is None:
            assert negative_inputs is None
            return None

        # Get the root model's (conditional) logits.
        condition_logits = self.root_worker.execute_model_part(inputs)
        # Get the guidance model's (unconditional) logits.
        unconditional_logits = self.guidance_worker.execute_model_part(
            negative_inputs)

        # Apply classifier-free guidance to the conditional logits.
        model_input, _ = inputs
        if condition_logits is not None:
            for seq_group in model_input.sampling_metadata.seq_groups:
                seq_ids = seq_group.seq_ids
                guidance_scale = seq_group.sampling_params.guidance_scale
                if guidance_scale == 1.0:
                    break
                for seq_id, logits_row_idx in zip(seq_ids,
                                                  seq_group.sample_indices):
                    logits_row = torch.nn.functional.log_softmax(
                        condition_logits[logits_row_idx], dim=-1)
                    unconditional_logits_row = torch.nn.functional.log_softmax(
                        unconditional_logits[logits_row_idx], dim=-1)
                    # guided = scale * (cond - uncond) + uncond, computed in
                    # log-probability space.
                    condition_logits[logits_row_idx] = (
                        guidance_scale *
                        (logits_row - unconditional_logits_row) +
                        unconditional_logits_row)

        # Run the logits processors on the guided logits.
        scores = self.root_worker.compute_logits(condition_logits, model_input)

        if not self.is_driver_worker:
            return []

        # Sample next tokens from the processed logits.
        output = self.root_worker.do_sample(scores, model_input)

        if not get_pp_group().is_last_rank:
            # output is IntermediateTensors
            get_pp_group().send_tensor_dict(output.tensors,
                                            all_gather_group=get_tp_group())
            return [None]

        # output is List[SamplerOutput]
        return output

    def get_cache_block_size_bytes(self):
        raise NotImplementedError
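

# Minimal standalone sketch (not part of the worker; shapes and values are
# assumptions for illustration): reproduces the classifier-free guidance
# combination used in execute_model on dummy logit rows.
if __name__ == "__main__":
    vocab_size = 8
    guidance_scale = 1.5
    condition_row = torch.randn(vocab_size)
    unconditional_row = torch.randn(vocab_size)
    cond = torch.nn.functional.log_softmax(condition_row, dim=-1)
    uncond = torch.nn.functional.log_softmax(unconditional_row, dim=-1)
    # guidance_scale == 1.0 would reduce this to the conditional distribution,
    # which is why execute_model skips the rewrite for that value.
    guided = guidance_scale * (cond - uncond) + uncond
    print(guided)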