# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from pydantic import (BaseModel, ConfigDict, Field, model_validator,
                      root_validator)
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated

from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import (LogitsProcessorFunc,
                                              SamplingParams)
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.chat_utils import ChatCompletionMessageParam
from aphrodite.endpoints.openai.logits_processors import get_logits_processors


class OpenAIBaseModel(BaseModel):
    model_config = ConfigDict(extra="ignore")


class ErrorResponse(OpenAIBaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class ModelPermission(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(OpenAIBaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "pygmalionai"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(OpenAIBaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(OpenAIBaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ResponseFormat(OpenAIBaseModel):
    # type must be "json_object" or "text"
    type: Literal["text", "json_object"]


class StreamOptions(OpenAIBaseModel):
    include_usage: Optional[bool] = True
    continuous_usage_stats: Optional[bool] = True


class FunctionDefinition(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None


class ChatCompletionToolsParam(OpenAIBaseModel):
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


class ChatCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: List[ChatCompletionMessageParam]
    model: str
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = 0
    max_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tools: Optional[List[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[Literal["none"],
                                ChatCompletionNamedToolChoiceParam]] = "none"
    user: Optional[str] = None

    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    # doc: end-chat-completion-sampling-params

    # doc: begin-chat-completion-extra-params
    echo: Optional[bool] = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."),
    )
    add_generation_prompt: Optional[bool] = Field(
        default=True,
        description=(
            "If true, the generation prompt will be added to the chat "
            "template. This is a parameter used by the chat template in the "
            "tokenizer config of the model."),
    )
    add_special_tokens: Optional[bool] = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to False (as is the "
            "default)."),
    )
    documents: Optional[List[Dict[str, str]]] = Field(
        default=None,
        description=(
            "A list of dicts representing documents that will be accessible "
            "to the model if it is performing RAG (retrieval-augmented "
            "generation). If the template does not support RAG, this "
            "argument will have no effect. We recommend that each document "
            "should be a dict containing \"title\" and \"text\" keys."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )
    include_stop_str_in_output: Optional[bool] = Field(
        default=False,
        description=(
            "Whether to include the stop string in the output. "
            "This is only applied when the stop or stop_token_ids is set."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' or 'lm-format-enforcer'."))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    # doc: end-chat-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens
        # We now allow logprobs being true without top_logprobs.
        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=None,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        return SamplingParams(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.top_logprobs if self.echo else None,
            best_of=self.best_of,
            top_k=self.top_k,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
        )

    @model_validator(mode='before')
    @classmethod
    def validate_stream_options(cls, values):
        if (values.get('stream_options') is not None
                and not values.get('stream')):
            raise ValueError(
                "stream_options can only be set if stream is true")
        return values
    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        # you can only either use guided decoding or tools, not both
        if guide_count > 0 and "tool_choice" in data and data[
                "tool_choice"] != "none":
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_choice(cls, data):
        if "tool_choice" in data and data["tool_choice"] != "none":
            if not isinstance(data["tool_choice"], dict):
                raise ValueError("Currently only named tools are supported.")
            if "tools" not in data or data["tools"] is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "top_logprobs" in data and data["top_logprobs"] is not None:
            if "logprobs" not in data or data["logprobs"] is False:
                raise ValueError(
                    "When using `top_logprobs`, `logprobs` must be set to "
                    "true.")
            elif data["top_logprobs"] < 0:
                raise ValueError(
                    "`top_logprobs` must be a non-negative value.")
        return data
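

# A minimal usage sketch of the request-to-SamplingParams conversion
# (illustrative only; the "gpt2" tokenizer name and the literal values
# below are assumptions, not part of this module):
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     request = ChatCompletionRequest(
#         model="gpt2",
#         messages=[{"role": "user", "content": "Hello!"}],
#         temperature=0.5,
#     )
#     params = request.to_sampling_params(
#         tokenizer=tokenizer,
#         guided_decode_logits_processor=None,
#         default_max_tokens=256,
#     )
#     # max_tokens was not set on the request, so the server default applies:
#     assert params.max_tokens == 256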


class CompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-completion-sampling-params
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    allowed_token_ids: Optional[List[int]] = None
    include_stop_str_in_output: Optional[bool] = False
    add_special_tokens: Optional[bool] = False
    # doc: end-completion-sampling-params

    # doc: begin-completion-extra-params
    response_format: Optional[ResponseFormat] = Field(
        default=None,
        description=(
            "Similar to chat completion, this parameter specifies the format "
            "of output. Only {'type': 'json_object'} or {'type': 'text'} is "
            "supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' or 'lm-format-enforcer'."))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    # doc: end-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        echo_without_generation = self.echo and self.max_tokens == 0

        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        return SamplingParams(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            logprobs=self.logprobs,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            prompt_logprobs=self.logprobs if self.echo else None,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "logprobs" in data and data[
                "logprobs"] is not None and data["logprobs"] < 0:
            raise ValueError(
                "If passed, `logprobs` must be a non-negative value.")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when stream is True.")
        return data
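

# A hedged sketch of the echo-without-generation path (model name and
# values are illustrative): echo=True together with max_tokens=0 makes
# to_sampling_params request a single token (max_tokens=1) and return
# prompt logprobs rather than newly generated text.
#
#     request = CompletionRequest(
#         model="gpt2",
#         prompt="The quick brown fox",
#         echo=True,
#         max_tokens=0,
#         logprobs=1,
#     )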


class EmbeddingRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
    dimensions: Optional[int] = None
    user: Optional[str] = None

    # doc: begin-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-embedding-pooling-params

    def to_pooling_params(self):
        return PoolingParams(additional_data=self.additional_data)
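

# Example (the model name is illustrative, not part of this module):
# embeddings carry no sampling parameters, so conversion is a one-liner.
#
#     request = EmbeddingRequest(model="my-embedding-model", input="hello")
#     pooling = request.to_pooling_params()
#     # -> PoolingParams(additional_data=None)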


class CompletionLogProbs(OpenAIBaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str,
                                     float]]] = Field(default_factory=list)


class CompletionResponseChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )


class CompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )


class CompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class EmbeddingResponseData(OpenAIBaseModel):
    index: int
    object: str = "embedding"
    embedding: Union[List[float], str]


class EmbeddingResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[EmbeddingResponseData]
    usage: UsageInfo


class FunctionCall(OpenAIBaseModel):
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall


class ChatMessage(OpenAIBaseModel):
    role: str
    content: str
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionLogProb(OpenAIBaseModel):
    token: str
    logprob: float = -9999.0
    bytes: Optional[List[int]] = None


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)


class ChatCompletionLogProbs(OpenAIBaseModel):
    content: Optional[List[ChatCompletionLogProbsContent]] = None


class ChatCompletionResponseChoice(OpenAIBaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo


class DeltaMessage(OpenAIBaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
    """

    # A developer-provided per-request id that will be used to match outputs
    # to inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. Currently
    # /v1/chat/completions is supported.
    url: str

    # The parameters of the request.
    body: ChatCompletionRequest
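

# A sketch of one line of a batch input file (the ids and body are made
# up); each JSONL line validates into a BatchRequestInput:
#
#     raw = ('{"custom_id": "request-1", "method": "POST", '
#            '"url": "/v1/chat/completions", '
#            '"body": {"model": "gpt2", '
#            '"messages": [{"role": "user", "content": "Hi"}]}}')
#     line = BatchRequestInput.model_validate_json(raw)
#     assert line.body.model == "gpt2"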


class BatchResponseData(OpenAIBaseModel):
    # HTTP status code of the response.
    status_code: int = 200

    # A unique identifier for the API request.
    request_id: str

    # The body of the response.
    body: Optional[ChatCompletionResponse] = None


class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs
    # to inputs.
    custom_id: str

    response: Optional[BatchResponseData]

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Optional[Any]


class TokenizeCompletionRequest(OpenAIBaseModel):
    model: str
    prompt: str
    add_special_tokens: bool = Field(default=True)


class TokenizeChatRequest(OpenAIBaseModel):
    model: str
    messages: List[ChatCompletionMessageParam]
    add_generation_prompt: bool = Field(default=True)
    add_special_tokens: bool = Field(default=False)


TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]


class TokenizeResponse(OpenAIBaseModel):
    tokens: List[int]
    count: int
    max_model_len: int


class DetokenizeRequest(OpenAIBaseModel):
    model: Optional[str]
    tokens: List[int]


class DetokenizeResponse(OpenAIBaseModel):
    prompt: str
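

# Sketch of how the tokenizer endpoints pair up (field values are
# illustrative, not part of this module):
#
#     TokenizeCompletionRequest(model="gpt2", prompt="hello world")
#     # -> TokenizeResponse(tokens=[...], count=..., max_model_len=...)
#     DetokenizeRequest(model="gpt2", tokens=[1, 2, 3])
#     # -> DetokenizeResponse(prompt="...")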


# ========== KoboldAI ========== #


class KoboldSamplingParams(BaseModel):
    n: int = Field(1, alias="n")
    best_of: Optional[int] = Field(None, alias="best_of")
    presence_penalty: float = Field(0.0, alias="presence_penalty")
    frequency_penalty: float = Field(0.0, alias="rep_pen")
    temperature: float = Field(1.0, alias="temperature")
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    top_p: float = Field(1.0, alias="top_p")
    top_k: float = Field(-1, alias="top_k")
    min_p: float = Field(0.0, alias="min_p")
    top_a: float = Field(0.0, alias="top_a")
    tfs: float = Field(1.0, alias="tfs")
    eta_cutoff: float = Field(0.0, alias="eta_cutoff")
    epsilon_cutoff: float = Field(0.0, alias="epsilon_cutoff")
    typical_p: float = Field(1.0, alias="typical_p")
    use_beam_search: bool = Field(False, alias="use_beam_search")
    length_penalty: float = Field(1.0, alias="length_penalty")
    early_stopping: Union[bool, str] = Field(False, alias="early_stopping")
    stop: Union[None, str, List[str]] = Field(None, alias="stop_sequence")
    include_stop_str_in_output: Optional[bool] = False
    ignore_eos: bool = Field(False, alias="ignore_eos")
    max_tokens: int = Field(16, alias="max_length")
    logprobs: Optional[int] = Field(None, alias="logprobs")
    custom_token_bans: Optional[List[int]] = Field(None,
                                                   alias="custom_token_bans")

    @root_validator(pre=False, skip_on_failure=True)
    def validate_best_of(cls, values):  # pylint: disable=no-self-argument
        best_of = values.get("best_of")
        n = values.get("n")
        if best_of is not None and (best_of <= 0 or best_of > n):
            raise ValueError(
                "best_of must be a positive integer less than or equal to n")
        return values
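

# The Field aliases above let Kobold-style payload keys populate the
# OpenAI-style names: "rep_pen" feeds frequency_penalty and "max_length"
# feeds max_tokens (a sketch; the values are illustrative):
#
#     params = KoboldSamplingParams(rep_pen=1.1, max_length=64)
#     assert params.frequency_penalty == 1.1
#     assert params.max_tokens == 64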


class KAIGenerationInputSchema(BaseModel):
    genkey: Optional[str] = None
    prompt: str
    n: Optional[int] = 1
    max_context_length: int
    max_length: int
    rep_pen: Optional[float] = 1.0
    rep_pen_range: Optional[int] = None
    rep_pen_slope: Optional[float] = None
    top_k: Optional[int] = 0
    top_a: Optional[float] = 0.0
    top_p: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eps_cutoff: Optional[float] = 0.0
    eta_cutoff: Optional[float] = 0.0
    typical: Optional[float] = 1.0
    temperature: Optional[float] = 1.0
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    use_memory: Optional[bool] = None
    use_story: Optional[bool] = None
    use_authors_note: Optional[bool] = None
    use_world_info: Optional[bool] = None
    use_userscripts: Optional[bool] = None
    soft_prompt: Optional[str] = None
    disable_output_formatting: Optional[bool] = None
    frmtrmblln: Optional[bool] = None
    frmtrmspch: Optional[bool] = None
    singleline: Optional[bool] = None
    use_default_badwordsids: Optional[bool] = None
    mirostat: Optional[int] = 0
    mirostat_tau: Optional[float] = 0.0
    mirostat_eta: Optional[float] = 0.0
    disable_input_formatting: Optional[bool] = None
    frmtadsnsp: Optional[bool] = None
    quiet: Optional[bool] = None
    # pylint: disable=unexpected-keyword-arg
    sampler_order: Optional[Union[List, str]] = Field(default_factory=list)
    sampler_seed: Optional[int] = None
    sampler_full_determinism: Optional[bool] = None
    stop_sequence: Optional[List[str]] = None
    include_stop_str_in_output: Optional[bool] = False

    @root_validator(pre=False, skip_on_failure=True)
    def check_context(cls, values):  # pylint: disable=no-self-argument
        assert values.get("max_length") <= values.get(
            "max_context_length"
        ), "max_length must not be larger than max_context_length"
        return values
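

# The check_context validator enforces max_length <= max_context_length
# (a sketch; the field values are illustrative):
#
#     KAIGenerationInputSchema(prompt="Once upon a time",
#                              max_context_length=1024,
#                              max_length=128)   # ok
#     KAIGenerationInputSchema(prompt="Once upon a time",
#                              max_context_length=128,
#                              max_length=1024)  # raises ValidationError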