protocol.py

# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from pydantic import BaseModel, ConfigDict, Field, model_validator
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated

from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import (LogitsProcessorFunc,
                                              SamplingParams)
from aphrodite.common.sequence import Logprob
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.chat_utils import ChatCompletionMessageParam
from aphrodite.endpoints.openai.logits_processors import get_logits_processors


class OpenAIBaseModel(BaseModel):
    model_config = ConfigDict(extra="ignore")


class ErrorResponse(OpenAIBaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class ModelPermission(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(OpenAIBaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "pygmalionai"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(OpenAIBaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(OpenAIBaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ResponseFormat(OpenAIBaseModel):
    # type must be "json_object" or "text"
    type: Literal["text", "json_object"]


class StreamOptions(OpenAIBaseModel):
    include_usage: Optional[bool] = True
    continuous_usage_stats: Optional[bool] = True


class FunctionDefinition(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None


class ChatCompletionToolsParam(OpenAIBaseModel):
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


class ChatCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: List[ChatCompletionMessageParam]
    model: str
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = 0
    max_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tools: Optional[List[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[Literal["none"],
                                ChatCompletionNamedToolChoiceParam]] = "none"
    user: Optional[str] = None

    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    custom_token_bans: Optional[List[int]] = None
    # doc: end-chat-completion-sampling-params

    # doc: begin-chat-completion-extra-params
    echo: Optional[bool] = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."),
    )
    add_generation_prompt: Optional[bool] = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
    )
    add_special_tokens: Optional[bool] = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to False (as is the "
            "default)."),
    )
    documents: Optional[List[Dict[str, str]]] = Field(
        default=None,
        description=
        ("A list of dicts representing documents that will be accessible to "
         "the model if it is performing RAG (retrieval-augmented generation)."
         " If the template does not support RAG, this argument will have no "
         "effect. We recommend that each document should be a dict containing "
         "\"title\" and \"text\" keys."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )
    include_stop_str_in_output: Optional[bool] = Field(
        default=False,
        description=(
            "Whether to include the stop string in the output. "
            "This is only applied when the stop or stop_token_ids is set."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be either "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    # doc: end-chat-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        # We now allow logprobs being true without top_logprobs.
        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=None,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        return SamplingParams(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
            (self.top_logprobs if self.echo else None),
            best_of=self.best_of,
            top_k=self.top_k,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            custom_token_bans=self.custom_token_bans,
        )

    @model_validator(mode='before')
    @classmethod
    def validate_stream_options(cls, values):
        if (values.get('stream_options') is not None
                and not values.get('stream')):
            raise ValueError(
                "stream_options can only be set if stream is true")
        return values

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        # you can only either use guided decoding or tools, not both
        if guide_count >= 1 and "tool_choice" in data and data[
                "tool_choice"] != "none":
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_choice(cls, data):
        if "tool_choice" in data and data["tool_choice"] != "none":
            if not isinstance(data["tool_choice"], dict):
                raise ValueError("Currently only named tools are supported.")
            if "tools" not in data or data["tools"] is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "top_logprobs" in data and data["top_logprobs"] is not None:
            if "logprobs" not in data or data["logprobs"] is False:
                raise ValueError(
                    "when using `top_logprobs`, `logprobs` must be set to "
                    "true.")
            elif data["top_logprobs"] < 0:
                raise ValueError(
                    "`top_logprobs` must be a non-negative value.")
        return data
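
# Example (illustrative sketch, not part of the original module): how a chat
# request might be turned into SamplingParams via to_sampling_params(). The
# model name and default_max_tokens value below are made up for the example;
# loading the tokenizer is assumed to work for that name.
#
#     from transformers import AutoTokenizer
#
#     request = ChatCompletionRequest(
#         model="my-model",  # hypothetical model name
#         messages=[{"role": "user", "content": "Hello!"}],
#         temperature=0.7,
#         max_tokens=None,   # None falls back to default_max_tokens
#     )
#     tokenizer = AutoTokenizer.from_pretrained("my-model")
#     sampling_params = request.to_sampling_params(
#         tokenizer=tokenizer,
#         guided_decode_logits_processor=None,
#         default_max_tokens=256,
#     )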


class CompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-completion-sampling-params
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    allowed_token_ids: Optional[List[int]] = None
    include_stop_str_in_output: Optional[bool] = False
    add_special_tokens: Optional[bool] = False
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    kl_threshold: Optional[float] = 0.0
    jsd_threshold: Optional[float] = 0.0
    min_typical_p: Optional[float] = 1.0
    max_typical_p: Optional[float] = 1.0
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    custom_token_bans: Optional[List[int]] = None
    # doc: end-completion-sampling-params

    # doc: begin-completion-extra-params
    response_format: Optional[ResponseFormat] = Field(
        default=None,
        description=
        ("Similar to chat completion, this parameter specifies the format of "
         "output. Only {'type': 'json_object'} or {'type': 'text' } is "
         "supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    # doc: end-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        echo_without_generation = self.echo and self.max_tokens == 0

        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        return SamplingParams(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            logprobs=self.logprobs,
            prompt_logprobs=self.prompt_logprobs
            if self.prompt_logprobs else self.logprobs if self.echo else None,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=(self.spaces_between_special_tokens),
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            kl_threshold=self.kl_threshold,
            jsd_threshold=self.jsd_threshold,
            min_typical_p=self.min_typical_p,
            max_typical_p=self.max_typical_p,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            custom_token_bans=self.custom_token_bans,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "logprobs" in data and data[
                "logprobs"] is not None and not data["logprobs"] >= 0:
            raise ValueError(
                "if passed, `logprobs` must be a non-negative value.")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when stream is True.")
        return data
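
# Example (illustrative sketch, not part of the original module): the
# mode="before" validators reject inconsistent requests at construction time,
# e.g. stream_options supplied without stream=True. The model name is a
# placeholder for the example.
#
#     import pydantic
#
#     try:
#         CompletionRequest(
#             model="my-model",  # hypothetical model name
#             prompt="Once upon a time",
#             stream=False,
#             stream_options=StreamOptions(include_usage=True),
#         )
#     except pydantic.ValidationError as e:
#         print(e)  # "Stream options can only be defined when stream is True."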


class EmbeddingRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
    dimensions: Optional[int] = None
    user: Optional[str] = None

    # doc: begin-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-embedding-pooling-params

    def to_pooling_params(self):
        return PoolingParams(additional_data=self.additional_data)
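
# Example (illustrative sketch, not part of the original module): an embedding
# request only carries pooling options on top of the OpenAI fields; everything
# else is resolved by the server. The model name is a placeholder.
#
#     req = EmbeddingRequest(model="my-embedder", input=["hello", "world"])
#     pooling = req.to_pooling_params()  # PoolingParams(additional_data=None)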


class CompletionLogProbs(OpenAIBaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str,
                                     float]]] = Field(default_factory=list)


class CompletionResponseChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class CompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )


class CompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class EmbeddingResponseData(OpenAIBaseModel):
    index: int
    object: str = "embedding"
    embedding: Union[List[float], str]


class EmbeddingResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[EmbeddingResponseData]
    usage: UsageInfo


class FunctionCall(OpenAIBaseModel):
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall


class ChatMessage(OpenAIBaseModel):
    role: str
    content: str
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionLogProb(OpenAIBaseModel):
    token: str
    logprob: float = -9999.0
    bytes: Optional[List[int]] = None


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)


class ChatCompletionLogProbs(OpenAIBaseModel):
    content: Optional[List[ChatCompletionLogProbsContent]] = None


class ChatCompletionResponseChoice(OpenAIBaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class DeltaMessage(OpenAIBaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
    """

    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. Currently
    # /v1/chat/completions is supported.
    url: str

    # The parameters of the request.
    body: Union[ChatCompletionRequest, EmbeddingRequest]


class BatchResponseData(OpenAIBaseModel):
    # HTTP status code of the response.
    status_code: int = 200

    # A unique identifier for the API request.
    request_id: str

    # The body of the response.
    body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None


class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files.
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs to
    # inputs.
    custom_id: str

    response: Optional[BatchResponseData]

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Optional[Any]
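
# Example (illustrative sketch, not part of the original module): one line of
# a batch input file parsed with pydantic. The custom_id, model name, and
# message content are made up for the example.
#
#     line = (
#         '{"custom_id": "req-1", "method": "POST", '
#         '"url": "/v1/chat/completions", '
#         '"body": {"model": "my-model", '
#         '"messages": [{"role": "user", "content": "Hi"}]}}'
#     )
#     item = BatchRequestInput.model_validate_json(line)
#     assert item.url == "/v1/chat/completions"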


class TokenizeCompletionRequest(OpenAIBaseModel):
    model: str
    prompt: str

    add_special_tokens: bool = Field(default=True)


class TokenizeChatRequest(OpenAIBaseModel):
    model: str
    messages: List[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(default=True)
    add_special_tokens: bool = Field(default=False)


TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]


class TokenizeResponse(OpenAIBaseModel):
    tokens: List[int]
    count: int
    max_model_len: int


class DetokenizeRequest(OpenAIBaseModel):
    model: Optional[str]
    tokens: List[int]


class DetokenizeResponse(OpenAIBaseModel):
    prompt: str
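
# Example (illustrative sketch, not part of the original module):
# TokenizeRequest accepts either form; a plain prompt maps to
# TokenizeCompletionRequest and chat messages map to TokenizeChatRequest.
# The model name is a placeholder.
#
#     req: TokenizeRequest = TokenizeCompletionRequest(
#         model="my-model", prompt="Hello world")
#     chat_req: TokenizeRequest = TokenizeChatRequest(
#         model="my-model",
#         messages=[{"role": "user", "content": "Hello world"}])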


# ========== KoboldAI ========== #


class KAIGenerationInputSchema(BaseModel):
    genkey: Optional[str] = None
    prompt: str
    n: Optional[int] = 1
    max_context_length: int
    max_length: int
    rep_pen: Optional[float] = 1.0
    top_k: Optional[int] = 0
    top_a: Optional[float] = 0.0
    top_p: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eps_cutoff: Optional[float] = 0.0
    eta_cutoff: Optional[float] = 0.0
    typical: Optional[float] = 1.0
    temperature: Optional[float] = 1.0
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    use_default_badwordsids: Optional[bool] = None
    quiet: Optional[bool] = None
    # pylint: disable=unexpected-keyword-arg
    sampler_seed: Optional[int] = None
    stop_sequence: Optional[List[str]] = None
    include_stop_str_in_output: Optional[bool] = False

    @model_validator(mode='before')
    def check_context(cls, values):  # pylint: disable=no-self-argument
        assert values.get("max_length") <= values.get(
            "max_context_length"
        ), "max_length must not be larger than max_context_length"
        return values
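
# Example (illustrative sketch, not part of the original module): the
# check_context validator rejects requests whose max_length exceeds
# max_context_length. The values below are made up for the example.
#
#     KAIGenerationInputSchema(
#         prompt="Hello", max_context_length=2048, max_length=128)   # ok
#     KAIGenerationInputSchema(
#         prompt="Hello", max_context_length=128, max_length=2048)   # raises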