# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import json
import time
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated

from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import (LogitsProcessorFunc,
                                              SamplingParams)
from aphrodite.common.sequence import Logprob
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.chat_utils import ChatCompletionMessageParam
from aphrodite.endpoints.openai.logits_processors import get_logits_processors


class OpenAIBaseModel(BaseModel):
    model_config = ConfigDict(extra="ignore")
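

# NOTE: `extra="ignore"` makes every OpenAI-compatible model below tolerant of
# unknown fields: parameters sent by newer clients that this server does not
# implement are silently dropped rather than causing a validation error.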


class ErrorResponse(OpenAIBaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class ModelPermission(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(OpenAIBaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "pygmalionai"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(OpenAIBaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(OpenAIBaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ResponseFormat(OpenAIBaseModel):
    # type must be "json_object" or "text"
    type: Literal["text", "json_object"]


class StreamOptions(OpenAIBaseModel):
    include_usage: Optional[bool] = True
    continuous_usage_stats: Optional[bool] = True


class FunctionDefinition(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None


class ChatCompletionToolsParam(OpenAIBaseModel):
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


class ChatCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: List[ChatCompletionMessageParam]
    model: str
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = 0
    max_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tools: Optional[List[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[Literal["none"],
                                ChatCompletionNamedToolChoiceParam]] = "none"
    user: Optional[str] = None

    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    no_repeat_ngram_size: Optional[int] = 0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    dry_multiplier: Optional[float] = 0
    dry_base: Optional[float] = 1.75
    dry_allowed_length: Optional[int] = 2
    dry_sequence_breakers: Optional[List[str]] = Field(
        default=["\n", ":", "\"", "*"])
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    nsigma: Optional[float] = 0.0
    skew: Optional[float] = 0.0
    custom_token_bans: Optional[List[int]] = None
    # doc: end-chat-completion-sampling-params

    # doc: begin-chat-completion-extra-params
    echo: Optional[bool] = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."),
    )
    add_generation_prompt: Optional[bool] = Field(
        default=True,
        description=(
            "If true, the generation prompt will be added to the chat "
            "template. This is a parameter used by the chat template in the "
            "tokenizer config of the model."),
    )
    add_special_tokens: Optional[bool] = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens, so this should be set to False (as is the "
            "default)."),
    )
    documents: Optional[List[Dict[str, str]]] = Field(
        default=None,
        description=(
            "A list of dicts representing documents that will be accessible "
            "to the model if it is performing RAG (retrieval-augmented "
            "generation). If the template does not support RAG, this "
            "argument will have no effect. We recommend that each document "
            "should be a dict containing \"title\" and \"text\" keys."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, the default chat template is no "
            "longer allowed, so you must provide a chat template if the "
            "tokenizer does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )
    include_stop_str_in_output: Optional[bool] = Field(
        default=False,
        description=(
            "Whether to include the stop string in the output. "
            "This is only applied when stop or stop_token_ids is set."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context-free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided JSON decoding."))
    # doc: end-chat-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        # We now allow logprobs to be true without top_logprobs.
        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=None,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        dry_sequence_breaker_ids = []
        if self.dry_sequence_breakers:
            for s in self.dry_sequence_breakers:
                # Prepend a dummy character so the breaker is tokenized as it
                # would appear mid-text, then keep only the last token id.
                token_id = tokenizer.encode(f'a{s}')[-1]
                dry_sequence_breaker_ids.append(token_id)

        return SamplingParams(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            temperature=self.temperature,
            top_p=self.top_p,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
            (self.top_logprobs if self.echo else None),
            best_of=self.best_of,
            top_k=self.top_k,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            dry_multiplier=self.dry_multiplier,
            dry_base=self.dry_base,
            dry_allowed_length=self.dry_allowed_length,
            dry_sequence_breaker_ids=dry_sequence_breaker_ids,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            nsigma=self.nsigma,
            skew=self.skew,
            custom_token_bans=self.custom_token_bans,
        )

    @model_validator(mode='before')
    @classmethod
    def validate_stream_options(cls, values):
        if (values.get('stream_options') is not None
                and not values.get('stream')):
            raise ValueError(
                "stream_options can only be set if stream is true")
        return values

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        # you can only either use guided decoding or tools, not both
        if guide_count > 0 and "tool_choice" in data and data[
                "tool_choice"] != "none":
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_choice(cls, data):
        if "tool_choice" in data and data["tool_choice"] != "none":
            if not isinstance(data["tool_choice"], dict):
                raise ValueError("Currently only named tools are supported.")
            if "tools" not in data or data["tools"] is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "top_logprobs" in data and data["top_logprobs"] is not None:
            if "logprobs" not in data or data["logprobs"] is False:
                raise ValueError(
                    "when using `top_logprobs`, `logprobs` must be set to "
                    "true.")
            elif data["top_logprobs"] < 0:
                raise ValueError(
                    "`top_logprobs` must be a non-negative value.")
        return data
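

# A minimal, illustrative sketch (not part of the module) of how a parsed
# request is converted into engine-level SamplingParams. `tokenizer` is
# assumed to be any loaded `PreTrainedTokenizer`; the other values are
# placeholders.
#
#     request = ChatCompletionRequest(
#         model="my-model",
#         messages=[{"role": "user", "content": "Hello!"}],
#         temperature=0.7,
#         max_tokens=64,
#     )
#     sampling_params = request.to_sampling_params(
#         tokenizer=tokenizer,
#         guided_decode_logits_processor=None,
#         default_max_tokens=512,
#     )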


class CompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-completion-sampling-params
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    no_repeat_ngram_size: Optional[int] = 0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    allowed_token_ids: Optional[List[int]] = None
    include_stop_str_in_output: Optional[bool] = False
    add_special_tokens: Optional[bool] = False
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    dry_multiplier: Optional[float] = 0
    dry_base: Optional[float] = 1.75
    dry_allowed_length: Optional[int] = 2
    dry_sequence_breakers: Optional[List[str]] = Field(
        default=["\n", ":", "\"", "*"])
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    nsigma: Optional[float] = 0.0
    skew: Optional[float] = 0.0
    custom_token_bans: Optional[List[int]] = None
    # doc: end-completion-sampling-params

    # doc: begin-completion-extra-params
    response_format: Optional[ResponseFormat] = Field(
        default=None,
        description=(
            "Similar to chat completion, this parameter specifies the format "
            "of the output. Only {'type': 'json_object'} or {'type': 'text'} "
            "is supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context-free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided JSON decoding."))
    # doc: end-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        echo_without_generation = self.echo and self.max_tokens == 0

        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        dry_sequence_breaker_ids = []
        if self.dry_sequence_breakers:
            for s in self.dry_sequence_breakers:
                # Unescape sequences such as "\\n" that arrive escaped, then
                # prepend a dummy character so the breaker is tokenized as it
                # would appear mid-text; keep only the last token id.
                s = bytes(s, "utf-8").decode("unicode_escape")
                token_id = tokenizer.encode(f'a{s}')[-1]
                dry_sequence_breaker_ids.append(token_id)

        return SamplingParams(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            logprobs=self.logprobs,
            prompt_logprobs=self.prompt_logprobs
            if self.prompt_logprobs else self.logprobs if self.echo else None,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            dry_multiplier=self.dry_multiplier,
            dry_base=self.dry_base,
            dry_allowed_length=self.dry_allowed_length,
            dry_sequence_breaker_ids=dry_sequence_breaker_ids,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            nsigma=self.nsigma,
            skew=self.skew,
            custom_token_bans=self.custom_token_bans,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "logprobs" in data and data[
                "logprobs"] is not None and not data["logprobs"] >= 0:
            raise ValueError(
                "if passed, `logprobs` must be a non-negative value.")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when stream is True.")
        return data

    @model_validator(mode='before')
    @classmethod
    def parse_dry_sequence_breakers(cls, data):
        if 'dry_sequence_breakers' in data:
            breakers = data['dry_sequence_breakers']
            if isinstance(breakers, str):
                try:
                    # Try to parse as a JSON string
                    data['dry_sequence_breakers'] = json.loads(breakers)
                except json.JSONDecodeError as e:
                    raise ValueError(f"Invalid JSON for dry_sequence_breakers:"
                                     f" {e}") from e

            # Validate that we now have a list of strings. Check the list
            # type first so a non-iterable value cannot raise a TypeError.
            breakers = data['dry_sequence_breakers']
            if not isinstance(breakers, list) or not all(
                    isinstance(x, str) for x in breakers):
                raise ValueError(
                    "dry_sequence_breakers must be a list of strings or a "
                    "JSON string representing a list of strings")
        return data
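

# A minimal, illustrative sketch (not part of the module) of constraining a
# completion with the guided decoding extras; the model name and schema are
# hypothetical placeholders.
#
#     request = CompletionRequest(
#         model="my-model",
#         prompt="Return the user as JSON: ",
#         max_tokens=128,
#         guided_json={
#             "type": "object",
#             "properties": {"name": {"type": "string"}},
#         },
#     )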


class EmbeddingRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
    dimensions: Optional[int] = None
    user: Optional[str] = None

    # doc: begin-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-embedding-pooling-params

    def to_pooling_params(self):
        return PoolingParams(additional_data=self.additional_data)
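

# An illustrative usage note (not part of the module): embedding requests
# carry no sampling parameters and map onto PoolingParams instead.
#
#     request = EmbeddingRequest(model="my-model", input="some text")
#     pooling_params = request.to_pooling_params()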


class CompletionLogProbs(OpenAIBaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str,
                                     float]]] = Field(default_factory=list)


class CompletionResponseChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class CompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )


class CompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class EmbeddingResponseData(OpenAIBaseModel):
    index: int
    object: str = "embedding"
    embedding: Union[List[float], str]


class EmbeddingResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[EmbeddingResponseData]
    usage: UsageInfo


class FunctionCall(OpenAIBaseModel):
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall


class ChatMessage(OpenAIBaseModel):
    role: str
    content: str
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionLogProb(OpenAIBaseModel):
    token: str
    logprob: float = -9999.0
    bytes: Optional[List[int]] = None


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)


class ChatCompletionLogProbs(OpenAIBaseModel):
    content: Optional[List[ChatCompletionLogProbsContent]] = None


class ChatCompletionResponseChoice(OpenAIBaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class DeltaMessage(OpenAIBaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
    """

    # A developer-provided per-request id that will be used to match outputs
    # to inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. Currently
    # /v1/chat/completions is supported.
    url: str

    # The parameters of the request.
    body: Union[ChatCompletionRequest, EmbeddingRequest]


class BatchResponseData(OpenAIBaseModel):
    # HTTP status code of the response.
    status_code: int = 200

    # A unique identifier for the API request.
    request_id: str

    # The body of the response.
    body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None


class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files.
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs
    # to inputs.
    custom_id: str

    response: Optional[BatchResponseData]

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Optional[Any]
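

# An illustrative example (not part of the module) of one line of a batch
# input file, following the OpenAI batch format; the model name is a
# placeholder. The output file mirrors this shape via BatchRequestOutput.
#
#     {"custom_id": "request-1", "method": "POST",
#      "url": "/v1/chat/completions",
#      "body": {"model": "my-model",
#               "messages": [{"role": "user", "content": "Hello!"}]}}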


class TokenizeCompletionRequest(OpenAIBaseModel):
    model: str
    prompt: str
    add_special_tokens: bool = Field(default=True)


class TokenizeChatRequest(OpenAIBaseModel):
    model: str
    messages: List[ChatCompletionMessageParam]
    add_generation_prompt: bool = Field(default=True)
    add_special_tokens: bool = Field(default=False)


TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]


class TokenizeResponse(OpenAIBaseModel):
    tokens: List[int]
    count: int
    max_model_len: int


class DetokenizeRequest(OpenAIBaseModel):
    model: Optional[str]
    tokens: List[int]


class DetokenizeResponse(OpenAIBaseModel):
    prompt: str


# ========== KoboldAI ========== #


class KAIGenerationInputSchema(BaseModel):
    genkey: Optional[str] = None
    prompt: str
    n: Optional[int] = 1
    max_context_length: int
    max_length: int
    rep_pen: Optional[float] = 1.0
    top_k: Optional[int] = 0
    top_a: Optional[float] = 0.0
    top_p: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eps_cutoff: Optional[float] = 0.0
    eta_cutoff: Optional[float] = 0.0
    typical: Optional[float] = 1.0
    temperature: Optional[float] = 1.0
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    use_default_badwordsids: Optional[bool] = None
    quiet: Optional[bool] = None
    # pylint: disable=unexpected-keyword-arg
    sampler_seed: Optional[int] = None
    stop_sequence: Optional[List[str]] = None
    include_stop_str_in_output: Optional[bool] = False

    @model_validator(mode='before')
    def check_context(cls, values):  # pylint: disable=no-self-argument
        assert values.get("max_length") <= values.get(
            "max_context_length"
        ), "max_length must not be larger than max_context_length"
        return values
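

# A minimal, illustrative sketch (not part of the module) of a KoboldAI
# generation payload; all values are placeholders. The validator above
# requires max_length <= max_context_length.
#
#     payload = KAIGenerationInputSchema(
#         prompt="Once upon a time",
#         max_context_length=2048,
#         max_length=80,
#         temperature=0.8,
#     )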