# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated

from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import (LogitsProcessorFunc,
                                              SamplingParams)
from aphrodite.common.sequence import Logprob
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.chat_utils import ChatCompletionMessageParam
from aphrodite.endpoints.openai.logits_processors import get_logits_processors


class OpenAIBaseModel(BaseModel):
    model_config = ConfigDict(extra="ignore")


class ErrorResponse(OpenAIBaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class ModelPermission(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(OpenAIBaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "pygmalionai"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(OpenAIBaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(OpenAIBaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ResponseFormat(OpenAIBaseModel):
    # type must be "json_object" or "text"
    type: Literal["text", "json_object"]


class StreamOptions(OpenAIBaseModel):
    include_usage: Optional[bool] = True
    continuous_usage_stats: Optional[bool] = True


class FunctionDefinition(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None


class ChatCompletionToolsParam(OpenAIBaseModel):
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


class ChatCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: List[ChatCompletionMessageParam]
    model: str
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = 0
    max_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tools: Optional[List[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[Literal["none"],
                                ChatCompletionNamedToolChoiceParam]] = "none"
    user: Optional[str] = None

    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    nsigma: Optional[float] = 0.0
    custom_token_bans: Optional[List[int]] = None
    # doc: end-chat-completion-sampling-params

    # doc: begin-chat-completion-extra-params
    echo: Optional[bool] = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."),
    )
    add_generation_prompt: Optional[bool] = Field(
        default=True,
        description=(
            "If true, the generation prompt will be added to the chat "
            "template. This is a parameter used by the chat template in the "
            "tokenizer config of the model."),
    )
    add_special_tokens: Optional[bool] = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to False (as is the "
            "default)."),
    )
    documents: Optional[List[Dict[str, str]]] = Field(
        default=None,
        description=(
            "A list of dicts representing documents that will be accessible to "
            "the model if it is performing RAG (retrieval-augmented generation)."
            " If the template does not support RAG, this argument will have no "
            "effect. We recommend that each document should be a dict containing "
            "\"title\" and \"text\" keys."),
    )
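    # Illustrative only, not part of the original schema: a hypothetical
    # `documents` value following the recommended "title"/"text" keys:
    #   documents=[{"title": "Intro", "text": "Retrieved context goes here."}]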
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=(
            "Additional kwargs to pass to the template renderer. "
            "Will be accessible by the chat template."),
    )
    include_stop_str_in_output: Optional[bool] = Field(
        default=False,
        description=(
            "Whether to include the stop string in the output. "
            "This is only applied when the stop or stop_token_ids is set."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    # doc: end-chat-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        # We now allow logprobs being true without top_logprobs.
        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=None,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        return SamplingParams(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
            (self.top_logprobs if self.echo else None),
            best_of=self.best_of,
            top_k=self.top_k,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            nsigma=self.nsigma,
            custom_token_bans=self.custom_token_bans,
        )

    @model_validator(mode='before')
    @classmethod
    def validate_stream_options(cls, values):
        if (values.get('stream_options') is not None
                and not values.get('stream')):
            raise ValueError(
                "stream_options can only be set if stream is true")
        return values

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        # you can only either use guided decoding or tools, not both
        if guide_count > 1 and "tool_choice" in data and data[
                "tool_choice"] != "none":
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_choice(cls, data):
        if "tool_choice" in data and data["tool_choice"] != "none":
            if not isinstance(data["tool_choice"], dict):
                raise ValueError("Currently only named tools are supported.")
            if "tools" not in data or data["tools"] is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "top_logprobs" in data and data["top_logprobs"] is not None:
            if "logprobs" not in data or data["logprobs"] is False:
                raise ValueError(
                    "when using `top_logprobs`, `logprobs` must be set to true."
                )
            elif data["top_logprobs"] < 0:
                raise ValueError(
                    "`top_logprobs` must be a non-negative value.")
        return data
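

# Illustrative sketch, not part of the original file: roughly how a chat
# request is turned into SamplingParams. The model name, tokenizer, and
# default_max_tokens value below are assumptions made for the example.
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   request = ChatCompletionRequest(
#       model="gpt2",
#       messages=[{"role": "user", "content": "Hello!"}],
#       temperature=0.7,
#       max_tokens=None,  # unset, so the server-side default is used
#   )
#   params = request.to_sampling_params(
#       tokenizer=tokenizer,
#       guided_decode_logits_processor=None,
#       default_max_tokens=256)
#   assert params.max_tokens == 256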


class CompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-completion-sampling-params
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    allowed_token_ids: Optional[List[int]] = None
    include_stop_str_in_output: Optional[bool] = False
    add_special_tokens: Optional[bool] = False
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    nsigma: Optional[float] = 0.0
    custom_token_bans: Optional[List[int]] = None
    # doc: end-completion-sampling-params

    # doc: begin-completion-extra-params
    response_format: Optional[ResponseFormat] = Field(
        default=None,
        description=(
            "Similar to chat completion, this parameter specifies the format "
            "of output. Only {'type': 'json_object'} or {'type': 'text' } is "
            "supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    # doc: end-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        echo_without_generation = self.echo and self.max_tokens == 0

        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        return SamplingParams(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            logprobs=self.logprobs,
            prompt_logprobs=self.prompt_logprobs
            if self.prompt_logprobs else self.logprobs if self.echo else None,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            nsigma=self.nsigma,
            custom_token_bans=self.custom_token_bans,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "logprobs" in data and data[
                "logprobs"] is not None and not data["logprobs"] >= 0:
            raise ValueError(
                "if passed, `logprobs` must be a non-negative value.")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when stream is True.")
        return data
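

# Illustrative sketch, not part of the original file: when `echo` is set and
# `max_tokens` is 0, to_sampling_params() caps generation at a single token so
# the prompt can be echoed back without producing new text. The tokenizer and
# prompt values below are assumptions made for the example.
#
#   request = CompletionRequest(model="gpt2", prompt="Hi", echo=True,
#                               max_tokens=0, logprobs=1)
#   params = request.to_sampling_params(
#       tokenizer=tokenizer,  # any PreTrainedTokenizer instance
#       guided_decode_logits_processor=None,
#       default_max_tokens=256)
#   assert params.max_tokens == 1
#   assert params.prompt_logprobs == 1  # echo + logprobs => prompt logprobs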


class EmbeddingRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
    dimensions: Optional[int] = None
    user: Optional[str] = None

    # doc: begin-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-embedding-pooling-params

    def to_pooling_params(self):
        return PoolingParams(additional_data=self.additional_data)


class CompletionLogProbs(OpenAIBaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str,
                                     float]]] = Field(default_factory=list)


class CompletionResponseChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class CompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )


class CompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class EmbeddingResponseData(OpenAIBaseModel):
    index: int
    object: str = "embedding"
    embedding: Union[List[float], str]


class EmbeddingResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[EmbeddingResponseData]
    usage: UsageInfo


class FunctionCall(OpenAIBaseModel):
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall


class ChatMessage(OpenAIBaseModel):
    role: str
    content: str
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionLogProb(OpenAIBaseModel):
    token: str
    logprob: float = -9999.0
    bytes: Optional[List[int]] = None


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)


class ChatCompletionLogProbs(OpenAIBaseModel):
    content: Optional[List[ChatCompletionLogProbsContent]] = None


class ChatCompletionResponseChoice(OpenAIBaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class DeltaMessage(OpenAIBaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
    """

    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. Currently
    # /v1/chat/completions is supported.
    url: str

    # The parameters of the request.
    body: Union[ChatCompletionRequest, EmbeddingRequest]
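

# Illustrative only, not part of the original file: one line of a batch input
# file would deserialize into BatchRequestInput roughly like this (the model
# name and message content are placeholders):
#
#   {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions",
#    "body": {"model": "my-model",
#             "messages": [{"role": "user", "content": "Hello!"}]}}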


class BatchResponseData(OpenAIBaseModel):
    # HTTP status code of the response.
    status_code: int = 200

    # A unique identifier for the API request.
    request_id: str

    # The body of the response.
    body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None


class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs to
    # inputs.
    custom_id: str

    response: Optional[BatchResponseData]

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Optional[Any]


class TokenizeCompletionRequest(OpenAIBaseModel):
    model: str
    prompt: str
    add_special_tokens: bool = Field(default=True)


class TokenizeChatRequest(OpenAIBaseModel):
    model: str
    messages: List[ChatCompletionMessageParam]
    add_generation_prompt: bool = Field(default=True)
    add_special_tokens: bool = Field(default=False)


TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]


class TokenizeResponse(OpenAIBaseModel):
    tokens: List[int]
    count: int
    max_model_len: int


class DetokenizeRequest(OpenAIBaseModel):
    model: Optional[str]
    tokens: List[int]


class DetokenizeResponse(OpenAIBaseModel):
    prompt: str


# ========== KoboldAI ========== #


class KAIGenerationInputSchema(BaseModel):
    genkey: Optional[str] = None
    prompt: str
    n: Optional[int] = 1
    max_context_length: int
    max_length: int
    rep_pen: Optional[float] = 1.0
    top_k: Optional[int] = 0
    top_a: Optional[float] = 0.0
    top_p: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eps_cutoff: Optional[float] = 0.0
    eta_cutoff: Optional[float] = 0.0
    typical: Optional[float] = 1.0
    temperature: Optional[float] = 1.0
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    use_default_badwordsids: Optional[bool] = None
    quiet: Optional[bool] = None
    # pylint: disable=unexpected-keyword-arg
    sampler_seed: Optional[int] = None
    stop_sequence: Optional[List[str]] = None
    include_stop_str_in_output: Optional[bool] = False

    @model_validator(mode='before')
    def check_context(cls, values):  # pylint: disable=no-self-argument
        assert values.get("max_length") <= values.get(
            "max_context_length"
        ), "max_length must not be larger than max_context_length"
        return values