# protocol.py
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Dict, List, Literal, Optional, Union

from pydantic import (AliasChoices, BaseModel, Field, conint, model_validator,
                      root_validator)

from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.utils import random_uuid
from aphrodite.common.logits_processor import BiasLogitsProcessor


class ErrorResponse(BaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int
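
# Illustrative shape of a serialized ErrorResponse (the field values below are
# made up for this sketch, not produced anywhere in this module):
#
#     {"object": "error", "message": "model not found",
#      "type": "invalid_request_error", "param": "model", "code": 404}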


class ModelPermission(BaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "pygmalionai"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)
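
# A minimal sketch of building the payload for a /v1/models-style endpoint
# from these schemas (the model id is a placeholder; the actual route lives
# outside this module):
#
#     card = ModelCard(id="my-model", root="my-model",
#                      permission=[ModelPermission()])
#     body = ModelList(data=[card]).model_dump_json()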


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0
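
# UsageInfo follows the usual OpenAI accounting convention, i.e. total_tokens
# is expected to equal prompt_tokens + completion_tokens. Illustrative:
#
#     UsageInfo(prompt_tokens=12, completion_tokens=30, total_tokens=42)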


class ResponseFormat(BaseModel):
    # must be "text" or "json_object"
    type: Literal["text", "json_object"]
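
# In a request body this typically appears as, e.g.:
#
#     {"response_format": {"type": "json_object"}}
#
# which parses to ResponseFormat(type="json_object").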


class ChatCompletionRequest(BaseModel):
    model: str
    # support list type in messages.content
    messages: List[Dict[str, Union[str, List[Dict[str, str]]]]]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    n: Optional[int] = 1
    max_tokens: Optional[int] = None
    min_tokens: Optional[int] = 0
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    include_stop_str_in_output: Optional[bool] = False
    stream: Optional[bool] = False
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    repetition_penalty: Optional[float] = 1.0
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    best_of: Optional[int] = None
    top_k: Optional[int] = -1
    top_a: Optional[float] = 0.0
    min_p: Optional[float] = 0.0
    mirostat_mode: Optional[int] = 0
    mirostat_tau: Optional[float] = 0.0
    mirostat_eta: Optional[float] = 0.0
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    custom_token_bans: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    add_generation_prompt: Optional[bool] = True
    echo: Optional[bool] = False
    length_penalty: Optional[float] = 1.0
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None
    guided_grammar: Optional[str] = None
    response_format: Optional[ResponseFormat] = None
    guided_decoding_backend: Optional[str] = Field(
        default="outlines",
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be either "
            "'outlines' or 'lm-format-enforcer'."))

    def to_sampling_params(self, vocab_size: int) -> SamplingParams:
        if self.logprobs and not self.top_logprobs:
            raise ValueError(
                "top_logprobs must be set when logprobs is enabled.")
        if self.top_k == 0:
            self.top_k = -1
        logits_processors = []
        if self.logit_bias:
            # Clamp each bias to [-100, 100] and drop out-of-vocabulary ids.
            biases = {
                int(tok): max(-100, min(float(bias), 100))
                for tok, bias in self.logit_bias.items()
                if 0 < int(tok) < vocab_size
            }
            logits_processors.append(BiasLogitsProcessor(biases))
        return SamplingParams(
            n=self.n,
            max_tokens=self.max_tokens,
            min_tokens=self.min_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.top_logprobs if self.echo else None,
            temperature=self.temperature,
            top_p=self.top_p,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            top_k=self.top_k,
            top_a=self.top_a,
            min_p=self.min_p,
            mirostat_mode=self.mirostat_mode,
            mirostat_tau=self.mirostat_tau,
            mirostat_eta=self.mirostat_eta,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            stop_token_ids=self.stop_token_ids,
            custom_token_bans=self.custom_token_bans,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            stop=self.stop,
            best_of=self.best_of,
            include_stop_str_in_output=self.include_stop_str_in_output,
            seed=self.seed,
            logits_processors=logits_processors,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data
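
# Usage sketch for ChatCompletionRequest (values are illustrative; vocab_size
# would come from the served model's tokenizer):
#
#     request = ChatCompletionRequest(
#         model="my-model",
#         messages=[{"role": "user", "content": "Hello!"}],
#         temperature=0.8,
#         logit_bias={"42": 250.0},  # clamped to 100 by to_sampling_params
#     )
#     params = request.to_sampling_params(vocab_size=32000)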


class CompletionRequest(BaseModel):
    model: str
    # a string, array of strings, array of tokens, or array of token arrays
    prompt: Union[List[int], List[List[int]], str, List[str]]
    suffix: Optional[str] = None
    max_tokens: Optional[int] = 16
    min_tokens: Optional[int] = 0
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    seed: Optional[int] = None
    include_stop_str_in_output: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    repetition_penalty: Optional[float] = 1.0
    best_of: Optional[int] = None
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    top_k: Optional[int] = -1
    top_a: Optional[float] = 0.0
    min_p: Optional[float] = 0.0
    mirostat_mode: Optional[int] = 0
    mirostat_tau: Optional[float] = 0.0
    mirostat_eta: Optional[float] = 0.0
    dynatemp_min: Optional[float] = Field(
        0.0,
        validation_alias=AliasChoices("dynatemp_min", "dynatemp_low"),
        description="Aliases: dynatemp_low")
    dynatemp_max: Optional[float] = Field(
        0.0,
        validation_alias=AliasChoices("dynatemp_max", "dynatemp_high"),
        description="Aliases: dynatemp_high")
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    custom_token_bans: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[conint(ge=1)] = None
    grammar: Optional[str] = None
    length_penalty: Optional[float] = 1.0
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None
    guided_grammar: Optional[str] = None
    response_format: Optional[ResponseFormat] = None
    guided_decoding_backend: Optional[str] = Field(
        default="outlines",
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' or 'lm-format-enforcer'."))

    def to_sampling_params(self, vocab_size: int) -> SamplingParams:
        echo_without_generation = self.echo and self.max_tokens == 0
        if self.top_k == 0:
            self.top_k = -1
        logits_processors = []
        if self.logit_bias:
            # Clamp each bias to [-100, 100] and drop out-of-vocabulary ids.
            biases = {
                int(tok): max(-100, min(float(bias), 100))
                for tok, bias in self.logit_bias.items()
                if 0 < int(tok) < vocab_size
            }
            logits_processors.append(BiasLogitsProcessor(biases))
        return SamplingParams(
            n=self.n,
            max_tokens=self.max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            temperature=self.temperature,
            top_p=self.top_p,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            top_k=self.top_k,
            top_a=self.top_a,
            min_p=self.min_p,
            mirostat_mode=self.mirostat_mode,
            mirostat_tau=self.mirostat_tau,
            mirostat_eta=self.mirostat_eta,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            logprobs=self.logprobs,
            prompt_logprobs=self.prompt_logprobs if self.echo else None,
            stop_token_ids=self.stop_token_ids,
            custom_token_bans=self.custom_token_bans,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            stop=self.stop,
            best_of=self.best_of,
            include_stop_str_in_output=self.include_stop_str_in_output,
            seed=self.seed,
            logits_processors=logits_processors,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data
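
# Usage sketch for CompletionRequest, including the dynatemp_low/dynatemp_high
# aliases accepted via AliasChoices (values are illustrative):
#
#     request = CompletionRequest(
#         model="my-model",
#         prompt="Once upon a time",
#         max_tokens=64,
#         dynatemp_low=0.5,   # populates dynatemp_min
#         dynatemp_high=1.5,  # populates dynatemp_max
#     )
#     params = request.to_sampling_params(vocab_size=32000)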


class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None


class CompletionResponseChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
    stop_reason: Union[None, int, str] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )


class CompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
    stop_reason: Union[None, int, str] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )


class CompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
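
# When streaming, each CompletionStreamResponse is typically serialized as one
# server-sent-events chunk; an illustrative chunk (id and values made up):
#
#     data: {"id": "cmpl-...", "object": "text_completion",
#            "created": 1700000000, "model": "my-model",
#            "choices": [{"index": 0, "text": " world", "logprobs": null,
#                         "finish_reason": null, "stop_reason": null}]}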


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
    stop_reason: Union[None, int, str] = None


class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo


class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
    stop_reason: Union[None, int, str] = None


class ChatCompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
    logprobs: Optional[LogProbs] = None
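
# Chat streaming works the same way, carrying DeltaMessage fragments rather
# than full ChatMessage objects; an illustrative chunk (values made up):
#
#     data: {"id": "chatcmpl-...", "object": "chat.completion.chunk",
#            "created": 1700000000, "model": "my-model",
#            "choices": [{"index": 0, "delta": {"content": "Hi"},
#                         "finish_reason": null, "stop_reason": null}]}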


class EmbeddingsRequest(BaseModel):
    input: List[str] = Field(
        ..., description="List of input texts to generate embeddings for.")
    encoding_format: str = Field(
        "float",
        description="Encoding format for the embeddings. "
        "Can be 'float' or 'base64'.")
    model: Optional[str] = Field(
        None,
        description="Name of the embedding model to use. "
        "If not provided, the default model will be used.")


class EmbeddingObject(BaseModel):
    object: str = Field("embedding", description="Type of the object.")
    embedding: List[float] = Field(
        ..., description="Embedding values as a list of floats.")
    index: int = Field(
        ...,
        description="Index of the input text corresponding to "
        "the embedding.")


class EmbeddingsResponse(BaseModel):
    object: str = Field("list", description="Type of the response object.")
    data: List[EmbeddingObject] = Field(
        ..., description="List of embedding objects.")
    model: str = Field(..., description="Name of the embedding model used.")
    usage: UsageInfo = Field(..., description="Information about token usage.")
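
# A minimal sketch of an embeddings round trip with these schemas (the
# embedding values, counts, and model name are placeholders):
#
#     req = EmbeddingsRequest(input=["hello", "world"])
#     resp = EmbeddingsResponse(
#         data=[EmbeddingObject(embedding=[0.1, 0.2], index=0),
#               EmbeddingObject(embedding=[0.3, 0.4], index=1)],
#         model="my-embedding-model",
#         usage=UsageInfo(prompt_tokens=2, total_tokens=2),
#     )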


class Prompt(BaseModel):
    prompt: str


# ========== KoboldAI ========== #
class KoboldSamplingParams(BaseModel):
    n: int = Field(1, alias="n")
    best_of: Optional[int] = Field(None, alias="best_of")
    presence_penalty: float = Field(0.0, alias="presence_penalty")
    frequency_penalty: float = Field(0.0, alias="rep_pen")
    temperature: float = Field(1.0, alias="temperature")
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    top_p: float = Field(1.0, alias="top_p")
    top_k: float = Field(-1, alias="top_k")
    min_p: float = Field(0.0, alias="min_p")
    top_a: float = Field(0.0, alias="top_a")
    tfs: float = Field(1.0, alias="tfs")
    eta_cutoff: float = Field(0.0, alias="eta_cutoff")
    epsilon_cutoff: float = Field(0.0, alias="epsilon_cutoff")
    typical_p: float = Field(1.0, alias="typical_p")
    use_beam_search: bool = Field(False, alias="use_beam_search")
    length_penalty: float = Field(1.0, alias="length_penalty")
    early_stopping: Union[bool, str] = Field(False, alias="early_stopping")
    stop: Union[None, str, List[str]] = Field(None, alias="stop_sequence")
    include_stop_str_in_output: Optional[bool] = False
    ignore_eos: bool = Field(False, alias="ignore_eos")
    max_tokens: int = Field(16, alias="max_length")
    logprobs: Optional[int] = Field(None, alias="logprobs")
    custom_token_bans: Optional[List[int]] = Field(
        None, alias="custom_token_bans")

    @root_validator(pre=False, skip_on_failure=True)
    def validate_best_of(cls, values):  # pylint: disable=no-self-argument
        best_of = values.get("best_of")
        n = values.get("n")
        if best_of is not None and (best_of <= 0 or best_of > n):
            raise ValueError(
                "best_of must be a positive integer less than or equal to n")
        return values
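
# KoboldSamplingParams maps KoboldAI field names onto the engine's names via
# pydantic aliases, so a Kobold-style payload parses directly (values are
# illustrative):
#
#     params = KoboldSamplingParams(rep_pen=1.1, max_length=80)
#     assert params.frequency_penalty == 1.1 and params.max_tokens == 80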


class KAIGenerationInputSchema(BaseModel):
    genkey: Optional[str] = None
    prompt: str
    n: Optional[int] = 1
    max_context_length: int
    max_length: int
    rep_pen: Optional[float] = 1.0
    rep_pen_range: Optional[int] = None
    rep_pen_slope: Optional[float] = None
    top_k: Optional[int] = 0
    top_a: Optional[float] = 0.0
    top_p: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eps_cutoff: Optional[float] = 0.0
    eta_cutoff: Optional[float] = 0.0
    typical: Optional[float] = 1.0
    temperature: Optional[float] = 1.0
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    use_memory: Optional[bool] = None
    use_story: Optional[bool] = None
    use_authors_note: Optional[bool] = None
    use_world_info: Optional[bool] = None
    use_userscripts: Optional[bool] = None
    soft_prompt: Optional[str] = None
    disable_output_formatting: Optional[bool] = None
    frmtrmblln: Optional[bool] = None
    frmtrmspch: Optional[bool] = None
    singleline: Optional[bool] = None
    use_default_badwordsids: Optional[bool] = None
    mirostat: Optional[int] = 0
    mirostat_tau: Optional[float] = 0.0
    mirostat_eta: Optional[float] = 0.0
    disable_input_formatting: Optional[bool] = None
    frmtadsnsp: Optional[bool] = None
    quiet: Optional[bool] = None
    # pylint: disable=unexpected-keyword-arg
    sampler_order: Optional[Union[List, str]] = Field(default_factory=list)
    sampler_seed: Optional[int] = None
    sampler_full_determinism: Optional[bool] = None
    stop_sequence: Optional[List[str]] = None
    include_stop_str_in_output: Optional[bool] = False

    @root_validator(pre=False, skip_on_failure=True)
    def check_context(cls, values):  # pylint: disable=no-self-argument
        assert values.get("max_length") <= values.get(
            "max_context_length"
        ), "max_length must not be larger than max_context_length"
        return values
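
# Usage sketch for KAIGenerationInputSchema; check_context rejects payloads
# whose max_length exceeds max_context_length (values are illustrative):
#
#     req = KAIGenerationInputSchema(
#         prompt="Once upon a time", max_context_length=2048, max_length=120)
#     # max_length=4096 with max_context_length=2048 would raise instead.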