# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from pydantic import (AliasChoices, BaseModel, Field, conint, model_validator,
                      root_validator)

from aphrodite.common.logits_processor import BiasLogitsProcessor
from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.utils import random_uuid


class ErrorResponse(BaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class ModelPermission(BaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "pygmalionai"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ResponseFormat(BaseModel):
    # type must be "json_object" or "text"
    type: Literal["text", "json_object"]


class StreamOptions(BaseModel):
    include_usage: Optional[bool]


class FunctionDefinition(BaseModel):
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None


class ChatCompletionToolsParam(BaseModel):
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(BaseModel):
    name: str


class ChatCompletionNamedToolChoiceParam(BaseModel):
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"
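
# Illustrative sketch (not part of this module): the request-body shape the
# tool models above describe. "get_weather" is a hypothetical tool name.
#
#   "tools": [{"type": "function",
#              "function": {"name": "get_weather",
#                           "parameters": {"type": "object"}}}],
#   "tool_choice": {"type": "function",
#                   "function": {"name": "get_weather"}}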


class ChatCompletionRequest(BaseModel):
    model: str
    # support list type in messages.content
    messages: List[Dict[str, Union[str, List[Dict[str, str]]]]]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    n: Optional[int] = 1
    max_tokens: Optional[int] = None
    min_tokens: Optional[int] = 0
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    include_stop_str_in_output: Optional[bool] = False
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    repetition_penalty: Optional[float] = 1.0
    logit_bias: Optional[Dict[str, float]] = None
    tools: Optional[List[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[Literal["none"],
                                ChatCompletionNamedToolChoiceParam]] = "none"
    user: Optional[str] = None
    best_of: Optional[int] = None
    top_k: Optional[int] = -1
    top_a: Optional[float] = 0.0
    min_p: Optional[float] = 0.0
    mirostat_mode: Optional[int] = 0
    mirostat_tau: Optional[float] = 0.0
    mirostat_eta: Optional[float] = 0.0
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    custom_token_bans: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    add_generation_prompt: Optional[bool] = True
    echo: Optional[bool] = False
    length_penalty: Optional[float] = 1.0
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None
    guided_grammar: Optional[str] = None
    response_format: Optional[ResponseFormat] = None
    guided_decoding_backend: Optional[str] = Field(
        default="outlines",
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))

    def to_sampling_params(self, vocab_size: int) -> SamplingParams:
        if self.logprobs and not self.top_logprobs:
            raise ValueError("Top logprobs must be set when logprobs is.")
        # Normalize top_k == 0 to -1 (disabled).
        if self.top_k == 0:
            self.top_k = -1
        logits_processors = []
        if self.logit_bias:
            # Clamp biases to [-100, 100] (the OpenAI API range) and drop
            # token ids that fall outside the vocabulary.
            biases = {
                int(tok): max(-100, min(float(bias), 100))
                for tok, bias in self.logit_bias.items()
                if 0 < int(tok) < vocab_size
            }
            logits_processors.append(BiasLogitsProcessor(biases))
        return SamplingParams(
            n=self.n,
            max_tokens=self.max_tokens,
            min_tokens=self.min_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.top_logprobs if self.echo else None,
            temperature=self.temperature,
            top_p=self.top_p,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            top_k=self.top_k,
            top_a=self.top_a,
            min_p=self.min_p,
            mirostat_mode=self.mirostat_mode,
            mirostat_tau=self.mirostat_tau,
            mirostat_eta=self.mirostat_eta,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            stop_token_ids=self.stop_token_ids,
            custom_token_bans=self.custom_token_bans,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            stop=self.stop,
            best_of=self.best_of,
            include_stop_str_in_output=self.include_stop_str_in_output,
            seed=self.seed,
            logits_processors=logits_processors,
        )

    @model_validator(mode='before')
    @classmethod
    def validate_stream_options(cls, values):
        if (values.get('stream_options') is not None
                and not values.get('stream')):
            raise ValueError(
                "stream_options can only be set if stream is true")
        return values

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        # you can only either use guided decoding or tools, not both
        # (`> 0`, not `> 1`: any guided option at all conflicts with tools;
        # with `> 1` this branch was unreachable past the check above)
        if guide_count > 0 and "tool_choice" in data and data[
                "tool_choice"] != "none":
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_choice(cls, data):
        if "tool_choice" in data and data["tool_choice"] != "none":
            if not isinstance(data["tool_choice"], dict):
                raise ValueError("Currently only named tools are supported.")
            if "tools" not in data or data["tools"] is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")
        return data
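
# Illustrative sketch (not part of this module): building a request and
# converting it to engine sampling parameters. The model name and vocab
# size are placeholders.
#
#   req = ChatCompletionRequest(
#       model="my-model",
#       messages=[{"role": "user", "content": "Hello!"}],
#       temperature=0.5,
#       logit_bias={"42": 5.0},
#   )
#   params = req.to_sampling_params(vocab_size=32000)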


class CompletionRequest(BaseModel):
    model: str
    # a string, array of strings, array of tokens, or array of token arrays
    prompt: Union[List[int], List[List[int]], str, List[str]]
    suffix: Optional[str] = None
    max_tokens: Optional[int] = 16
    min_tokens: Optional[int] = 0
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    include_stop_str_in_output: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    repetition_penalty: Optional[float] = 1.0
    best_of: Optional[int] = None
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    top_k: Optional[int] = -1
    top_a: Optional[float] = 0.0
    min_p: Optional[float] = 0.0
    mirostat_mode: Optional[int] = 0
    mirostat_tau: Optional[float] = 0.0
    mirostat_eta: Optional[float] = 0.0
    dynatemp_min: Optional[float] = Field(0.0,
                                          validation_alias=AliasChoices(
                                              "dynatemp_min", "dynatemp_low"),
                                          description="Aliases: dynatemp_low")
    dynatemp_max: Optional[float] = Field(0.0,
                                          validation_alias=AliasChoices(
                                              "dynatemp_max", "dynatemp_high"),
                                          description="Aliases: dynatemp_high")
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    custom_token_bans: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[conint(ge=1)] = None
    grammar: Optional[str] = None
    length_penalty: Optional[float] = 1.0
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None
    guided_grammar: Optional[str] = None
    response_format: Optional[ResponseFormat] = None
    guided_decoding_backend: Optional[str] = Field(
        default="outlines",
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))

    def to_sampling_params(self, vocab_size: int) -> SamplingParams:
        # `echo` with max_tokens == 0 asks for the prompt back without any
        # generation; the engine still needs max_tokens >= 1, so clamp below.
        echo_without_generation = self.echo and self.max_tokens == 0
        # Normalize top_k == 0 to -1 (disabled).
        if self.top_k == 0:
            self.top_k = -1
        logits_processors = []
        if self.logit_bias:
            # Clamp biases to [-100, 100] (the OpenAI API range) and drop
            # token ids that fall outside the vocabulary.
            biases = {
                int(tok): max(-100, min(float(bias), 100))
                for tok, bias in self.logit_bias.items()
                if 0 < int(tok) < vocab_size
            }
            logits_processors.append(BiasLogitsProcessor(biases))
        return SamplingParams(
            n=self.n,
            max_tokens=self.max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            temperature=self.temperature,
            top_p=self.top_p,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            top_k=self.top_k,
            top_a=self.top_a,
            min_p=self.min_p,
            mirostat_mode=self.mirostat_mode,
            mirostat_tau=self.mirostat_tau,
            mirostat_eta=self.mirostat_eta,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            logprobs=self.logprobs,
            prompt_logprobs=self.prompt_logprobs if self.echo else None,
            stop_token_ids=self.stop_token_ids,
            custom_token_bans=self.custom_token_bans,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            stop=self.stop,
            best_of=self.best_of,
            include_stop_str_in_output=self.include_stop_str_in_output,
            seed=self.seed,
            logits_processors=logits_processors,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data
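
# Illustrative sketch (not part of this module): `prompt` accepts any of the
# four documented forms; all values below are placeholders.
#
#   CompletionRequest(model="my-model", prompt="Once upon a time")
#   CompletionRequest(model="my-model", prompt=["a story", "a poem"])
#   CompletionRequest(model="my-model", prompt=[1, 2, 3])
#   CompletionRequest(model="my-model", prompt=[[1, 2], [3, 4]])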


class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None


class CompletionResponseChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
    stop_reason: Union[None, int, str] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )


class CompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
    stop_reason: Union[None, int, str] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )


class CompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class EmbeddingResponseData(BaseModel):
    index: int
    object: str = "embedding"
    embedding: List[float]


class EmbeddingResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[EmbeddingResponseData]
    usage: UsageInfo


class FunctionCall(BaseModel):
    name: str
    arguments: str


class ToolCall(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
    stop_reason: Union[None, int, str] = None


class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo


class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
    stop_reason: Union[None, int, str] = None


class ChatCompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
    logprobs: Optional[LogProbs] = None


class EmbeddingRequest(BaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
    dimensions: Optional[int] = None
    user: Optional[str] = None

    # doc: begin-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-embedding-pooling-params

    def to_pooling_params(self):
        return PoolingParams(additional_data=self.additional_data)
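
# Illustrative sketch (not part of this module): an embedding request and its
# pooling-params conversion; the model name is a placeholder.
#
#   req = EmbeddingRequest(model="my-embedding-model", input="some text")
#   pooling = req.to_pooling_params()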


class Prompt(BaseModel):
    prompt: str


# ========== KoboldAI ========== #
class KoboldSamplingParams(BaseModel):
    n: int = Field(1, alias="n")
    best_of: Optional[int] = Field(None, alias="best_of")
    presence_penalty: float = Field(0.0, alias="presence_penalty")
    frequency_penalty: float = Field(0.0, alias="rep_pen")
    temperature: float = Field(1.0, alias="temperature")
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    top_p: float = Field(1.0, alias="top_p")
    top_k: float = Field(-1, alias="top_k")
    min_p: float = Field(0.0, alias="min_p")
    top_a: float = Field(0.0, alias="top_a")
    tfs: float = Field(1.0, alias="tfs")
    eta_cutoff: float = Field(0.0, alias="eta_cutoff")
    epsilon_cutoff: float = Field(0.0, alias="epsilon_cutoff")
    typical_p: float = Field(1.0, alias="typical_p")
    use_beam_search: bool = Field(False, alias="use_beam_search")
    length_penalty: float = Field(1.0, alias="length_penalty")
    early_stopping: Union[bool, str] = Field(False, alias="early_stopping")
    stop: Union[None, str, List[str]] = Field(None, alias="stop_sequence")
    include_stop_str_in_output: Optional[bool] = False
    ignore_eos: bool = Field(False, alias="ignore_eos")
    max_tokens: int = Field(16, alias="max_length")
    logprobs: Optional[int] = Field(None, alias="logprobs")
    custom_token_bans: Optional[List[int]] = Field(None,
                                                   alias="custom_token_bans")

    @root_validator(pre=False, skip_on_failure=True)
    def validate_best_of(cls, values):  # pylint: disable=no-self-argument
        best_of = values.get("best_of")
        n = values.get("n")
        if best_of is not None and (best_of <= 0 or best_of > n):
            raise ValueError(
                "best_of must be a positive integer less than or equal to n")
        return values
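
# Illustrative sketch (not part of this module): KoboldAI payloads populate
# this model through the aliases above, e.g. "rep_pen" maps to
# frequency_penalty and "max_length" to max_tokens; the values here are
# placeholders.
#
#   KoboldSamplingParams(**{"rep_pen": 1.1, "max_length": 80,
#                           "stop_sequence": ["\n"]})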


class KAIGenerationInputSchema(BaseModel):
    genkey: Optional[str] = None
    prompt: str
    n: Optional[int] = 1
    max_context_length: int
    max_length: int
    rep_pen: Optional[float] = 1.0
    rep_pen_range: Optional[int] = None
    rep_pen_slope: Optional[float] = None
    top_k: Optional[int] = 0
    top_a: Optional[float] = 0.0
    top_p: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eps_cutoff: Optional[float] = 0.0
    eta_cutoff: Optional[float] = 0.0
    typical: Optional[float] = 1.0
    temperature: Optional[float] = 1.0
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    use_memory: Optional[bool] = None
    use_story: Optional[bool] = None
    use_authors_note: Optional[bool] = None
    use_world_info: Optional[bool] = None
    use_userscripts: Optional[bool] = None
    soft_prompt: Optional[str] = None
    disable_output_formatting: Optional[bool] = None
    frmtrmblln: Optional[bool] = None
    frmtrmspch: Optional[bool] = None
    singleline: Optional[bool] = None
    use_default_badwordsids: Optional[bool] = None
    mirostat: Optional[int] = 0
    mirostat_tau: Optional[float] = 0.0
    mirostat_eta: Optional[float] = 0.0
    disable_input_formatting: Optional[bool] = None
    frmtadsnsp: Optional[bool] = None
    quiet: Optional[bool] = None
    # pylint: disable=unexpected-keyword-arg
    sampler_order: Optional[Union[List, str]] = Field(default_factory=list)
    sampler_seed: Optional[int] = None
    sampler_full_determinism: Optional[bool] = None
    stop_sequence: Optional[List[str]] = None
    include_stop_str_in_output: Optional[bool] = False

    @root_validator(pre=False, skip_on_failure=True)
    def check_context(cls, values):  # pylint: disable=no-self-argument
        assert values.get("max_length") <= values.get(
            "max_context_length"
        ), "max_length must not be larger than max_context_length"
        return values
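
# Illustrative sketch (not part of this module): the minimal fields a KoboldAI
# generation request needs; check_context enforces
# max_length <= max_context_length. The values are placeholders.
#
#   KAIGenerationInputSchema(prompt="Hello", max_length=80,
#                            max_context_length=2048)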