1234567891011121314151617181920212223242526272829303132333435363738 |
- from dataclasses import dataclass
- from typing import Dict, List, Optional, TypedDict, Union
- from pydantic import BaseModel
# Keyword options for guided decoding, typed as a TypedDict.
# total=False makes every key optional, so a caller may supply any subset
# (including none) of them. The key names and value types mirror the
# fields of GuidedDecodingRequest.
#   guided_json               -- a dict, a pydantic BaseModel, or a str
#   guided_regex              -- a str
#   guided_choice             -- a list of str
#   guided_grammar            -- a str
#   guided_decoding_backend   -- a str naming the backend (presumably; confirm)
#   guided_whitespace_pattern -- a str
#   guided_json_object        -- a bool flag
LLMGuidedOptions = TypedDict(
    "LLMGuidedOptions",
    {
        "guided_json": Union[Dict, BaseModel, str],
        "guided_regex": str,
        "guided_choice": List[str],
        "guided_grammar": str,
        "guided_decoding_backend": str,
        "guided_whitespace_pattern": str,
        "guided_json_object": bool,
    },
    total=False,
)
@dataclass
class GuidedDecodingRequest:
    """Bundle of guided-decoding settings for a single request.

    One of the guide fields will be used to retrieve the logit processor.
    The guide fields are ``guided_json``, ``guided_regex``,
    ``guided_choice``, ``guided_grammar`` and ``guided_json_object``. At
    most one of them may be set to a non-None value, and leaving all of
    them unset is allowed. ``guided_decoding_backend`` and
    ``guided_whitespace_pattern`` are auxiliary settings and are not part
    of that exclusivity check.

    Raises:
        ValueError: at construction time, if two or more guide fields are
            not None.
    """

    # Guide fields: at most one of these may be non-None.
    guided_json: Optional[Union[Dict, BaseModel, str]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None
    guided_grammar: Optional[str] = None
    # Auxiliary settings; they may accompany any guide field.
    guided_decoding_backend: Optional[str] = None
    guided_whitespace_pattern: Optional[str] = None
    # Also a guide field. Note that False still counts as "set", because
    # the exclusivity check below tests `is not None`, not truthiness.
    guided_json_object: Optional[bool] = None

    def __post_init__(self):
        """Reject a request that sets more than one kind of guide.

        Raises:
            ValueError: if two or more of the guide fields are not None.
        """
        guide_values = (
            self.guided_json,
            self.guided_regex,
            self.guided_choice,
            self.guided_grammar,
            self.guided_json_object,
        )
        chosen = [value for value in guide_values if value is not None]
        if len(chosen) <= 1:
            return
        raise ValueError(
            "You can only use one kind of guided decoding but multiple are "
            f"specified: {self.__dict__}")
|