metrics.py

import time
from dataclasses import dataclass
from typing import Callable, Optional

import torch

from aphrodite.common.utils import is_pin_memory_available
from aphrodite.modeling.layers.rejection import RejectionSampler


@dataclass
class SpecDecodeWorkerMetrics:
    """Dataclass holding metrics emitted from the spec decode worker."""

    # The empirical acceptance rate of the proposal method on a per-token
    # basis. This is useful for evaluating how well the proposal method
    # aligns with the scoring method.
    draft_acceptance_rate: float

    # The empirical efficiency, measured as the number of tokens emitted by
    # the system divided by the number of tokens that could be emitted by the
    # system if the proposal method were perfect.
    system_efficiency: float

    # The number of speculative tokens produced by the proposal method.
    draft_tokens: int

    # The number of tokens emitted by the entire system.
    emitted_tokens: int

    # The number of tokens accepted by the scoring model and verification
    # routine, e.g. Llama2-70B and lossless rejection sampling.
    #
    # NOTE: Any token accepted by the verification routine is considered
    # accepted (regardless of whether the speculative prefix is also
    # accepted). The user will usually see fewer accepted tokens. This metric
    # is helpful when evaluating alignment of the proposal method with the
    # scoring model.
    accepted_tokens: int

    # The number of speculative tokens per sequence.
    num_spec_tokens: int

Timer = Callable[[], float]


class AsyncMetricsCollector:
    """Class which copies rejection sampler metrics from the device to CPU on
    a non-default Torch stream.
    """

    def __init__(self,
                 rejection_sampler: RejectionSampler,
                 timer: Optional[Timer] = None,
                 collect_interval_s: float = 5.0):
        self._rejection_sampler = rejection_sampler
        self._timer = time.time if timer is None else timer
        self._rank: Optional[int] = None

        # We don't have a device set yet.
        self._copy_stream: Optional[torch.cuda.Stream] = None
        self._in_flight_copy: Optional[torch.cuda.Event] = None

        # Pinned CPU memory allows the device-to-host copies below to run
        # truly asynchronously with respect to the host.
        pin_memory = is_pin_memory_available()
        self._aggregate_num_accepted_tokens = torch.tensor(
            0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
        self._aggregate_num_emitted_tokens = torch.tensor(
            0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
        self._aggregate_num_draft_tokens = 0

        self._rejsample_metrics_collect_interval_s = collect_interval_s
        self._last_metrics_collect_time = self._timer()

    def init_gpu_tensors(self, rank: int) -> None:
        self._rank = rank
        self._copy_stream = torch.cuda.Stream()

    def maybe_collect_rejsample_metrics(
            self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
        # If a copy was initiated in the previous call, collect and return.
        if self._in_flight_copy is not None:
            ready_event = self._in_flight_copy
            self._in_flight_copy = None
            return self._collect_rejsample_metrics(k, ready_event)

        # Otherwise, check if we should start a new copy.
        if self._should_collect_rejsample_metrics(self._timer()):
            assert self._in_flight_copy is None
            self._in_flight_copy = self._copy_rejsample_metrics_async()

        return None
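
    # Typical call pattern (an illustrative sketch; the surrounding decode
    # loop and the `collector` name are assumptions, not part of this
    # module): invoke once per step. The first qualifying call kicks off the
    # async copy and returns None; a later call waits on the recorded event
    # and returns the metrics.
    #
    #   metrics = collector.maybe_collect_rejsample_metrics(k)
    #   if metrics is not None:
    #       print(metrics.draft_acceptance_rate, metrics.system_efficiency)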

    def _should_collect_rejsample_metrics(self, now: float) -> bool:
        """Return whether this iteration should collect rejection sampling
        metrics. Only rank 0 collects, and at most once per collect interval.
        """
        if self._rank != 0:
            return False

        if (now - self._last_metrics_collect_time <
                self._rejsample_metrics_collect_interval_s):
            return False
        return True

    def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
        """Copy rejection sampling metrics (number of accepted tokens, etc.)
        to CPU asynchronously.

        Returns a CUDA event recording when the copy is complete.
        """
        assert self._copy_stream is not None

        # Reset the interval clock; without this, the gate in
        # _should_collect_rejsample_metrics would pass on every call once the
        # first interval elapsed.
        self._last_metrics_collect_time = self._timer()

        # Ensure the copy stream sees all work queued on the current stream
        # (i.e. the sampler counters are up to date) before copying.
        self._copy_stream.wait_stream(torch.cuda.current_stream())

        with torch.cuda.stream(self._copy_stream):
            self._aggregate_num_accepted_tokens.copy_(
                self._rejection_sampler.num_accepted_tokens,
                non_blocking=True)
            self._aggregate_num_emitted_tokens.copy_(
                self._rejection_sampler.num_emitted_tokens,
                non_blocking=True)
            # Number of draft tokens is calculated on CPU, so no copy is
            # required.
            self._aggregate_num_draft_tokens = (
                self._rejection_sampler.num_draft_tokens)

        aggregate_metrics_ready = torch.cuda.Event()
        aggregate_metrics_ready.record(self._copy_stream)

        return aggregate_metrics_ready

    def _collect_rejsample_metrics(
            self, k: int,
            ready_event: torch.cuda.Event) -> SpecDecodeWorkerMetrics:
        """Create metrics object from statistics copied asynchronously.

        Args:
            k: int. The number of speculative tokens; used to determine
                system efficiency.
            ready_event: torch.cuda.Event. The CUDA event recording when the
                async GPU->CPU copy is complete.
        """
        ready_event.synchronize()
        accepted_tokens = self._aggregate_num_accepted_tokens.item()
        emitted_tokens = self._aggregate_num_emitted_tokens.item()
        draft_tokens = self._aggregate_num_draft_tokens
        max_num_emitted_tokens = self.get_max_num_emitted_tokens(
            draft_tokens, k)

        if draft_tokens > 0:
            draft_acceptance_rate = accepted_tokens / draft_tokens
        else:
            draft_acceptance_rate = float("nan")

        if max_num_emitted_tokens > 0:
            system_efficiency = emitted_tokens / max_num_emitted_tokens
        else:
            system_efficiency = float("nan")

        return SpecDecodeWorkerMetrics(
            num_spec_tokens=k,
            draft_acceptance_rate=draft_acceptance_rate,
            system_efficiency=system_efficiency,
            accepted_tokens=accepted_tokens,
            draft_tokens=draft_tokens,
            emitted_tokens=emitted_tokens,
        )
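
    # Worked example (illustrative numbers, not from the source): with k=4
    # and aggregates of draft_tokens=40, accepted_tokens=30 and
    # emitted_tokens=36, the ceiling is (40 // 4) * (4 + 1) = 50 tokens, so
    # draft_acceptance_rate = 30 / 40 = 0.75 and
    # system_efficiency = 36 / 50 = 0.72. Before any speculation has run,
    # both rates are reported as NaN rather than raising ZeroDivisionError.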

    @staticmethod
    def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int:
        """Calculate the number of emitted tokens, assuming all tokens are
        accepted.

        This is equal to the number of sequences that have been speculated
        on, times (speculation len + 1). The +1 comes from the bonus token.
        """
        # Determine the number of sequences that have been speculated on.
        # Since the batch size can be variable, we divide by k.
        assert draft_tokens % k == 0
        total_num_spec_seqs = draft_tokens // k

        # A single sequence may emit k accepted tokens and one bonus token in
        # the best case.
        num_emitted_per_seq_if_all_accepted = k + 1

        # The max num of emitted tokens is the number of speculated sequences
        # times the max emitted per seq.
        return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted
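

# Minimal usage sketch (an editorial illustration; the worker loop, the rank
# handling, and the `rejection_sampler` instance are assumptions, not part of
# this module):
#
#   collector = AsyncMetricsCollector(rejection_sampler)
#   collector.init_gpu_tensors(rank=0)
#   # ...inside the per-step decode loop, with k speculative tokens:
#   metrics = collector.maybe_collect_rejsample_metrics(k)
#
# The static helper can be exercised without a GPU: with k=5 and 15 draft
# tokens, three sequences were speculated on, so the ceiling is
# 3 * (5 + 1) = 18 emitted tokens:
#
#   assert AsyncMetricsCollector.get_max_num_emitted_tokens(15, 5) == 18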