# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from pydantic import (BaseModel, ConfigDict, Field, model_validator,
                      root_validator)
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated

from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import (LogitsProcessorFunc,
                                              SamplingParams)
from aphrodite.common.sequence import Logprob
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.chat_utils import ChatCompletionMessageParam
from aphrodite.endpoints.openai.logits_processors import get_logits_processors

class OpenAIBaseModel(BaseModel):
    model_config = ConfigDict(extra="ignore")


class ErrorResponse(OpenAIBaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class ModelPermission(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(OpenAIBaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "pygmalionai"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(OpenAIBaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(OpenAIBaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ResponseFormat(OpenAIBaseModel):
    # type must be "json_object" or "text"
    type: Literal["text", "json_object"]
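

# Example (illustrative) of how ResponseFormat appears in an OpenAI-style
# request body; the model name is a placeholder:
#
#   {"model": "my-model",
#    "messages": [{"role": "user", "content": "Reply in JSON."}],
#    "response_format": {"type": "json_object"}}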


class StreamOptions(OpenAIBaseModel):
    include_usage: Optional[bool] = True
    continuous_usage_stats: Optional[bool] = True


class FunctionDefinition(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None


class ChatCompletionToolsParam(OpenAIBaseModel):
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


class ChatCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: List[ChatCompletionMessageParam]
    model: str
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = 0
    max_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tools: Optional[List[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[Literal["none"],
                                ChatCompletionNamedToolChoiceParam]] = "none"
    user: Optional[str] = None

    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    # doc: end-chat-completion-sampling-params
    # doc: begin-chat-completion-extra-params
    echo: Optional[bool] = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last "
            "message if they belong to the same role."),
    )
    add_generation_prompt: Optional[bool] = Field(
        default=True,
        description=(
            "If true, the generation prompt will be added to the chat "
            "template. This is a parameter used by the chat template in the "
            "tokenizer config of the model."),
    )
    add_special_tokens: Optional[bool] = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens, so this should be set to False (the default)."),
    )
    documents: Optional[List[Dict[str, str]]] = Field(
        default=None,
        description=(
            "A list of dicts representing documents that will be accessible "
            "to the model if it is performing RAG (retrieval-augmented "
            "generation). If the template does not support RAG, this "
            "argument has no effect. We recommend that each document be a "
            "dict containing \"title\" and \"text\" keys."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, a default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )
    include_stop_str_in_output: Optional[bool] = Field(
        default=False,
        description=(
            "Whether to include the stop string in the output. "
            "This is only applied when stop or stop_token_ids is set."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=(
            "If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context-free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' or 'lm-format-enforcer'."))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided JSON decoding."))
    # doc: end-chat-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        # We now allow logprobs being true without top_logprobs.
        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=None,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        return SamplingParams(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
            (self.top_logprobs if self.echo else None),
            best_of=self.best_of,
            top_k=self.top_k,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            temperature_last=self.temperature_last,
        )

    @model_validator(mode='before')
    @classmethod
    def validate_stream_options(cls, values):
        if (values.get('stream_options') is not None
                and not values.get('stream')):
            raise ValueError(
                "stream_options can only be set if stream is true")
        return values

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        # you can only either use guided decoding or tools, not both
        if guide_count >= 1 and "tool_choice" in data and data[
                "tool_choice"] != "none":
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_choice(cls, data):
        if "tool_choice" in data and data["tool_choice"] != "none":
            if not isinstance(data["tool_choice"], dict):
                raise ValueError("Currently only named tools are supported.")
            if "tools" not in data or data["tools"] is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "top_logprobs" in data and data["top_logprobs"] is not None:
            if "logprobs" not in data or data["logprobs"] is False:
                raise ValueError(
                    "When using `top_logprobs`, `logprobs` must be set to "
                    "true.")
            elif data["top_logprobs"] < 0:
                raise ValueError(
                    "`top_logprobs` must be a non-negative value.")
        return data
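

# Illustrative usage of ChatCompletionRequest.to_sampling_params (a minimal
# sketch; the model name is a placeholder, and `tokenizer` plus the default
# max-token budget are assumed to come from the serving engine):
#
#   request = ChatCompletionRequest(
#       model="my-model",
#       messages=[{"role": "user", "content": "Hi"}],
#       temperature=0.5,
#   )
#   params = request.to_sampling_params(
#       tokenizer, guided_decode_logits_processor=None,
#       default_max_tokens=256)
#   # `params` is the engine-level SamplingParams derived from the request;
#   # max_tokens falls back to default_max_tokens when the client omits it.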


class CompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-completion-sampling-params
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    dry_multiplier: Optional[float] = 0.0
    dry_base: Optional[float] = 1.75
    dry_allowed_length: Optional[int] = 2
    dry_sequence_breakers: Optional[List[str]] = Field(default_factory=list)
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    allowed_token_ids: Optional[List[int]] = None
    include_stop_str_in_output: Optional[bool] = False
    add_special_tokens: Optional[bool] = False
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    # doc: end-completion-sampling-params
    # doc: begin-completion-extra-params
    response_format: Optional[ResponseFormat] = Field(
        default=None,
        description=(
            "Similar to chat completion, this parameter specifies the format "
            "of the output. Only {'type': 'json_object'} or {'type': 'text'} "
            "is supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=(
            "If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context-free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' or 'lm-format-enforcer'."))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided JSON decoding."))
    # doc: end-completion-extra-params

    def _tokenize_dry_sequence_breakers(self, tokenizer: PreTrainedTokenizer):
        if not self.dry_sequence_breakers:
            return []
        tokenized_breakers = []
        for breaker in self.dry_sequence_breakers:
            tokenized_breaker = tokenizer.encode(breaker,
                                                 add_special_tokens=False)
            tokenized_breakers.extend(tokenized_breaker)
        return tokenized_breakers

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        echo_without_generation = self.echo and self.max_tokens == 0

        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        tokenized_dry_sequence_breakers = (
            self._tokenize_dry_sequence_breakers(tokenizer))

        return SamplingParams(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            logprobs=self.logprobs,
            prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
            (self.logprobs if self.echo else None),
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            temperature_last=self.temperature_last,
            dry_multiplier=self.dry_multiplier,
            dry_base=self.dry_base,
            dry_allowed_length=self.dry_allowed_length,
            dry_sequence_breakers=tokenized_dry_sequence_breakers,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "logprobs" in data and data[
                "logprobs"] is not None and not data["logprobs"] >= 0:
            raise ValueError(
                "If passed, `logprobs` must be a non-negative value.")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when stream is True.")
        return data
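

# Illustrative note on DRY sequence breakers (a sketch; the model name is a
# placeholder and the resulting token ids depend entirely on the tokenizer):
#
#   req = CompletionRequest(model="my-model", prompt="Once upon a time",
#                           dry_multiplier=0.8,
#                           dry_sequence_breakers=["\n", ":"])
#   # _tokenize_dry_sequence_breakers flattens each breaker string into its
#   # token ids (encoded without special tokens) before they are handed to
#   # SamplingParams in to_sampling_params.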


class EmbeddingRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Optional[str] = Field('float',
                                           pattern='^(float|base64)$')
    dimensions: Optional[int] = None
    user: Optional[str] = None

    # doc: begin-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-embedding-pooling-params

    def to_pooling_params(self):
        return PoolingParams(additional_data=self.additional_data)
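

# Example (illustrative) embeddings request body; the model name is a
# placeholder:
#
#   {"model": "my-embedding-model",
#    "input": ["first sentence", "second sentence"],
#    "encoding_format": "float"}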


class CompletionLogProbs(OpenAIBaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str,
                                     float]]] = Field(default_factory=list)


class CompletionResponseChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason, "
            "including encountering the EOS token."),
    )
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class CompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason, "
            "including encountering the EOS token."),
    )


class CompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class EmbeddingResponseData(OpenAIBaseModel):
    index: int
    object: str = "embedding"
    embedding: Union[List[float], str]


class EmbeddingResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[EmbeddingResponseData]
    usage: UsageInfo


class FunctionCall(OpenAIBaseModel):
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall


class ChatMessage(OpenAIBaseModel):
    role: str
    content: str
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionLogProb(OpenAIBaseModel):
    token: str
    logprob: float = -9999.0
    bytes: Optional[List[int]] = None


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)


class ChatCompletionLogProbs(OpenAIBaseModel):
    content: Optional[List[ChatCompletionLogProbsContent]] = None


class ChatCompletionResponseChoice(OpenAIBaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class DeltaMessage(OpenAIBaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
    """

    # A developer-provided per-request id that will be used to match outputs
    # to inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. Currently
    # /v1/chat/completions is supported.
    url: str

    # The parameters of the request.
    body: Union[ChatCompletionRequest, EmbeddingRequest]
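

# Example (illustrative) batch input-file line, one JSON object per line;
# the custom_id and model name are placeholders:
#
#   {"custom_id": "request-1", "method": "POST",
#    "url": "/v1/chat/completions",
#    "body": {"model": "my-model",
#             "messages": [{"role": "user", "content": "Hello"}]}}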


class BatchResponseData(OpenAIBaseModel):
    # HTTP status code of the response.
    status_code: int = 200

    # A unique identifier for the API request.
    request_id: str

    # The body of the response.
    body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None


class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files.
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs
    # to inputs.
    custom_id: str

    response: Optional[BatchResponseData]

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Optional[Any]


class TokenizeCompletionRequest(OpenAIBaseModel):
    model: str
    prompt: str
    add_special_tokens: bool = Field(default=True)


class TokenizeChatRequest(OpenAIBaseModel):
    model: str
    messages: List[ChatCompletionMessageParam]
    add_generation_prompt: bool = Field(default=True)
    add_special_tokens: bool = Field(default=False)


TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]


class TokenizeResponse(OpenAIBaseModel):
    tokens: List[int]
    count: int
    max_model_len: int


class DetokenizeRequest(OpenAIBaseModel):
    model: Optional[str]
    tokens: List[int]


class DetokenizeResponse(OpenAIBaseModel):
    prompt: str
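

# Illustrative tokenize/detokenize round trip (payload shapes only; the
# token ids, model name, and max_model_len are made up and depend on the
# tokenizer and server config):
#
#   tokenize:   {"model": "my-model", "prompt": "Hello world"}
#            -> {"tokens": [101, 202], "count": 2, "max_model_len": 4096}
#   detokenize: {"model": "my-model", "tokens": [101, 202]}
#            -> {"prompt": "Hello world"}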


# ========== KoboldAI ========== #


class KoboldSamplingParams(BaseModel):
    n: int = Field(1, alias="n")
    best_of: Optional[int] = Field(None, alias="best_of")
    presence_penalty: float = Field(0.0, alias="presence_penalty")
    frequency_penalty: float = Field(0.0, alias="rep_pen")
    temperature: float = Field(1.0, alias="temperature")
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    top_p: float = Field(1.0, alias="top_p")
    top_k: float = Field(-1, alias="top_k")
    min_p: float = Field(0.0, alias="min_p")
    top_a: float = Field(0.0, alias="top_a")
    tfs: float = Field(1.0, alias="tfs")
    eta_cutoff: float = Field(0.0, alias="eta_cutoff")
    epsilon_cutoff: float = Field(0.0, alias="epsilon_cutoff")
    typical_p: float = Field(1.0, alias="typical_p")
    use_beam_search: bool = Field(False, alias="use_beam_search")
    length_penalty: float = Field(1.0, alias="length_penalty")
    early_stopping: Union[bool, str] = Field(False, alias="early_stopping")
    stop: Union[None, str, List[str]] = Field(None, alias="stop_sequence")
    include_stop_str_in_output: Optional[bool] = False
    ignore_eos: bool = Field(False, alias="ignore_eos")
    max_tokens: int = Field(16, alias="max_length")
    logprobs: Optional[int] = Field(None, alias="logprobs")
    custom_token_bans: Optional[List[int]] = Field(None,
                                                   alias="custom_token_bans")

    @root_validator(pre=False, skip_on_failure=True)
    def validate_best_of(cls, values):  # pylint: disable=no-self-argument
        best_of = values.get("best_of")
        n = values.get("n")
        if best_of is not None and (best_of <= 0 or best_of > n):
            raise ValueError(
                "best_of must be a positive integer less than or equal to n")
        return values


class KAIGenerationInputSchema(BaseModel):
    genkey: Optional[str] = None
    prompt: str
    n: Optional[int] = 1
    max_context_length: int
    max_length: int
    rep_pen: Optional[float] = 1.0
    rep_pen_range: Optional[int] = None
    rep_pen_slope: Optional[float] = None
    top_k: Optional[int] = 0
    top_a: Optional[float] = 0.0
    top_p: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eps_cutoff: Optional[float] = 0.0
    eta_cutoff: Optional[float] = 0.0
    typical: Optional[float] = 1.0
    temperature: Optional[float] = 1.0
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    use_memory: Optional[bool] = None
    use_story: Optional[bool] = None
    use_authors_note: Optional[bool] = None
    use_world_info: Optional[bool] = None
    use_userscripts: Optional[bool] = None
    soft_prompt: Optional[str] = None
    disable_output_formatting: Optional[bool] = None
    frmtrmblln: Optional[bool] = None
    frmtrmspch: Optional[bool] = None
    singleline: Optional[bool] = None
    use_default_badwordsids: Optional[bool] = None
    mirostat: Optional[int] = 0
    mirostat_tau: Optional[float] = 0.0
    mirostat_eta: Optional[float] = 0.0
    disable_input_formatting: Optional[bool] = None
    frmtadsnsp: Optional[bool] = None
    quiet: Optional[bool] = None
    # pylint: disable=unexpected-keyword-arg
    sampler_order: Optional[Union[List, str]] = Field(default_factory=list)
    sampler_seed: Optional[int] = None
    sampler_full_determinism: Optional[bool] = None
    stop_sequence: Optional[List[str]] = None
    include_stop_str_in_output: Optional[bool] = False

    @root_validator(pre=False, skip_on_failure=True)
    def check_context(cls, values):  # pylint: disable=no-self-argument
        assert values.get("max_length") <= values.get(
            "max_context_length"
        ), "max_length must not be larger than max_context_length"
        return values
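

# Illustrative KoboldAI generation payload (field values are made up):
#
#   {"prompt": "Once upon a time", "max_context_length": 2048,
#    "max_length": 80, "temperature": 0.7, "top_p": 0.9}
#
# check_context above rejects any payload whose max_length exceeds
# max_context_length.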