# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import json
import time
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from pydantic import (AliasChoices, BaseModel, ConfigDict, Field,
                      model_validator)
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated

from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import (LogitsProcessorFunc,
                                              SamplingParams)
from aphrodite.common.sequence import Logprob
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.chat_utils import ChatCompletionMessageParam
from aphrodite.endpoints.openai.logits_processors import get_logits_processors


class OpenAIBaseModel(BaseModel):
    model_config = ConfigDict(extra="ignore")


class ErrorResponse(OpenAIBaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class ModelPermission(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(OpenAIBaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "pygmalionai"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(OpenAIBaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)
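
# Example (illustrative sketch only, not part of the API surface): building a
# /v1/models-style listing from the models above. The model name is a
# placeholder.
#
#     card = ModelCard(id="example-model", max_model_len=4096)
#     listing = ModelList(data=[card])
#     listing.model_dump_json()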
  51. class UsageInfo(OpenAIBaseModel):
  52. prompt_tokens: int = 0
  53. total_tokens: int = 0
  54. completion_tokens: Optional[int] = 0
  55. class ResponseFormat(OpenAIBaseModel):
  56. # type must be "json_object" or "text"
  57. type: Literal["text", "json_object"]
  58. class StreamOptions(OpenAIBaseModel):
  59. include_usage: Optional[bool] = True
  60. continuous_usage_stats: Optional[bool] = True
  61. class FunctionDefinition(OpenAIBaseModel):
  62. name: str
  63. description: Optional[str] = None
  64. parameters: Optional[Dict[str, Any]] = None
  65. class ChatCompletionToolsParam(OpenAIBaseModel):
  66. type: Literal["function"] = "function"
  67. function: FunctionDefinition
  68. class ChatCompletionNamedFunction(OpenAIBaseModel):
  69. name: str
  70. class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
  71. function: ChatCompletionNamedFunction
  72. type: Literal["function"] = "function"
  73. class ChatCompletionRequest(OpenAIBaseModel):
  74. # Ordered by official OpenAI API documentation
  75. # https://platform.openai.com/docs/api-reference/chat/create
  76. messages: List[ChatCompletionMessageParam]
  77. model: str
  78. frequency_penalty: Optional[float] = 0.0
  79. logit_bias: Optional[Dict[str, float]] = None
  80. logprobs: Optional[bool] = False
  81. top_logprobs: Optional[int] = 0
  82. max_tokens: Optional[int] = None
  83. n: Optional[int] = 1
  84. presence_penalty: Optional[float] = 0.0
  85. response_format: Optional[ResponseFormat] = None
  86. seed: Optional[int] = Field(None,
  87. ge=torch.iinfo(torch.long).min,
  88. le=torch.iinfo(torch.long).max)
  89. stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
  90. stream: Optional[bool] = False
  91. stream_options: Optional[StreamOptions] = None
  92. temperature: Optional[float] = 0.7
  93. top_p: Optional[float] = 1.0
  94. tools: Optional[List[ChatCompletionToolsParam]] = None
  95. tool_choice: Optional[Union[Literal["none"],
  96. ChatCompletionNamedToolChoiceParam]] = "none"
  97. user: Optional[str] = None
  98. # doc: begin-chat-completion-sampling-params
  99. best_of: Optional[int] = None
  100. use_beam_search: Optional[bool] = False
  101. top_k: Optional[int] = -1
  102. min_p: Optional[float] = 0.0
  103. top_a: Optional[float] = 0.0
  104. tfs: Optional[float] = 1.0
  105. eta_cutoff: Optional[float] = 0.0
  106. epsilon_cutoff: Optional[float] = 0.0
  107. typical_p: Optional[float] = 1.0
  108. smoothing_factor: Optional[float] = 0.0
  109. smoothing_curve: Optional[float] = 1.0
  110. repetition_penalty: Optional[float] = 1.0
  111. no_repeat_ngram_size: Optional[int] = 0
  112. length_penalty: Optional[float] = 1.0
  113. early_stopping: Optional[bool] = False
  114. ignore_eos: Optional[bool] = False
  115. min_tokens: Optional[int] = 0
  116. stop_token_ids: Optional[List[int]] = Field(default_factory=list)
  117. skip_special_tokens: Optional[bool] = True
  118. spaces_between_special_tokens: Optional[bool] = True
  119. truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
  120. temperature_last: Optional[bool] = False
  121. prompt_logprobs: Optional[int] = None
  122. xtc_threshold: Optional[float] = 0.1
  123. xtc_probability: Optional[float] = 0.0
  124. dry_multiplier: Optional[float] = 0
  125. dry_base: Optional[float] = 1.75
  126. dry_allowed_length: Optional[int] = 2
  127. dry_sequence_breakers: Optional[List[str]] = Field(
  128. default=["\n", ":", "\"", "*"])
  129. dynatemp_min: Optional[float] = 0.0
  130. dynatemp_max: Optional[float] = 0.0
  131. dynatemp_exponent: Optional[float] = 1.0
  132. nsigma: Optional[float] = 0.0
  133. skew: Optional[float] = 0.0
  134. custom_token_bans: Optional[List[int]] = None
  135. sampler_priority: Optional[List[int]] = Field(
  136. default=[],
  137. validation_alias=AliasChoices("sampler_priority",
  138. "sampler_order"))
  139. # doc: end-chat-completion-sampling-params
  140. # doc: begin-chat-completion-extra-params
  141. echo: Optional[bool] = Field(
  142. default=False,
  143. description=(
  144. "If true, the new message will be prepended with the last message "
  145. "if they belong to the same role."),
  146. )
  147. add_generation_prompt: Optional[bool] = Field(
  148. default=True,
  149. description=
  150. ("If true, the generation prompt will be added to the chat template. "
  151. "This is a parameter used by chat template in tokenizer config of the "
  152. "model."),
  153. )
  154. add_special_tokens: Optional[bool] = Field(
  155. default=False,
  156. description=(
  157. "If true, special tokens (e.g. BOS) will be added to the prompt "
  158. "on top of what is added by the chat template. "
  159. "For most models, the chat template takes care of adding the "
  160. "special tokens so this should be set to False (as is the "
  161. "default)."),
  162. )
  163. documents: Optional[List[Dict[str, str]]] = Field(
  164. default=None,
  165. description=
  166. ("A list of dicts representing documents that will be accessible to "
  167. "the model if it is performing RAG (retrieval-augmented generation)."
  168. " If the template does not support RAG, this argument will have no "
  169. "effect. We recommend that each document should be a dict containing "
  170. "\"title\" and \"text\" keys."),
  171. )
  172. chat_template: Optional[str] = Field(
  173. default=None,
  174. description=(
  175. "A Jinja template to use for this conversion. "
  176. "As of transformers v4.44, default chat template is no longer "
  177. "allowed, so you must provide a chat template if the tokenizer "
  178. "does not define one."),
  179. )
  180. chat_template_kwargs: Optional[Dict[str, Any]] = Field(
  181. default=None,
  182. description=("Additional kwargs to pass to the template renderer. "
  183. "Will be accessible by the chat template."),
  184. )
  185. include_stop_str_in_output: Optional[bool] = Field(
  186. default=False,
  187. description=(
  188. "Whether to include the stop string in the output. "
  189. "This is only applied when the stop or stop_token_ids is set."),
  190. )
  191. guided_json: Optional[Union[str, dict, BaseModel]] = Field(
  192. default=None,
  193. description=("If specified, the output will follow the JSON schema."),
  194. )
  195. guided_regex: Optional[str] = Field(
  196. default=None,
  197. description=(
  198. "If specified, the output will follow the regex pattern."),
  199. )
  200. guided_choice: Optional[List[str]] = Field(
  201. default=None,
  202. description=(
  203. "If specified, the output will be exactly one of the choices."),
  204. )
  205. guided_grammar: Optional[str] = Field(
  206. default=None,
  207. description=(
  208. "If specified, the output will follow the context free grammar."),
  209. )
  210. guided_decoding_backend: Optional[str] = Field(
  211. default=None,
  212. description=(
  213. "If specified, will override the default guided decoding backend "
  214. "of the server for this specific request. If set, must be either "
  215. "'outlines' / 'lm-format-enforcer'"))
  216. guided_whitespace_pattern: Optional[str] = Field(
  217. default=None,
  218. description=(
  219. "If specified, will override the default whitespace pattern "
  220. "for guided json decoding."))
  221. # doc: end-chat-completion-extra-params
  222. def to_sampling_params(
  223. self, tokenizer: PreTrainedTokenizer,
  224. guided_decode_logits_processor: Optional[LogitsProcessorFunc],
  225. default_max_tokens: int) -> SamplingParams:
  226. max_tokens = self.max_tokens
  227. if max_tokens is None:
  228. max_tokens = default_max_tokens
        # We now allow logprobs being true without top_logprobs.
        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=None,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        dry_sequence_breaker_ids = []
        if self.dry_sequence_breakers:
            for s in self.dry_sequence_breakers:
                token_id = tokenizer.encode(f'a{s}')[-1]
                dry_sequence_breaker_ids.append(token_id)

        return SamplingParams(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            temperature=self.temperature,
            top_p=self.top_p,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
            (self.top_logprobs if self.echo else None),
            best_of=self.best_of,
            top_k=self.top_k,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            dry_multiplier=self.dry_multiplier,
            dry_base=self.dry_base,
            dry_allowed_length=self.dry_allowed_length,
            dry_sequence_breaker_ids=dry_sequence_breaker_ids,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            nsigma=self.nsigma,
            skew=self.skew,
            custom_token_bans=self.custom_token_bans,
            sampler_priority=self.sampler_priority,
        )

    @model_validator(mode='before')
    @classmethod
    def validate_stream_options(cls, values):
        if (values.get('stream_options') is not None
                and not values.get('stream')):
            raise ValueError(
                "stream_options can only be set if stream is true")
        return values

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        # you can only either use guided decoding or tools, not both
        if guide_count > 1 and "tool_choice" in data and data[
                "tool_choice"] != "none":
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_choice(cls, data):
        if "tool_choice" in data and data["tool_choice"] != "none":
            if not isinstance(data["tool_choice"], dict):
                raise ValueError("Currently only named tools are supported.")
            if "tools" not in data or data["tools"] is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "top_logprobs" in data and data["top_logprobs"] is not None:
            if "logprobs" not in data or data["logprobs"] is False:
                raise ValueError(
                    "when using `top_logprobs`, `logprobs` must be set to true."
                )
            elif data["top_logprobs"] < 0:
                raise ValueError(
                    "`top_logprobs` must be a non-negative value.")
        return data
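
# Example (illustrative sketch only): converting a chat request into
# SamplingParams. The model name is a placeholder; in the server the tokenizer
# and default_max_tokens come from the serving engine, not from user code.
#
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
#     request = ChatCompletionRequest(
#         model="EleutherAI/pythia-70m",
#         messages=[{"role": "user", "content": "Hello!"}],
#         temperature=0.7,
#         max_tokens=32,
#     )
#     sampling_params = request.to_sampling_params(
#         tokenizer,
#         guided_decode_logits_processor=None,
#         default_max_tokens=2048)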

class CompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-completion-sampling-params
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    no_repeat_ngram_size: Optional[int] = 0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    allowed_token_ids: Optional[List[int]] = None
    include_stop_str_in_output: Optional[bool] = False
    add_special_tokens: Optional[bool] = False
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    dry_multiplier: Optional[float] = 0
    dry_base: Optional[float] = 1.75
    dry_allowed_length: Optional[int] = 2
    dry_sequence_breakers: Optional[List[str]] = Field(
        default=["\n", ":", "\"", "*"])
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    nsigma: Optional[float] = 0.0
    skew: Optional[float] = 0.0
    custom_token_bans: Optional[List[int]] = None
    sampler_priority: Optional[List[int]] = Field(
        default=[],
        validation_alias=AliasChoices("sampler_priority",
                                      "sampler_order"))
    # doc: end-completion-sampling-params

    # doc: begin-completion-extra-params
    response_format: Optional[ResponseFormat] = Field(
        default=None,
        description=
        ("Similar to chat completion, this parameter specifies the format of "
         "output. Only {'type': 'json_object'} or {'type': 'text' } is "
         "supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    # doc: end-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        echo_without_generation = self.echo and self.max_tokens == 0

        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        dry_sequence_breaker_ids = []
        if self.dry_sequence_breakers:
            for s in self.dry_sequence_breakers:
                s = bytes(s, "utf-8").decode("unicode_escape")
                token_id = tokenizer.encode(f'a{s}')[-1]
                dry_sequence_breaker_ids.append(token_id)

        return SamplingParams(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            logprobs=self.logprobs,
            prompt_logprobs=self.prompt_logprobs
            if self.prompt_logprobs else self.logprobs if self.echo else None,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=(self.spaces_between_special_tokens),
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            dry_multiplier=self.dry_multiplier,
            dry_base=self.dry_base,
            dry_allowed_length=self.dry_allowed_length,
            dry_sequence_breaker_ids=dry_sequence_breaker_ids,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            nsigma=self.nsigma,
            skew=self.skew,
            custom_token_bans=self.custom_token_bans,
            sampler_priority=self.sampler_priority,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "logprobs" in data and data[
                "logprobs"] is not None and not data["logprobs"] >= 0:
            raise ValueError(
                "if passed, `logprobs` must be a non-negative value.")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when stream is True.")
        return data

    @model_validator(mode='before')
    @classmethod
    def parse_dry_sequence_breakers(cls, data):
        if 'dry_sequence_breakers' in data:
            breakers = data['dry_sequence_breakers']
            if isinstance(breakers, str):
                try:
                    # Try to parse as JSON string
                    data['dry_sequence_breakers'] = json.loads(breakers)
                except json.JSONDecodeError as e:
                    raise ValueError(f"Invalid JSON for dry_sequence_breakers:"
                                     f" {e}") from e

            # Validate that we now have a list of strings
            is_list = isinstance(data['dry_sequence_breakers'], list)
            all_strings = all(
                isinstance(x, str)
                for x in data['dry_sequence_breakers']
            )
            if not is_list or not all_strings:
                raise ValueError(
                    "dry_sequence_breakers must be a list of strings or a "
                    "JSON string representing a list of strings"
                )
        return data
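
# Example (illustrative only): `dry_sequence_breakers` may be sent either as a
# list of strings or as a JSON-encoded string; the validator above normalizes
# both hypothetical payloads below to the same field value.
#
#     CompletionRequest(model="example-model", prompt="hi",
#                       dry_sequence_breakers=["\n", ":"])
#     CompletionRequest(model="example-model", prompt="hi",
#                       dry_sequence_breakers='["\\n", ":"]')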

class EmbeddingRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
    dimensions: Optional[int] = None
    user: Optional[str] = None

    # doc: begin-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-embedding-pooling-params

    def to_pooling_params(self):
        return PoolingParams(additional_data=self.additional_data)
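
# Example (illustrative only): an embeddings request and the pooling params
# derived from it. The model name is a placeholder.
#
#     req = EmbeddingRequest(model="example-embedder", input=["hello world"])
#     pooling = req.to_pooling_params()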

class CompletionLogProbs(OpenAIBaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str,
                                     float]]] = Field(default_factory=list)


class CompletionResponseChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class CompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )


class CompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class EmbeddingResponseData(OpenAIBaseModel):
    index: int
    object: str = "embedding"
    embedding: Union[List[float], str]


class EmbeddingResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[EmbeddingResponseData]
    usage: UsageInfo


class FunctionCall(OpenAIBaseModel):
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall


class ChatMessage(OpenAIBaseModel):
    role: str
    content: str
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionLogProb(OpenAIBaseModel):
    token: str
    logprob: float = -9999.0
    bytes: Optional[List[int]] = None


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)


class ChatCompletionLogProbs(OpenAIBaseModel):
    content: Optional[List[ChatCompletionLogProbsContent]] = None


class ChatCompletionResponseChoice(OpenAIBaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class DeltaMessage(OpenAIBaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
    """

    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. Currently
    # /v1/chat/completions is supported.
    url: str

    # The parameters of the request.
    body: Union[ChatCompletionRequest, EmbeddingRequest]


class BatchResponseData(OpenAIBaseModel):
    # HTTP status code of the response.
    status_code: int = 200

    # An unique identifier for the API request.
    request_id: str

    # The body of the response.
    body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None


class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs to
    # inputs.
    custom_id: str

    response: Optional[BatchResponseData]

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Optional[Any]


class TokenizeCompletionRequest(OpenAIBaseModel):
    model: str
    prompt: str

    add_special_tokens: bool = Field(default=True)


class TokenizeChatRequest(OpenAIBaseModel):
    model: str
    messages: List[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(default=True)
    add_special_tokens: bool = Field(default=False)


TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]


class TokenizeResponse(OpenAIBaseModel):
    tokens: List[int]
    count: int
    max_model_len: int


class DetokenizeRequest(OpenAIBaseModel):
    model: Optional[str]
    tokens: List[int]


class DetokenizeResponse(OpenAIBaseModel):
    prompt: str
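
# Example (illustrative only): the tokenization endpoints accept either a raw
# prompt or a chat message list; both hypothetical payloads below validate
# against the models above. The model name is a placeholder.
#
#     TokenizeCompletionRequest(model="example-model", prompt="Hello there")
#     TokenizeChatRequest(
#         model="example-model",
#         messages=[{"role": "user", "content": "Hello there"}])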

# ========== KoboldAI ========== #


class KAIGenerationInputSchema(BaseModel):
    genkey: Optional[str] = None
    prompt: str
    n: Optional[int] = 1
    max_context_length: int
    max_length: int
    rep_pen: Optional[float] = 1.0
    top_k: Optional[int] = 0
    top_a: Optional[float] = 0.0
    top_p: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eps_cutoff: Optional[float] = 0.0
    eta_cutoff: Optional[float] = 0.0
    typical: Optional[float] = 1.0
    temperature: Optional[float] = 1.0
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    use_default_badwordsids: Optional[bool] = None
    quiet: Optional[bool] = None
    # pylint: disable=unexpected-keyword-arg
    sampler_seed: Optional[int] = None
    stop_sequence: Optional[List[str]] = None
    include_stop_str_in_output: Optional[bool] = False

    @model_validator(mode='before')
    def check_context(cls, values):  # pylint: disable=no-self-argument
        assert values.get("max_length") <= values.get(
            "max_context_length"
        ), "max_length must not be larger than max_context_length"
        return values
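
# Example (illustrative only): a minimal KoboldAI-style generation request.
# The validator above rejects payloads where max_length exceeds
# max_context_length.
#
#     KAIGenerationInputSchema(
#         prompt="Once upon a time", max_context_length=2048, max_length=256)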