# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import json
import time
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from pydantic import (AliasChoices, BaseModel, ConfigDict, Field,
                      model_validator)
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated

from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import (LogitsProcessorFunc,
                                              SamplingParams)
from aphrodite.common.sequence import Logprob
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.chat_utils import ChatCompletionMessageParam
from aphrodite.endpoints.openai.logits_processors import get_logits_processors


class OpenAIBaseModel(BaseModel):
    model_config = ConfigDict(extra="ignore")


class ErrorResponse(OpenAIBaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class ModelPermission(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(OpenAIBaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "pygmalionai"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(OpenAIBaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(OpenAIBaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ResponseFormat(OpenAIBaseModel):
    # type must be "json_object" or "text"
    type: Literal["text", "json_object"]


class StreamOptions(OpenAIBaseModel):
    include_usage: Optional[bool] = True
    continuous_usage_stats: Optional[bool] = True


class FunctionDefinition(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None


class ChatCompletionToolsParam(OpenAIBaseModel):
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


class ChatCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: List[ChatCompletionMessageParam]
    model: str
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = 0
    max_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tools: Optional[List[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[Literal["none"],
                                ChatCompletionNamedToolChoiceParam]] = "none"
    user: Optional[str] = None

    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    no_repeat_ngram_size: Optional[int] = 0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    dry_multiplier: Optional[float] = 0
    dry_base: Optional[float] = 1.75
    dry_allowed_length: Optional[int] = 2
    dry_sequence_breakers: Optional[List[str]] = Field(
        default=["\n", ":", "\"", "*"])
    dry_range: Optional[int] = Field(
        default=0,
        validation_alias=AliasChoices("dry_range",
                                      "dry_penalty_last_n"))
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    nsigma: Optional[float] = 0.0
    skew: Optional[float] = 0.0
    custom_token_bans: Optional[List[int]] = None
    sampler_priority: Optional[Union[List[int], List[str]]] = Field(
        default=[],
        validation_alias=AliasChoices("sampler_priority",
                                      "sampler_order"))
    # doc: end-chat-completion-sampling-params

    # doc: begin-chat-completion-extra-params
    echo: Optional[bool] = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."),
    )
    add_generation_prompt: Optional[bool] = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
    )
    add_special_tokens: Optional[bool] = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to False (as is the "
            "default)."),
    )
    documents: Optional[List[Dict[str, str]]] = Field(
        default=None,
        description=
        ("A list of dicts representing documents that will be accessible to "
         "the model if it is performing RAG (retrieval-augmented generation)."
         " If the template does not support RAG, this argument will have no "
         "effect. We recommend that each document should be a dict containing "
         "\"title\" and \"text\" keys."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )
    include_stop_str_in_output: Optional[bool] = Field(
        default=False,
        description=(
            "Whether to include the stop string in the output. "
            "This is only applied when the stop or stop_token_ids is set."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' or 'lm-format-enforcer'."))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    # doc: end-chat-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        # We now allow logprobs being true without top_logprobs.
        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=None,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        dry_sequence_breaker_ids = []
        if self.dry_sequence_breakers:
            for s in self.dry_sequence_breakers:
                token_id = tokenizer.encode(f'a{s}')[-1]
                dry_sequence_breaker_ids.append(token_id)

        return SamplingParams(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            temperature=self.temperature,
            top_p=self.top_p,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
            (self.top_logprobs if self.echo else None),
            best_of=self.best_of,
            top_k=self.top_k,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            dry_multiplier=self.dry_multiplier,
            dry_base=self.dry_base,
            dry_allowed_length=self.dry_allowed_length,
            dry_sequence_breaker_ids=dry_sequence_breaker_ids,
            dry_range=self.dry_range,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            nsigma=self.nsigma,
            skew=self.skew,
            custom_token_bans=self.custom_token_bans,
            sampler_priority=self.sampler_priority,
        )

    @model_validator(mode='before')
    @classmethod
    def validate_stream_options(cls, values):
        if (values.get('stream_options') is not None
                and not values.get('stream')):
            raise ValueError(
                "stream_options can only be set if stream is true")
        return values

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        # you can only either use guided decoding or tools, not both
        if guide_count > 1 and "tool_choice" in data and data[
                "tool_choice"] != "none":
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_choice(cls, data):
        if "tool_choice" in data and data["tool_choice"] != "none":
            if not isinstance(data["tool_choice"], dict):
                raise ValueError("Currently only named tools are supported.")
            if "tools" not in data or data["tools"] is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "top_logprobs" in data and data["top_logprobs"] is not None:
            if "logprobs" not in data or data["logprobs"] is False:
                raise ValueError(
                    "when using `top_logprobs`, `logprobs` must be set to true."
                )
            elif data["top_logprobs"] < 0:
                raise ValueError(
                    "`top_logprobs` must be a non-negative value.")
        return data
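
# Illustrative sketch (not executed): how a chat request is converted into
# SamplingParams via to_sampling_params(). The `tokenizer` object is an
# assumption here -- any HF PreTrainedTokenizer loaded by the server would do,
# and the model name is a placeholder.
#
#   >>> req = ChatCompletionRequest(
#   ...     model="my-model",
#   ...     messages=[{"role": "user", "content": "Hi"}],
#   ...     temperature=0.7, top_p=0.9, max_tokens=32)
#   >>> params = req.to_sampling_params(
#   ...     tokenizer, guided_decode_logits_processor=None,
#   ...     default_max_tokens=512)
#   >>> params.max_tokens   # request value wins over the server default
#   32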


class CompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-completion-sampling-params
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    no_repeat_ngram_size: Optional[int] = 0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    allowed_token_ids: Optional[List[int]] = None
    include_stop_str_in_output: Optional[bool] = False
    add_special_tokens: Optional[bool] = False
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    dry_multiplier: Optional[float] = 0
    dry_base: Optional[float] = 1.75
    dry_allowed_length: Optional[int] = 2
    dry_sequence_breakers: Optional[List[str]] = Field(
        default=["\n", ":", "\"", "*"])
    dry_range: Optional[int] = Field(
        default=0,
        validation_alias=AliasChoices("dry_range",
                                      "dry_penalty_last_n"))
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    nsigma: Optional[float] = 0.0
    skew: Optional[float] = 0.0
    custom_token_bans: Optional[List[int]] = None
    sampler_priority: Optional[Union[List[int], List[str]]] = Field(
        default=[],
        validation_alias=AliasChoices("sampler_priority",
                                      "sampler_order"))
    # doc: end-completion-sampling-params

    # doc: begin-completion-extra-params
    response_format: Optional[ResponseFormat] = Field(
        default=None,
        description=
        ("Similar to chat completion, this parameter specifies the format of "
         "output. Only {'type': 'json_object'} or {'type': 'text' } is "
         "supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' or 'lm-format-enforcer'."))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    # doc: end-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        echo_without_generation = self.echo and self.max_tokens == 0

        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        dry_sequence_breaker_ids = []
        if self.dry_sequence_breakers:
            for s in self.dry_sequence_breakers:
                s = bytes(s, "utf-8").decode("unicode_escape")
                token_id = tokenizer.encode(f'a{s}')[-1]
                dry_sequence_breaker_ids.append(token_id)

        return SamplingParams(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            logprobs=self.logprobs,
            prompt_logprobs=self.prompt_logprobs
            if self.prompt_logprobs else self.logprobs if self.echo else None,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=(self.spaces_between_special_tokens),
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            dry_multiplier=self.dry_multiplier,
            dry_base=self.dry_base,
            dry_allowed_length=self.dry_allowed_length,
            dry_sequence_breaker_ids=dry_sequence_breaker_ids,
            dry_range=self.dry_range,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            nsigma=self.nsigma,
            skew=self.skew,
            custom_token_bans=self.custom_token_bans,
            sampler_priority=self.sampler_priority,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "logprobs" in data and data[
                "logprobs"] is not None and not data["logprobs"] >= 0:
            raise ValueError(
                "if passed, `logprobs` must be a non-negative value.")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when stream is True.")
        return data

    @model_validator(mode='before')
    @classmethod
    def parse_dry_sequence_breakers(cls, data):
        if 'dry_sequence_breakers' in data:
            breakers = data['dry_sequence_breakers']
            if isinstance(breakers, str):
                try:
                    # Try to parse as JSON string
                    data['dry_sequence_breakers'] = json.loads(breakers)
                except json.JSONDecodeError as e:
                    raise ValueError(f"Invalid JSON for dry_sequence_breakers:"
                                     f" {e}") from e

            # Validate that we now have a list of strings
            is_list = isinstance(data['dry_sequence_breakers'], list)
            all_strings = all(
                isinstance(x, str)
                for x in data['dry_sequence_breakers']
            )
            if not is_list or not all_strings:
                raise ValueError(
                    "dry_sequence_breakers must be a list of strings or a "
                    "JSON string representing a list of strings"
                )
        return data
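
# Note on dry_sequence_breakers: per the validator above, the field accepts
# either a native list of strings or a JSON-encoded string. Both of the
# following hypothetical payloads (illustrative only, not executed) validate
# to the same value:
#
#   >>> CompletionRequest(model="m", prompt="hi",
#   ...                   dry_sequence_breakers=["\n", ":"])
#   >>> CompletionRequest(model="m", prompt="hi",
#   ...                   dry_sequence_breakers='["\\n", ":"]')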


class EmbeddingRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
    dimensions: Optional[int] = None
    user: Optional[str] = None

    # doc: begin-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-embedding-pooling-params

    def to_pooling_params(self):
        return PoolingParams(additional_data=self.additional_data)


class CompletionLogProbs(OpenAIBaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str,
                                     float]]] = Field(default_factory=list)


class CompletionResponseChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class CompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )


class CompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class EmbeddingResponseData(OpenAIBaseModel):
    index: int
    object: str = "embedding"
    embedding: Union[List[float], str]


class EmbeddingResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[EmbeddingResponseData]
    usage: UsageInfo


class FunctionCall(OpenAIBaseModel):
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall


class ChatMessage(OpenAIBaseModel):
    role: str
    content: str
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionLogProb(OpenAIBaseModel):
    token: str
    logprob: float = -9999.0
    bytes: Optional[List[int]] = None


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)


class ChatCompletionLogProbs(OpenAIBaseModel):
    content: Optional[List[ChatCompletionLogProbsContent]] = None


class ChatCompletionResponseChoice(OpenAIBaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class DeltaMessage(OpenAIBaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.
    NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
    """
    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str
    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str
    # The OpenAI API relative URL to be used for the request. Currently
    # /v1/chat/completions is supported.
    url: str
    # The parameters of the request.
    body: Union[ChatCompletionRequest, EmbeddingRequest]


class BatchResponseData(OpenAIBaseModel):
    # HTTP status code of the response.
    status_code: int = 200
    # A unique identifier for the API request.
    request_id: str
    # The body of the response.
    body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None


class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files
    """
    id: str
    # A developer-provided per-request id that will be used to match outputs to
    # inputs.
    custom_id: str
    response: Optional[BatchResponseData]
    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Optional[Any]


class TokenizeCompletionRequest(OpenAIBaseModel):
    model: str
    prompt: str
    add_special_tokens: bool = Field(default=True)


class TokenizeChatRequest(OpenAIBaseModel):
    model: str
    messages: List[ChatCompletionMessageParam]
    add_generation_prompt: bool = Field(default=True)
    add_special_tokens: bool = Field(default=False)


TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]


class TokenizeResponse(OpenAIBaseModel):
    tokens: List[int]
    count: int
    max_model_len: int


class DetokenizeRequest(OpenAIBaseModel):
    model: Optional[str]
    tokens: List[int]


class DetokenizeResponse(OpenAIBaseModel):
    prompt: str


# ========== KoboldAI ========== #
class KAIGenerationInputSchema(BaseModel):
    genkey: Optional[str] = None
    prompt: str
    n: Optional[int] = 1
    max_context_length: int
    max_length: int
    rep_pen: Optional[float] = 1.0
    top_k: Optional[int] = 0
    top_a: Optional[float] = 0.0
    top_p: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eps_cutoff: Optional[float] = 0.0
    eta_cutoff: Optional[float] = 0.0
    typical: Optional[float] = 1.0
    temperature: Optional[float] = 1.0
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    use_default_badwordsids: Optional[bool] = None
    quiet: Optional[bool] = None
    # pylint: disable=unexpected-keyword-arg
    sampler_seed: Optional[int] = None
    stop_sequence: Optional[List[str]] = None
    include_stop_str_in_output: Optional[bool] = False

    @model_validator(mode='before')
    def check_context(cls, values):  # pylint: disable=no-self-argument
        assert values.get("max_length") <= values.get(
            "max_context_length"
        ), "max_length must not be larger than max_context_length"
        return values
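

# Illustrative sketch (not executed): the KoboldAI schema rejects payloads
# whose max_length exceeds max_context_length via the validator above. The
# payload values are placeholders.
#
#   >>> KAIGenerationInputSchema(prompt="Hello", max_context_length=2048,
#   ...                          max_length=80)    # validates fine
#   >>> KAIGenerationInputSchema(prompt="Hello", max_context_length=1024,
#   ...                          max_length=2048)  # fails validation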