protocol.py

# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import json
import time
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from openai.types.chat import ChatCompletionContentPartParam
from pydantic import (AliasChoices, BaseModel, ConfigDict, Field,
                      model_validator)
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated, Required, TypedDict

from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import (LogitsProcessorFunc,
                                              SamplingParams)
from aphrodite.common.sequence import Logprob
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.chat_utils import ChatCompletionMessageParam
from aphrodite.endpoints.openai.logits_processors import get_logits_processors


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""
    role: Required[str]
    """The role of the message's author."""

    content: Union[str, List[ChatCompletionContentPartParam]]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants
    of the same role.
    """

    tool_call_id: Optional[str]
    tool_calls: Optional[List[dict]]
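
# Illustrative usage (a sketch, not part of the module): a message with a
# custom role that the stock OpenAI message types would not accept. The
# role name "narrator" is a placeholder.
#
# message: CustomChatCompletionMessageParam = {
#     "role": "narrator",
#     "content": "The sun sets over the hills.",
# }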


class OpenAIBaseModel(BaseModel):
    model_config = ConfigDict(extra="ignore")


class ErrorResponse(OpenAIBaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class ModelPermission(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(OpenAIBaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "pygmalionai"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(OpenAIBaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(OpenAIBaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class JsonSchemaResponseFormat(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
    # "schema" is the field name in the OpenAI API, but it conflicts with
    # pydantic, so use json_schema with an alias instead.
    json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema')
    strict: Optional[bool] = None


class ResponseFormat(OpenAIBaseModel):
    # type must be "json_schema", "json_object" or "text"
    type: Literal["text", "json_object", "json_schema"]
    json_schema: Optional[JsonSchemaResponseFormat] = None
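
# Illustrative payloads (a sketch): on the wire the schema field is named
# "schema"; the alias above maps it onto `json_schema`.
#
# ResponseFormat.model_validate({"type": "json_object"})
# ResponseFormat.model_validate({
#     "type": "json_schema",
#     "json_schema": {"name": "answer", "schema": {"type": "object"}},
# })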


class StreamOptions(OpenAIBaseModel):
    include_usage: Optional[bool] = True
    continuous_usage_stats: Optional[bool] = True


class FunctionDefinition(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None


class ChatCompletionToolsParam(OpenAIBaseModel):
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"
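
# Illustrative payload (a sketch): a named tool choice forcing a specific
# function; the function name "get_weather" is a placeholder.
#
# ChatCompletionNamedToolChoiceParam.model_validate({
#     "type": "function",
#     "function": {"name": "get_weather"},
# })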


class ChatCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: List[ChatCompletionMessageParam]
    model: str
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = 0
    max_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tools: Optional[List[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[Literal["none"], Literal["auto"],
                                ChatCompletionNamedToolChoiceParam]] = "none"
    # NOTE this will be ignored by Aphrodite - the model determines the
    # behavior
    parallel_tool_calls: Optional[bool] = False
    user: Optional[str] = None

    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    no_repeat_ngram_size: Optional[int] = 0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    dry_multiplier: Optional[float] = 0
    dry_base: Optional[float] = 1.75
    dry_allowed_length: Optional[int] = 2
    dry_sequence_breakers: Optional[List[str]] = Field(
        default=["\n", ":", "\"", "*"])
    dry_range: Optional[int] = Field(
        default=0,
        validation_alias=AliasChoices("dry_range",
                                      "dry_penalty_last_n"))
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    nsigma: Optional[float] = 0.0
    skew: Optional[float] = 0.0
    custom_token_bans: Optional[List[int]] = None
    sampler_priority: Optional[Union[List[int], List[str]]] = Field(
        default=[],
        validation_alias=AliasChoices("sampler_priority",
                                      "sampler_order"))
    # doc: end-chat-completion-sampling-params

    # doc: begin-chat-completion-extra-params
    echo: Optional[bool] = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."),
    )
    add_generation_prompt: Optional[bool] = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by the chat template in the tokenizer "
         "config of the model."),
    )
    add_special_tokens: Optional[bool] = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to False (as is the "
            "default)."),
    )
    documents: Optional[List[Dict[str, str]]] = Field(
        default=None,
        description=
        ("A list of dicts representing documents that will be accessible to "
         "the model if it is performing RAG (retrieval-augmented generation)."
         " If the template does not support RAG, this argument will have no "
         "effect. We recommend that each document should be a dict containing "
         "\"title\" and \"text\" keys."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, a default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )
    include_stop_str_in_output: Optional[bool] = Field(
        default=False,
        description=(
            "Whether to include the stop string in the output. "
            "This is only applied when stop or stop_token_ids is set."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be either "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    # doc: end-chat-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        # We now allow logprobs being true without top_logprobs.
        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=None,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        dry_sequence_breaker_ids = []
        if self.dry_sequence_breakers:
            for s in self.dry_sequence_breakers:
                # Encode behind a dummy prefix so the breaker is tokenized as
                # it would appear mid-sequence, then keep the final token id.
                token_id = tokenizer.encode(f'a{s}')[-1]
                dry_sequence_breaker_ids.append(token_id)

        return SamplingParams(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            temperature=self.temperature,
            top_p=self.top_p,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
            (self.top_logprobs if self.echo else None),
            best_of=self.best_of,
            top_k=self.top_k,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            dry_multiplier=self.dry_multiplier,
            dry_base=self.dry_base,
            dry_allowed_length=self.dry_allowed_length,
            dry_sequence_breaker_ids=dry_sequence_breaker_ids,
            dry_range=self.dry_range,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            nsigma=self.nsigma,
            skew=self.skew,
            custom_token_bans=self.custom_token_bans,
            sampler_priority=self.sampler_priority,
        )

    @model_validator(mode='before')
    @classmethod
    def validate_stream_options(cls, values):
        if (values.get('stream_options') is not None
                and not values.get('stream')):
            raise ValueError(
                "stream_options can only be set if stream is true")
        return values

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        if isinstance(data, ValueError):
            raise data
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        # you can only either use guided decoding or tools, not both
        if guide_count > 1 and data.get("tool_choice",
                                        "none") not in ("none", "auto"):
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_usage(cls, data):
        # if "tool_choice" is not specified but tools are provided,
        # default to "auto" tool_choice
        if "tool_choice" not in data and "tools" in data:
            data["tool_choice"] = "auto"

        # if "tool_choice" is specified -- validation
        if "tool_choice" in data:
            # ensure that if "tool_choice" is specified, tools are present
            if "tools" not in data or data["tools"] is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")

            # make sure that tool choice is either a named tool
            # OR that it's set to "auto"
            if data["tool_choice"] != "auto" and not isinstance(
                    data["tool_choice"], dict):
                raise ValueError(
                    "`tool_choice` must either be a named tool or \"auto\". "
                    "`tool_choice=\"none\"` is not supported.")

            # ensure that if "tool_choice" is specified as an object,
            # it matches a valid tool
            if isinstance(data["tool_choice"], dict):
                valid_tool = False
                specified_function = data["tool_choice"]["function"]
                if not specified_function:
                    raise ValueError(
                        "Incorrectly formatted `tool_choice`. Should be like "
                        "`{\"type\": \"function\","
                        " \"function\": {\"name\": \"my_function\"}}`")
                specified_function_name = specified_function["name"]
                if not specified_function_name:
                    raise ValueError(
                        "Incorrectly formatted `tool_choice`. Should be like "
                        "`{\"type\": \"function\", "
                        "\"function\": {\"name\": \"my_function\"}}`")
                for tool in data["tools"]:
                    if tool["function"]["name"] == specified_function_name:
                        valid_tool = True
                        break
                if not valid_tool:
                    raise ValueError(
                        "The tool specified in `tool_choice` does not match "
                        "any of the specified `tools`")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "top_logprobs" in data and data["top_logprobs"] is not None:
            if "logprobs" not in data or data["logprobs"] is False:
                raise ValueError(
                    "When using `top_logprobs`, `logprobs` must be set to "
                    "true.")
            elif data["top_logprobs"] < 0:
                raise ValueError(
                    "`top_logprobs` must be a non-negative value.")
        return data
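
# Illustrative usage (a sketch; `tokenizer` stands in for any loaded
# PreTrainedTokenizer and "my-model" is a placeholder name):
#
# request = ChatCompletionRequest(
#     model="my-model",
#     messages=[{"role": "user", "content": "Hello!"}],
#     logprobs=True,
#     top_logprobs=5,
# )
# params = request.to_sampling_params(
#     tokenizer=tokenizer,
#     guided_decode_logits_processor=None,
#     default_max_tokens=512,
# )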


class CompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-completion-sampling-params
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    no_repeat_ngram_size: Optional[int] = 0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    allowed_token_ids: Optional[List[int]] = None
    include_stop_str_in_output: Optional[bool] = False
    add_special_tokens: Optional[bool] = False
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    dry_multiplier: Optional[float] = 0
    dry_base: Optional[float] = 1.75
    dry_allowed_length: Optional[int] = 2
    dry_sequence_breakers: Optional[List[str]] = Field(
        default=["\n", ":", "\"", "*"])
    dry_range: Optional[int] = Field(
        default=0,
        validation_alias=AliasChoices("dry_range",
                                      "dry_penalty_last_n"))
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    nsigma: Optional[float] = 0.0
    skew: Optional[float] = 0.0
    custom_token_bans: Optional[List[int]] = None
    sampler_priority: Optional[Union[List[int], List[str]]] = Field(
        default=[],
        validation_alias=AliasChoices("sampler_priority",
                                      "sampler_order"))
    # doc: end-completion-sampling-params

    # doc: begin-completion-extra-params
    response_format: Optional[ResponseFormat] = Field(
        default=None,
        description=
        ("Similar to chat completion, this parameter specifies the format of "
         "output. Only {'type': 'json_object'} or {'type': 'text'} is "
         "supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))
    # doc: end-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        echo_without_generation = self.echo and self.max_tokens == 0

        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        dry_sequence_breaker_ids = []
        if self.dry_sequence_breakers:
            for s in self.dry_sequence_breakers:
                # Unescape sequences such as "\n" that arrive escaped, then
                # encode behind a dummy prefix so the breaker is tokenized as
                # it would appear mid-sequence, keeping the final token id.
                s = bytes(s, "utf-8").decode("unicode_escape")
                token_id = tokenizer.encode(f'a{s}')[-1]
                dry_sequence_breaker_ids.append(token_id)

        return SamplingParams(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            logprobs=self.logprobs,
            prompt_logprobs=self.prompt_logprobs
            if self.prompt_logprobs else self.logprobs if self.echo else None,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            dry_multiplier=self.dry_multiplier,
            dry_base=self.dry_base,
            dry_allowed_length=self.dry_allowed_length,
            dry_sequence_breaker_ids=dry_sequence_breaker_ids,
            dry_range=self.dry_range,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            nsigma=self.nsigma,
            skew=self.skew,
            custom_token_bans=self.custom_token_bans,
            sampler_priority=self.sampler_priority,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if ("logprobs" in data and data["logprobs"] is not None
                and data["logprobs"] < 0):
            raise ValueError(
                "If passed, `logprobs` must be a non-negative value.")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when stream is True.")
        return data

    @model_validator(mode='before')
    @classmethod
    def parse_dry_sequence_breakers(cls, data):
        if 'dry_sequence_breakers' in data:
            breakers = data['dry_sequence_breakers']
            if isinstance(breakers, str):
                try:
                    # Try to parse as a JSON string
                    data['dry_sequence_breakers'] = json.loads(breakers)
                except json.JSONDecodeError as e:
                    raise ValueError(f"Invalid JSON for dry_sequence_breakers:"
                                     f" {e}") from e

            # Validate that we now have a list of strings. Short-circuit so a
            # non-list value fails validation instead of raising a TypeError.
            is_list = isinstance(data['dry_sequence_breakers'], list)
            all_strings = is_list and all(
                isinstance(x, str)
                for x in data['dry_sequence_breakers'])
            if not is_list or not all_strings:
                raise ValueError(
                    "dry_sequence_breakers must be a list of strings or a "
                    "JSON string representing a list of strings")
        return data
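
# Illustrative (a sketch): `dry_sequence_breakers` also accepts a
# JSON-encoded string, which parse_dry_sequence_breakers() above decodes
# before validation. "my-model" is a placeholder name.
#
# CompletionRequest(
#     model="my-model",
#     prompt="Hi",
#     dry_sequence_breakers='["\\n", ":"]',  # parsed via json.loads
# )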


class EmbeddingRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
    dimensions: Optional[int] = None
    user: Optional[str] = None

    # doc: begin-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-embedding-pooling-params

    def to_pooling_params(self):
        return PoolingParams(additional_data=self.additional_data)
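
# Illustrative request (a sketch; "my-model" is a placeholder name):
# `encoding_format` must match "float" or "base64" per the pattern above.
#
# EmbeddingRequest(model="my-model",
#                  input=["first text", "second text"],
#                  encoding_format="base64")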


class CompletionLogProbs(OpenAIBaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str,
                                     float]]] = Field(default_factory=list)


class CompletionResponseChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class CompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )


class CompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class EmbeddingResponseData(OpenAIBaseModel):
    index: int
    object: str = "embedding"
    embedding: Union[List[float], str]


class EmbeddingResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[EmbeddingResponseData]
    usage: UsageInfo


class FunctionCall(OpenAIBaseModel):
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall


class DeltaFunctionCall(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


# a tool call delta where everything is optional
class DeltaToolCall(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    index: int
    function: Optional[DeltaFunctionCall] = None


class ExtractedToolCallInformation(BaseModel):
    # indicate if tools were called
    tools_called: bool

    # extracted tool calls
    tool_calls: List[ToolCall]

    # content: per the OpenAI spec, content AND tool calls are rarely
    # returned together, but some models will do this intentionally
    content: Optional[str] = None


class ChatMessage(OpenAIBaseModel):
    role: str
    content: Optional[str] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionLogProb(OpenAIBaseModel):
    token: str
    logprob: float = -9999.0
    bytes: Optional[List[int]] = None


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)


class ChatCompletionLogProbs(OpenAIBaseModel):
    content: Optional[List[ChatCompletionLogProbsContent]] = None


class ChatCompletionResponseChoice(OpenAIBaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    # per OpenAI spec this is the default
    finish_reason: Optional[str] = "stop"
    # not part of the OpenAI spec but included in Aphrodite for legacy reasons
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class DeltaMessage(OpenAIBaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: List[DeltaToolCall] = Field(default_factory=list)


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
    """

    # A developer-provided per-request id that will be used to match outputs
    # to inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. Currently
    # /v1/chat/completions is supported.
    url: str

    # The parameters of the request.
    body: Union[ChatCompletionRequest, EmbeddingRequest]


class BatchResponseData(OpenAIBaseModel):
    # HTTP status code of the response.
    status_code: int = 200

    # A unique identifier for the API request.
    request_id: str

    # The body of the response.
    body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None


class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files.
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs
    # to inputs.
    custom_id: str

    response: Optional[BatchResponseData]

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Optional[Any]
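
# Illustrative batch input line (a sketch of one JSONL record; the id and
# model name are placeholders). Outputs are matched back to inputs by
# `custom_id`.
#
# {"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions",
#  "body": {"model": "my-model",
#           "messages": [{"role": "user", "content": "Hi"}]}}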


class TokenizeCompletionRequest(OpenAIBaseModel):
    model: str
    prompt: str

    add_special_tokens: bool = Field(default=True)


class TokenizeChatRequest(OpenAIBaseModel):
    model: str
    messages: List[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(default=True)
    add_special_tokens: bool = Field(default=False)


TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
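
# Illustrative (a sketch; "my-model" is a placeholder): the union accepts
# either form, and pydantic selects the variant whose fields match.
#
# TokenizeCompletionRequest(model="my-model", prompt="Hello")
# TokenizeChatRequest(model="my-model",
#                     messages=[{"role": "user", "content": "Hello"}])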


class TokenizeResponse(OpenAIBaseModel):
    tokens: List[int]
    count: int
    max_model_len: int


class DetokenizeRequest(OpenAIBaseModel):
    model: Optional[str]
    tokens: List[int]


class DetokenizeResponse(OpenAIBaseModel):
    prompt: str


# ========== KoboldAI ========== #
class KAIGenerationInputSchema(BaseModel):
    genkey: Optional[str] = None
    prompt: str
    n: Optional[int] = 1
    max_context_length: int
    max_length: int
    rep_pen: Optional[float] = 1.0
    top_k: Optional[int] = 0
    top_a: Optional[float] = 0.0
    top_p: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eps_cutoff: Optional[float] = 0.0
    eta_cutoff: Optional[float] = 0.0
    typical: Optional[float] = 1.0
    temperature: Optional[float] = 1.0
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    use_default_badwordsids: Optional[bool] = None
    quiet: Optional[bool] = None
    # pylint: disable=unexpected-keyword-arg
    sampler_seed: Optional[int] = None
    stop_sequence: Optional[List[str]] = None
    include_stop_str_in_output: Optional[bool] = False

    @model_validator(mode='before')
    def check_context(cls, values):  # pylint: disable=no-self-argument
        assert values.get("max_length") <= values.get(
            "max_context_length"
        ), "max_length must not be larger than max_context_length"
        return values
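
# Illustrative (a sketch): check_context rejects requests whose max_length
# exceeds max_context_length; pydantic surfaces the failed assertion as a
# ValidationError.
#
# KAIGenerationInputSchema(prompt="Hi", max_context_length=2048,
#                          max_length=128)    # ok
# KAIGenerationInputSchema(prompt="Hi", max_context_length=128,
#                          max_length=2048)   # raises ValidationError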