# log_generation.py (907 B)
  1. def _get_typed_dict_name(typed_dict: dict) -> str:
  2. if typed_dict.get("_type", None):
  3. return typed_dict["_type"].capitalize()
  4. return "Params"
  5. def custom_repr(value):
  6. if isinstance(value, dict):
  7. return "dict"
  8. return repr(value)
  9. def StringifyParams(x):
  10. params = ",\n ".join(f"{k}={custom_repr(v)}" for k, v in x.items())
  11. return f"{_get_typed_dict_name(x)}(\n {params}\n)"
  12. def middleware_log_generation(params: dict):
  13. print("Generating: '''", params["text"], "'''")
  14. print(StringifyParams(params))
  15. if __name__ == "__main__":
  16. kwargs = {
  17. "text": "I am a robot.",
  18. "text_temp": 1.0,
  19. "waveform_temp": 1.0,
  20. "history_prompt": "",
  21. "output_full": False,
  22. "seed": 0,
  23. "max_length": 15,
  24. "burn_in_prompt": "",
  25. "history_prompt_semantic": None,
  26. }
  27. middleware_log_generation(kwargs)