# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import json
import time
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from pydantic import (AliasChoices, BaseModel, ConfigDict, Field,
                      model_validator)
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated

from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import (LogitsProcessorFunc,
                                              SamplingParams)
from aphrodite.common.sequence import Logprob
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.chat_utils import ChatCompletionMessageParam
from aphrodite.endpoints.openai.logits_processors import get_logits_processors


class OpenAIBaseModel(BaseModel):
    model_config = ConfigDict(extra="ignore")


class ErrorResponse(OpenAIBaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class ModelPermission(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(OpenAIBaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "pygmalionai"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(OpenAIBaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(OpenAIBaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class JsonSchemaResponseFormat(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
    # schema is the field in openai but that causes conflicts with pydantic so
    # instead use json_schema with an alias
    json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema')
    strict: Optional[bool] = None


class ResponseFormat(OpenAIBaseModel):
    # type must be "json_schema", "json_object" or "text"
    type: Literal["text", "json_object", "json_schema"]
    json_schema: Optional[JsonSchemaResponseFormat] = None
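

# Illustrative sketch (not part of the original module): the OpenAI field is
# literally named "schema", which conflicts with pydantic, so
# JsonSchemaResponseFormat stores it as json_schema behind the alias "schema".
# Request payloads therefore keep the original key, as in this made-up example.
def _example_response_format() -> ResponseFormat:
    return ResponseFormat.model_validate({
        "type": "json_schema",
        "json_schema": {
            "name": "city",
            # Populates JsonSchemaResponseFormat.json_schema via its alias.
            "schema": {
                "type": "object",
                "properties": {"name": {"type": "string"}},
            },
        },
    })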


class StreamOptions(OpenAIBaseModel):
    include_usage: Optional[bool] = True
    continuous_usage_stats: Optional[bool] = True


class FunctionDefinition(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None


class ChatCompletionToolsParam(OpenAIBaseModel):
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"


class ChatCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: List[ChatCompletionMessageParam]
    model: str
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = 0
    max_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tools: Optional[List[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[Literal["none"],
                                ChatCompletionNamedToolChoiceParam]] = "none"
    user: Optional[str] = None

    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    no_repeat_ngram_size: Optional[int] = 0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    dry_multiplier: Optional[float] = 0
    dry_base: Optional[float] = 1.75
    dry_allowed_length: Optional[int] = 2
    dry_sequence_breakers: Optional[List[str]] = Field(
        default=["\n", ":", "\"", "*"])
    dry_range: Optional[int] = Field(
        default=0,
        validation_alias=AliasChoices("dry_range",
                                      "dry_penalty_last_n"))
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    nsigma: Optional[float] = 0.0
    skew: Optional[float] = 0.0
    custom_token_bans: Optional[List[int]] = None
    sampler_priority: Optional[Union[List[int], List[str]]] = Field(
        default=[],
        validation_alias=AliasChoices("sampler_priority",
                                      "sampler_order"))
    # doc: end-chat-completion-sampling-params

    # doc: begin-chat-completion-extra-params
    echo: Optional[bool] = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."),
    )
    add_generation_prompt: Optional[bool] = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
    )
    add_special_tokens: Optional[bool] = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to False (as is the "
            "default)."),
    )
    documents: Optional[List[Dict[str, str]]] = Field(
        default=None,
        description=
        ("A list of dicts representing documents that will be accessible to "
         "the model if it is performing RAG (retrieval-augmented generation)."
         " If the template does not support RAG, this argument will have no "
         "effect. We recommend that each document should be a dict containing "
         "\"title\" and \"text\" keys."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )
    include_stop_str_in_output: Optional[bool] = Field(
        default=False,
        description=(
            "Whether to include the stop string in the output. "
            "This is only applied when the stop or stop_token_ids is set."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    # doc: end-chat-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens
        # We now allow logprobs being true without top_logprobs.
        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=None,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        dry_sequence_breaker_ids = []
        if self.dry_sequence_breakers:
            for s in self.dry_sequence_breakers:
                token_id = tokenizer.encode(f'a{s}')[-1]
                dry_sequence_breaker_ids.append(token_id)

        return SamplingParams(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            temperature=self.temperature,
            top_p=self.top_p,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
            (self.top_logprobs if self.echo else None),
            best_of=self.best_of,
            top_k=self.top_k,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            dry_multiplier=self.dry_multiplier,
            dry_base=self.dry_base,
            dry_allowed_length=self.dry_allowed_length,
            dry_sequence_breaker_ids=dry_sequence_breaker_ids,
            dry_range=self.dry_range,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            nsigma=self.nsigma,
            skew=self.skew,
            custom_token_bans=self.custom_token_bans,
            sampler_priority=self.sampler_priority,
        )

    @model_validator(mode='before')
    @classmethod
    def validate_stream_options(cls, values):
        if (values.get('stream_options') is not None
                and not values.get('stream')):
            raise ValueError(
                "stream_options can only be set if stream is true")
        return values

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        # you can only either use guided decoding or tools, not both
        if guide_count > 1 and "tool_choice" in data and data[
                "tool_choice"] != "none":
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_choice(cls, data):
        if "tool_choice" in data and data["tool_choice"] != "none":
            if not isinstance(data["tool_choice"], dict):
                raise ValueError("Currently only named tools are supported.")
            if "tools" not in data or data["tools"] is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "top_logprobs" in data and data["top_logprobs"] is not None:
            if "logprobs" not in data or data["logprobs"] is False:
                raise ValueError(
                    "when using `top_logprobs`, `logprobs` must be set to true."
                )
            elif data["top_logprobs"] < 0:
                raise ValueError(
                    "`top_logprobs` must be a non-negative value.")
        return data
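

# Illustrative sketch (not part of the original module): a minimal chat request
# that passes the validators above.  The model name is hypothetical;
# top_logprobs is only accepted because logprobs is true, and stream_options
# would similarly be rejected without stream=True.
def _example_chat_completion_request() -> ChatCompletionRequest:
    return ChatCompletionRequest.model_validate({
        "model": "example-model",
        "messages": [{"role": "user", "content": "Hello!"}],
        "logprobs": True,
        "top_logprobs": 5,
    })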


class CompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-completion-sampling-params
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    no_repeat_ngram_size: Optional[int] = 0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    allowed_token_ids: Optional[List[int]] = None
    include_stop_str_in_output: Optional[bool] = False
    add_special_tokens: Optional[bool] = False
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    dry_multiplier: Optional[float] = 0
    dry_base: Optional[float] = 1.75
    dry_allowed_length: Optional[int] = 2
    dry_sequence_breakers: Optional[List[str]] = Field(
        default=["\n", ":", "\"", "*"])
    dry_range: Optional[int] = Field(
        default=0,
        validation_alias=AliasChoices("dry_range",
                                      "dry_penalty_last_n"))
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    nsigma: Optional[float] = 0.0
    skew: Optional[float] = 0.0
    custom_token_bans: Optional[List[int]] = None
    sampler_priority: Optional[Union[List[int], List[str]]] = Field(
        default=[],
        validation_alias=AliasChoices("sampler_priority",
                                      "sampler_order"))
    # doc: end-completion-sampling-params

    # doc: begin-completion-extra-params
    response_format: Optional[ResponseFormat] = Field(
        default=None,
        description=
        ("Similar to chat completion, this parameter specifies the format of "
         "output. Only {'type': 'json_object'} or {'type': 'text' } is "
         "supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    # doc: end-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        echo_without_generation = self.echo and self.max_tokens == 0

        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        dry_sequence_breaker_ids = []
        if self.dry_sequence_breakers:
            for s in self.dry_sequence_breakers:
                s = bytes(s, "utf-8").decode("unicode_escape")
                token_id = tokenizer.encode(f'a{s}')[-1]
                dry_sequence_breaker_ids.append(token_id)

        return SamplingParams(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            logprobs=self.logprobs,
            prompt_logprobs=self.prompt_logprobs
            if self.prompt_logprobs else self.logprobs if self.echo else None,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            dry_multiplier=self.dry_multiplier,
            dry_base=self.dry_base,
            dry_allowed_length=self.dry_allowed_length,
            dry_sequence_breaker_ids=dry_sequence_breaker_ids,
            dry_range=self.dry_range,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            nsigma=self.nsigma,
            skew=self.skew,
            custom_token_bans=self.custom_token_bans,
            sampler_priority=self.sampler_priority,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "logprobs" in data and data[
                "logprobs"] is not None and not data["logprobs"] >= 0:
            raise ValueError("if passed, `logprobs` must be a non-negative "
                             "value.")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when stream is True.")
        return data

    @model_validator(mode='before')
    @classmethod
    def parse_dry_sequence_breakers(cls, data):
        if 'dry_sequence_breakers' in data:
            breakers = data['dry_sequence_breakers']
            if isinstance(breakers, str):
                try:
                    # Try to parse as JSON string
                    data['dry_sequence_breakers'] = json.loads(breakers)
                except json.JSONDecodeError as e:
                    raise ValueError(f"Invalid JSON for dry_sequence_breakers:"
                                     f" {e}") from e

            # Validate that we now have a list of strings
            # (check the type first so we never iterate over a non-list value)
            is_list = isinstance(data['dry_sequence_breakers'], list)
            all_strings = is_list and all(
                isinstance(x, str) for x in data['dry_sequence_breakers'])
            if not is_list or not all_strings:
                raise ValueError(
                    "dry_sequence_breakers must be a list of strings or a "
                    "JSON string representing a list of strings")
        return data
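

# Illustrative sketch (not part of the original module): the
# parse_dry_sequence_breakers validator above accepts either a list of strings
# or a JSON-encoded string for that list.  The model name is hypothetical.
def _example_completion_request() -> CompletionRequest:
    return CompletionRequest.model_validate({
        "model": "example-model",
        "prompt": "Once upon a time",
        # A JSON string; parsed back into ["\n", ":"] before field validation.
        "dry_sequence_breakers": "[\"\\n\", \":\"]",
    })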


class EmbeddingRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
    dimensions: Optional[int] = None
    user: Optional[str] = None

    # doc: begin-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-embedding-pooling-params

    def to_pooling_params(self):
        return PoolingParams(additional_data=self.additional_data)


class CompletionLogProbs(OpenAIBaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str,
                                     float]]] = Field(default_factory=list)


class CompletionResponseChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class CompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )


class CompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class EmbeddingResponseData(OpenAIBaseModel):
    index: int
    object: str = "embedding"
    embedding: Union[List[float], str]


class EmbeddingResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[EmbeddingResponseData]
    usage: UsageInfo


class FunctionCall(OpenAIBaseModel):
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall


class ChatMessage(OpenAIBaseModel):
    role: str
    content: str
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionLogProb(OpenAIBaseModel):
    token: str
    logprob: float = -9999.0
    bytes: Optional[List[int]] = None


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)


class ChatCompletionLogProbs(OpenAIBaseModel):
    content: Optional[List[ChatCompletionLogProbsContent]] = None


class ChatCompletionResponseChoice(OpenAIBaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class DeltaMessage(OpenAIBaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
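

# Illustrative sketch (not part of the original module): building one streaming
# chunk and serializing it to the JSON payload of a server-sent event.  The
# model name is hypothetical and the plain model_dump_json call is not
# necessarily the exact serialization the server performs.
def _example_stream_chunk_json() -> str:
    chunk = ChatCompletionStreamResponse(
        model="example-model",
        choices=[
            ChatCompletionResponseStreamChoice(
                index=0,
                delta=DeltaMessage(content="Hello"),
            )
        ],
    )
    return chunk.model_dump_json()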


class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
    """

    # A developer-provided per-request id that will be used to match outputs to
    # inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. Currently
    # /v1/chat/completions is supported.
    url: str

    # The parameters of the request.
    body: Union[ChatCompletionRequest, EmbeddingRequest]


class BatchResponseData(OpenAIBaseModel):
    # HTTP status code of the response.
    status_code: int = 200

    # A unique identifier for the API request.
    request_id: str

    # The body of the response.
    body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None


class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs to
    # inputs.
    custom_id: str

    response: Optional[BatchResponseData]

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Optional[Any]
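

# Illustrative sketch (not part of the original module): parsing one line of a
# batch input file (JSONL) into a BatchRequestInput; the body validates as a
# ChatCompletionRequest.  The custom_id and model name are hypothetical.
def _example_batch_input_line() -> BatchRequestInput:
    line = ('{"custom_id": "request-1", "method": "POST", '
            '"url": "/v1/chat/completions", '
            '"body": {"model": "example-model", '
            '"messages": [{"role": "user", "content": "Hi"}]}}')
    return BatchRequestInput.model_validate(json.loads(line))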


class TokenizeCompletionRequest(OpenAIBaseModel):
    model: str
    prompt: str
    add_special_tokens: bool = Field(default=True)


class TokenizeChatRequest(OpenAIBaseModel):
    model: str
    messages: List[ChatCompletionMessageParam]
    add_generation_prompt: bool = Field(default=True)
    add_special_tokens: bool = Field(default=False)


TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]


class TokenizeResponse(OpenAIBaseModel):
    tokens: List[int]
    count: int
    max_model_len: int


class DetokenizeRequest(OpenAIBaseModel):
    model: Optional[str]
    tokens: List[int]


class DetokenizeResponse(OpenAIBaseModel):
    prompt: str


# ========== KoboldAI ========== #
class KAIGenerationInputSchema(BaseModel):
    genkey: Optional[str] = None
    prompt: str
    n: Optional[int] = 1
    max_context_length: int
    max_length: int
    rep_pen: Optional[float] = 1.0
    top_k: Optional[int] = 0
    top_a: Optional[float] = 0.0
    top_p: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eps_cutoff: Optional[float] = 0.0
    eta_cutoff: Optional[float] = 0.0
    typical: Optional[float] = 1.0
    temperature: Optional[float] = 1.0
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    use_default_badwordsids: Optional[bool] = None
    quiet: Optional[bool] = None
    # pylint: disable=unexpected-keyword-arg
    sampler_seed: Optional[int] = None
    stop_sequence: Optional[List[str]] = None
    include_stop_str_in_output: Optional[bool] = False

    @model_validator(mode='before')
    def check_context(cls, values):  # pylint: disable=no-self-argument
        assert values.get("max_length") <= values.get(
            "max_context_length"
        ), "max_length must not be larger than max_context_length"
        return values
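

# Illustrative sketch (not part of the original module): the KoboldAI schema's
# check_context validator rejects requests whose max_length exceeds
# max_context_length.  The values below are made up.
def _example_kai_request() -> KAIGenerationInputSchema:
    return KAIGenerationInputSchema.model_validate({
        "prompt": "Hello",
        "max_context_length": 2048,
        "max_length": 256,  # must not exceed max_context_length
    })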