123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293 |
- from typing import List, Optional, Union
- from pydantic import BaseModel, Field, root_validator, conint, confloat, conlist, NonNegativeFloat, NonNegativeInt, PositiveInt
class SamplingParams(BaseModel):
    """Sampling settings for a single generation request.

    Several field aliases (``rep_pen``, ``stop_sequence``, ``max_length``)
    match field names of ``KAIGenerationInputSchema`` in this module, so this
    model is presumably built from a KoboldAI-style request dict -- confirm
    against the caller. Under pydantic's default config, a field declared
    with an alias is populated by that alias (e.g. ``max_length``), not by
    its attribute name (``max_tokens``).
    """

    n: int = Field(1, alias="n")
    # Number of candidates to sample; when set it must be >= n
    # (enforced by validate_best_of).
    best_of: Optional[int] = Field(None, alias="best_of")
    presence_penalty: float = Field(0.0, alias="presence_penalty")
    # NOTE(review): populated from KAI's ``rep_pen`` (ge=1, default 1.0 in
    # KAIGenerationInputSchema) while this field defaults to 0.0 -- confirm
    # the caller converts between the two scales.
    frequency_penalty: float = Field(0.0, alias="rep_pen")
    temperature: float = Field(1.0, alias="temperature")
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    top_p: float = Field(1.0, alias="top_p")
    # Defaults to -1; presumably a sentinel for "top-k disabled" -- confirm.
    top_k: float = Field(-1, alias="top_k")
    min_p: float = Field(0.0, alias="min_p")
    top_a: float = Field(0.0, alias="top_a")
    tfs: float = Field(1.0, alias="tfs")
    eta_cutoff: float = Field(0.0, alias="eta_cutoff")
    epsilon_cutoff: float = Field(0.0, alias="epsilon_cutoff")
    typical_p: float = Field(1.0, alias="typical_p")
    use_beam_search: bool = Field(False, alias="use_beam_search")
    length_penalty: float = Field(1.0, alias="length_penalty")
    early_stopping: Union[bool, str] = Field(False, alias="early_stopping")
    # Accepts a single stop string or a list of them.
    stop: Union[None, str, List[str]] = Field(None, alias="stop_sequence")
    include_stop_str_in_output: Optional[bool] = False
    ignore_eos: bool = Field(False, alias="ignore_eos")
    max_tokens: int = Field(16, alias="max_length")
    logprobs: Optional[int] = Field(None, alias="logprobs")
    custom_token_bans: Optional[List[int]] = Field(
        None, alias="custom_token_bans")

    @root_validator
    def validate_best_of(cls, values):  # pylint: disable=no-self-argument
        """Reject a ``best_of`` that is non-positive or smaller than ``n``.

        Following the OpenAI/vLLM convention, ``best_of`` is the number of
        candidates sampled, of which ``n`` are returned, so it may not be
        below ``n``.

        Raises:
            ValueError: if ``best_of`` is set and is <= 0 or < ``n``.
        """
        best_of = values.get("best_of")
        n = values.get("n")
        # A field that failed its own validation is absent from ``values``
        # and pydantic already reports it; don't compare against None.
        if best_of is None or n is None:
            return values
        if best_of <= 0 or best_of < n:
            raise ValueError(
                "best_of must be a positive integer greater than or equal "
                "to n")
        return values
class KAIGenerationInputSchema(BaseModel):
    """Input schema for a KoboldAI-style generation request.

    Presumably the JSON body of a KAI generate endpoint -- confirm against
    the route that uses it. ``prompt``, ``max_context_length`` and
    ``max_length`` are required. In pydantic v1, fields typed
    ``Optional[...]`` without an explicit default are optional and default
    to ``None``.
    """

    genkey: Optional[str]
    prompt: str
    n: Optional[conint(ge=1, le=5)] = 1
    max_context_length: PositiveInt
    # Must not exceed max_context_length (enforced by check_context).
    max_length: PositiveInt
    rep_pen: Optional[confloat(ge=1)] = 1.0
    rep_pen_range: Optional[NonNegativeInt]
    rep_pen_slope: Optional[NonNegativeFloat]
    # Integer default for an integer field: pydantic v1 does not validate
    # defaults, so a float default would leak through unchanged.
    top_k: Optional[NonNegativeInt] = 0
    top_a: Optional[NonNegativeFloat] = 0.0
    top_p: Optional[confloat(ge=0, le=1)] = 1.0
    min_p: Optional[confloat(ge=0, le=1)] = 0.0
    tfs: Optional[confloat(ge=0, le=1)] = 1.0
    eps_cutoff: Optional[confloat(ge=0, le=1000)] = 0.0
    eta_cutoff: Optional[NonNegativeFloat] = 0.0
    typical: Optional[confloat(ge=0, le=1)] = 1.0
    temperature: Optional[NonNegativeFloat] = 1.0
    dynatemp_range: Optional[NonNegativeFloat] = 0.0
    dynatemp_exponent: Optional[NonNegativeFloat] = 1.0
    smoothing_factor: Optional[NonNegativeFloat] = 0.0
    use_memory: Optional[bool]
    use_story: Optional[bool]
    use_authors_note: Optional[bool]
    use_world_info: Optional[bool]
    use_userscripts: Optional[bool]
    soft_prompt: Optional[str]
    disable_output_formatting: Optional[bool]
    frmtrmblln: Optional[bool]
    frmtrmspch: Optional[bool]
    singleline: Optional[bool]
    use_default_badwordsids: Optional[bool]
    mirostat: Optional[int] = 0
    mirostat_tau: Optional[float] = 0.0
    mirostat_eta: Optional[float] = 0.0
    disable_input_formatting: Optional[bool]
    frmtadsnsp: Optional[bool]
    quiet: Optional[bool]
    # pylint: disable=unexpected-keyword-arg
    sampler_order: Optional[conlist(int, min_items=6)]
    # Bounded to the unsigned 64-bit range.
    sampler_seed: Optional[conint(ge=0, le=2**64 - 1)]
    sampler_full_determinism: Optional[bool]
    stop_sequence: Optional[List[str]]
    include_stop_str_in_output: Optional[bool] = False

    @root_validator
    def check_context(cls, values):  # pylint: disable=no-self-argument
        """Ensure the requested generation length fits in the context.

        Raises:
            ValueError: if ``max_length`` exceeds ``max_context_length``.
        """
        max_length = values.get("max_length")
        max_context_length = values.get("max_context_length")
        # Skip when either field failed its own validation: pydantic v1
        # omits it from ``values`` and already reports the field error.
        if max_length is None or max_context_length is None:
            return values
        # Not ``assert``: asserts are stripped under ``python -O``.
        if max_length > max_context_length:
            raise ValueError(
                "max_length must not be larger than max_context_length")
        return values
|