cfg_model_runner.py

from typing import List, Optional, Union

import torch

from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.distributed import get_pp_group
from aphrodite.multimodal import MultiModalInputs
from aphrodite.task_handler.model_runner import (
    FLASHINFER_WORKSPACE_BUFFER_SIZE, BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper, ModelInputForGPUWithSamplingMetadata,
    ModelRunner)

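# CFGModelRunner exposes the stages of ModelRunner.execute_model as separate
# methods (model_execute -> get_logits -> compute_logits -> do_sample) so a
# caller can intervene between the forward pass and sampling, e.g. to blend
# conditional and unconditional logits for classifier-free guidance (CFG).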
class CFGModelRunner(ModelRunner):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    @torch.inference_mode()
    def model_execute(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> torch.Tensor:
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in ModelRunner")

        if self.lora_config:
            assert model_input.lora_requests is not None
            assert model_input.lora_mapping is not None
            self.set_active_loras(model_input.lora_requests,
                                  model_input.lora_mapping)

        if self.prompt_adapter_config:
            assert model_input.prompt_adapter_requests is not None
            assert model_input.prompt_adapter_mapping is not None
            self.set_active_prompt_adapters(
                model_input.prompt_adapter_requests,
                model_input.prompt_adapter_mapping)

        if self.attn_backend.get_name() == "flashinfer":
            assert model_input.attn_metadata is not None
            assert model_input.input_tokens is not None
            if self.flashinfer_decode_workspace_buffer is None:
                self.flashinfer_decode_workspace_buffer = torch.empty(
                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
                    dtype=torch.uint8,
                    device=self.device)
                self.flashinfer_decode_wrapper = \
                    BatchDecodeWithPagedKVCacheWrapper(
                        self.flashinfer_decode_workspace_buffer, "NHD")
                self.flashinfer_prefill_workspace_buffer = torch.empty(
                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
                    dtype=torch.uint8,
                    device=self.device)
                self.flashinfer_prefill_wrapper = \
                    BatchPrefillWithPagedKVCacheWrapper(
                        self.flashinfer_prefill_workspace_buffer, "NHD")

            model_input.attn_metadata.prefill_wrapper = \
                self.flashinfer_prefill_wrapper
            if model_input.attn_metadata.use_cuda_graph:
                batch_size = model_input.input_tokens.shape[0]
                model_input.attn_metadata.decode_wrapper = self.graph_runners[
                    model_input.
                    virtual_engine][batch_size].flashinfer_decode_wrapper
            else:
                model_input.attn_metadata.decode_wrapper = \
                    self.flashinfer_decode_wrapper
            model_input.attn_metadata.begin_forward()

        # Currently cuda graph is only supported by the decode phase.
        assert model_input.attn_metadata is not None
        prefill_meta = model_input.attn_metadata.prefill_metadata
        decode_meta = model_input.attn_metadata.decode_metadata
        # TODO(andoorve): We can remove this once all
        # virtual engines share the same kv cache.
        virtual_engine = model_input.virtual_engine
        if prefill_meta is None and decode_meta.use_cuda_graph:
            assert model_input.input_tokens is not None
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = self.graph_runners[virtual_engine][
                graph_batch_size]
        else:
            model_executable = self.model

        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
        seqlen_agnostic_kwargs = {
            "finished_requests_ids": model_input.finished_requests_ids,
            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
        } if self.has_seqlen_agnostic else {}

        hidden_or_intermediate_states = model_executable(
            input_ids=model_input.input_tokens,
            positions=model_input.input_positions,
            kv_caches=kv_caches,
            attn_metadata=model_input.attn_metadata,
            intermediate_tensors=intermediate_tensors,
            **MultiModalInputs.as_kwargs(multi_modal_kwargs,
                                         device=self.device),
            **seqlen_agnostic_kwargs)

        return hidden_or_intermediate_states
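    # The logits path is split in two: get_logits runs the model's
    # _get_logits hook on the hidden states, and compute_logits then applies
    # the model's regular logits post-processing to the result.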
    @torch.inference_mode()
    def get_logits(
        self,
        hidden_or_intermediate_states: torch.Tensor,
        model_input: ModelInputForGPUWithSamplingMetadata,
    ) -> torch.Tensor:
        return self.model._get_logits(hidden_or_intermediate_states,
                                      model_input.sampling_metadata)
    @torch.inference_mode()
    def compute_logits(
        self,
        logits: torch.Tensor,
        model_input: ModelInputForGPUWithSamplingMetadata,
    ) -> torch.Tensor:
        return self.model.compute_logits(logits,
                                         model_input.sampling_metadata)
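    # Sampling only happens on the driver worker; other workers return an
    # empty list, mirroring the base ModelRunner behavior.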
    @torch.inference_mode()
    def do_sample(
        self,
        logits: torch.Tensor,
        model_input: ModelInputForGPUWithSamplingMetadata,
    ):
        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )

        if self.return_hidden_states:
            raise NotImplementedError("return_hidden_states is not supported "
                                      "in CFGModelRunner")

        return [output]
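    # execute_model preserves the base ModelRunner interface by chaining the
    # stages above; on non-last pipeline-parallel ranks it returns the
    # intermediate tensors directly instead of computing logits and sampling.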
    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
        hidden_or_intermediate_states = self.model_execute(
            model_input, kv_caches, intermediate_tensors, num_steps)

        if not get_pp_group().is_last_rank:
            return hidden_or_intermediate_states

        hidden_or_intermediate_states = self.get_logits(
            hidden_or_intermediate_states, model_input)
        logits = self.compute_logits(hidden_or_intermediate_states,
                                     model_input)

        return self.do_sample(logits, model_input)
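
# ---------------------------------------------------------------------------
# Illustrative sketch only (not part of the original runner): one way a
# caller *might* combine the staged methods above for classifier-free
# guidance. The helper name `cfg_step`, the `guidance_scale` parameter, and
# the assumption that the conditional and unconditional sequences arrive as
# two separately prepared model inputs sharing the paged KV caches are all
# hypothetical; the blend below is the standard CFG rule, not something
# defined in this file.
# ---------------------------------------------------------------------------
def cfg_step(
    runner: CFGModelRunner,
    cond_input: ModelInputForGPUWithSamplingMetadata,
    uncond_input: ModelInputForGPUWithSamplingMetadata,
    kv_caches: List[torch.Tensor],
    guidance_scale: float = 1.5,
) -> List[SamplerOutput]:
    # Forward pass for the conditional and unconditional branches.
    cond_hidden = runner.model_execute(cond_input, kv_caches)
    uncond_hidden = runner.model_execute(uncond_input, kv_caches)
    # Project both branches' hidden states to vocabulary logits.
    cond_logits = runner.get_logits(cond_hidden, cond_input)
    uncond_logits = runner.get_logits(uncond_hidden, uncond_input)
    # Standard classifier-free guidance blend of the raw logits.
    blended = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
    # Apply the model's usual logits processing, then sample using the
    # conditional branch's sampling metadata.
    blended = runner.compute_logits(blended, cond_input)
    return runner.do_sample(blended, cond_input)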