# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Dict, List, Literal, Optional, Union

from pydantic import (AliasChoices, BaseModel, Field, conint, model_validator,
                      root_validator)

from aphrodite.common.logits_processor import BiasLogitsProcessor
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.utils import random_uuid


class ErrorResponse(BaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class ModelPermission(BaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "pygmalionai"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ResponseFormat(BaseModel):
    # type must be "json_object" or "text"
    type: Literal["text", "json_object"]


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Dict[str, str]]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    n: Optional[int] = 1
    max_tokens: Optional[int] = None
    min_tokens: Optional[int] = 0
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    include_stop_str_in_output: Optional[bool] = False
    stream: Optional[bool] = False
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    repetition_penalty: Optional[float] = 1.0
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    best_of: Optional[int] = None
    top_k: Optional[int] = -1
    top_a: Optional[float] = 0.0
    min_p: Optional[float] = 0.0
    mirostat_mode: Optional[int] = 0
    mirostat_tau: Optional[float] = 0.0
    mirostat_eta: Optional[float] = 0.0
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    custom_token_bans: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    add_generation_prompt: Optional[bool] = True
    echo: Optional[bool] = False
    length_penalty: Optional[float] = 1.0
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None
    guided_grammar: Optional[str] = None
    response_format: Optional[ResponseFormat] = None
    guided_decoding_backend: Optional[str] = Field(
        default="outlines",
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))

    def to_sampling_params(self, vocab_size: int) -> SamplingParams:
        if self.logprobs and not self.top_logprobs:
            raise ValueError("Top logprobs must be set when logprobs is.")
        # top_k == 0 means "disabled"; normalize to the engine's -1 sentinel.
        if self.top_k == 0:
            self.top_k = -1
        logits_processors = []
        if self.logit_bias:
            # Clamp each bias to [-100, 100] and drop out-of-vocab token ids.
            biases = {
                int(tok): max(-100, min(float(bias), 100))
                for tok, bias in self.logit_bias.items()
                if 0 < int(tok) < vocab_size
            }
            logits_processors.append(BiasLogitsProcessor(biases))
        return SamplingParams(
            n=self.n,
            max_tokens=self.max_tokens,
            min_tokens=self.min_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.top_logprobs if self.echo else None,
            temperature=self.temperature,
            top_p=self.top_p,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            top_k=self.top_k,
            top_a=self.top_a,
            min_p=self.min_p,
            mirostat_mode=self.mirostat_mode,
            mirostat_tau=self.mirostat_tau,
            mirostat_eta=self.mirostat_eta,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            stop_token_ids=self.stop_token_ids,
            custom_token_bans=self.custom_token_bans,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            stop=self.stop,
            best_of=self.best_of,
            include_stop_str_in_output=self.include_stop_str_in_output,
            seed=self.seed,
            logits_processors=logits_processors,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data
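
# Example (illustrative sketch, not part of the API surface): building a
# ChatCompletionRequest and converting it into engine SamplingParams. The
# model name and vocab_size below are placeholders.
#
#   req = ChatCompletionRequest(
#       model="my-model",
#       messages=[{"role": "user", "content": "Hello!"}],
#       temperature=0.5,
#       logit_bias={"42": 5.0},
#   )
#   params = req.to_sampling_params(vocab_size=32000)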


class CompletionRequest(BaseModel):
    model: str
    # a string, array of strings, array of tokens, or array of token arrays
    prompt: Union[List[int], List[List[int]], str, List[str]]
    suffix: Optional[str] = None
    max_tokens: Optional[int] = 16
    min_tokens: Optional[int] = 0
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    seed: Optional[int] = None
    include_stop_str_in_output: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    repetition_penalty: Optional[float] = 1.0
    best_of: Optional[int] = None
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    top_k: Optional[int] = -1
    top_a: Optional[float] = 0.0
    min_p: Optional[float] = 0.0
    mirostat_mode: Optional[int] = 0
    mirostat_tau: Optional[float] = 0.0
    mirostat_eta: Optional[float] = 0.0
    dynatemp_min: Optional[float] = Field(0.0,
                                          validation_alias=AliasChoices(
                                              "dynatemp_min", "dynatemp_low"),
                                          description="Aliases: dynatemp_low")
    dynatemp_max: Optional[float] = Field(0.0,
                                          validation_alias=AliasChoices(
                                              "dynatemp_max", "dynatemp_high"),
                                          description="Aliases: dynatemp_high")
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    custom_token_bans: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[conint(ge=1)] = None
    grammar: Optional[str] = None
    length_penalty: Optional[float] = 1.0
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None
    guided_grammar: Optional[str] = None
    response_format: Optional[ResponseFormat] = None
    guided_decoding_backend: Optional[str] = Field(
        default="outlines",
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))

    def to_sampling_params(self, vocab_size: int) -> SamplingParams:
        # "echo" with max_tokens == 0 means the caller only wants the prompt
        # back; generate a single token so the engine still runs.
        echo_without_generation = self.echo and self.max_tokens == 0
        if self.top_k == 0:
            self.top_k = -1
        logits_processors = []
        if self.logit_bias:
            # Clamp each bias to [-100, 100] and drop out-of-vocab token ids.
            biases = {
                int(tok): max(-100, min(float(bias), 100))
                for tok, bias in self.logit_bias.items()
                if 0 < int(tok) < vocab_size
            }
            logits_processors.append(BiasLogitsProcessor(biases))
        return SamplingParams(
            n=self.n,
            max_tokens=self.max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            temperature=self.temperature,
            top_p=self.top_p,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            top_k=self.top_k,
            top_a=self.top_a,
            min_p=self.min_p,
            mirostat_mode=self.mirostat_mode,
            mirostat_tau=self.mirostat_tau,
            mirostat_eta=self.mirostat_eta,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            logprobs=self.logprobs,
            prompt_logprobs=self.prompt_logprobs if self.echo else None,
            stop_token_ids=self.stop_token_ids,
            custom_token_bans=self.custom_token_bans,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            stop=self.stop,
            best_of=self.best_of,
            include_stop_str_in_output=self.include_stop_str_in_output,
            seed=self.seed,
            logits_processors=logits_processors,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data
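
# Example (illustrative sketch): a plain completion request. The prompt may
# also be given as token ids; the values below are placeholders.
#
#   req = CompletionRequest(
#       model="my-model",
#       prompt="Once upon a time",
#       max_tokens=32,
#       truncate_prompt_tokens=2048,
#   )
#   params = req.to_sampling_params(vocab_size=32000)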


class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None


class CompletionResponseChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
    stop_reason: Union[None, int, str] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )


class CompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
    stop_reason: Union[None, int, str] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )


class CompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
    stop_reason: Union[None, int, str] = None


class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo


class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
    stop_reason: Union[None, int, str] = None


class ChatCompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
    logprobs: Optional[LogProbs] = None


class EmbeddingsRequest(BaseModel):
    input: List[str] = Field(
        ..., description="List of input texts to generate embeddings for.")
    encoding_format: str = Field(
        "float",
        description="Encoding format for the embeddings. "
        "Can be 'float' or 'base64'.")
    model: Optional[str] = Field(
        None,
        description="Name of the embedding model to use. "
        "If not provided, the default model will be used.")


class EmbeddingObject(BaseModel):
    object: str = Field("embedding", description="Type of the object.")
    embedding: List[float] = Field(
        ..., description="Embedding values as a list of floats.")
    index: int = Field(
        ...,
        description="Index of the input text corresponding to "
        "the embedding.")


class EmbeddingsResponse(BaseModel):
    object: str = Field("list", description="Type of the response object.")
    data: List[EmbeddingObject] = Field(
        ..., description="List of embedding objects.")
    model: str = Field(..., description="Name of the embedding model used.")
    usage: UsageInfo = Field(..., description="Information about token usage.")
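
# Example (illustrative sketch): a minimal embeddings request; the model name
# is a placeholder.
#
#   req = EmbeddingsRequest(input=["hello world"], model="my-embedder")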


class Prompt(BaseModel):
    prompt: str


# ========== KoboldAI ========== #
class KoboldSamplingParams(BaseModel):
    n: int = Field(1, alias="n")
    best_of: Optional[int] = Field(None, alias="best_of")
    presence_penalty: float = Field(0.0, alias="presence_penalty")
    frequency_penalty: float = Field(0.0, alias="rep_pen")
    temperature: float = Field(1.0, alias="temperature")
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    top_p: float = Field(1.0, alias="top_p")
    top_k: float = Field(-1, alias="top_k")
    min_p: float = Field(0.0, alias="min_p")
    top_a: float = Field(0.0, alias="top_a")
    tfs: float = Field(1.0, alias="tfs")
    eta_cutoff: float = Field(0.0, alias="eta_cutoff")
    epsilon_cutoff: float = Field(0.0, alias="epsilon_cutoff")
    typical_p: float = Field(1.0, alias="typical_p")
    use_beam_search: bool = Field(False, alias="use_beam_search")
    length_penalty: float = Field(1.0, alias="length_penalty")
    early_stopping: Union[bool, str] = Field(False, alias="early_stopping")
    stop: Union[None, str, List[str]] = Field(None, alias="stop_sequence")
    include_stop_str_in_output: Optional[bool] = False
    ignore_eos: bool = Field(False, alias="ignore_eos")
    max_tokens: int = Field(16, alias="max_length")
    logprobs: Optional[int] = Field(None, alias="logprobs")
    custom_token_bans: Optional[List[int]] = Field(None,
                                                   alias="custom_token_bans")

    @root_validator(pre=False, skip_on_failure=True)
    def validate_best_of(cls, values):  # pylint: disable=no-self-argument
        best_of = values.get("best_of")
        n = values.get("n")
        if best_of is not None and (best_of <= 0 or best_of > n):
            raise ValueError(
                "best_of must be a positive integer less than or equal to n")
        return values
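
# Example (illustrative sketch): KoboldAI clients send their own field names;
# the aliases above map them onto engine parameters, e.g. "rep_pen" populates
# frequency_penalty and "max_length" populates max_tokens.
#
#   params = KoboldSamplingParams(rep_pen=1.1, max_length=64)
#   assert params.frequency_penalty == 1.1 and params.max_tokens == 64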


class KAIGenerationInputSchema(BaseModel):
    genkey: Optional[str] = None
    prompt: str
    n: Optional[int] = 1
    max_context_length: int
    max_length: int
    rep_pen: Optional[float] = 1.0
    rep_pen_range: Optional[int] = None
    rep_pen_slope: Optional[float] = None
    top_k: Optional[int] = 0
    top_a: Optional[float] = 0.0
    top_p: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eps_cutoff: Optional[float] = 0.0
    eta_cutoff: Optional[float] = 0.0
    typical: Optional[float] = 1.0
    temperature: Optional[float] = 1.0
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    use_memory: Optional[bool] = None
    use_story: Optional[bool] = None
    use_authors_note: Optional[bool] = None
    use_world_info: Optional[bool] = None
    use_userscripts: Optional[bool] = None
    soft_prompt: Optional[str] = None
    disable_output_formatting: Optional[bool] = None
    frmtrmblln: Optional[bool] = None
    frmtrmspch: Optional[bool] = None
    singleline: Optional[bool] = None
    use_default_badwordsids: Optional[bool] = None
    mirostat: Optional[int] = 0
    mirostat_tau: Optional[float] = 0.0
    mirostat_eta: Optional[float] = 0.0
    disable_input_formatting: Optional[bool] = None
    frmtadsnsp: Optional[bool] = None
    quiet: Optional[bool] = None
    # pylint: disable=unexpected-keyword-arg
    sampler_order: Optional[Union[List, str]] = Field(default_factory=list)
    sampler_seed: Optional[int] = None
    sampler_full_determinism: Optional[bool] = None
    stop_sequence: Optional[List[str]] = None
    include_stop_str_in_output: Optional[bool] = False

    @root_validator(pre=False, skip_on_failure=True)
    def check_context(cls, values):  # pylint: disable=no-self-argument
        assert values.get("max_length") <= values.get(
            "max_context_length"
        ), "max_length must not be larger than max_context_length"
        return values
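
# Example (illustrative sketch): check_context rejects requests whose
# generation budget exceeds the declared context window, so the following
# raises a validation error.
#
#   KAIGenerationInputSchema(prompt="Hi", max_context_length=1024,
#                            max_length=2048)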