|
@@ -4,8 +4,7 @@ import time
|
|
|
from typing import Any, Dict, List, Literal, Optional, Union
|
|
|
|
|
|
import torch
|
|
|
-from pydantic import (BaseModel, ConfigDict, Field, model_validator,
|
|
|
- root_validator)
|
|
|
+from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|
|
from transformers import PreTrainedTokenizer
|
|
|
from typing_extensions import Annotated
|
|
|
|
|
@@ -148,6 +147,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
|
|
prompt_logprobs: Optional[int] = None
|
|
|
xtc_threshold: Optional[float] = 0.1
|
|
|
xtc_probability: Optional[float] = 0.0
|
|
|
+ dynatemp_min: Optional[float] = 0.0
|
|
|
+ dynatemp_max: Optional[float] = 0.0
|
|
|
+ dynatemp_exponent: Optional[float] = 1.0
|
|
|
+ custom_token_bans: Optional[List[int]] = None
|
|
|
# doc: end-chat-completion-sampling-params
|
|
|
|
|
|
# doc: begin-chat-completion-extra-params
|
|
@@ -287,6 +290,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
|
|
temperature_last=self.temperature_last,
|
|
|
xtc_threshold=self.xtc_threshold,
|
|
|
xtc_probability=self.xtc_probability,
|
|
|
+ dynatemp_min=self.dynatemp_min,
|
|
|
+ dynatemp_max=self.dynatemp_max,
|
|
|
+ dynatemp_exponent=self.dynatemp_exponent,
|
|
|
+ custom_token_bans=self.custom_token_bans,
|
|
|
)
|
|
|
|
|
|
@model_validator(mode='before')
|
|
@@ -394,6 +401,10 @@ class CompletionRequest(OpenAIBaseModel):
|
|
|
prompt_logprobs: Optional[int] = None
|
|
|
xtc_threshold: Optional[float] = 0.1
|
|
|
xtc_probability: Optional[float] = 0.0
|
|
|
+ dynatemp_min: Optional[float] = 0.0
|
|
|
+ dynatemp_max: Optional[float] = 0.0
|
|
|
+ dynatemp_exponent: Optional[float] = 1.0
|
|
|
+ custom_token_bans: Optional[List[int]] = None
|
|
|
# doc: end-completion-sampling-params
|
|
|
|
|
|
# doc: begin-completion-extra-params
|
|
@@ -492,6 +503,10 @@ class CompletionRequest(OpenAIBaseModel):
|
|
|
temperature_last=self.temperature_last,
|
|
|
xtc_threshold=self.xtc_threshold,
|
|
|
xtc_probability=self.xtc_probability,
|
|
|
+ dynatemp_min=self.dynatemp_min,
|
|
|
+ dynatemp_max=self.dynatemp_max,
|
|
|
+ dynatemp_exponent=self.dynatemp_exponent,
|
|
|
+ custom_token_bans=self.custom_token_bans,
|
|
|
)
|
|
|
|
|
|
@model_validator(mode="before")
|
|
@@ -771,45 +786,6 @@ class DetokenizeResponse(OpenAIBaseModel):
|
|
|
# ========== KoboldAI ========== #
|
|
|
|
|
|
|
|
|
-class KoboldSamplingParams(BaseModel):
|
|
|
- n: int = Field(1, alias="n")
|
|
|
- best_of: Optional[int] = Field(None, alias="best_of")
|
|
|
- presence_penalty: float = Field(0.0, alias="presence_penalty")
|
|
|
- frequency_penalty: float = Field(0.0, alias="rep_pen")
|
|
|
- temperature: float = Field(1.0, alias="temperature")
|
|
|
- dynatemp_range: Optional[float] = 0.0
|
|
|
- dynatemp_exponent: Optional[float] = 1.0
|
|
|
- smoothing_factor: Optional[float] = 0.0
|
|
|
- smoothing_curve: Optional[float] = 1.0
|
|
|
- top_p: float = Field(1.0, alias="top_p")
|
|
|
- top_k: float = Field(-1, alias="top_k")
|
|
|
- min_p: float = Field(0.0, alias="min_p")
|
|
|
- top_a: float = Field(0.0, alias="top_a")
|
|
|
- tfs: float = Field(1.0, alias="tfs")
|
|
|
- eta_cutoff: float = Field(0.0, alias="eta_cutoff")
|
|
|
- epsilon_cutoff: float = Field(0.0, alias="epsilon_cutoff")
|
|
|
- typical_p: float = Field(1.0, alias="typical_p")
|
|
|
- use_beam_search: bool = Field(False, alias="use_beam_search")
|
|
|
- length_penalty: float = Field(1.0, alias="length_penalty")
|
|
|
- early_stopping: Union[bool, str] = Field(False, alias="early_stopping")
|
|
|
- stop: Union[None, str, List[str]] = Field(None, alias="stop_sequence")
|
|
|
- include_stop_str_in_output: Optional[bool] = False
|
|
|
- ignore_eos: bool = Field(False, alias="ignore_eos")
|
|
|
- max_tokens: int = Field(16, alias="max_length")
|
|
|
- logprobs: Optional[int] = Field(None, alias="logprobs")
|
|
|
- custom_token_bans: Optional[List[int]] = Field(None,
|
|
|
- alias="custom_token_bans")
|
|
|
-
|
|
|
- @root_validator(pre=False, skip_on_failure=True)
|
|
|
- def validate_best_of(cls, values): # pylint: disable=no-self-argument
|
|
|
- best_of = values.get("best_of")
|
|
|
- n = values.get("n")
|
|
|
- if best_of is not None and (best_of <= 0 or best_of > n):
|
|
|
- raise ValueError(
|
|
|
- "best_of must be a positive integer less than or equal to n")
|
|
|
- return values
|
|
|
-
|
|
|
-
|
|
|
class KAIGenerationInputSchema(BaseModel):
|
|
|
genkey: Optional[str] = None
|
|
|
prompt: str
|
|
@@ -817,8 +793,6 @@ class KAIGenerationInputSchema(BaseModel):
|
|
|
max_context_length: int
|
|
|
max_length: int
|
|
|
rep_pen: Optional[float] = 1.0
|
|
|
- rep_pen_range: Optional[int] = None
|
|
|
- rep_pen_slope: Optional[float] = None
|
|
|
top_k: Optional[int] = 0
|
|
|
top_a: Optional[float] = 0.0
|
|
|
top_p: Optional[float] = 1.0
|
|
@@ -832,31 +806,16 @@ class KAIGenerationInputSchema(BaseModel):
|
|
|
dynatemp_exponent: Optional[float] = 1.0
|
|
|
smoothing_factor: Optional[float] = 0.0
|
|
|
smoothing_curve: Optional[float] = 1.0
|
|
|
- use_memory: Optional[bool] = None
|
|
|
- use_story: Optional[bool] = None
|
|
|
- use_authors_note: Optional[bool] = None
|
|
|
- use_world_info: Optional[bool] = None
|
|
|
- use_userscripts: Optional[bool] = None
|
|
|
- soft_prompt: Optional[str] = None
|
|
|
- disable_output_formatting: Optional[bool] = None
|
|
|
- frmtrmblln: Optional[bool] = None
|
|
|
- frmtrmspch: Optional[bool] = None
|
|
|
- singleline: Optional[bool] = None
|
|
|
+ xtc_threshold: Optional[float] = 0.1
|
|
|
+ xtc_probability: Optional[float] = 0.0
|
|
|
use_default_badwordsids: Optional[bool] = None
|
|
|
- mirostat: Optional[int] = 0
|
|
|
- mirostat_tau: Optional[float] = 0.0
|
|
|
- mirostat_eta: Optional[float] = 0.0
|
|
|
- disable_input_formatting: Optional[bool] = None
|
|
|
- frmtadsnsp: Optional[bool] = None
|
|
|
quiet: Optional[bool] = None
|
|
|
# pylint: disable=unexpected-keyword-arg
|
|
|
- sampler_order: Optional[Union[List, str]] = Field(default_factory=list)
|
|
|
sampler_seed: Optional[int] = None
|
|
|
- sampler_full_determinism: Optional[bool] = None
|
|
|
stop_sequence: Optional[List[str]] = None
|
|
|
include_stop_str_in_output: Optional[bool] = False
|
|
|
|
|
|
- @root_validator(pre=False, skip_on_failure=True)
|
|
|
+ @model_validator(mode='after')
|
|
|
def check_context(cls, values): # pylint: disable=no-self-argument
|
|
|
assert values.get("max_length") <= values.get(
|
|
|
"max_context_length"
|