# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import json
import time
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from openai.types.chat import ChatCompletionContentPartParam
from pydantic import (AliasChoices, BaseModel, ConfigDict, Field,
                      model_validator)
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated, Required, TypedDict

from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import (LogitsProcessorFunc,
                                              RequestOutputKind,
                                              SamplingParams)
from aphrodite.common.sequence import Logprob
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.chat_utils import ChatCompletionMessageParam
from aphrodite.endpoints.openai.logits_processors import get_logits_processors


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""
    role: Required[str]
    """The role of the message's author."""

    content: Union[str, List[ChatCompletionContentPartParam]]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of
    the same role.
    """

    tool_call_id: Optional[str]
    tool_calls: Optional[List[dict]]
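
# Example of a custom-role message this type accepts (illustrative values
# only; the point is that `role` is not restricted to the OpenAI roles):
#     {"role": "narrator", "content": "It was a dark and stormy night.",
#      "name": "scene-1"}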


class OpenAIBaseModel(BaseModel):
    model_config = ConfigDict(extra="ignore")


class ErrorResponse(OpenAIBaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class ModelPermission(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(OpenAIBaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "pygmalionai"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(OpenAIBaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(OpenAIBaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class JsonSchemaResponseFormat(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
    # "schema" is the field name in the OpenAI spec, but it conflicts with
    # pydantic's own `schema` attribute, so use json_schema with an alias.
    json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema')
    strict: Optional[bool] = None


class ResponseFormat(OpenAIBaseModel):
    # type must be "json_schema", "json_object" or "text"
    type: Literal["text", "json_object", "json_schema"]
    json_schema: Optional[JsonSchemaResponseFormat] = None
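
# Example response_format payloads (illustrative; shapes follow the two
# classes above):
#     {"type": "json_object"}
#     {"type": "json_schema",
#      "json_schema": {"name": "answer", "schema": {"type": "object"}}}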


class StreamOptions(OpenAIBaseModel):
    include_usage: Optional[bool] = True
    continuous_usage_stats: Optional[bool] = True


class FunctionDefinition(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None


class ChatCompletionToolsParam(OpenAIBaseModel):
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"
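
# Example of a named tool choice (this is the dict shape the tool-usage
# validator below expects):
#     {"type": "function", "function": {"name": "my_function"}}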


class ChatCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: List[ChatCompletionMessageParam]
    model: str
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = 0
    max_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tools: Optional[List[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[Literal["none"], Literal["auto"],
                                ChatCompletionNamedToolChoiceParam]] = "none"
    # NOTE: this will be ignored by Aphrodite - the model determines the
    # behavior
    parallel_tool_calls: Optional[bool] = False
    user: Optional[str] = None

    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    no_repeat_ngram_size: Optional[int] = 0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    dry_multiplier: Optional[float] = 0
    dry_base: Optional[float] = 1.75
    dry_allowed_length: Optional[int] = 2
    dry_sequence_breakers: Optional[List[str]] = Field(
        default=["\n", ":", "\"", "*"])
    dry_range: Optional[int] = Field(
        default=0,
        validation_alias=AliasChoices("dry_range",
                                      "dry_penalty_last_n"))
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    nsigma: Optional[float] = 0.0
    skew: Optional[float] = 0.0
    custom_token_bans: Optional[List[int]] = None
    sampler_priority: Optional[Union[List[int], List[str]]] = Field(
        default=[],
        validation_alias=AliasChoices("sampler_priority",
                                      "sampler_order"))
    # doc: end-chat-completion-sampling-params
    # doc: begin-chat-completion-extra-params
    echo: Optional[bool] = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."),
    )
    add_generation_prompt: Optional[bool] = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by the chat template in the tokenizer "
         "config of the model."),
    )
    add_special_tokens: Optional[bool] = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens, so this should be set to False (as is the "
            "default)."),
    )
    documents: Optional[List[Dict[str, str]]] = Field(
        default=None,
        description=
        ("A list of dicts representing documents that will be accessible to "
         "the model if it is performing RAG (retrieval-augmented generation)."
         " If the template does not support RAG, this argument will have no "
         "effect. We recommend that each document be a dict containing "
         "\"title\" and \"text\" keys."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, the default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )
    include_stop_str_in_output: Optional[bool] = Field(
        default=False,
        description=(
            "Whether to include the stop string in the output. "
            "This is only applied when `stop` or `stop_token_ids` is set."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context-free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'."))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided JSON decoding."))
    # doc: end-chat-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        # We now allow logprobs being true without top_logprobs.
        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=None,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)
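        # Map each string breaker to a single token id. Prepending "a" before
        # encoding makes the tokenizer emit the mid-word form of the breaker
        # rather than a start-of-text variant; the breaker's id is then the
        # last token of that encoding.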
        dry_sequence_breaker_ids = []
        if self.dry_sequence_breakers:
            for s in self.dry_sequence_breakers:
                token_id = tokenizer.encode(f'a{s}')[-1]
                dry_sequence_breaker_ids.append(token_id)

        return SamplingParams(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            temperature=self.temperature,
            top_p=self.top_p,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
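            # When echoing the prompt, fall back to top_logprobs for the
            # prompt logprob count unless prompt_logprobs was set explicitly.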
            prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
            (self.top_logprobs if self.echo else None),
            best_of=self.best_of,
            top_k=self.top_k,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            dry_multiplier=self.dry_multiplier,
            dry_base=self.dry_base,
            dry_allowed_length=self.dry_allowed_length,
            dry_sequence_breaker_ids=dry_sequence_breaker_ids,
            dry_range=self.dry_range,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            nsigma=self.nsigma,
            skew=self.skew,
            custom_token_bans=self.custom_token_bans,
            sampler_priority=self.sampler_priority,
            output_kind=(RequestOutputKind.DELTA
                         if self.stream else RequestOutputKind.FINAL_ONLY),
        )

    @model_validator(mode='before')
    @classmethod
    def validate_stream_options(cls, values):
        if (values.get('stream_options') is not None
                and not values.get('stream')):
            raise ValueError(
                "stream_options can only be set if stream is true")
        return values

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        if isinstance(data, ValueError):
            raise data

        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        # you can only either use guided decoding or tools, not both
        if guide_count > 0 and data.get("tool_choice",
                                        "none") not in ("none", "auto"):
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_usage(cls, data):
        # if "tool_choice" is not specified but tools are provided,
        # default to "auto" tool_choice
        if "tool_choice" not in data and "tools" in data:
            data["tool_choice"] = "auto"

        # if "tool_choice" is specified -- validation
        if "tool_choice" in data:
            # ensure that if "tool_choice" is specified, tools are present
            if "tools" not in data or data["tools"] is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")

            # make sure that tool choice is either a named tool
            # OR that it's set to "auto"
            if data["tool_choice"] != "auto" and not isinstance(
                    data["tool_choice"], dict):
                raise ValueError(
                    "`tool_choice` must either be a named tool or \"auto\". "
                    "`tool_choice=\"none\"` is not supported.")

            # ensure that if "tool_choice" is specified as an object,
            # it matches a valid tool
            if isinstance(data["tool_choice"], dict):
                valid_tool = False
                specified_function = data["tool_choice"]["function"]
                if not specified_function:
                    raise ValueError(
                        "Incorrectly formatted `tool_choice`. Should be like "
                        "`{\"type\": \"function\","
                        " \"function\": {\"name\": \"my_function\"}}`")
                specified_function_name = specified_function["name"]
                if not specified_function_name:
                    raise ValueError(
                        "Incorrectly formatted `tool_choice`. Should be like "
                        "`{\"type\": \"function\", "
                        "\"function\": {\"name\": \"my_function\"}}`")
                for tool in data["tools"]:
                    if tool["function"]["name"] == specified_function_name:
                        valid_tool = True
                        break
                if not valid_tool:
                    raise ValueError(
                        "The tool specified in `tool_choice` does not match "
                        "any of the specified `tools`.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "top_logprobs" in data and data["top_logprobs"] is not None:
            if "logprobs" not in data or data["logprobs"] is False:
                raise ValueError(
                    "When using `top_logprobs`, `logprobs` must be set to "
                    "true.")
            elif data["top_logprobs"] < 0:
                raise ValueError(
                    "`top_logprobs` must be a non-negative value.")
        return data
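

# Illustrative request body for the schema above (shape only; the model name
# is a placeholder):
#     {"model": "my-model",
#      "messages": [{"role": "user", "content": "Hello!"}],
#      "temperature": 0.7, "top_k": 40, "stream": true}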


class CompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-completion-sampling-params
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    no_repeat_ngram_size: Optional[int] = 0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    allowed_token_ids: Optional[List[int]] = None
    include_stop_str_in_output: Optional[bool] = False
    add_special_tokens: Optional[bool] = False
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    dry_multiplier: Optional[float] = 0
    dry_base: Optional[float] = 1.75
    dry_allowed_length: Optional[int] = 2
    dry_sequence_breakers: Optional[List[str]] = Field(
        default=["\n", ":", "\"", "*"])
    dry_range: Optional[int] = Field(
        default=0,
        validation_alias=AliasChoices("dry_range",
                                      "dry_penalty_last_n"))
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    nsigma: Optional[float] = 0.0
    skew: Optional[float] = 0.0
    custom_token_bans: Optional[List[int]] = None
    sampler_priority: Optional[Union[List[int], List[str]]] = Field(
        default=[],
        validation_alias=AliasChoices("sampler_priority",
                                      "sampler_order"))
    # doc: end-completion-sampling-params
    # doc: begin-completion-extra-params
    response_format: Optional[ResponseFormat] = Field(
        default=None,
        description=
        ("Similar to chat completion, this parameter specifies the format "
         "of the output. Only {'type': 'json_object'} or {'type': 'text'} "
         "is supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context-free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'."))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided JSON decoding."))
    # doc: end-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        echo_without_generation = self.echo and self.max_tokens == 0

        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        dry_sequence_breaker_ids = []
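        # Same token mapping as the chat endpoint, except that breakers may
        # arrive here as escaped text (e.g. "\\n"), so they are unescaped
        # before encoding.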
        if self.dry_sequence_breakers:
            for s in self.dry_sequence_breakers:
                s = bytes(s, "utf-8").decode("unicode_escape")
                token_id = tokenizer.encode(f'a{s}')[-1]
                dry_sequence_breaker_ids.append(token_id)

        return SamplingParams(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            logprobs=self.logprobs,
            prompt_logprobs=self.prompt_logprobs
            if self.prompt_logprobs else self.logprobs if self.echo else None,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            dry_multiplier=self.dry_multiplier,
            dry_base=self.dry_base,
            dry_allowed_length=self.dry_allowed_length,
            dry_sequence_breaker_ids=dry_sequence_breaker_ids,
            dry_range=self.dry_range,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            nsigma=self.nsigma,
            skew=self.skew,
            custom_token_bans=self.custom_token_bans,
            sampler_priority=self.sampler_priority,
            output_kind=(RequestOutputKind.DELTA
                         if self.stream else RequestOutputKind.FINAL_ONLY),
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if ("logprobs" in data and data["logprobs"] is not None
                and data["logprobs"] < 0):
            raise ValueError(
                "If passed, `logprobs` must be a non-negative value.")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when stream is True.")
        return data

    @model_validator(mode='before')
    @classmethod
    def parse_dry_sequence_breakers(cls, data):
        if 'dry_sequence_breakers' in data:
            breakers = data['dry_sequence_breakers']
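            # Clients may send either a real list (e.g. ["\n", ":"]) or a
            # JSON-encoded string of that list; normalize the string form
            # to a list here.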
            if isinstance(breakers, str):
                try:
                    # Try to parse as a JSON string
                    data['dry_sequence_breakers'] = json.loads(breakers)
                except json.JSONDecodeError as e:
                    raise ValueError(f"Invalid JSON for dry_sequence_breakers:"
                                     f" {e}") from e

            # Validate that we now have a list of strings
            breakers = data['dry_sequence_breakers']
            if not isinstance(breakers, list) or not all(
                    isinstance(x, str) for x in breakers):
                raise ValueError(
                    "dry_sequence_breakers must be a list of strings or a "
                    "JSON string representing a list of strings")
        return data


class EmbeddingRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
    dimensions: Optional[int] = None
    user: Optional[str] = None

    # doc: begin-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-embedding-pooling-params

    def to_pooling_params(self):
        return PoolingParams(additional_data=self.additional_data)


class CompletionLogProbs(OpenAIBaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str,
                                     float]]] = Field(default_factory=list)


class CompletionResponseChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion to stop; "
            "None if the completion finished for some other reason, "
            "including encountering the EOS token."),
    )
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class CompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion to stop; "
            "None if the completion finished for some other reason, "
            "including encountering the EOS token."),
    )


class CompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class EmbeddingResponseData(OpenAIBaseModel):
    index: int
    object: str = "embedding"
    embedding: Union[List[float], str]


class EmbeddingResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[EmbeddingResponseData]
    usage: UsageInfo


class FunctionCall(OpenAIBaseModel):
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall


class DeltaFunctionCall(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


# a tool call delta where everything is optional
class DeltaToolCall(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    index: int
    function: Optional[DeltaFunctionCall] = None


class ExtractedToolCallInformation(BaseModel):
    # indicate if tools were called
    tools_called: bool

    # extracted tool calls
    tool_calls: List[ToolCall]

    # content - per the OpenAI spec, content AND tool calls are only rarely
    # returned together, but some models will do this intentionally
    content: Optional[str] = None


class ChatMessage(OpenAIBaseModel):
    role: str
    content: Optional[str] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionLogProb(OpenAIBaseModel):
    token: str
    logprob: float = -9999.0
    bytes: Optional[List[int]] = None


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)


class ChatCompletionLogProbs(OpenAIBaseModel):
    content: Optional[List[ChatCompletionLogProbsContent]] = None


class ChatCompletionResponseChoice(OpenAIBaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    # per the OpenAI spec this is the default
    finish_reason: Optional[str] = "stop"
    # not part of the OpenAI spec but included in Aphrodite for legacy reasons
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class DeltaMessage(OpenAIBaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: List[DeltaToolCall] = Field(default_factory=list)


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
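

# Illustrative stream chunk as serialized over SSE (shape only; id and
# timestamp are placeholders):
#     data: {"id": "chatcmpl-...", "object": "chat.completion.chunk",
#            "created": 1700000000, "model": "my-model",
#            "choices": [{"index": 0, "delta": {"content": "Hel"}}]}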


class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
    """

    # A developer-provided per-request id that will be used to match outputs
    # to inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. Currently
    # /v1/chat/completions is supported.
    url: str

    # The parameters of the request.
    body: Union[ChatCompletionRequest, EmbeddingRequest]


class BatchResponseData(OpenAIBaseModel):
    # HTTP status code of the response.
    status_code: int = 200

    # A unique identifier for the API request.
    request_id: str

    # The body of the response.
    body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None


class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files.
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs
    # to inputs.
    custom_id: str

    response: Optional[BatchResponseData]

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Optional[Any]
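

# Illustrative batch input line (.jsonl; one JSON object per line, values
# are placeholders):
#     {"custom_id": "request-1", "method": "POST",
#      "url": "/v1/chat/completions",
#      "body": {"model": "my-model",
#               "messages": [{"role": "user", "content": "Hi"}]}}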


class TokenizeCompletionRequest(OpenAIBaseModel):
    model: str
    prompt: str

    add_special_tokens: bool = Field(default=True)


class TokenizeChatRequest(OpenAIBaseModel):
    model: str
    messages: List[ChatCompletionMessageParam]

    add_generation_prompt: bool = Field(default=True)
    add_special_tokens: bool = Field(default=False)


TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]


class TokenizeResponse(OpenAIBaseModel):
    tokens: List[int]
    count: int
    max_model_len: int


class DetokenizeRequest(OpenAIBaseModel):
    model: Optional[str]
    tokens: List[int]


class DetokenizeResponse(OpenAIBaseModel):
    prompt: str


# ========== KoboldAI ========== #
class KAIGenerationInputSchema(BaseModel):
    genkey: Optional[str] = None
    prompt: str
    n: Optional[int] = 1
    max_context_length: int
    max_length: int
    rep_pen: Optional[float] = 1.0
    top_k: Optional[int] = 0
    top_a: Optional[float] = 0.0
    top_p: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eps_cutoff: Optional[float] = 0.0
    eta_cutoff: Optional[float] = 0.0
    typical: Optional[float] = 1.0
    temperature: Optional[float] = 1.0
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    use_default_badwordsids: Optional[bool] = None
    quiet: Optional[bool] = None
    # pylint: disable=unexpected-keyword-arg
    sampler_seed: Optional[int] = None
    stop_sequence: Optional[List[str]] = None
    include_stop_str_in_output: Optional[bool] = False

    @model_validator(mode='before')
    def check_context(cls, values):  # pylint: disable=no-self-argument
        assert values.get("max_length") <= values.get(
            "max_context_length"
        ), "max_length must not be larger than max_context_length"
        return values