12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926 |
- # Adapted from
- # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
- import json
- import time
- from typing import Any, Dict, List, Literal, Optional, Union
- import torch
- from pydantic import (AliasChoices, BaseModel, ConfigDict, Field,
- model_validator)
- from transformers import PreTrainedTokenizer
- from typing_extensions import Annotated
- from aphrodite.common.pooling_params import PoolingParams
- from aphrodite.common.sampling_params import (LogitsProcessorFunc,
- SamplingParams)
- from aphrodite.common.sequence import Logprob
- from aphrodite.common.utils import random_uuid
- from aphrodite.endpoints.chat_utils import ChatCompletionMessageParam
- from aphrodite.endpoints.openai.logits_processors import get_logits_processors
class OpenAIBaseModel(BaseModel):
    """Shared base class for the OpenAI-compatible schemas in this module.

    Keys in an incoming payload that have no matching field are ignored
    instead of causing a validation error.
    """

    model_config = ConfigDict(extra="ignore")
class ErrorResponse(OpenAIBaseModel):
    """Body sent back to the client when a request cannot be served."""

    object: str = "error"
    message: str
    type: str
    # Name of the offending request parameter, when one applies.
    param: Union[str, None] = None
    code: int
class ModelPermission(OpenAIBaseModel):
    """Permission record attached to a model card.

    ``id`` and ``created`` are generated per instance (random suffix and
    current Unix time respectively).
    """

    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Union[str, None] = None
    is_blocking: bool = False
class ModelCard(OpenAIBaseModel):
    """Description of one served model, as listed by the models endpoint."""

    id: str
    object: str = "model"
    # Unix timestamp taken when the card object is built.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "pygmalionai"
    root: Union[str, None] = None
    parent: Union[str, None] = None
    max_model_len: Union[int, None] = None
    permission: List[ModelPermission] = Field(default_factory=list)
class ModelList(OpenAIBaseModel):
    """List wrapper around the ``ModelCard`` entries being reported."""

    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)
class UsageInfo(OpenAIBaseModel):
    """Token accounting for a request; every counter starts at zero."""

    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Union[int, None] = 0
class JsonSchemaResponseFormat(OpenAIBaseModel):
    """Named JSON schema supplied with a ``json_schema`` response format."""

    name: str
    description: Union[str, None] = None
    # OpenAI calls this field "schema", which collides with pydantic
    # internals; it is stored as ``json_schema`` and populated from the
    # "schema" key through the alias.
    json_schema: Union[Dict[str, Any], None] = Field(default=None,
                                                     alias='schema')
    strict: Union[bool, None] = None
class ResponseFormat(OpenAIBaseModel):
    """Requested output format: plain text, any JSON object, or a schema."""

    type: Literal["text", "json_object", "json_schema"]
    json_schema: Union[JsonSchemaResponseFormat, None] = None
class StreamOptions(OpenAIBaseModel):
    """Extra knobs that apply only to streamed responses."""

    include_usage: Union[bool, None] = True
    continuous_usage_stats: Union[bool, None] = True
class FunctionDefinition(OpenAIBaseModel):
    """A callable tool function: its name, description and parameters."""

    name: str
    description: Union[str, None] = None
    parameters: Union[Dict[str, Any], None] = None
class ChatCompletionToolsParam(OpenAIBaseModel):
    """One entry of a chat request's ``tools`` list."""

    type: Literal["function"] = "function"
    function: FunctionDefinition
class ChatCompletionNamedFunction(OpenAIBaseModel):
    """Names the single function a named tool choice refers to."""

    name: str
class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    """``tool_choice`` value that forces a specific named function."""

    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"
class ChatCompletionRequest(OpenAIBaseModel):
    """Request body for the OpenAI-compatible chat completions endpoint.

    Besides the official OpenAI fields it accepts the engine's extended
    sampler parameters and extra options for chat templating and guided
    decoding. ``to_sampling_params`` turns a validated request into
    engine ``SamplingParams``.
    """

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: List[ChatCompletionMessageParam]
    model: str
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = 0
    max_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tools: Optional[List[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[Literal["none"],
                                ChatCompletionNamedToolChoiceParam]] = "none"
    user: Optional[str] = None

    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    no_repeat_ngram_size: Optional[int] = 0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    dry_multiplier: Optional[float] = 0
    dry_base: Optional[float] = 1.75
    dry_allowed_length: Optional[int] = 2
    dry_sequence_breakers: Optional[List[str]] = Field(
        default=["\n", ":", "\"", "*"])
    dry_range: Optional[int] = Field(
        default=0,
        validation_alias=AliasChoices("dry_range",
                                      "dry_penalty_last_n"))
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    nsigma: Optional[float] = 0.0
    skew: Optional[float] = 0.0
    custom_token_bans: Optional[List[int]] = None
    sampler_priority: Optional[Union[List[int], List[str]]] = Field(
        default=[],
        validation_alias=AliasChoices("sampler_priority",
                                      "sampler_order"))
    # doc: end-chat-completion-sampling-params

    # doc: begin-chat-completion-extra-params
    echo: Optional[bool] = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."),
    )
    add_generation_prompt: Optional[bool] = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
    )
    add_special_tokens: Optional[bool] = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to False (as is the "
            "default)."),
    )
    documents: Optional[List[Dict[str, str]]] = Field(
        default=None,
        description=
        ("A list of dicts representing documents that will be accessible to "
         "the model if it is performing RAG (retrieval-augmented generation)."
         " If the template does not support RAG, this argument will have no "
         "effect. We recommend that each document should be a dict containing "
         "\"title\" and \"text\" keys."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )
    include_stop_str_in_output: Optional[bool] = Field(
        default=False,
        description=(
            "Whether to include the stop string in the output. "
            "This is only applied when the stop or stop_token_ids is set."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be either "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    # doc: end-chat-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        """Build engine ``SamplingParams`` from this request.

        Args:
            tokenizer: Passed to ``get_logits_processors`` and used to map
                each DRY sequence breaker string to a token id.
            guided_decode_logits_processor: Optional processor appended after
                the ones produced by ``get_logits_processors``.
            default_max_tokens: Used when the request omits ``max_tokens``.

        Returns:
            The ``SamplingParams`` for the engine.
        """
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        # We now allow logprobs being true without top_logprobs.
        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=None,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        # Each breaker is encoded after a leading "a" and only the final
        # token id is kept. NOTE(review): presumably so the breaker is
        # tokenized as it appears mid-text, not at sequence start.
        dry_sequence_breaker_ids = [
            tokenizer.encode(f'a{s}')[-1]
            for s in (self.dry_sequence_breakers or [])
        ]

        return SamplingParams(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            temperature=self.temperature,
            top_p=self.top_p,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
            (self.top_logprobs if self.echo else None),
            best_of=self.best_of,
            top_k=self.top_k,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            dry_multiplier=self.dry_multiplier,
            dry_base=self.dry_base,
            dry_allowed_length=self.dry_allowed_length,
            dry_sequence_breaker_ids=dry_sequence_breaker_ids,
            dry_range=self.dry_range,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            nsigma=self.nsigma,
            skew=self.skew,
            custom_token_bans=self.custom_token_bans,
            sampler_priority=self.sampler_priority,
        )

    @model_validator(mode='before')
    @classmethod
    def validate_stream_options(cls, values):
        """Reject ``stream_options`` unless ``stream`` is true."""
        if (values.get('stream_options') is not None
                and not values.get('stream')):
            raise ValueError(
                "stream_options can only be set if stream is true")
        return values

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        """Allow at most one guided-decoding mode, and never with tools.

        Raises:
            ValueError: if more than one of ``guided_json``, ``guided_regex``
                and ``guided_choice`` is set, or if any of them is combined
                with a ``tool_choice`` other than "none"/null.
        """
        guide_count = sum(
            data.get(key) is not None
            for key in ("guided_json", "guided_regex", "guided_choice"))
        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        # You can only either use guided decoding or tools, not both. This
        # must trigger for a single guided mode; a `> 1` test here would be
        # unreachable after the check above.
        if guide_count > 0 and data.get("tool_choice") not in (None, "none"):
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_choice(cls, data):
        """Require a named ``tool_choice`` to come with a ``tools`` list.

        An explicit null ``tool_choice`` (allowed by the field type) is
        treated like the default "none".

        Raises:
            ValueError: if ``tool_choice`` is set but is not a dict (named
                tool), or if ``tools`` is missing/null.
        """
        tool_choice = data.get("tool_choice")
        if tool_choice is not None and tool_choice != "none":
            if not isinstance(tool_choice, dict):
                raise ValueError("Currently only named tools are supported.")
            if data.get("tools") is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Require ``logprobs`` to be truthy when ``top_logprobs`` is sent.

        A missing, false or null ``logprobs`` is rejected. Otherwise
        ``to_sampling_params`` would silently drop ``top_logprobs``.

        Raises:
            ValueError: if ``top_logprobs`` is given without ``logprobs``,
                or is negative.
        """
        top_logprobs = data.get("top_logprobs")
        if top_logprobs is not None:
            if not data.get("logprobs"):
                raise ValueError(
                    "when using `top_logprobs`, `logprobs` must be set to true."
                )
            if top_logprobs < 0:
                raise ValueError(
                    "`top_logprobs` must be a non-negative value.")
        return data
class CompletionRequest(OpenAIBaseModel):
    """Request body for the OpenAI-compatible (legacy) completions endpoint.

    Besides the official OpenAI fields it accepts the engine's extended
    sampler parameters and guided-decoding options. ``to_sampling_params``
    turns a validated request into engine ``SamplingParams``.
    """

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-completion-sampling-params
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    no_repeat_ngram_size: Optional[int] = 0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    allowed_token_ids: Optional[List[int]] = None
    include_stop_str_in_output: Optional[bool] = False
    add_special_tokens: Optional[bool] = False
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    dry_multiplier: Optional[float] = 0
    dry_base: Optional[float] = 1.75
    dry_allowed_length: Optional[int] = 2
    dry_sequence_breakers: Optional[List[str]] = Field(
        default=["\n", ":", "\"", "*"])
    dry_range: Optional[int] = Field(
        default=0,
        validation_alias=AliasChoices("dry_range",
                                      "dry_penalty_last_n"))
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    nsigma: Optional[float] = 0.0
    skew: Optional[float] = 0.0
    custom_token_bans: Optional[List[int]] = None
    sampler_priority: Optional[Union[List[int], List[str]]] = Field(
        default=[],
        validation_alias=AliasChoices("sampler_priority",
                                      "sampler_order"))
    # doc: end-completion-sampling-params

    # doc: begin-completion-extra-params
    response_format: Optional[ResponseFormat] = Field(
        default=None,
        description=
        ("Similar to chat completion, this parameter specifies the format of "
         "output. Only {'type': 'json_object'} or {'type': 'text' } is "
         "supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    # doc: end-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        """Build engine ``SamplingParams`` from this request.

        Args:
            tokenizer: Passed to ``get_logits_processors`` and used to map
                each DRY sequence breaker string to a token id.
            guided_decode_logits_processor: Optional processor appended after
                the ones produced by ``get_logits_processors``.
            default_max_tokens: Used when the request has ``max_tokens`` null.

        Returns:
            The ``SamplingParams`` for the engine. When ``echo`` is set and
            ``max_tokens`` is 0, ``max_tokens`` is sent as 1.
        """
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        echo_without_generation = self.echo and self.max_tokens == 0

        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        # Breakers may contain written-out escapes (a literal backslash
        # followed by "n", ...); decode them. "unicode_escape" reads its
        # input bytes as latin-1, so encoding as UTF-8 first would turn any
        # non-ASCII breaker into mojibake. Encoding as latin-1 with
        # backslashreplace turns other characters into \uXXXX escapes,
        # which then decode back to the original character.
        dry_sequence_breaker_ids = []
        for breaker in self.dry_sequence_breakers or []:
            try:
                breaker = breaker.encode(
                    "latin-1", "backslashreplace").decode("unicode_escape")
            except UnicodeDecodeError:
                # Malformed escape (e.g. a trailing backslash): use the
                # breaker exactly as given rather than failing the request.
                pass
            # Encode after a leading "a" and keep the final token id.
            # NOTE(review): presumably so the breaker is tokenized as it
            # appears mid-text, not at sequence start.
            dry_sequence_breaker_ids.append(
                tokenizer.encode(f'a{breaker}')[-1])

        return SamplingParams(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            logprobs=self.logprobs,
            prompt_logprobs=self.prompt_logprobs
            if self.prompt_logprobs else self.logprobs if self.echo else None,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=(self.spaces_between_special_tokens),
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            dry_multiplier=self.dry_multiplier,
            dry_base=self.dry_base,
            dry_allowed_length=self.dry_allowed_length,
            dry_sequence_breaker_ids=dry_sequence_breaker_ids,
            dry_range=self.dry_range,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            nsigma=self.nsigma,
            skew=self.skew,
            custom_token_bans=self.custom_token_bans,
            sampler_priority=self.sampler_priority,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        """Allow at most one of guided_json / guided_regex / guided_choice.

        Raises:
            ValueError: if more than one of them is set.
        """
        guide_count = sum(
            data.get(key) is not None
            for key in ("guided_json", "guided_regex", "guided_choice"))
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        """Reject a negative ``logprobs`` (0 is accepted).

        Raises:
            ValueError: if ``logprobs`` is present, not null and below 0.
        """
        logprobs = data.get("logprobs")
        if logprobs is not None and not logprobs >= 0:
            raise ValueError(
                "if passed, `logprobs` must be a non-negative value.")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        """Reject truthy ``stream_options`` unless ``stream`` is truthy."""
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when stream is True.")
        return data

    @model_validator(mode='before')
    @classmethod
    def parse_dry_sequence_breakers(cls, data):
        """Accept ``dry_sequence_breakers`` as a list or a JSON-encoded list.

        A null value is passed through unchanged, because the field is
        Optional. The caller's dict is not mutated; a copy carrying the
        parsed list is returned.

        Raises:
            ValueError: if the value is a string that is not valid JSON, or
                is not (or does not decode to) a list of strings.
        """
        if not isinstance(data, dict) or 'dry_sequence_breakers' not in data:
            return data

        breakers = data['dry_sequence_breakers']
        if breakers is None:
            return data

        if isinstance(breakers, str):
            try:
                # Try to parse as JSON string
                breakers = json.loads(breakers)
            except json.JSONDecodeError as e:
                raise ValueError(f"Invalid JSON for dry_sequence_breakers:"
                                 f" {e}") from e

        # Check the container type before iterating. Otherwise null or
        # numeric input raises TypeError (an unhandled server error) rather
        # than a validation error.
        if not isinstance(breakers, list) or not all(
                isinstance(x, str) for x in breakers):
            raise ValueError(
                "dry_sequence_breakers must be a list of strings or a "
                "JSON string representing a list of strings"
            )

        return {**data, 'dry_sequence_breakers': breakers}
class EmbeddingRequest(OpenAIBaseModel):
    """Request body for the OpenAI-compatible embeddings endpoint."""

    # Field order follows the official OpenAI API documentation:
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    input: Union[List[int], List[List[int]], str, List[str]]
    # Only "float" or "base64" pass the pattern check.
    encoding_format: Union[str, None] = Field('float',
                                             pattern='^(float|base64)$')
    dimensions: Union[int, None] = None
    user: Union[str, None] = None

    # doc: begin-embedding-pooling-params
    additional_data: Union[Any, None] = None
    # doc: end-embedding-pooling-params

    def to_pooling_params(self):
        """Wrap ``additional_data`` in a ``PoolingParams`` for the engine."""
        extra = self.additional_data
        return PoolingParams(additional_data=extra)
class CompletionLogProbs(OpenAIBaseModel):
    """Per-token logprob data attached to a completion choice.

    Each list starts empty.
    """

    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Union[float, None]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Union[Dict[str, float],
                             None]] = Field(default_factory=list)
class CompletionResponseChoice(OpenAIBaseModel):
    """A single choice within a non-streamed completion response."""

    index: int
    text: str
    logprobs: Union[CompletionLogProbs, None] = None
    finish_reason: Union[str, None] = None
    stop_reason: Union[int, str, None] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
    prompt_logprobs: Union[List[Union[Dict[int, Logprob], None]],
                           None] = None
class CompletionResponse(OpenAIBaseModel):
    """Full (non-streamed) completions response.

    ``id`` gets a fresh ``cmpl-`` prefixed value and ``created`` the
    current Unix time on construction.
    """

    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo
class CompletionResponseStreamChoice(OpenAIBaseModel):
    """A single choice within one streamed completion chunk."""

    index: int
    text: str
    logprobs: Union[CompletionLogProbs, None] = None
    finish_reason: Union[str, None] = None
    stop_reason: Union[int, str, None] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
class CompletionStreamResponse(OpenAIBaseModel):
    """One chunk of a streamed completions response."""

    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    # Left as None unless usage is reported on this chunk.
    usage: Union[UsageInfo, None] = None
class EmbeddingResponseData(OpenAIBaseModel):
    """One embedding entry: a list of floats or an encoded string."""

    index: int
    object: str = "embedding"
    embedding: Union[List[float], str]
class EmbeddingResponse(OpenAIBaseModel):
    """Response body of the embeddings endpoint."""

    # Note: the generated id uses the "cmpl-" prefix.
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[EmbeddingResponseData]
    usage: UsageInfo
class FunctionCall(OpenAIBaseModel):
    """A function invocation: its name and its arguments as a string."""

    name: str
    arguments: str
class ToolCall(OpenAIBaseModel):
    """A tool call emitted by the model; ``id`` is generated if omitted."""

    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall
class ChatMessage(OpenAIBaseModel):
    """A complete chat message returned in a non-streamed response."""

    role: str
    content: str
    tool_calls: List[ToolCall] = Field(default_factory=list)
class ChatCompletionLogProb(OpenAIBaseModel):
    """Logprob entry for a single token (default logprob -9999.0)."""

    token: str
    logprob: float = -9999.0
    # Field name mirrors the OpenAI schema even though it shadows a builtin.
    bytes: Union[List[int], None] = None
class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    """Token logprob plus the alternatives listed in ``top_logprobs``."""

    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)
class ChatCompletionLogProbs(OpenAIBaseModel):
    """Container for the per-token logprobs of a chat choice."""

    content: Union[List[ChatCompletionLogProbsContent], None] = None
class ChatCompletionResponseChoice(OpenAIBaseModel):
    """A single choice within a non-streamed chat completion response."""

    index: int
    message: ChatMessage
    logprobs: Union[ChatCompletionLogProbs, None] = None
    finish_reason: Union[str, None] = None
    stop_reason: Union[int, str, None] = None
class ChatCompletionResponse(OpenAIBaseModel):
    """Full (non-streamed) chat completion response."""

    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
    prompt_logprobs: Union[List[Union[Dict[int, Logprob], None]],
                           None] = None
class DeltaMessage(OpenAIBaseModel):
    """Partial message carried by a streamed chat chunk."""

    role: Union[str, None] = None
    content: Union[str, None] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    """A single choice within one streamed chat completion chunk."""

    index: int
    delta: DeltaMessage
    logprobs: Union[ChatCompletionLogProbs, None] = None
    finish_reason: Union[str, None] = None
    stop_reason: Union[int, str, None] = None
class ChatCompletionStreamResponse(OpenAIBaseModel):
    """One chunk of a streamed chat completion response."""

    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Union[UsageInfo, None] = None
class BatchRequestInput(OpenAIBaseModel):
    """
    One line of a batch input file.

    ``body`` is validated as either a chat completion request or an
    embedding request.
    """

    # Caller-chosen id, unique within the batch, used to pair each output
    # line with its input line.
    custom_id: str

    # HTTP method for the request (POST is the one expected).
    method: str

    # Relative OpenAI API URL the request targets.
    url: str

    # The request payload itself.
    body: Union[ChatCompletionRequest, EmbeddingRequest]
class BatchResponseData(OpenAIBaseModel):
    """HTTP-style result recorded for one batch request."""

    # HTTP status code of the response (200 unless set otherwise).
    status_code: int = 200

    # Identifier of the API request that produced this response.
    request_id: str

    # Response payload, if any.
    body: Union[ChatCompletionResponse, EmbeddingResponse, None] = None
class BatchRequestOutput(OpenAIBaseModel):
    """
    One line of a batch output or error file.
    """

    id: str

    # Echo of the caller-chosen id from the matching input line.
    custom_id: str

    # Required but nullable.
    response: Union[BatchResponseData, None]

    # Details for requests that failed with a non-HTTP error; required but
    # nullable.
    error: Union[Any, None]
class TokenizeCompletionRequest(OpenAIBaseModel):
    """Tokenize a raw prompt string (special tokens added by default)."""

    model: str
    prompt: str
    add_special_tokens: bool = True
class TokenizeChatRequest(OpenAIBaseModel):
    """Tokenize a list of chat messages.

    By default the generation prompt is added and extra special tokens are
    not.
    """

    model: str
    messages: List[ChatCompletionMessageParam]
    add_generation_prompt: bool = True
    add_special_tokens: bool = False
# Either flavour of tokenize request (prompt string or chat messages).
TokenizeRequest = Union[TokenizeCompletionRequest,
                        TokenizeChatRequest]
class TokenizeResponse(OpenAIBaseModel):
    """Token ids, their count, and the served model's max length."""

    tokens: List[int]
    count: int
    max_model_len: int
class DetokenizeRequest(OpenAIBaseModel):
    """Convert token ids back to text."""

    # Has no default, so the key must be present, but it may be null.
    model: Union[str, None]
    tokens: List[int]
class DetokenizeResponse(OpenAIBaseModel):
    """Text decoded from the requested token ids."""

    prompt: str
- # ========== KoboldAI ========== #
class KAIGenerationInputSchema(BaseModel):
    """Request body for the KoboldAI-compatible generation endpoint.

    Unlike the OpenAI schemas above, this derives directly from pydantic's
    ``BaseModel``.
    """

    genkey: Optional[str] = None
    prompt: str
    n: Optional[int] = 1
    max_context_length: int
    max_length: int
    rep_pen: Optional[float] = 1.0
    top_k: Optional[int] = 0
    top_a: Optional[float] = 0.0
    top_p: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eps_cutoff: Optional[float] = 0.0
    eta_cutoff: Optional[float] = 0.0
    typical: Optional[float] = 1.0
    temperature: Optional[float] = 1.0
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    use_default_badwordsids: Optional[bool] = None
    quiet: Optional[bool] = None
    # pylint: disable=unexpected-keyword-arg
    sampler_seed: Optional[int] = None
    stop_sequence: Optional[List[str]] = None
    include_stop_str_in_output: Optional[bool] = False

    @model_validator(mode='before')
    @classmethod
    def check_context(cls, values):
        """Ensure ``max_length`` does not exceed ``max_context_length``.

        Uses an explicit ``raise`` rather than ``assert``, so the check still
        runs under ``python -O``. If either field is absent or null, the
        comparison is skipped and normal field validation reports the
        problem; comparing against ``None`` would raise ``TypeError``.

        Raises:
            ValueError: if both are given and ``max_length`` is larger.
        """
        max_length = values.get("max_length")
        max_context_length = values.get("max_context_length")
        if (max_length is not None and max_context_length is not None
                and max_length > max_context_length):
            raise ValueError(
                "max_length must not be larger than max_context_length")
        return values
|