protocol.py

# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Dict, List, Literal, Optional, Union

import torch
from pydantic import (AliasChoices, BaseModel, Field, conint, model_validator,
                      root_validator)

from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.utils import random_uuid
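
# The models below define the OpenAI-compatible request/response schema for
# the chat completions, completions, and embeddings endpoints, plus the
# Aphrodite sampler extensions and the KoboldAI-compatible schema at the end.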


class ErrorResponse(BaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class ModelPermission(BaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "pygmalionai"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ResponseFormat(BaseModel):
    # type must be "json_object" or "text"
    type: Literal["text", "json_object"] = "text"


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Dict[str, str]]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    n: Optional[int] = 1
    max_tokens: Optional[int] = None
    min_tokens: Optional[int] = 0
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    include_stop_str_in_output: Optional[bool] = False
    stream: Optional[bool] = False
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    repetition_penalty: Optional[float] = 1.0
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    best_of: Optional[int] = None
    top_k: Optional[int] = -1
    top_a: Optional[float] = 0.0
    min_p: Optional[float] = 0.0
    mirostat_mode: Optional[int] = 0
    mirostat_tau: Optional[float] = 0.0
    mirostat_eta: Optional[float] = 0.0
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    custom_token_bans: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    add_generation_prompt: Optional[bool] = True
    echo: Optional[bool] = False
    length_penalty: Optional[float] = 1.0
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None
    guided_grammar: Optional[str] = None
    response_format: Optional[ResponseFormat] = None

    def to_sampling_params(self) -> SamplingParams:
        if self.logprobs and not self.top_logprobs:
            raise ValueError("top_logprobs must be set when logprobs is "
                             "enabled.")
        if self.top_k == 0:
            self.top_k = -1
        logits_processors = None
        if self.logit_bias:

            def logit_bias_logits_processor(
                    token_ids: List[int],
                    logits: torch.Tensor) -> torch.Tensor:
                for token_id, bias in self.logit_bias.items():
                    # Clamp the bias between -100 and 100 per OpenAI API spec
                    bias = min(100, max(-100, bias))
                    logits[int(token_id)] += bias
                return logits

            logits_processors = [logit_bias_logits_processor]

        return SamplingParams(
            n=self.n,
            max_tokens=self.max_tokens,
            min_tokens=self.min_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.top_logprobs if self.echo else None,
            temperature=self.temperature,
            top_p=self.top_p,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            top_k=self.top_k,
            top_a=self.top_a,
            min_p=self.min_p,
            mirostat_mode=self.mirostat_mode,
            mirostat_tau=self.mirostat_tau,
            mirostat_eta=self.mirostat_eta,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            stop_token_ids=self.stop_token_ids,
            custom_token_bans=self.custom_token_bans,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            stop=self.stop,
            best_of=self.best_of,
            include_stop_str_in_output=self.include_stop_str_in_output,
            seed=self.seed,
            logits_processors=logits_processors,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data
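
# A minimal usage sketch (illustrative only; the payload values below are
# made-up examples, not defaults):
#
#     request = ChatCompletionRequest(
#         model="my-model",
#         messages=[{"role": "user", "content": "Hello"}],
#         temperature=0.7,
#         logit_bias={"50256": -100},
#     )
#     params = request.to_sampling_params()
#     # params.logits_processors applies the clamped bias per token id.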


class CompletionRequest(BaseModel):
    model: str
    # a string, array of strings, array of tokens, or array of token arrays
    prompt: Union[List[int], List[List[int]], str, List[str]]
    suffix: Optional[str] = None
    max_tokens: Optional[int] = 16
    min_tokens: Optional[int] = 0
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    seed: Optional[int] = None
    include_stop_str_in_output: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    repetition_penalty: Optional[float] = 1.0
    best_of: Optional[int] = None
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    top_k: Optional[int] = -1
    top_a: Optional[float] = 0.0
    min_p: Optional[float] = 0.0
    mirostat_mode: Optional[int] = 0
    mirostat_tau: Optional[float] = 0.0
    mirostat_eta: Optional[float] = 0.0
    dynatemp_min: Optional[float] = Field(0.0,
                                          validation_alias=AliasChoices(
                                              "dynatemp_min", "dynatemp_low"),
                                          description="Aliases: dynatemp_low")
    dynatemp_max: Optional[float] = Field(0.0,
                                          validation_alias=AliasChoices(
                                              "dynatemp_max", "dynatemp_high"),
                                          description="Aliases: dynatemp_high")
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    custom_token_bans: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[conint(ge=1)] = None
    grammar: Optional[str] = None
    length_penalty: Optional[float] = 1.0
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None
    guided_grammar: Optional[str] = None
    response_format: Optional[ResponseFormat] = None

    def to_sampling_params(self) -> SamplingParams:
        echo_without_generation = self.echo and self.max_tokens == 0
        if self.top_k == 0:
            self.top_k = -1
        logits_processors = None
        if self.logit_bias:

            def logit_bias_logits_processor(
                    token_ids: List[int],
                    logits: torch.Tensor) -> torch.Tensor:
                for token_id, bias in self.logit_bias.items():
                    # Clamp the bias between -100 and 100 per OpenAI API spec
                    bias = min(100, max(-100, bias))
                    logits[int(token_id)] += bias
                return logits

            logits_processors = [logit_bias_logits_processor]

        return SamplingParams(
            n=self.n,
            max_tokens=self.max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            temperature=self.temperature,
            top_p=self.top_p,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            top_k=self.top_k,
            top_a=self.top_a,
            min_p=self.min_p,
            mirostat_mode=self.mirostat_mode,
            mirostat_tau=self.mirostat_tau,
            mirostat_eta=self.mirostat_eta,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            logprobs=self.logprobs,
            prompt_logprobs=self.prompt_logprobs if self.echo else None,
            stop_token_ids=self.stop_token_ids,
            custom_token_bans=self.custom_token_bans,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            stop=self.stop,
            best_of=self.best_of,
            include_stop_str_in_output=self.include_stop_str_in_output,
            seed=self.seed,
            logits_processors=logits_processors,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data
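
# The completion endpoint accepts four prompt shapes (mirroring OpenAI):
#     "a single string prompt"
#     ["a batch", "of string prompts"]
#     [1, 2, 3]                # a single tokenized prompt
#     [[1, 2, 3], [4, 5, 6]]   # a batch of tokenized prompts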


class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
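
# LogProbs mirrors the OpenAI completions logprobs object: parallel lists
# indexed by output token, where top_logprobs[i] maps candidate token strings
# to their log-probabilities at step i (None where unavailable).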


class CompletionResponseChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
    stop_reason: Union[None, int, str] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion to stop, "
            "or None if the completion finished for some other reason, "
            "including encountering the EOS token."),
    )


class CompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
    stop_reason: Union[None, int, str] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion to stop, "
            "or None if the completion finished for some other reason, "
            "including encountering the EOS token."),
    )


class CompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
    stop_reason: Union[None, int, str] = None


class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo


class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
    stop_reason: Union[None, int, str] = None


class ChatCompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
    logprobs: Optional[LogProbs] = None
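
# Shape of one streamed chunk serialized from the models above (an
# illustrative payload, not captured from a live server):
#
#     {"id": "chatcmpl-...",
#      "object": "chat.completion.chunk",
#      "created": 1700000000,
#      "model": "my-model",
#      "choices": [{"index": 0,
#                   "delta": {"content": "Hel"},
#                   "finish_reason": null}]}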


class EmbeddingsRequest(BaseModel):
    input: List[str] = Field(
        ..., description="List of input texts to generate embeddings for.")
    encoding_format: str = Field(
        "float",
        description="Encoding format for the embeddings. "
        "Can be 'float' or 'base64'.")
    model: Optional[str] = Field(
        None,
        description="Name of the embedding model to use. "
        "If not provided, the default model will be used.")


class EmbeddingObject(BaseModel):
    object: str = Field("embedding", description="Type of the object.")
    embedding: List[float] = Field(
        ..., description="Embedding values as a list of floats.")
    index: int = Field(
        ...,
        description="Index of the input text corresponding to "
        "the embedding.")


class EmbeddingsResponse(BaseModel):
    object: str = Field("list", description="Type of the response object.")
    data: List[EmbeddingObject] = Field(
        ..., description="List of embedding objects.")
    model: str = Field(..., description="Name of the embedding model used.")
    usage: UsageInfo = Field(..., description="Information about token usage.")
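
# An illustrative embeddings exchange (values are made up):
#
#     EmbeddingsRequest(input=["hello world"], model="my-embedder")
#     -> EmbeddingsResponse(
#            object="list",
#            data=[EmbeddingObject(embedding=[0.1, -0.2], index=0)],
#            model="my-embedder",
#            usage=UsageInfo(prompt_tokens=2, total_tokens=2))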


class Prompt(BaseModel):
    prompt: str


# ========== KoboldAI ========== #
class KoboldSamplingParams(BaseModel):
    n: int = Field(1, alias="n")
    best_of: Optional[int] = Field(None, alias="best_of")
    presence_penalty: float = Field(0.0, alias="presence_penalty")
    frequency_penalty: float = Field(0.0, alias="rep_pen")
    temperature: float = Field(1.0, alias="temperature")
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    top_p: float = Field(1.0, alias="top_p")
    top_k: int = Field(-1, alias="top_k")
    min_p: float = Field(0.0, alias="min_p")
    top_a: float = Field(0.0, alias="top_a")
    tfs: float = Field(1.0, alias="tfs")
    eta_cutoff: float = Field(0.0, alias="eta_cutoff")
    epsilon_cutoff: float = Field(0.0, alias="epsilon_cutoff")
    typical_p: float = Field(1.0, alias="typical_p")
    use_beam_search: bool = Field(False, alias="use_beam_search")
    length_penalty: float = Field(1.0, alias="length_penalty")
    early_stopping: Union[bool, str] = Field(False, alias="early_stopping")
    stop: Union[None, str, List[str]] = Field(None, alias="stop_sequence")
    include_stop_str_in_output: Optional[bool] = False
    ignore_eos: bool = Field(False, alias="ignore_eos")
    max_tokens: int = Field(16, alias="max_length")
    logprobs: Optional[int] = Field(None, alias="logprobs")
    custom_token_bans: Optional[List[int]] = Field(
        None, alias="custom_token_bans")

    @root_validator(pre=False, skip_on_failure=True)
    def validate_best_of(cls, values):  # pylint: disable=no-self-argument
        best_of = values.get("best_of")
        n = values.get("n")
        if best_of is not None and (best_of <= 0 or best_of > n):
            raise ValueError(
                "best_of must be a positive integer less than or equal to n")
        return values
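
# Note the KoboldAI -> Aphrodite field mapping done via the aliases above:
# rep_pen feeds frequency_penalty, stop_sequence feeds stop, and max_length
# feeds max_tokens.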


class KAIGenerationInputSchema(BaseModel):
    genkey: Optional[str] = None
    prompt: str
    n: Optional[int] = 1
    max_context_length: int
    max_length: int
    rep_pen: Optional[float] = 1.0
    rep_pen_range: Optional[int] = None
    rep_pen_slope: Optional[float] = None
    top_k: Optional[int] = 0
    top_a: Optional[float] = 0.0
    top_p: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eps_cutoff: Optional[float] = 0.0
    eta_cutoff: Optional[float] = 0.0
    typical: Optional[float] = 1.0
    temperature: Optional[float] = 1.0
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    use_memory: Optional[bool] = None
    use_story: Optional[bool] = None
    use_authors_note: Optional[bool] = None
    use_world_info: Optional[bool] = None
    use_userscripts: Optional[bool] = None
    soft_prompt: Optional[str] = None
    disable_output_formatting: Optional[bool] = None
    frmtrmblln: Optional[bool] = None
    frmtrmspch: Optional[bool] = None
    singleline: Optional[bool] = None
    use_default_badwordsids: Optional[bool] = None
    mirostat: Optional[int] = 0
    mirostat_tau: Optional[float] = 0.0
    mirostat_eta: Optional[float] = 0.0
    disable_input_formatting: Optional[bool] = None
    frmtadsnsp: Optional[bool] = None
    quiet: Optional[bool] = None
    # pylint: disable=unexpected-keyword-arg
    sampler_order: Optional[Union[List, str]] = Field(default_factory=list)
    sampler_seed: Optional[int] = None
    sampler_full_determinism: Optional[bool] = None
    stop_sequence: Optional[List[str]] = None
    include_stop_str_in_output: Optional[bool] = False

    @root_validator(pre=False, skip_on_failure=True)
    def check_context(cls, values):  # pylint: disable=no-self-argument
        assert values.get("max_length") <= values.get(
            "max_context_length"
        ), "max_length must not be larger than max_context_length"
        return values
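

if __name__ == "__main__":
    # A minimal sanity-check sketch (illustrative; the model name and prompt
    # are made up, and running it assumes the torch and aphrodite imports at
    # the top of this module resolve).
    req = CompletionRequest(model="demo-model", prompt="Hello", max_tokens=8)
    print(req.model_dump(exclude_none=True))

    # The guided-decoding validator rejects combined modes; pydantic's
    # ValidationError subclasses ValueError, so this catch works.
    try:
        CompletionRequest(model="demo-model", prompt="Hi",
                          guided_regex=r"\d+", guided_choice=["a", "b"])
    except ValueError as exc:
        print(f"rejected as expected: {exc}")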