metrics.py

import time
from dataclasses import dataclass
from typing import Callable, Optional

import torch

from aphrodite.common.utils import is_pin_memory_available
from aphrodite.modeling.layers.rejection import RejectionSampler


@dataclass
class SpecDecodeWorkerMetrics:
    """Dataclass holding metrics emitted from the spec decode worker."""

    # The empirical acceptance rate of the proposal method on a per-token
    # basis. This is useful for evaluating how well the proposal method
    # aligns with the scoring method.
    draft_acceptance_rate: float

    # The empirical efficiency, measured as the number of tokens emitted by
    # the system divided by the number of tokens that could be emitted by the
    # system if the proposal method were perfect.
    system_efficiency: float

    # The number of speculative tokens produced by the proposal method.
    draft_tokens: int

    # The number of tokens emitted by the entire system.
    emitted_tokens: int

    # The number of tokens accepted by the scoring model and verification
    # routine, e.g. Llama2-70B and lossless rejection sampling.
    #
    # NOTE: Any token accepted by the verification routine is considered
    # accepted (regardless of whether the speculative prefix is also
    # accepted). The user will usually see fewer accepted tokens. This metric
    # is helpful when evaluating alignment of the proposal method with the
    # scoring model.
    accepted_tokens: int

    # The number of speculative tokens per sequence.
    num_spec_tokens: int

Timer = Callable[[], float]
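# A Timer returns the current time in seconds; time.time (the default used
# by AsyncMetricsCollector below) or time.monotonic both fit this signature.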


class AsyncMetricsCollector:
    """Class which copies rejection sampler metrics from the device to CPU on
    a non-default Torch stream.
    """

    def __init__(
        self,
        rejection_sampler: RejectionSampler,
        timer: Optional[Timer] = None,
        collect_interval_s: float = 5.0,
    ):
        self._rejection_sampler = rejection_sampler
        self._timer = time.time if timer is None else timer

        self._rank: Optional[int] = None

        # We don't have a device set yet.
        self._copy_stream: Optional[torch.cuda.Stream] = None

        self._in_flight_copy: Optional[torch.cuda.Event] = None

        pin_memory = is_pin_memory_available()
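        # Pinned (page-locked) host memory is what allows the non_blocking
        # device->CPU copies in _copy_rejsample_metrics_async to overlap with
        # GPU work; without it those copies fall back to synchronous
        # transfers.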
        self._aggregate_num_accepted_tokens = torch.tensor(
            0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
        self._aggregate_num_emitted_tokens = torch.tensor(
            0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
        self._aggregate_num_draft_tokens = 0

        self._rejsample_metrics_collect_interval_s = collect_interval_s
        self._last_metrics_collect_time = self._timer()

    def init_gpu_tensors(self, rank: int) -> None:
        self._rank = rank
        self._copy_stream = torch.cuda.Stream()

    def maybe_collect_rejsample_metrics(
            self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
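        # Collection is pipelined across calls: one call starts an async
        # device->CPU copy and returns None; a later call consumes the copy
        # once its CUDA event has been reached, hiding most of the transfer
        # latency behind intervening engine steps.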
        # If a copy was initiated in the previous call, collect and return.
        if self._in_flight_copy is not None:
            ready_event = self._in_flight_copy
            self._in_flight_copy = None
            return self._collect_rejsample_metrics(k, ready_event)

        # Otherwise, check if we should start a new copy.
        if self._should_collect_rejsample_metrics(self._timer()):
            assert self._in_flight_copy is None
            self._in_flight_copy = self._copy_rejsample_metrics_async()

        return None

    def _should_collect_rejsample_metrics(self, now: float) -> bool:
        """Return whether or not this iteration should collect rejection
        sampling metrics.
        """
        # Only rank 0 collects and reports metrics.
        if self._rank != 0:
            return False

        if (now - self._last_metrics_collect_time <
                self._rejsample_metrics_collect_interval_s):
            return False
        return True

    def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
        """Copy rejection sampling metrics (number of accepted tokens, etc.)
        to CPU asynchronously.

        Returns a CUDA event recording when the copy is complete.
        """
        self._copy_stream.wait_stream(torch.cuda.current_stream())
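        # wait_stream makes the copy stream wait for all work already queued
        # on the current stream, so the copies below observe the sampler's
        # counters only after this step's kernels have updated them.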
        with torch.cuda.stream(self._copy_stream):
            self._aggregate_num_accepted_tokens.copy_(
                self._rejection_sampler.num_accepted_tokens,
                non_blocking=True)
            self._aggregate_num_emitted_tokens.copy_(
                self._rejection_sampler.num_emitted_tokens,
                non_blocking=True)
            # Number of draft tokens is calculated on CPU, so no copy is
            # required.
            self._aggregate_num_draft_tokens = (
                self._rejection_sampler.num_draft_tokens)

        aggregate_metrics_ready = torch.cuda.Event()
        aggregate_metrics_ready.record(self._copy_stream)

        return aggregate_metrics_ready

    def _collect_rejsample_metrics(
            self, k: int,
            ready_event: torch.cuda.Event) -> SpecDecodeWorkerMetrics:
        """Create a metrics object from statistics copied asynchronously.

        Args:
            k: int. The number of speculative tokens; used to determine
                system efficiency.
            ready_event: torch.cuda.Event. The CUDA event recording when the
                async GPU->CPU copy is complete.
        """
        ready_event.synchronize()
        accepted_tokens = self._aggregate_num_accepted_tokens.item()
        emitted_tokens = self._aggregate_num_emitted_tokens.item()
        draft_tokens = self._aggregate_num_draft_tokens
        num_possible_tokens = self.get_max_num_accepted_tokens(
            draft_tokens, k)

        if draft_tokens > 0:
            draft_acceptance_rate = accepted_tokens / draft_tokens
        else:
            draft_acceptance_rate = float("nan")

        if num_possible_tokens > 0:
            system_efficiency = emitted_tokens / num_possible_tokens
        else:
            system_efficiency = float("nan")

        return SpecDecodeWorkerMetrics(
            num_spec_tokens=k,
            draft_acceptance_rate=draft_acceptance_rate,
            system_efficiency=system_efficiency,
            accepted_tokens=accepted_tokens,
            draft_tokens=draft_tokens,
            emitted_tokens=emitted_tokens,
        )

    @staticmethod
    def get_max_num_accepted_tokens(draft_tokens: int, k: int) -> int:
        # Divide by k since batch size can be variable.
        total_num_spec_seqs = draft_tokens / k
        # If every draft token were accepted, each sequence would emit its k
        # accepted tokens plus one bonus token from the verifier.
        num_accepted_per_seq_if_all_accepted = k + 1
        return int(total_num_spec_seqs * num_accepted_per_seq_if_all_accepted)
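

# Usage sketch (illustrative, not part of the module): in the spec decode
# worker's step loop, metrics are polled once per step and logged whenever a
# collection completes. Assumes a CUDA device, an already-constructed
# RejectionSampler named `sampler`, and a hypothetical `logger`:
#
#     collector = AsyncMetricsCollector(sampler)
#     collector.init_gpu_tensors(rank=0)
#     ...
#     # Once per engine step:
#     metrics = collector.maybe_collect_rejsample_metrics(k=num_spec_tokens)
#     if metrics is not None:
#         logger.info(metrics)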