# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated

from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import (LogitsProcessorFunc,
                                              SamplingParams)
from aphrodite.common.sequence import Logprob
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.chat_utils import ChatCompletionMessageParam
from aphrodite.endpoints.openai.logits_processors import get_logits_processors

class OpenAIBaseModel(BaseModel):
    model_config = ConfigDict(extra="ignore")


class ErrorResponse(OpenAIBaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class ModelPermission(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(OpenAIBaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "pygmalionai"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(OpenAIBaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(OpenAIBaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ResponseFormat(OpenAIBaseModel):
    # type must be "json_object" or "text"
    type: Literal["text", "json_object"]


class StreamOptions(OpenAIBaseModel):
    include_usage: Optional[bool] = True
    continuous_usage_stats: Optional[bool] = True


class FunctionDefinition(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None


class ChatCompletionToolsParam(OpenAIBaseModel):
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"

class ChatCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: List[ChatCompletionMessageParam]
    model: str
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = 0
    max_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tools: Optional[List[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[Literal["none"],
                                ChatCompletionNamedToolChoiceParam]] = "none"
    user: Optional[str] = None

    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    custom_token_bans: Optional[List[int]] = None
    # doc: end-chat-completion-sampling-params

    # doc: begin-chat-completion-extra-params
    echo: Optional[bool] = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."),
    )
    add_generation_prompt: Optional[bool] = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
    )
    add_special_tokens: Optional[bool] = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to False (as is the "
            "default)."),
    )
    documents: Optional[List[Dict[str, str]]] = Field(
        default=None,
        description=
        ("A list of dicts representing documents that will be accessible to "
         "the model if it is performing RAG (retrieval-augmented generation)."
         " If the template does not support RAG, this argument will have no "
         "effect. We recommend that each document should be a dict containing "
         "\"title\" and \"text\" keys."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )
    include_stop_str_in_output: Optional[bool] = Field(
        default=False,
        description=(
            "Whether to include the stop string in the output. "
            "This is only applied when the stop or stop_token_ids is set."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be either "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    # doc: end-chat-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        # We now allow logprobs being true without top_logprobs.
        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=None,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        return SamplingParams(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
            (self.top_logprobs if self.echo else None),
            best_of=self.best_of,
            top_k=self.top_k,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            custom_token_bans=self.custom_token_bans,
        )

    @model_validator(mode='before')
    @classmethod
    def validate_stream_options(cls, values):
        if (values.get('stream_options') is not None
                and not values.get('stream')):
            raise ValueError(
                "stream_options can only be set if stream is true")
        return values

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        # you can only either use guided decoding or tools, not both
        if guide_count > 0 and "tool_choice" in data and data[
                "tool_choice"] != "none":
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_choice(cls, data):
        if "tool_choice" in data and data["tool_choice"] != "none":
            if not isinstance(data["tool_choice"], dict):
                raise ValueError("Currently only named tools are supported.")
            if "tools" not in data or data["tools"] is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "top_logprobs" in data and data["top_logprobs"] is not None:
            if "logprobs" not in data or data["logprobs"] is False:
                raise ValueError(
                    "When using `top_logprobs`, `logprobs` must be set to "
                    "true.")
            elif data["top_logprobs"] < 0:
                raise ValueError("`top_logprobs` must be a positive value.")
        return data
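

# Illustrative sketch (not part of the original protocol module): how a chat
# request is typically turned into SamplingParams on the server side. The
# model name, message contents, and default_max_tokens value below are
# assumptions for demonstration purposes only.
def _example_chat_request_to_sampling_params(
        tokenizer: PreTrainedTokenizer) -> SamplingParams:
    request = ChatCompletionRequest(
        model="example-model",
        messages=[{"role": "user", "content": "Hello!"}],
        temperature=0.7,
        top_p=0.9,
    )
    # max_tokens was not set on the request, so the server-side default
    # (assumed here to be 512) is used instead.
    return request.to_sampling_params(
        tokenizer=tokenizer,
        guided_decode_logits_processor=None,
        default_max_tokens=512,
    )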

class CompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-completion-sampling-params
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    allowed_token_ids: Optional[List[int]] = None
    include_stop_str_in_output: Optional[bool] = False
    add_special_tokens: Optional[bool] = False
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    custom_token_bans: Optional[List[int]] = None
    # doc: end-completion-sampling-params

    # doc: begin-completion-extra-params
    response_format: Optional[ResponseFormat] = Field(
        default=None,
        description=
        ("Similar to chat completion, this parameter specifies the format of "
         "output. Only {'type': 'json_object'} or {'type': 'text' } is "
         "supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    # doc: end-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        echo_without_generation = self.echo and self.max_tokens == 0

        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        return SamplingParams(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            logprobs=self.logprobs,
            prompt_logprobs=self.prompt_logprobs
            if self.prompt_logprobs else self.logprobs if self.echo else None,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=(self.spaces_between_special_tokens),
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            custom_token_bans=self.custom_token_bans,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "logprobs" in data and data[
                "logprobs"] is not None and not data["logprobs"] >= 0:
            raise ValueError("if passed, `logprobs` must be a positive value.")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when stream is True.")
        return data
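

# Illustrative sketch (not part of the original protocol module): the guided
# decoding fields of CompletionRequest are mutually exclusive, which
# check_guided_decoding_count enforces before validation. The model name,
# prompt, and JSON schema below are made-up examples.
def _example_guided_completion_request() -> CompletionRequest:
    return CompletionRequest(
        model="example-model",
        prompt="Return the user as JSON: name=Alice, age=30",
        max_tokens=64,
        # Exactly one of guided_json / guided_regex / guided_choice may be
        # set; supplying more than one raises a ValueError.
        guided_json={
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"},
            },
            "required": ["name", "age"],
        },
    )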

class EmbeddingRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
    dimensions: Optional[int] = None
    user: Optional[str] = None

    # doc: begin-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-embedding-pooling-params

    def to_pooling_params(self):
        return PoolingParams(additional_data=self.additional_data)
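

# Illustrative sketch (not part of the original protocol module): an
# EmbeddingRequest carries no sampling parameters; it only forwards the
# optional additional_data field to PoolingParams. The model name and input
# below are assumptions for demonstration purposes only.
def _example_embedding_request_to_pooling_params() -> PoolingParams:
    request = EmbeddingRequest(
        model="example-embedding-model",
        input=["The quick brown fox"],
        encoding_format="float",
    )
    return request.to_pooling_params()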

class CompletionLogProbs(OpenAIBaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str,
                                     float]]] = Field(default_factory=list)


class CompletionResponseChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class CompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )


class CompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class EmbeddingResponseData(OpenAIBaseModel):
    index: int
    object: str = "embedding"
    embedding: Union[List[float], str]


class EmbeddingResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[EmbeddingResponseData]
    usage: UsageInfo


class FunctionCall(OpenAIBaseModel):
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall


class ChatMessage(OpenAIBaseModel):
    role: str
    content: str
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionLogProb(OpenAIBaseModel):
    token: str
    logprob: float = -9999.0
    bytes: Optional[List[int]] = None


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)


class ChatCompletionLogProbs(OpenAIBaseModel):
    content: Optional[List[ChatCompletionLogProbsContent]] = None


class ChatCompletionResponseChoice(OpenAIBaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class DeltaMessage(OpenAIBaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
    """

    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. Currently
    # /v1/chat/completions is supported.
    url: str

    # The parameters of the request.
    body: Union[ChatCompletionRequest, EmbeddingRequest]
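

# Illustrative sketch (not part of the original protocol module): one line of
# a batch input file, parsed into a BatchRequestInput. The custom_id, model
# name, and message contents are made-up examples.
def _example_batch_request_input() -> BatchRequestInput:
    return BatchRequestInput.model_validate({
        "custom_id": "request-1",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "example-model",
            "messages": [{"role": "user", "content": "Hello!"}],
        },
    })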

class BatchResponseData(OpenAIBaseModel):
    # HTTP status code of the response.
    status_code: int = 200

    # A unique identifier for the API request.
    request_id: str

    # The body of the response.
    body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None


class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files.
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs to
    # inputs.
    custom_id: str

    response: Optional[BatchResponseData]

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Optional[Any]


class TokenizeCompletionRequest(OpenAIBaseModel):
    model: str
    prompt: str
    add_special_tokens: bool = Field(default=True)


class TokenizeChatRequest(OpenAIBaseModel):
    model: str
    messages: List[ChatCompletionMessageParam]
    add_generation_prompt: bool = Field(default=True)
    add_special_tokens: bool = Field(default=False)


TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]


class TokenizeResponse(OpenAIBaseModel):
    tokens: List[int]
    count: int
    max_model_len: int


class DetokenizeRequest(OpenAIBaseModel):
    model: Optional[str]
    tokens: List[int]


class DetokenizeResponse(OpenAIBaseModel):
    prompt: str


# ========== KoboldAI ========== #
class KAIGenerationInputSchema(BaseModel):
    genkey: Optional[str] = None
    prompt: str
    n: Optional[int] = 1
    max_context_length: int
    max_length: int
    rep_pen: Optional[float] = 1.0
    top_k: Optional[int] = 0
    top_a: Optional[float] = 0.0
    top_p: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eps_cutoff: Optional[float] = 0.0
    eta_cutoff: Optional[float] = 0.0
    typical: Optional[float] = 1.0
    temperature: Optional[float] = 1.0
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    use_default_badwordsids: Optional[bool] = None
    quiet: Optional[bool] = None
    # pylint: disable=unexpected-keyword-arg
    sampler_seed: Optional[int] = None
    stop_sequence: Optional[List[str]] = None
    include_stop_str_in_output: Optional[bool] = False

    @model_validator(mode='before')
    def check_context(cls, values):  # pylint: disable=no-self-argument
        assert values.get("max_length") <= values.get(
            "max_context_length"
        ), "max_length must not be larger than max_context_length"
        return values