metrics.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. import time
  2. from typing import Callable, Optional
  3. import msgspec
  4. import torch
  5. from aphrodite.common.utils import is_pin_memory_available
  6. from aphrodite.modeling.layers.spec_decode_base_sampler import (
  7. SpecDecodeBaseSampler)
  8. class SpecDecodeWorkerMetrics(
  9. msgspec.Struct,
  10. omit_defaults=True, # type: ignore[call-arg]
  11. array_like=True): # type: ignore[call-arg]
  12. """Dataclass holding metrics emitted from the spec decode worker.
  13. """
  14. # The empirical acceptance rate of the proposal method on a per-token basis.
  15. # This is useful for evaluating how well the proposal method aligns with the
  16. # scoring method.
  17. draft_acceptance_rate: float
  18. # The empirical efficiency, measured as the number of tokens emitted by the
  19. # system divided by the number of tokens that could be emitted by the system
  20. # if the proposal method were perfect.
  21. system_efficiency: float
  22. # The number of speculative tokens produced by the proposal method.
  23. draft_tokens: int
  24. # The number of tokens emitted by the entire system.
  25. emitted_tokens: int
  26. # The number of tokens accepted by the scoring model and verification
  27. # routine, e.g. Llama2-70B and lossless rejection sampling.
  28. #
  29. # NOTE: Any token accepted by the verification routine is considered
  30. # accepted (regardless of if the speculative prefix is also accepted). The
  31. # user will usually see less accepted tokens. This metric is helpful when
  32. # evaluating alignment of the proposal method with the scoring model.
  33. accepted_tokens: int
  34. # The number of speculative tokens per sequence.
  35. num_spec_tokens: int
# Clock abstraction used by AsyncMetricsCollector: a zero-argument callable
# returning the current time as a float. The collector defaults to
# ``time.time`` (seconds); an injected timer must use the same units as
# ``collect_interval_s``.
Timer = Callable[[], float]
  37. class AsyncMetricsCollector:
  38. """Class which copies rejection/typical-acceptance sampler metrics
  39. from the device to CPU on a non-default Torch stream.
  40. """
  41. def __init__(self,
  42. spec_decode_sampler: SpecDecodeBaseSampler,
  43. timer: Optional[Timer] = None,
  44. collect_interval_s: float = 5.0):
  45. self.spec_decode_sampler = spec_decode_sampler
  46. self._timer = time.time if timer is None else timer
  47. self._rank: Optional[int] = None
  48. # We don't have a device set yet.
  49. self._copy_stream: Optional[torch.cuda.Stream] = None
  50. self._in_flight_copy: Optional[torch.cuda.Event] = None
  51. pin_memory = is_pin_memory_available()
  52. self._aggregate_num_accepted_tokens = torch.tensor(
  53. 0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
  54. self._aggregate_num_emitted_tokens = torch.tensor(
  55. 0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
  56. self._aggregate_num_draft_tokens = 0
  57. self._rejsample_metrics_collect_interval_s = collect_interval_s
  58. self._last_metrics_collect_time = self._timer()
  59. def init_gpu_tensors(self, rank: int) -> None:
  60. self._rank = rank
  61. self._copy_stream = torch.cuda.Stream()
  62. def maybe_collect_rejsample_metrics(
  63. self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
  64. # If a copy was initiated in the previous call, collect and return.
  65. if self._in_flight_copy is not None:
  66. ready_event = self._in_flight_copy
  67. self._in_flight_copy = None
  68. return self._collect_rejsample_metrics(k, ready_event)
  69. # Otherwise, check if we should start a new copy.
  70. if self._should_collect_rejsample_metrics(self._timer()):
  71. assert self._in_flight_copy is None
  72. self._in_flight_copy = self._copy_rejsample_metrics_async()
  73. return None
  74. def _should_collect_rejsample_metrics(self, now: float) -> bool:
  75. """Return whether or not this iteration should print sampling
  76. metrics.
  77. """
  78. if self._rank != 0:
  79. return False
  80. if (now - self._last_metrics_collect_time <
  81. self._rejsample_metrics_collect_interval_s):
  82. return False
  83. return True
  84. def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
  85. """Copy rejection/typical-acceptance sampling metrics
  86. (number of accepted tokens, etc) to CPU asynchronously.
  87. Returns a CUDA event recording when the copy is complete.
  88. """
  89. assert self._copy_stream is not None
  90. self._copy_stream.wait_stream(torch.cuda.current_stream())
  91. with torch.cuda.stream(self._copy_stream):
  92. self._aggregate_num_accepted_tokens.copy_(
  93. self.spec_decode_sampler.num_accepted_tokens,
  94. non_blocking=True)
  95. self._aggregate_num_emitted_tokens.copy_(
  96. self.spec_decode_sampler.num_emitted_tokens, non_blocking=True)
  97. # Number of draft tokens is calculated on CPU, so no copy is
  98. # required.
  99. self._aggregate_num_draft_tokens = (
  100. self.spec_decode_sampler.num_draft_tokens)
  101. aggregate_metrics_ready = torch.cuda.Event()
  102. aggregate_metrics_ready.record(self._copy_stream)
  103. return aggregate_metrics_ready
  104. def _collect_rejsample_metrics(
  105. self, k: int,
  106. ready_event: torch.cuda.Event) -> SpecDecodeWorkerMetrics:
  107. """Create metrics object from statistics copied asynchronously.
  108. Args:
  109. k: int. The number of speculative tokens; used to determine system
  110. efficiency.
  111. ready_event: torch.cuda.Event. The CUDA event recording when the
  112. async GPU->CPU copy is complete.
  113. """
  114. ready_event.synchronize()
  115. # update time of last collection
  116. self._last_metrics_collect_time = self._timer()
  117. accepted_tokens = self._aggregate_num_accepted_tokens.item()
  118. emitted_tokens = self._aggregate_num_emitted_tokens.item()
  119. draft_tokens = self._aggregate_num_draft_tokens
  120. max_num_emitted_tokens = self.get_max_num_emitted_tokens(
  121. draft_tokens, k)
  122. if draft_tokens > 0:
  123. draft_acceptance_rate = accepted_tokens / draft_tokens
  124. else:
  125. draft_acceptance_rate = float("nan")
  126. if max_num_emitted_tokens > 0:
  127. system_efficiency = emitted_tokens / max_num_emitted_tokens
  128. else:
  129. system_efficiency = float("nan")
  130. return SpecDecodeWorkerMetrics(
  131. num_spec_tokens=k,
  132. draft_acceptance_rate=draft_acceptance_rate,
  133. system_efficiency=system_efficiency,
  134. accepted_tokens=accepted_tokens,
  135. draft_tokens=draft_tokens,
  136. emitted_tokens=emitted_tokens,
  137. )
  138. @staticmethod
  139. def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int:
  140. """Calculate the number of emitted tokens, assuming all tokens are
  141. accepted.
  142. This is equal to the number of sequences that have been speculated on,
  143. times (speculation len + 1). The +1 comes from the bonus token.
  144. """
  145. # Determine the number of sequences that have been speculated on. Since
  146. # the batch size can be variable, we divide by k.
  147. assert draft_tokens % k == 0
  148. total_num_spec_seqs = draft_tokens // k
  149. # A single sequence may emit k accepted tokens and one bonus token in
  150. # the best case.
  151. num_emitted_per_seq_if_all_accepted = k + 1
  152. # The max num of emitted tokens is the number of speculated sequences
  153. # times the max emitted per seq.
  154. return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted