metrics.py

import time
from dataclasses import dataclass
from typing import Callable, Optional

import torch

from aphrodite.common.utils import is_pin_memory_available
from aphrodite.modeling.layers.spec_decode_base_sampler import (
    SpecDecodeBaseSampler)


@dataclass
class SpecDecodeWorkerMetrics:
    """Dataclass holding metrics emitted from the spec decode worker."""

    # The empirical acceptance rate of the proposal method on a per-token
    # basis. This is useful for evaluating how well the proposal method
    # aligns with the scoring method.
    draft_acceptance_rate: float

    # The empirical efficiency, measured as the number of tokens emitted by
    # the system divided by the number of tokens that could be emitted by
    # the system if the proposal method were perfect.
    system_efficiency: float

    # The number of speculative tokens produced by the proposal method.
    draft_tokens: int

    # The number of tokens emitted by the entire system.
    emitted_tokens: int

    # The number of tokens accepted by the scoring model and verification
    # routine, e.g. Llama2-70B and lossless rejection sampling.
    #
    # NOTE: Any token accepted by the verification routine is considered
    # accepted (regardless of whether the speculative prefix is also
    # accepted). The user will usually see fewer accepted tokens. This
    # metric is helpful when evaluating alignment of the proposal method
    # with the scoring model.
    accepted_tokens: int

    # The number of speculative tokens per sequence.
    num_spec_tokens: int
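
# Worked example with illustrative numbers (not from a real run): suppose
# k = 5 speculative tokens per sequence and 3 sequences were speculated on,
# giving draft_tokens = 15. If 9 draft tokens were accepted and the system
# emitted 12 tokens (accepted plus bonus tokens), the derived metrics are:
#   draft_acceptance_rate = 9 / 15 = 0.6
#   system_efficiency     = 12 / (3 * (5 + 1)) = 12 / 18 ≈ 0.67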
Timer = Callable[[], float]


class AsyncMetricsCollector:
    """Class which copies rejection/typical-acceptance sampler metrics
    from the device to CPU on a non-default Torch stream.
    """

    def __init__(self,
                 spec_decode_sampler: SpecDecodeBaseSampler,
                 timer: Optional[Timer] = None,
                 collect_interval_s: float = 5.0):
        self.spec_decode_sampler = spec_decode_sampler
        self._timer = time.time if timer is None else timer

        self._rank: Optional[int] = None
        # We don't have a device set yet.
        self._copy_stream: Optional[torch.cuda.Stream] = None
        self._in_flight_copy: Optional[torch.cuda.Event] = None

        pin_memory = is_pin_memory_available()
        self._aggregate_num_accepted_tokens = torch.tensor(
            0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
        self._aggregate_num_emitted_tokens = torch.tensor(
            0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
        self._aggregate_num_draft_tokens = 0

        self._rejsample_metrics_collect_interval_s = collect_interval_s
        self._last_metrics_collect_time = self._timer()

    def init_gpu_tensors(self, rank: int) -> None:
        self._rank = rank
        self._copy_stream = torch.cuda.Stream()

    def maybe_collect_rejsample_metrics(
            self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
        # If a copy was initiated in the previous call, collect and return.
        if self._in_flight_copy is not None:
            ready_event = self._in_flight_copy
            self._in_flight_copy = None
            return self._collect_rejsample_metrics(k, ready_event)

        # Otherwise, check if we should start a new copy.
        if self._should_collect_rejsample_metrics(self._timer()):
            assert self._in_flight_copy is None
            self._in_flight_copy = self._copy_rejsample_metrics_async()

        return None
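
    # Example polling pattern (a sketch; in Aphrodite the spec decode worker
    # is the actual caller, invoking this once per step). Note the two-call
    # behavior: the first qualifying call starts the async copy and returns
    # None, and a later call waits on the recorded event and returns the
    # metrics:
    #
    #   metrics = collector.maybe_collect_rejsample_metrics(k)
    #   if metrics is not None:
    #       log(metrics)  # `log` is a hypothetical sink, not in this module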
    def _should_collect_rejsample_metrics(self, now: float) -> bool:
        """Return whether or not this iteration should collect sampling
        metrics (only rank 0 collects, and at most once per collection
        interval).
        """
        if self._rank != 0:
            return False

        if (now - self._last_metrics_collect_time <
                self._rejsample_metrics_collect_interval_s):
            return False
        return True

    def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
        """Copy rejection/typical-acceptance sampling metrics
        (number of accepted tokens, etc) to CPU asynchronously.

        Returns a CUDA event recording when the copy is complete.
        """
        assert self._copy_stream is not None
        self._copy_stream.wait_stream(torch.cuda.current_stream())

        with torch.cuda.stream(self._copy_stream):
            self._aggregate_num_accepted_tokens.copy_(
                self.spec_decode_sampler.num_accepted_tokens,
                non_blocking=True)
            self._aggregate_num_emitted_tokens.copy_(
                self.spec_decode_sampler.num_emitted_tokens,
                non_blocking=True)
            # Number of draft tokens is calculated on CPU, so no copy is
            # required.
            self._aggregate_num_draft_tokens = (
                self.spec_decode_sampler.num_draft_tokens)

        aggregate_metrics_ready = torch.cuda.Event()
        aggregate_metrics_ready.record(self._copy_stream)

        return aggregate_metrics_ready
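
    # Synchronization notes on the method above: waiting on the current
    # stream ensures the sampler's counters are fully written before the
    # copy stream reads them, and the recorded event is synchronized in
    # _collect_rejsample_metrics before .item() is called, so the CPU never
    # observes a partially copied value. The pinned CPU tensors allocated in
    # __init__ are what allow the non_blocking copies to overlap with other
    # GPU work.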
    def _collect_rejsample_metrics(
            self, k: int,
            ready_event: torch.cuda.Event) -> SpecDecodeWorkerMetrics:
        """Create metrics object from statistics copied asynchronously.

        Args:
            k: int. The number of speculative tokens; used to determine
                system efficiency.
            ready_event: torch.cuda.Event. The CUDA event recording when the
                async GPU->CPU copy is complete.
        """
        ready_event.synchronize()

        # Update the time of the last collection.
        self._last_metrics_collect_time = self._timer()

        accepted_tokens = self._aggregate_num_accepted_tokens.item()
        emitted_tokens = self._aggregate_num_emitted_tokens.item()
        draft_tokens = self._aggregate_num_draft_tokens
        max_num_emitted_tokens = self.get_max_num_emitted_tokens(
            draft_tokens, k)

        if draft_tokens > 0:
            draft_acceptance_rate = accepted_tokens / draft_tokens
        else:
            draft_acceptance_rate = float("nan")

        if max_num_emitted_tokens > 0:
            system_efficiency = emitted_tokens / max_num_emitted_tokens
        else:
            system_efficiency = float("nan")

        return SpecDecodeWorkerMetrics(
            num_spec_tokens=k,
            draft_acceptance_rate=draft_acceptance_rate,
            system_efficiency=system_efficiency,
            accepted_tokens=accepted_tokens,
            draft_tokens=draft_tokens,
            emitted_tokens=emitted_tokens,
        )
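
    # Note that the aggregates are cumulative: the sampler's counters are
    # copied wholesale and never reset by this collector, so (assuming they
    # are not reset elsewhere) the reported rates are lifetime averages
    # rather than per-interval values.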
    @staticmethod
    def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int:
        """Calculate the number of emitted tokens, assuming all tokens are
        accepted.

        This is equal to the number of sequences that have been speculated
        on, times (speculation len + 1). The +1 comes from the bonus token.
        """
        # Determine the number of sequences that have been speculated on.
        # Since the batch size can be variable, we divide by k.
        assert draft_tokens % k == 0
        total_num_spec_seqs = draft_tokens // k

        # A single sequence may emit k accepted tokens and one bonus token
        # in the best case.
        num_emitted_per_seq_if_all_accepted = k + 1

        # The max num of emitted tokens is the number of speculated
        # sequences times the max emitted per seq.
        return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted
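
# A minimal usage sketch (assumptions: a CUDA device is available, `sampler`
# is an initialized SpecDecodeBaseSampler subclass, and `run_decode_step` is
# a hypothetical stand-in for one engine step; neither name is defined in
# this module):
#
#   collector = AsyncMetricsCollector(spec_decode_sampler=sampler,
#                                     collect_interval_s=5.0)
#   collector.init_gpu_tensors(rank=0)
#   k = 5  # speculative tokens per sequence
#   while True:
#       run_decode_step()
#       metrics = collector.maybe_collect_rejsample_metrics(k)
#       if metrics is not None:
#           print(f"draft acceptance: {metrics.draft_acceptance_rate:.2f}")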