# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Dict, List, Literal, Optional, Union

from pydantic import (AliasChoices, BaseModel, Field, model_validator,
                      root_validator)

from aphrodite.common.utils import random_uuid
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.logits_processor import BiasLogitsProcessor


class ErrorResponse(BaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class ModelPermission(BaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "pygmalionai"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)
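

# Illustrative only: a minimal sketch of assembling a /v1/models-style
# response from the models above. The model id is a placeholder, not
# something this module defines.
def _example_model_list() -> ModelList:
    card = ModelCard(id="example-model",  # hypothetical model id
                     permission=[ModelPermission()])
    return ModelList(data=[card])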


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Dict[str, str]]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    n: Optional[int] = 1
    max_tokens: Optional[int] = None
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    include_stop_str_in_output: Optional[bool] = False
    stream: Optional[bool] = False
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    repetition_penalty: Optional[float] = 1.0
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    best_of: Optional[int] = None
    top_k: Optional[int] = -1
    top_a: Optional[float] = 0.0
    min_p: Optional[float] = 0.0
    mirostat_mode: Optional[int] = 0
    mirostat_tau: Optional[float] = 0.0
    mirostat_eta: Optional[float] = 0.0
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    custom_token_bans: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    add_generation_prompt: Optional[bool] = True
    echo: Optional[bool] = False
    length_penalty: Optional[float] = 1.0
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None

    def to_sampling_params(self, vocab_size: int) -> SamplingParams:
        if self.logprobs and not self.top_logprobs:
            raise ValueError(
                "top_logprobs must be set when logprobs is enabled.")
        logits_processors = []
        if self.logit_bias:
            # Clamp each bias to [-100, 100] (the OpenAI-documented range)
            # and drop token ids that fall outside the vocabulary.
            biases = {
                int(tok): min(100, max(float(bias), -100))
                for tok, bias in self.logit_bias.items()
                if 0 <= int(tok) < vocab_size
            }
            logits_processors.append(BiasLogitsProcessor(biases))
        return SamplingParams(
            n=self.n,
            max_tokens=self.max_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.top_logprobs if self.echo else None,
            temperature=self.temperature,
            top_p=self.top_p,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            top_k=self.top_k,
            top_a=self.top_a,
            min_p=self.min_p,
            mirostat_mode=self.mirostat_mode,
            mirostat_tau=self.mirostat_tau,
            mirostat_eta=self.mirostat_eta,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            stop_token_ids=self.stop_token_ids,
            custom_token_bans=self.custom_token_bans,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            stop=self.stop,
            best_of=self.best_of,
            include_stop_str_in_output=self.include_stop_str_in_output,
            seed=self.seed,
            logits_processors=logits_processors,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None,
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data
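

# Illustrative only: a sketch of turning an OpenAI-style chat request into
# engine SamplingParams. The model name, message, and vocab size are
# placeholder assumptions, not values this module defines.
def _example_chat_sampling_params() -> SamplingParams:
    request = ChatCompletionRequest(
        model="example-model",  # hypothetical model id
        messages=[{"role": "user", "content": "Hello!"}],
        temperature=0.7,
        logit_bias={"42": -100.0},  # strongly discourage token id 42
    )
    # vocab_size would normally come from the tokenizer; 32000 is assumed.
    return request.to_sampling_params(vocab_size=32000)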


class CompletionRequest(BaseModel):
    model: str
    # a string, array of strings, array of tokens, or array of token arrays
    prompt: Union[List[int], List[List[int]], str, List[str]]
    suffix: Optional[str] = None
    max_tokens: Optional[int] = 16
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    seed: Optional[int] = None
    include_stop_str_in_output: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    repetition_penalty: Optional[float] = 1.0
    best_of: Optional[int] = None
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    top_k: Optional[int] = -1
    top_a: Optional[float] = 0.0
    min_p: Optional[float] = 0.0
    mirostat_mode: Optional[int] = 0
    mirostat_tau: Optional[float] = 0.0
    mirostat_eta: Optional[float] = 0.0
    dynatemp_min: Optional[float] = Field(0.0,
                                          validation_alias=AliasChoices(
                                              "dynatemp_min", "dynatemp_low"),
                                          description="Aliases: dynatemp_low")
    dynatemp_max: Optional[float] = Field(0.0,
                                          validation_alias=AliasChoices(
                                              "dynatemp_max", "dynatemp_high"),
                                          description="Aliases: dynatemp_high")
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    custom_token_bans: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    grammar: Optional[str] = None
    length_penalty: Optional[float] = 1.0
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None

    def to_sampling_params(self, vocab_size: int) -> SamplingParams:
        # echo with max_tokens == 0 asks for the prompt back with no new
        # tokens; SamplingParams still needs max_tokens >= 1, so cap at 1.
        echo_without_generation = self.echo and self.max_tokens == 0
        logits_processors = []
        if self.logit_bias:
            # Clamp each bias to [-100, 100] (the OpenAI-documented range)
            # and drop token ids that fall outside the vocabulary.
            biases = {
                int(tok): min(100, max(float(bias), -100))
                for tok, bias in self.logit_bias.items()
                if 0 <= int(tok) < vocab_size
            }
            logits_processors.append(BiasLogitsProcessor(biases))
        return SamplingParams(
            n=self.n,
            max_tokens=self.max_tokens if not echo_without_generation else 1,
            temperature=self.temperature,
            top_p=self.top_p,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            top_k=self.top_k,
            top_a=self.top_a,
            min_p=self.min_p,
            mirostat_mode=self.mirostat_mode,
            mirostat_tau=self.mirostat_tau,
            mirostat_eta=self.mirostat_eta,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            logprobs=self.logprobs,
            prompt_logprobs=self.prompt_logprobs if self.echo else None,
            stop_token_ids=self.stop_token_ids,
            custom_token_bans=self.custom_token_bans,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            stop=self.stop,
            best_of=self.best_of,
            include_stop_str_in_output=self.include_stop_str_in_output,
            seed=self.seed,
            logits_processors=logits_processors,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None,
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data
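

# Illustrative only: the validator above rejects requests that combine more
# than one guided-decoding mode. A sketch of the failure case (the prompt
# and patterns are placeholders):
def _example_guided_decoding_conflict() -> None:
    try:
        CompletionRequest(
            model="example-model",  # hypothetical model id
            prompt="Hello",
            guided_regex=r"\d+",
            guided_choice=["yes", "no"],  # a second guide -> rejected
        )
    except ValueError as err:  # pydantic surfaces this as a ValidationError
        print(err)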


class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None
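
# Note: LogProbs is columnar; tokens, token_logprobs, and text_offset are
# parallel lists indexed by output position, and top_logprobs[i] (when
# present) maps candidate token ids to their logprobs at position i.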


class CompletionResponseChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo


class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
    logprobs: Optional[LogProbs] = None
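

# Illustrative only: a sketch of one server-sent-events chunk for a streamed
# chat completion, serialized the way an SSE handler might emit it. The
# model id is a placeholder.
def _example_stream_chunk() -> str:
    chunk = ChatCompletionStreamResponse(
        model="example-model",  # hypothetical model id
        choices=[
            ChatCompletionResponseStreamChoice(
                index=0,
                delta=DeltaMessage(content="Hel"),
            )
        ],
    )
    return f"data: {chunk.model_dump_json()}\n\n"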


class Prompt(BaseModel):
    prompt: str


# ========== KoboldAI ========== #


class KoboldSamplingParams(BaseModel):
    n: int = Field(1, alias="n")
    best_of: Optional[int] = Field(None, alias="best_of")
    presence_penalty: float = Field(0.0, alias="presence_penalty")
    frequency_penalty: float = Field(0.0, alias="rep_pen")
    temperature: float = Field(1.0, alias="temperature")
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    top_p: float = Field(1.0, alias="top_p")
    top_k: int = Field(-1, alias="top_k")
    min_p: float = Field(0.0, alias="min_p")
    top_a: float = Field(0.0, alias="top_a")
    tfs: float = Field(1.0, alias="tfs")
    eta_cutoff: float = Field(0.0, alias="eta_cutoff")
    epsilon_cutoff: float = Field(0.0, alias="epsilon_cutoff")
    typical_p: float = Field(1.0, alias="typical_p")
    use_beam_search: bool = Field(False, alias="use_beam_search")
    length_penalty: float = Field(1.0, alias="length_penalty")
    early_stopping: Union[bool, str] = Field(False, alias="early_stopping")
    stop: Union[None, str, List[str]] = Field(None, alias="stop_sequence")
    include_stop_str_in_output: Optional[bool] = False
    ignore_eos: bool = Field(False, alias="ignore_eos")
    max_tokens: int = Field(16, alias="max_length")
    logprobs: Optional[int] = Field(None, alias="logprobs")
    custom_token_bans: Optional[List[int]] = Field(None,
                                                   alias="custom_token_bans")

    @root_validator(pre=False, skip_on_failure=True)
    def validate_best_of(cls, values):  # pylint: disable=no-self-argument
        best_of = values.get("best_of")
        n = values.get("n")
        if best_of is not None and (best_of <= 0 or best_of > n):
            raise ValueError(
                "best_of must be a positive integer less than or equal to n")
        return values
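

# Illustrative only: KoboldSamplingParams maps KoboldAI field names onto the
# engine's names via pydantic aliases, so e.g. "rep_pen" and "max_length"
# populate frequency_penalty and max_tokens (values here are placeholders):
def _example_kobold_aliases() -> KoboldSamplingParams:
    return KoboldSamplingParams(rep_pen=1.1, max_length=64)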


class KAIGenerationInputSchema(BaseModel):
    genkey: Optional[str] = None
    prompt: str
    n: Optional[int] = 1
    max_context_length: int
    max_length: int
    rep_pen: Optional[float] = 1.0
    rep_pen_range: Optional[int] = None
    rep_pen_slope: Optional[float] = None
    top_k: Optional[int] = 0
    top_a: Optional[float] = 0.0
    top_p: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eps_cutoff: Optional[float] = 0.0
    eta_cutoff: Optional[float] = 0.0
    typical: Optional[float] = 1.0
    temperature: Optional[float] = 1.0
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    use_memory: Optional[bool] = None
    use_story: Optional[bool] = None
    use_authors_note: Optional[bool] = None
    use_world_info: Optional[bool] = None
    use_userscripts: Optional[bool] = None
    soft_prompt: Optional[str] = None
    disable_output_formatting: Optional[bool] = None
    frmtrmblln: Optional[bool] = None
    frmtrmspch: Optional[bool] = None
    singleline: Optional[bool] = None
    use_default_badwordsids: Optional[bool] = None
    mirostat: Optional[int] = 0
    mirostat_tau: Optional[float] = 0.0
    mirostat_eta: Optional[float] = 0.0
    disable_input_formatting: Optional[bool] = None
    frmtadsnsp: Optional[bool] = None
    quiet: Optional[bool] = None
    # pylint: disable=unexpected-keyword-arg
    sampler_order: Optional[Union[List, str]] = Field(default_factory=list)
    sampler_seed: Optional[int] = None
    sampler_full_determinism: Optional[bool] = None
    stop_sequence: Optional[List[str]] = None
    include_stop_str_in_output: Optional[bool] = False

    @root_validator(pre=False, skip_on_failure=True)
    def check_context(cls, values):  # pylint: disable=no-self-argument
        assert values.get("max_length") <= values.get(
            "max_context_length"
        ), "max_length must not be larger than max_context_length"
        return values
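

# Illustrative only: the check_context validator above enforces that the
# requested generation length fits inside the context window. A sketch of
# the failing case (values are placeholders):
def _example_context_overflow() -> None:
    try:
        KAIGenerationInputSchema(
            prompt="Hello",
            max_context_length=1024,
            max_length=2048,  # larger than the context window -> rejected
        )
    except ValueError as err:  # pydantic surfaces the failed assertion
        print(err)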