logger.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. """
  2. Internal logging utility.
  3. """
  4. import logging
  5. import os
  6. from loguru import logger
  7. from rich.console import Console
  8. from rich.markup import escape
  9. from rich.progress import (
  10. Progress,
  11. TextColumn,
  12. BarColumn,
  13. TimeRemainingColumn,
  14. TaskProgressColumn,
  15. MofNCompleteColumn,
  16. )
  17. RICH_CONSOLE = Console()
  18. LOG_LEVEL = os.getenv("APHRODITE_LOG_LEVEL", "INFO").upper()
  19. def unwrap(wrapped, default=None):
  20. """Unwrap function for Optionals."""
  21. if wrapped is None:
  22. return default
  23. return wrapped
  24. def get_loading_progress_bar():
  25. """Gets a pre-made progress bar for loading tasks."""
  26. return Progress(
  27. TextColumn("[progress.description]{task.description}"),
  28. BarColumn(),
  29. TaskProgressColumn(),
  30. MofNCompleteColumn(),
  31. TimeRemainingColumn(),
  32. console=RICH_CONSOLE,
  33. )
  34. def _log_formatter(record: dict):
  35. """Log message formatter."""
  36. color_map = {
  37. "TRACE": "dim blue",
  38. "DEBUG": "cyan",
  39. "INFO": "green",
  40. "SUCCESS": "bold green",
  41. "WARNING": "yellow",
  42. "ERROR": "red",
  43. "CRITICAL": "bold white on red",
  44. }
  45. level = record.get("level")
  46. level_color = color_map.get(level.name, "cyan")
  47. colored_level = f"[{level_color}]{level.name}[/{level_color}]:"
  48. separator = " " * (9 - len(level.name))
  49. message = unwrap(record.get("message"), "")
  50. # Replace once loguru allows for turning off str.format
  51. message = message.replace("{", "{{").replace("}", "}}").replace("<", "\<")
  52. # Escape markup tags from Rich
  53. message = escape(message)
  54. lines = message.splitlines()
  55. fmt = ""
  56. if len(lines) > 1:
  57. fmt = "\n".join(
  58. [f"{colored_level}{separator}{line}" for line in lines])
  59. else:
  60. fmt = f"{colored_level}{separator}{message}"
  61. return fmt
  62. _logged_messages = set()
  63. def log_once(level, message, *args, **kwargs):
  64. if message not in _logged_messages:
  65. _logged_messages.add(message)
  66. logger.log(level, message, *args, **kwargs)
  67. # Uvicorn log handler
  68. # Uvicorn log portions inspired from https://github.com/encode/uvicorn/discussions/2027#discussioncomment-6432362
  69. class UvicornLoggingHandler(logging.Handler):
  70. def emit(self, record: logging.LogRecord) -> None:
  71. logger.opt(exception=record.exc_info).log(record.levelname,
  72. self.format(record).rstrip())
  73. # Uvicorn config for logging. Passed into run when creating all loggers in
  74. # server
  75. UVICORN_LOG_CONFIG = {
  76. "version": 1,
  77. "disable_existing_loggers": False,
  78. "handlers": {
  79. "uvicorn": {
  80. "class":
  81. f"{UvicornLoggingHandler.__module__}.{UvicornLoggingHandler.__qualname__}", # noqa
  82. },
  83. },
  84. "root": {
  85. "handlers": ["uvicorn"],
  86. "propagate": False,
  87. "level": LOG_LEVEL
  88. },
  89. }
  90. def setup_logger():
  91. """Bootstrap the logger."""
  92. logger.remove()
  93. logger.add(
  94. RICH_CONSOLE.print,
  95. level=LOG_LEVEL,
  96. format=_log_formatter,
  97. colorize=True,
  98. )
  99. logger.log_once = log_once
  100. setup_logger()