# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import json
import time
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from openai.types.chat import ChatCompletionContentPartParam
from pydantic import (AliasChoices, BaseModel, ConfigDict, Field,
                      model_validator)
from transformers import PreTrainedTokenizer
from typing_extensions import Annotated, Required, TypedDict

from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import (LogitsProcessorFunc,
                                              RequestOutputKind,
                                              SamplingParams)
from aphrodite.common.sequence import Logprob
from aphrodite.common.utils import random_uuid
from aphrodite.endpoints.chat_utils import ChatCompletionMessageParam
from aphrodite.endpoints.openai.logits_processors import get_logits_processors


class CustomChatCompletionMessageParam(TypedDict, total=False):
    """Enables custom roles in the Chat Completion API."""
    role: Required[str]
    """The role of the message's author."""

    content: Union[str, List[ChatCompletionContentPartParam]]
    """The contents of the message."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of
    the same role.
    """

    tool_call_id: Optional[str]
    tool_calls: Optional[List[dict]]
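
# Example (illustrative; the "narrator" role is an assumption, not part of
# the API): a message with a custom role, which the standard OpenAI message
# types would reject:
#
#     message: CustomChatCompletionMessageParam = {
#         "role": "narrator",
#         "content": "The story begins on a rainy night.",
#     }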


class OpenAIBaseModel(BaseModel):
    model_config = ConfigDict(extra="ignore")


class ErrorResponse(OpenAIBaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class ModelPermission(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(OpenAIBaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "pygmalionai"
    root: Optional[str] = None
    parent: Optional[str] = None
    max_model_len: Optional[int] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(OpenAIBaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(OpenAIBaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class JsonSchemaResponseFormat(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
    # `schema` is the field name in the OpenAI API, but that conflicts with
    # pydantic, so use `json_schema` with an alias instead
    json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema')
    strict: Optional[bool] = None


class ResponseFormat(OpenAIBaseModel):
    # type must be "json_schema", "json_object" or "text"
    type: Literal["text", "json_object", "json_schema"]
    json_schema: Optional[JsonSchemaResponseFormat] = None
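
# Example (illustrative): the wire-level payload for schema-constrained
# output. Note that OpenAI's field is named `schema`; it is exposed here as
# `json_schema` via an alias to avoid clashing with pydantic's BaseModel:
#
#     {"type": "json_schema",
#      "json_schema": {"name": "answer",
#                      "schema": {"type": "object",
#                                 "properties": {"x": {"type": "integer"}}}}}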


class StreamOptions(OpenAIBaseModel):
    include_usage: Optional[bool] = True
    continuous_usage_stats: Optional[bool] = True


class FunctionDefinition(OpenAIBaseModel):
    name: str
    description: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = None


class ChatCompletionToolsParam(OpenAIBaseModel):
    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatCompletionNamedFunction(OpenAIBaseModel):
    name: str


class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
    function: ChatCompletionNamedFunction
    type: Literal["function"] = "function"
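
# Example (illustrative): the wire format for forcing one named tool, as
# validated by ChatCompletionRequest.check_tool_usage below:
#
#     {"type": "function", "function": {"name": "my_function"}}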


class ChatCompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: List[ChatCompletionMessageParam]
    model: str
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = 0
    max_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tools: Optional[List[ChatCompletionToolsParam]] = None
    tool_choice: Optional[Union[Literal["none"], Literal["auto"],
                                ChatCompletionNamedToolChoiceParam]] = "none"

    # NOTE this will be ignored by Aphrodite - the model determines the
    # behavior
    parallel_tool_calls: Optional[bool] = False
    user: Optional[str] = None

    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    no_repeat_ngram_size: Optional[int] = 0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    dry_multiplier: Optional[float] = 0
    dry_base: Optional[float] = 1.75
    dry_allowed_length: Optional[int] = 2
    dry_sequence_breakers: Optional[List[str]] = Field(
        default=["\n", ":", "\"", "*"])
    dry_range: Optional[int] = Field(
        default=0,
        validation_alias=AliasChoices("dry_range", "dry_penalty_last_n"))
    dry_max_ngram: Optional[int] = 12
    dry_max_occurrences: Optional[int] = 8
    dry_early_exit_match_len: Optional[int] = 8
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    nsigma: Optional[float] = 0.0
    skew: Optional[float] = 0.0
    custom_token_bans: Optional[List[int]] = None
    sampler_priority: Optional[Union[List[int], List[str]]] = Field(
        default=[],
        validation_alias=AliasChoices("sampler_priority", "sampler_order"))
    # doc: end-chat-completion-sampling-params

    # doc: begin-chat-completion-extra-params
    echo: Optional[bool] = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."),
    )
    add_generation_prompt: Optional[bool] = Field(
        default=True,
        description=(
            "If true, the generation prompt will be added to the chat "
            "template. This is a parameter used by the chat template in the "
            "tokenizer config of the model."),
    )
    add_special_tokens: Optional[bool] = Field(
        default=False,
        description=(
            "If true, special tokens (e.g. BOS) will be added to the prompt "
            "on top of what is added by the chat template. "
            "For most models, the chat template takes care of adding the "
            "special tokens so this should be set to False (as is the "
            "default)."),
    )
    documents: Optional[List[Dict[str, str]]] = Field(
        default=None,
        description=(
            "A list of dicts representing documents that will be accessible "
            "to the model if it is performing RAG (retrieval-augmented "
            "generation). If the template does not support RAG, this argument "
            "will have no effect. We recommend that each document should be a "
            "dict containing \"title\" and \"text\" keys."),
    )
    chat_template: Optional[str] = Field(
        default=None,
        description=(
            "A Jinja template to use for this conversion. "
            "As of transformers v4.44, default chat template is no longer "
            "allowed, so you must provide a chat template if the tokenizer "
            "does not define one."),
    )
    chat_template_kwargs: Optional[Dict[str, Any]] = Field(
        default=None,
        description=("Additional kwargs to pass to the template renderer. "
                     "Will be accessible by the chat template."),
    )
    include_stop_str_in_output: Optional[bool] = Field(
        default=False,
        description=(
            "Whether to include the stop string in the output. "
            "This is only applied when the stop or stop_token_ids is set."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context-free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be either "
            "'outlines' or 'lm-format-enforcer'."))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided JSON decoding."))
    # doc: end-chat-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        # We now allow logprobs being true without top_logprobs.
        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=None,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        dry_sequence_breaker_ids = []
        if self.dry_sequence_breakers:
            for s in self.dry_sequence_breakers:
                # Prepend a dummy character so the breaker is tokenized in
                # mid-text position, then keep only its last token id.
                token_id = tokenizer.encode(f'a{s}')[-1]
                dry_sequence_breaker_ids.append(token_id)

        return SamplingParams(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            temperature=self.temperature,
            top_p=self.top_p,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=max_tokens,
            min_tokens=self.min_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else
            (self.top_logprobs if self.echo else None),
            best_of=self.best_of,
            top_k=self.top_k,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            dry_multiplier=self.dry_multiplier,
            dry_base=self.dry_base,
            dry_allowed_length=self.dry_allowed_length,
            dry_sequence_breaker_ids=dry_sequence_breaker_ids,
            dry_range=self.dry_range,
            dry_max_ngram=self.dry_max_ngram,
            dry_max_occurrences=self.dry_max_occurrences,
            dry_early_exit_match_len=self.dry_early_exit_match_len,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            nsigma=self.nsigma,
            skew=self.skew,
            custom_token_bans=self.custom_token_bans,
            sampler_priority=self.sampler_priority,
            output_kind=RequestOutputKind.DELTA
            if self.stream else RequestOutputKind.FINAL_ONLY,
        )

    @model_validator(mode='before')
    @classmethod
    def validate_stream_options(cls, values):
        if (values.get('stream_options') is not None
                and not values.get('stream')):
            raise ValueError(
                "stream_options can only be set if stream is true")
        return values

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        if isinstance(data, ValueError):
            raise data
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        # you can only use one kind of guided decoding
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        # you can only either use guided decoding or tools, not both
        if guide_count >= 1 and data.get("tool_choice",
                                         "none") not in ("none", "auto"):
            raise ValueError(
                "You can only either use guided decoding or tools, not both.")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_tool_usage(cls, data):
        # if "tool_choice" is not specified but tools are provided,
        # default to "auto" tool_choice
        if "tool_choice" not in data and "tools" in data:
            data["tool_choice"] = "auto"

        # if "tool_choice" is specified -- validation
        if "tool_choice" in data:
            # ensure that if "tool_choice" is specified, tools are present
            if "tools" not in data or data["tools"] is None:
                raise ValueError(
                    "When using `tool_choice`, `tools` must be set.")

            # make sure that tool choice is either a named tool
            # OR that it's set to "auto"
            if data["tool_choice"] != "auto" and not isinstance(
                    data["tool_choice"], dict):
                raise ValueError(
                    "`tool_choice` must either be a named tool or \"auto\". "
                    "`tool_choice=\"none\"` is not supported.")

            # ensure that if "tool_choice" is specified as an object,
            # it matches a valid tool
            if isinstance(data["tool_choice"], dict):
                valid_tool = False
                specified_function = data["tool_choice"].get("function")
                if not specified_function:
                    raise ValueError(
                        "Incorrectly formatted `tool_choice`. Should be like "
                        "`{\"type\": \"function\","
                        " \"function\": {\"name\": \"my_function\"}}`")
                specified_function_name = specified_function.get("name")
                if not specified_function_name:
                    raise ValueError(
                        "Incorrectly formatted `tool_choice`. Should be like "
                        "`{\"type\": \"function\", "
                        "\"function\": {\"name\": \"my_function\"}}`")
                for tool in data["tools"]:
                    if tool["function"]["name"] == specified_function_name:
                        valid_tool = True
                        break
                if not valid_tool:
                    raise ValueError(
                        "The tool specified in `tool_choice` does not match "
                        "any of the specified `tools`")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "top_logprobs" in data and data["top_logprobs"] is not None:
            if "logprobs" not in data or data["logprobs"] is False:
                raise ValueError(
                    "when using `top_logprobs`, `logprobs` must be set to "
                    "true.")
            elif data["top_logprobs"] < 0:
                raise ValueError(
                    "`top_logprobs` must be a non-negative value.")
        return data
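
# Example (a sketch; `tokenizer` and the server-side default below are
# assumptions): how the serving layer might turn a parsed request into
# engine-level sampling parameters. `guided_decode_logits_processor` may be
# None when no guided decoding was requested:
#
#     request = ChatCompletionRequest(
#         model="my-model",
#         messages=[{"role": "user", "content": "Hi!"}],
#         temperature=0.7)
#     params = request.to_sampling_params(
#         tokenizer=tokenizer,
#         guided_decode_logits_processor=None,
#         default_max_tokens=256)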


class CompletionRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: int = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    stream_options: Optional[StreamOptions] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-completion-sampling-params
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    top_a: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    no_repeat_ngram_size: Optional[int] = 0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
    allowed_token_ids: Optional[List[int]] = None
    include_stop_str_in_output: Optional[bool] = False
    add_special_tokens: Optional[bool] = False
    temperature_last: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    dry_multiplier: Optional[float] = 0
    dry_base: Optional[float] = 1.75
    dry_allowed_length: Optional[int] = 2
    dry_sequence_breakers: Optional[List[str]] = Field(
        default=["\n", ":", "\"", "*"])
    dry_range: Optional[int] = Field(
        default=0,
        validation_alias=AliasChoices("dry_range", "dry_penalty_last_n"))
    dry_max_ngram: Optional[int] = 12
    dry_max_occurrences: Optional[int] = 8
    dry_early_exit_match_len: Optional[int] = 8
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    nsigma: Optional[float] = 0.0
    skew: Optional[float] = 0.0
    custom_token_bans: Optional[List[int]] = None
    sampler_priority: Optional[Union[List[int], List[str]]] = Field(
        default=[],
        validation_alias=AliasChoices("sampler_priority", "sampler_order"))
    # doc: end-completion-sampling-params

    # doc: begin-completion-extra-params
    response_format: Optional[ResponseFormat] = Field(
        default=None,
        description=(
            "Similar to chat completion, this parameter specifies the format "
            "of output. Only {'type': 'json_object'} or {'type': 'text'} is "
            "supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context-free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' or 'lm-format-enforcer'."))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided JSON decoding."))
    # doc: end-completion-extra-params

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessorFunc],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        echo_without_generation = self.echo and self.max_tokens == 0

        logits_processors = get_logits_processors(
            logit_bias=self.logit_bias,
            allowed_token_ids=self.allowed_token_ids,
            tokenizer=tokenizer,
        )
        if guided_decode_logits_processor:
            logits_processors.append(guided_decode_logits_processor)

        dry_sequence_breaker_ids = []
        if self.dry_sequence_breakers:
            for s in self.dry_sequence_breakers:
                # Unescape sequences such as "\\n" that arrive escaped, then
                # tokenize with a dummy prefix and keep only the last token.
                s = bytes(s, "utf-8").decode("unicode_escape")
                token_id = tokenizer.encode(f'a{s}')[-1]
                dry_sequence_breaker_ids.append(token_id)

        return SamplingParams(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            no_repeat_ngram_size=self.no_repeat_ngram_size,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            top_a=self.top_a,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
            max_tokens=max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            logprobs=self.logprobs,
            prompt_logprobs=self.prompt_logprobs
            if self.prompt_logprobs else self.logprobs if self.echo else None,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
            temperature_last=self.temperature_last,
            xtc_threshold=self.xtc_threshold,
            xtc_probability=self.xtc_probability,
            dry_multiplier=self.dry_multiplier,
            dry_base=self.dry_base,
            dry_allowed_length=self.dry_allowed_length,
            dry_sequence_breaker_ids=dry_sequence_breaker_ids,
            dry_range=self.dry_range,
            dry_max_ngram=self.dry_max_ngram,
            dry_max_occurrences=self.dry_max_occurrences,
            dry_early_exit_match_len=self.dry_early_exit_match_len,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            nsigma=self.nsigma,
            skew=self.skew,
            custom_token_bans=self.custom_token_bans,
            sampler_priority=self.sampler_priority,
            output_kind=RequestOutputKind.DELTA
            if self.stream else RequestOutputKind.FINAL_ONLY,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if ("logprobs" in data and data["logprobs"] is not None
                and not data["logprobs"] >= 0):
            raise ValueError(
                "if passed, `logprobs` must be a non-negative value.")
        return data

    @model_validator(mode="before")
    @classmethod
    def validate_stream_options(cls, data):
        if data.get("stream_options") and not data.get("stream"):
            raise ValueError(
                "Stream options can only be defined when stream is True.")
        return data

    @model_validator(mode='before')
    @classmethod
    def parse_dry_sequence_breakers(cls, data):
        if 'dry_sequence_breakers' in data:
            breakers = data['dry_sequence_breakers']
            if isinstance(breakers, str):
                try:
                    # Try to parse as a JSON string
                    data['dry_sequence_breakers'] = json.loads(breakers)
                except json.JSONDecodeError as e:
                    raise ValueError(f"Invalid JSON for dry_sequence_breakers:"
                                     f" {e}") from e

            # Validate that we now have a list of strings (check the list
            # type first so non-iterable values raise ValueError, not
            # TypeError)
            breakers = data['dry_sequence_breakers']
            if not isinstance(breakers, list) or not all(
                    isinstance(x, str) for x in breakers):
                raise ValueError(
                    "dry_sequence_breakers must be a list of strings or a "
                    "JSON string representing a list of strings")
        return data
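
# Example (illustrative): `dry_sequence_breakers` is accepted either as a
# JSON array or as a JSON-encoded string; parse_dry_sequence_breakers above
# normalizes both forms to the same list:
#
#     CompletionRequest(model="m", prompt="hi",
#                       dry_sequence_breakers=["\n", ":"])
#     CompletionRequest(model="m", prompt="hi",
#                       dry_sequence_breakers='["\\n", ":"]')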


class EmbeddingRequest(OpenAIBaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Optional[str] = Field('float',
                                           pattern='^(float|base64)$')
    dimensions: Optional[int] = None
    user: Optional[str] = None

    # doc: begin-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-embedding-pooling-params

    def to_pooling_params(self):
        return PoolingParams(additional_data=self.additional_data)


class CompletionLogProbs(OpenAIBaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Optional[Dict[str, float]]] = Field(
        default_factory=list)


class CompletionResponseChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class CompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    text: str
    logprobs: Optional[CompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )


class CompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class EmbeddingResponseData(OpenAIBaseModel):
    index: int
    object: str = "embedding"
    embedding: Union[List[float], str]


class EmbeddingResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[EmbeddingResponseData]
    usage: UsageInfo


class FunctionCall(OpenAIBaseModel):
    name: str
    arguments: str


class ToolCall(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    function: FunctionCall


class DeltaFunctionCall(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


# a tool call delta where everything is optional
class DeltaToolCall(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}")
    type: Literal["function"] = "function"
    index: int
    function: Optional[DeltaFunctionCall] = None


class ExtractedToolCallInformation(BaseModel):
    # indicate if tools were called
    tools_called: bool
    # extracted tool calls
    tool_calls: List[ToolCall]
    # content - per the OpenAI spec, content AND tool calls are only rarely
    # returned together, but some models will do this intentionally
    content: Optional[str] = None


class ChatMessage(OpenAIBaseModel):
    role: str
    content: Optional[str] = None
    tool_calls: List[ToolCall] = Field(default_factory=list)


class ChatCompletionLogProb(OpenAIBaseModel):
    token: str
    logprob: float = -9999.0
    bytes: Optional[List[int]] = None


class ChatCompletionLogProbsContent(ChatCompletionLogProb):
    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)


class ChatCompletionLogProbs(OpenAIBaseModel):
    content: Optional[List[ChatCompletionLogProbsContent]] = None


class ChatCompletionResponseChoice(OpenAIBaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    # per OpenAI spec this is the default
    finish_reason: Optional[str] = "stop"
    # not part of the OpenAI spec but included in Aphrodite for legacy reasons
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
    prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None


class DeltaMessage(OpenAIBaseModel):
    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: List[DeltaToolCall] = Field(default_factory=list)


class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[ChatCompletionLogProbs] = None
    finish_reason: Optional[str] = None
    stop_reason: Optional[Union[int, str]] = None


class ChatCompletionStreamResponse(OpenAIBaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class BatchRequestInput(OpenAIBaseModel):
    """
    The per-line object of the batch input file.

    NOTE: Currently only the `/v1/chat/completions` endpoint is supported.
    """

    # A developer-provided per-request id that will be used to match outputs
    # to inputs. Must be unique for each request in a batch.
    custom_id: str

    # The HTTP method to be used for the request. Currently only POST is
    # supported.
    method: str

    # The OpenAI API relative URL to be used for the request. Currently
    # /v1/chat/completions is supported.
    url: str

    # The parameters of the request.
    body: Union[ChatCompletionRequest, EmbeddingRequest]


class BatchResponseData(OpenAIBaseModel):
    # HTTP status code of the response.
    status_code: int = 200

    # A unique identifier for the API request.
    request_id: str

    # The body of the response.
    body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None


class BatchRequestOutput(OpenAIBaseModel):
    """
    The per-line object of the batch output and error files.
    """

    id: str

    # A developer-provided per-request id that will be used to match outputs
    # to inputs.
    custom_id: str

    response: Optional[BatchResponseData]

    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Optional[Any]


class TokenizeCompletionRequest(OpenAIBaseModel):
    model: str
    prompt: str
    add_special_tokens: bool = Field(default=True)


class TokenizeChatRequest(OpenAIBaseModel):
    model: str
    messages: List[ChatCompletionMessageParam]
    add_generation_prompt: bool = Field(default=True)
    add_special_tokens: bool = Field(default=False)


TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]


class TokenizeResponse(OpenAIBaseModel):
    tokens: List[int]
    count: int
    max_model_len: int


class DetokenizeRequest(OpenAIBaseModel):
    model: Optional[str]
    tokens: List[int]


class DetokenizeResponse(OpenAIBaseModel):
    prompt: str


# ========== KoboldAI ========== #
class KAIGenerationInputSchema(BaseModel):
    genkey: Optional[str] = None
    prompt: str
    n: Optional[int] = 1
    max_context_length: int
    max_length: int
    rep_pen: Optional[float] = 1.0
    top_k: Optional[int] = 0
    top_a: Optional[float] = 0.0
    top_p: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eps_cutoff: Optional[float] = 0.0
    eta_cutoff: Optional[float] = 0.0
    typical: Optional[float] = 1.0
    temperature: Optional[float] = 1.0
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    xtc_threshold: Optional[float] = 0.1
    xtc_probability: Optional[float] = 0.0
    use_default_badwordsids: Optional[bool] = None
    quiet: Optional[bool] = None
    # pylint: disable=unexpected-keyword-arg
    sampler_seed: Optional[int] = None
    stop_sequence: Optional[List[str]] = None
    include_stop_str_in_output: Optional[bool] = False

    @model_validator(mode='before')
    def check_context(cls, values):  # pylint: disable=no-self-argument
        assert values.get("max_length") <= values.get(
            "max_context_length"
        ), "max_length must not be larger than max_context_length"
        return values
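
# Example (illustrative): a minimal KoboldAI payload. check_context above
# rejects any request where max_length exceeds max_context_length:
#
#     KAIGenerationInputSchema(prompt="Once upon a time",
#                              max_context_length=2048,
#                              max_length=256)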