# guided_fields.py (1.3 KB)

  1. from dataclasses import dataclass
  2. from typing import Dict, List, Optional, TypedDict, Union
  3. from pydantic import BaseModel
  4. class LLMGuidedOptions(TypedDict, total=False):
  5. guided_json: Union[Dict, BaseModel, str]
  6. guided_regex: str
  7. guided_choice: List[str]
  8. guided_grammar: str
  9. guided_decoding_backend: str
  10. guided_whitespace_pattern: str
  11. guided_json_object: bool
  12. @dataclass
  13. class GuidedDecodingRequest:
  14. """One of the fields will be used to retrieve the logit processor."""
  15. guided_json: Optional[Union[Dict, BaseModel, str]] = None
  16. guided_regex: Optional[str] = None
  17. guided_choice: Optional[List[str]] = None
  18. guided_grammar: Optional[str] = None
  19. guided_decoding_backend: Optional[str] = None
  20. guided_whitespace_pattern: Optional[str] = None
  21. guided_json_object: Optional[bool] = None
  22. def __post_init__(self):
  23. """Validate that some fields are mutually exclusive."""
  24. guide_count = sum([
  25. self.guided_json is not None, self.guided_regex is not None,
  26. self.guided_choice is not None, self.guided_grammar is not None,
  27. self.guided_json_object is not None
  28. ])
  29. if guide_count > 1:
  30. raise ValueError(
  31. "You can only use one kind of guided decoding but multiple are "
  32. f"specified: {self.__dict__}")