123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559 |
- from typing import TYPE_CHECKING
- from typing import Counter as CollectionsCounter
- from typing import Dict, List, Optional, Union
- import numpy as np
- import prometheus_client
- from loguru import logger
- from aphrodite.engine.metrics_types import (StatLoggerBase, Stats,
- SupportsMetricsInfo)
- from aphrodite.executor.ray_utils import ray
- if ray is not None:
- from ray.util import metrics as ray_metrics
- else:
- ray_metrics = None
- if TYPE_CHECKING:
- from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
# Module-level side effect: turn off prometheus_client's automatic
# ``*_created`` series for the metrics defined below (see the
# prometheus_client ``disable_created_metrics`` documentation).
prometheus_client.disable_created_metrics()
# The begin-* and end* here are used by the documentation generator
# to extract the metrics definitions.
# begin-metrics-definitions
class Metrics:
    """
    Aphrodite uses a multiprocessing-based frontend for the OpenAI server.
    This means that we need to run prometheus_client in multiprocessing mode.
    See https://prometheus.github.io/client_python/multiprocess/ for more
    details on limitations.

    The metric types are looked up through the ``_gauge_cls``,
    ``_counter_cls`` and ``_histogram_cls`` class attributes, so a subclass
    can swap in another backend that accepts the same constructor keywords.
    """

    # Extra label name added to ``counter_request_success`` on top of the
    # shared ``labelnames``.
    labelname_finish_reason = "finished_reason"

    _gauge_cls = prometheus_client.Gauge
    _counter_cls = prometheus_client.Counter
    _histogram_cls = prometheus_client.Histogram

    def __init__(self, labelnames: List[str], max_model_len: int):
        """Create every Aphrodite metric.

        Args:
            labelnames: Label names declared on every metric.
            max_model_len: Upper bound used to build the 1-2-5 buckets of the
                per-request token-count histograms.
        """
        # Unregister any existing Aphrodite collectors (for CI/CD)
        self._unregister_aphrodite_metrics()

        # System stats
        #   Scheduler State
        self.gauge_scheduler_running = self._gauge_cls(
            name="aphrodite:num_requests_running",
            documentation="Number of requests currently running on GPU.",
            labelnames=labelnames,
            multiprocess_mode="sum")
        self.gauge_scheduler_waiting = self._gauge_cls(
            name="aphrodite:num_requests_waiting",
            documentation="Number of requests waiting to be processed.",
            labelnames=labelnames,
            multiprocess_mode="sum")
        self.gauge_scheduler_swapped = self._gauge_cls(
            name="aphrodite:num_requests_swapped",
            documentation="Number of requests swapped to CPU.",
            labelnames=labelnames,
            multiprocess_mode="sum")
        #   KV Cache Usage in %
        self.gauge_gpu_cache_usage = self._gauge_cls(
            name="aphrodite:gpu_cache_usage_perc",
            documentation="GPU KV-cache usage. 1 means 100 percent usage.",
            labelnames=labelnames,
            multiprocess_mode="sum")
        self.gauge_cpu_cache_usage = self._gauge_cls(
            name="aphrodite:cpu_cache_usage_perc",
            documentation="CPU KV-cache usage. 1 means 100 percent usage.",
            labelnames=labelnames,
            multiprocess_mode="sum")

        #   Prefix caching block hit rate
        self.gauge_cpu_prefix_cache_hit_rate = self._gauge_cls(
            name="aphrodite:cpu_prefix_cache_hit_rate",
            documentation="CPU prefix cache block hit rate.",
            labelnames=labelnames,
            multiprocess_mode="sum")
        self.gauge_gpu_prefix_cache_hit_rate = self._gauge_cls(
            name="aphrodite:gpu_prefix_cache_hit_rate",
            documentation="GPU prefix cache block hit rate.",
            labelnames=labelnames,
            multiprocess_mode="sum")

        # Iteration stats
        self.counter_num_preemption = self._counter_cls(
            name="aphrodite:num_preemptions_total",
            documentation="Cumulative number of preemption from the engine.",
            labelnames=labelnames)
        self.counter_prompt_tokens = self._counter_cls(
            name="aphrodite:prompt_tokens_total",
            documentation="Number of prefill tokens processed.",
            labelnames=labelnames)
        self.counter_generation_tokens = self._counter_cls(
            name="aphrodite:generation_tokens_total",
            documentation="Number of generation tokens processed.",
            labelnames=labelnames)
        self.histogram_time_to_first_token = self._histogram_cls(
            name="aphrodite:time_to_first_token_seconds",
            documentation="Histogram of time to first token in seconds.",
            labelnames=labelnames,
            buckets=[
                0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
                0.75, 1.0, 2.5, 5.0, 7.5, 10.0
            ])
        self.histogram_time_per_output_token = self._histogram_cls(
            name="aphrodite:time_per_output_token_seconds",
            documentation="Histogram of time per output token in seconds.",
            labelnames=labelnames,
            buckets=[
                0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
                1.0, 2.5
            ])

        # Request stats
        #   Latency
        self.histogram_e2e_time_request = self._histogram_cls(
            name="aphrodite:e2e_request_latency_seconds",
            documentation="Histogram of end to end request latency in seconds.",
            labelnames=labelnames,
            buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
        #   Metadata
        self.histogram_num_prompt_tokens_request = self._histogram_cls(
            name="aphrodite:request_prompt_tokens",
            documentation="Number of prefill tokens processed.",
            labelnames=labelnames,
            buckets=build_1_2_5_buckets(max_model_len),
        )
        self.histogram_num_generation_tokens_request = \
            self._histogram_cls(
                name="aphrodite:request_generation_tokens",
                documentation="Number of generation tokens processed.",
                labelnames=labelnames,
                buckets=build_1_2_5_buckets(max_model_len),
            )
        self.histogram_best_of_request = self._histogram_cls(
            name="aphrodite:request_params_best_of",
            documentation="Histogram of the best_of request parameter.",
            labelnames=labelnames,
            buckets=[1, 2, 5, 10, 20],
        )
        self.histogram_n_request = self._histogram_cls(
            name="aphrodite:request_params_n",
            documentation="Histogram of the n request parameter.",
            labelnames=labelnames,
            buckets=[1, 2, 5, 10, 20],
        )
        # Carries one extra label (``labelname_finish_reason``).
        self.counter_request_success = self._counter_cls(
            name="aphrodite:request_success_total",
            documentation="Count of successfully processed requests.",
            labelnames=labelnames + [Metrics.labelname_finish_reason])

        # Speculative decoding stats
        self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
            name="aphrodite:spec_decode_draft_acceptance_rate",
            documentation="Speculative token acceptance rate.",
            labelnames=labelnames,
            multiprocess_mode="sum")
        self.gauge_spec_decode_efficiency = self._gauge_cls(
            name="aphrodite:spec_decode_efficiency",
            documentation="Speculative decoding system efficiency.",
            labelnames=labelnames,
            multiprocess_mode="sum")
        self.counter_spec_decode_num_accepted_tokens = (self._counter_cls(
            name="aphrodite:spec_decode_num_accepted_tokens_total",
            documentation="Number of accepted tokens.",
            labelnames=labelnames))
        self.counter_spec_decode_num_draft_tokens = self._counter_cls(
            name="aphrodite:spec_decode_num_draft_tokens_total",
            documentation="Number of draft tokens.",
            labelnames=labelnames)
        self.counter_spec_decode_num_emitted_tokens = (self._counter_cls(
            name="aphrodite:spec_decode_num_emitted_tokens_total",
            documentation="Number of emitted tokens.",
            labelnames=labelnames))

        # Deprecated in favor of aphrodite:prompt_tokens_total
        self.gauge_avg_prompt_throughput = self._gauge_cls(
            name="aphrodite:avg_prompt_throughput_toks_per_s",
            documentation="Average prefill throughput in tokens/s.",
            labelnames=labelnames,
            multiprocess_mode="sum",
        )
        # Deprecated in favor of aphrodite:generation_tokens_total
        self.gauge_avg_generation_throughput = self._gauge_cls(
            name="aphrodite:avg_generation_throughput_toks_per_s",
            documentation="Average generation throughput in tokens/s.",
            labelnames=labelnames,
            multiprocess_mode="sum",
        )

# end-metrics-definitions

    def _unregister_aphrodite_metrics(self) -> None:
        """Remove previously registered collectors whose name contains
        "aphrodite" from prometheus_client's default REGISTRY.

        Lets the metrics be constructed again in the same process without
        name clashes. Relies on the registry's private
        ``_collector_to_names`` mapping, which is snapshotted with ``list``
        because it is mutated while iterating.
        """
        for collector in list(prometheus_client.REGISTRY._collector_to_names):
            if hasattr(collector, "_name") and "aphrodite" in collector._name:
                prometheus_client.REGISTRY.unregister(collector)
class _RayGaugeWrapper:
    """Adapter giving ``ray.util.metrics.Gauge`` the call shape of
    ``prometheus_client.Gauge`` (``labels(...).set(...)``)."""

    def __init__(self,
                 name: str,
                 documentation: str = "",
                 labelnames: Optional[List[str]] = None,
                 multiprocess_mode: str = ""):
        # Accepted only so callers can pass prometheus-style kwargs; Ray's
        # Gauge has no counterpart, so the value is discarded.
        del multiprocess_mode
        tag_keys = None
        if labelnames:
            tag_keys = tuple(labelnames)
        self._gauge = ray_metrics.Gauge(name=name,
                                        description=documentation,
                                        tag_keys=tag_keys)

    def labels(self, **labels):
        """Install ``labels`` as the default tags and return ``self`` so
        the call can be chained like prometheus_client's ``labels``."""
        self._gauge.set_default_tags(labels)
        return self

    def set(self, value: Union[int, float]):
        """Forward ``value`` to the wrapped Ray gauge."""
        gauge = self._gauge
        return gauge.set(value)
class _RayCounterWrapper:
    """Adapter giving ``ray.util.metrics.Counter`` the call shape of
    ``prometheus_client.Counter`` (``labels(...).inc(...)``)."""

    def __init__(self,
                 name: str,
                 documentation: str = "",
                 labelnames: Optional[List[str]] = None):
        tag_keys = None
        if labelnames:
            tag_keys = tuple(labelnames)
        self._counter = ray_metrics.Counter(name=name,
                                            description=documentation,
                                            tag_keys=tag_keys)

    def labels(self, **labels):
        """Install ``labels`` as the default tags and return ``self`` for
        chaining."""
        self._counter.set_default_tags(labels)
        return self

    def inc(self, value: Union[int, float] = 1.0):
        """Increment the wrapped Ray counter by ``value``.

        A zero increment is not forwarded to Ray at all; ``None`` is
        returned in that case.
        """
        if value != 0:
            return self._counter.inc(value)
        return None
class _RayHistogramWrapper:
    """Adapter giving ``ray.util.metrics.Histogram`` the call shape of
    ``prometheus_client.Histogram`` (``labels(...).observe(...)``)."""

    def __init__(self,
                 name: str,
                 documentation: str = "",
                 labelnames: Optional[List[str]] = None,
                 buckets: Optional[List[float]] = None):
        tag_keys = None
        if labelnames:
            tag_keys = tuple(labelnames)
        # prometheus "buckets" map onto Ray's "boundaries".
        self._histogram = ray_metrics.Histogram(name=name,
                                                description=documentation,
                                                tag_keys=tag_keys,
                                                boundaries=buckets)

    def labels(self, **labels):
        """Install ``labels`` as the default tags and return ``self`` for
        chaining."""
        self._histogram.set_default_tags(labels)
        return self

    def observe(self, value: Union[int, float]):
        """Record one sample on the wrapped Ray histogram."""
        histogram = self._histogram
        return histogram.observe(value)
class RayMetrics(Metrics):
    """
    Same metric set as ``Metrics``, but built on Ray's ``util.metrics``
    library through the ``_Ray*Wrapper`` adapters. Selected by
    RayPrometheusStatLogger via its ``_metrics_cls``.
    """

    _gauge_cls = _RayGaugeWrapper
    _counter_cls = _RayCounterWrapper
    _histogram_cls = _RayHistogramWrapper

    def __init__(self, labelnames: List[str], max_model_len: int):
        if ray_metrics is not None:
            super().__init__(labelnames, max_model_len)
            return
        raise ImportError("RayMetrics requires Ray to be installed.")

    def _unregister_aphrodite_metrics(self) -> None:
        # Deliberately does nothing: overrides the base class's
        # prometheus_client REGISTRY cleanup, which does not apply here.
        return None
def build_1_2_5_buckets(max_value: int) -> List[int]:
    """
    Return the 1-2-5 series (1, 2, 5, 10, 20, 50, ...) truncated at the
    first value larger than ``max_value``; ``max_value`` itself is included
    when it belongs to the series.
    Example:
    >>> build_1_2_5_buckets(100)
    [1, 2, 5, 10, 20, 50, 100]
    """
    result: List[int] = []
    scale = 1
    while True:
        for mantissa in (1, 2, 5):
            candidate = mantissa * scale
            if candidate > max_value:
                return result
            result.append(candidate)
        scale *= 10
def local_interval_elapsed(now: float, last_log: float,
                           local_interval: float) -> bool:
    """Return True when strictly more than ``local_interval`` separates
    ``last_log`` from ``now``."""
    return now - last_log > local_interval
def get_throughput(tracked_stats: List[int], now: float,
                   last_log: float) -> float:
    """Return the average rate of ``tracked_stats`` over ``[last_log, now]``.

    Args:
        tracked_stats: Counts to sum (e.g. tokens per engine iteration).
        now: End of the window, in the same time unit as ``last_log``.
        last_log: Start of the window.

    Returns:
        ``sum(tracked_stats) / (now - last_log)`` as a float. Returns
        ``0.0`` when the window is empty or negative; dividing there would
        produce inf/nan (zero span) or a negative rate (clock went
        backwards).
    """
    elapsed = now - last_log
    if elapsed <= 0:
        return 0.0
    return float(sum(tracked_stats) / elapsed)
class LoggingStatLogger(StatLoggerBase):
    """Stat logger that periodically writes a human-readable engine summary
    through ``logger.info``.

    Token counts are buffered on every ``log`` call. Once more than
    ``self.local_interval`` seconds separate ``stats.now`` from the previous
    summary, throughputs are computed, the summary lines are emitted and the
    buffers are reset.
    """

    def log(self, stats: Stats) -> None:
        """Buffer ``stats`` and, if the interval has elapsed, log a summary.

        Called by LLMEngine.
        """
        # Buffer this iteration's token counts.
        self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
        self.num_generation_tokens.append(stats.num_generation_tokens_iter)
        self.maybe_update_spec_decode_metrics(stats)

        interval_done = local_interval_elapsed(stats.now, self.last_local_log,
                                               self.local_interval)
        if not interval_done:
            return

        window_start = self.last_local_log
        prompt_tps = get_throughput(self.num_prompt_tokens,
                                    now=stats.now,
                                    last_log=window_start)
        generation_tps = get_throughput(self.num_generation_tokens,
                                        now=stats.now,
                                        last_log=window_start)

        logger.info(
            f"Avg prompt throughput: {prompt_tps:.1f} tokens/s, "
            f"Avg generation throughput: {generation_tps:.1f} "
            "tokens/s, "
            f"Running: {stats.num_running_sys} reqs, "
            f"Swapped: {stats.num_swapped_sys} reqs, "
            f"Pending: {stats.num_waiting_sys} reqs, "
            f"GPU KV cache usage: {stats.gpu_cache_usage_sys * 100:.1f}%, "
            f"CPU KV cache usage: {stats.cpu_cache_usage_sys * 100:.1f}%.")

        # Printed when at least one of the two hit rates is non-negative.
        show_prefix = (stats.cpu_prefix_cache_hit_rate >= 0
                       or stats.gpu_prefix_cache_hit_rate >= 0)
        if show_prefix:
            logger.info(
                "Prefix cache hit rate: "
                f"GPU: {stats.gpu_prefix_cache_hit_rate * 100:.2f}%, "
                f"CPU: {stats.cpu_prefix_cache_hit_rate * 100:.2f}%")

        spec_metrics = self.spec_decode_metrics
        if spec_metrics is not None:
            logger.info(self._format_spec_decode_metrics_str(spec_metrics))

        # Start a fresh window.
        self.num_prompt_tokens = []
        self.num_generation_tokens = []
        self.last_local_log = stats.now
        self.spec_decode_metrics = None

    def _format_spec_decode_metrics_str(
            self, metrics: "SpecDecodeWorkerMetrics") -> str:
        """Render the speculative-decoding counters as one log line."""
        summary = ("Speculative metrics: "
                   f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
                   f"System efficiency: {metrics.system_efficiency:.3f}, "
                   f"Number of speculative tokens: {metrics.num_spec_tokens}, "
                   f"Number of accepted tokens: {metrics.accepted_tokens}, "
                   f"Number of draft tokens: {metrics.draft_tokens}, "
                   f"Number of emitted tokens: {metrics.emitted_tokens}.")
        return summary

    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        """Info metrics are not supported by this logger."""
        raise NotImplementedError
class PrometheusStatLogger(StatLoggerBase):
    """PrometheusStatLogger is used by LLMEngine to log to Prometheus.

    Every ``log`` call pushes the per-iteration values of ``Stats`` to the
    metrics in ``self.metrics``. Values summarised over a time window (the
    legacy throughput gauges and the speculative-decoding metrics) are
    pushed only once more than ``self.local_interval`` seconds have passed
    since the previous window.

    Subclasses choose the metrics backend through ``_metrics_cls``.
    """
    _metrics_cls = Metrics
    _gauge_cls = prometheus_client.Gauge

    def __init__(self, local_interval: float, labels: Dict[str, str],
                 max_model_len: int) -> None:
        super().__init__(local_interval)
        # Prometheus metrics. ``labels`` supplies both the label names every
        # metric is declared with (its keys) and the values attached to each
        # sample.
        self.labels = labels
        self.metrics = self._metrics_cls(labelnames=list(labels.keys()),
                                         max_model_len=max_model_len)

    def _log_gauge(self, gauge, data: Union[int, float]) -> None:
        """Set ``gauge`` (under ``self.labels``) to ``data``."""
        gauge.labels(**self.labels).set(data)

    def _log_counter(self, counter, data: Union[int, float]) -> None:
        """Increment ``counter`` (under ``self.labels``) by ``data``."""
        counter.labels(**self.labels).inc(data)

    def _log_counter_labels(self, counter, data: CollectionsCounter,
                            label_key: str) -> None:
        """For each ``label -> count`` in ``data``, increment ``counter`` by
        ``count`` with ``self.labels`` plus ``{label_key: label}``."""
        for label, count in data.items():
            counter.labels(**{**self.labels, label_key: label}).inc(count)

    def _log_histogram(self, histogram, data: Union[List[int],
                                                    List[float]]) -> None:
        """Observe every value of ``data`` on ``histogram``."""
        if not data:
            # Nothing to record: leave the labelled child untouched, exactly
            # as when the loop below runs zero times.
            return
        # Resolve the labelled child once rather than once per sample.
        child = histogram.labels(**self.labels)
        for datum in data:
            child.observe(datum)

    def _log_prometheus(self, stats: Stats) -> None:
        """Push the per-iteration and per-request values of ``stats``."""
        # System state data
        self._log_gauge(self.metrics.gauge_scheduler_running,
                        stats.num_running_sys)
        self._log_gauge(self.metrics.gauge_scheduler_swapped,
                        stats.num_swapped_sys)
        self._log_gauge(self.metrics.gauge_scheduler_waiting,
                        stats.num_waiting_sys)
        self._log_gauge(self.metrics.gauge_gpu_cache_usage,
                        stats.gpu_cache_usage_sys)
        self._log_gauge(self.metrics.gauge_cpu_cache_usage,
                        stats.cpu_cache_usage_sys)
        self._log_gauge(self.metrics.gauge_cpu_prefix_cache_hit_rate,
                        stats.cpu_prefix_cache_hit_rate)
        self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate,
                        stats.gpu_prefix_cache_hit_rate)

        # Iteration level data
        self._log_counter(self.metrics.counter_num_preemption,
                          stats.num_preemption_iter)
        self._log_counter(self.metrics.counter_prompt_tokens,
                          stats.num_prompt_tokens_iter)
        self._log_counter(self.metrics.counter_generation_tokens,
                          stats.num_generation_tokens_iter)
        self._log_histogram(self.metrics.histogram_time_to_first_token,
                            stats.time_to_first_tokens_iter)
        self._log_histogram(self.metrics.histogram_time_per_output_token,
                            stats.time_per_output_tokens_iter)

        # Request level data
        # Latency
        self._log_histogram(self.metrics.histogram_e2e_time_request,
                            stats.time_e2e_requests)
        # Metadata: one success-counter increment per distinct finish reason.
        finished_reason_counter = CollectionsCounter(
            stats.finished_reason_requests)
        self._log_counter_labels(self.metrics.counter_request_success,
                                 finished_reason_counter,
                                 Metrics.labelname_finish_reason)
        self._log_histogram(self.metrics.histogram_num_prompt_tokens_request,
                            stats.num_prompt_tokens_requests)
        self._log_histogram(
            self.metrics.histogram_num_generation_tokens_request,
            stats.num_generation_tokens_requests)
        self._log_histogram(self.metrics.histogram_n_request, stats.n_requests)
        self._log_histogram(self.metrics.histogram_best_of_request,
                            stats.best_of_requests)

    def _log_prometheus_interval(self, prompt_throughput: float,
                                 generation_throughput: float) -> None:
        """Push metrics computed once per logging interval.

        Supports the legacy gauge metrics that compute throughput on the
        Aphrodite side. Going forward, prefer the raw counters
        (counter_prompt_tokens, counter_generation_tokens) and compute
        summaries with rate() on the Grafana/Prometheus side.
        """
        self._log_gauge(self.metrics.gauge_avg_prompt_throughput,
                        prompt_throughput)
        self._log_gauge(self.metrics.gauge_avg_generation_throughput,
                        generation_throughput)

    def log(self, stats: Stats):
        """Logs to prometheus and tracked stats every iteration."""
        # Log to prometheus.
        self._log_prometheus(stats)

        # Save tracked stats for token counters.
        self.num_prompt_tokens.append(stats.num_prompt_tokens_iter)
        self.num_generation_tokens.append(stats.num_generation_tokens_iter)

        # Update spec decode metrics
        self.maybe_update_spec_decode_metrics(stats)

        # Log interval-level metrics every local_interval seconds.
        if local_interval_elapsed(stats.now, self.last_local_log,
                                  self.local_interval):
            # Compute summary metrics for tracked stats and push them to
            # Prometheus.
            prompt_throughput = get_throughput(self.num_prompt_tokens,
                                               now=stats.now,
                                               last_log=self.last_local_log)
            generation_throughput = get_throughput(
                self.num_generation_tokens,
                now=stats.now,
                last_log=self.last_local_log)

            self._log_prometheus_interval(
                prompt_throughput=prompt_throughput,
                generation_throughput=generation_throughput)

            if self.spec_decode_metrics is not None:
                self._log_gauge(
                    self.metrics.gauge_spec_decode_draft_acceptance_rate,
                    self.spec_decode_metrics.draft_acceptance_rate)
                self._log_gauge(self.metrics.gauge_spec_decode_efficiency,
                                self.spec_decode_metrics.system_efficiency)
                self._log_counter(
                    self.metrics.counter_spec_decode_num_accepted_tokens,
                    self.spec_decode_metrics.accepted_tokens)
                self._log_counter(
                    self.metrics.counter_spec_decode_num_draft_tokens,
                    self.spec_decode_metrics.draft_tokens)
                self._log_counter(
                    self.metrics.counter_spec_decode_num_emitted_tokens,
                    self.spec_decode_metrics.emitted_tokens)

            # Reset tracked stats for next interval.
            self.num_prompt_tokens = []
            self.num_generation_tokens = []
            self.last_local_log = stats.now
            self.spec_decode_metrics = None

    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        """Publish ``obj.metrics_info()`` as an info-style gauge.

        Info metrics are syntactic sugar for a gauge permanently set to 1.
        Prometheus multiprocessing mode does not support Info, so this
        emulates one with a gauge. Only ``type == "cache_config"`` is
        handled; other types are ignored.
        """
        if type == "cache_config":
            metrics_info = obj.metrics_info()
            # NOTE(review): each call constructs a new gauge with this fixed
            # name; a second call would presumably clash in the default
            # registry. Confirm this is only invoked once per process.
            info_gauge = self._gauge_cls(
                name="aphrodite:cache_config_info",
                documentation="Information of the LLMEngine CacheConfig",
                labelnames=metrics_info.keys(),
                multiprocess_mode="mostrecent")
            info_gauge.labels(**metrics_info).set(1)
class RayPrometheusStatLogger(PrometheusStatLogger):
    """PrometheusStatLogger variant whose metrics come from ``RayMetrics``
    (Ray's ``util.metrics``) rather than prometheus_client."""

    _metrics_cls = RayMetrics

    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
        # Intentionally a no-op: the parent's info gauge is not emitted for
        # the Ray backend.
        return
|