# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Any, Dict, List, Literal, Optional, Union

import torch
from pydantic import (AliasChoices, BaseModel, Field, conint, model_validator,
                      root_validator)

from aphrodite.common.logits_processor import BiasLogitsProcessor
from aphrodite.common.pooling_params import PoolingParams
from aphrodite.common.sampling_params import SamplingParams
from aphrodite.common.utils import random_uuid


class ErrorResponse(BaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class ModelPermission(BaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "pygmalionai"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)
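
# Example: a minimal model listing such as a /v1/models endpoint might
# return (illustrative sketch; the model id "my-model" is a placeholder):
#
#   card = ModelCard(id="my-model", permission=[ModelPermission()])
#   models = ModelList(data=[card])
#   models.model_dump_json()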


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ResponseFormat(BaseModel):
    # type must be "json_object" or "text"
    type: Literal["text", "json_object"]


class ChatCompletionRequest(BaseModel):
    model: str
    # support list type in messages.content
    messages: List[Dict[str, Union[str, List[Dict[str, str]]]]]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    n: Optional[int] = 1
    max_tokens: Optional[int] = None
    min_tokens: Optional[int] = 0
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    include_stop_str_in_output: Optional[bool] = False
    stream: Optional[bool] = False
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    repetition_penalty: Optional[float] = 1.0
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    best_of: Optional[int] = None
    top_k: Optional[int] = -1
    top_a: Optional[float] = 0.0
    min_p: Optional[float] = 0.0
    mirostat_mode: Optional[int] = 0
    mirostat_tau: Optional[float] = 0.0
    mirostat_eta: Optional[float] = 0.0
    dynatemp_min: Optional[float] = 0.0
    dynatemp_max: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    custom_token_bans: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    add_generation_prompt: Optional[bool] = True
    echo: Optional[bool] = False
    length_penalty: Optional[float] = 1.0
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None
    guided_grammar: Optional[str] = None
    response_format: Optional[ResponseFormat] = None
    guided_decoding_backend: Optional[str] = Field(
        default="outlines",
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))

    def to_sampling_params(self, vocab_size: int) -> SamplingParams:
        if self.logprobs and not self.top_logprobs:
            raise ValueError("Top logprobs must be set when logprobs is.")
        # top_k == 0 is treated as "disabled", which SamplingParams
        # expresses as -1.
        if self.top_k == 0:
            self.top_k = -1
        logits_processors = []
        if self.logit_bias:
            # Clamp OpenAI-style biases to [-100, 100] and drop token ids
            # outside the vocabulary.
            biases = {
                int(tok): max(-100, min(float(bias), 100))
                for tok, bias in self.logit_bias.items()
                if 0 < int(tok) < vocab_size
            }
            logits_processors.append(BiasLogitsProcessor(biases))
        return SamplingParams(
            n=self.n,
            max_tokens=self.max_tokens,
            min_tokens=self.min_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.top_logprobs if self.echo else None,
            temperature=self.temperature,
            top_p=self.top_p,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            top_k=self.top_k,
            top_a=self.top_a,
            min_p=self.min_p,
            mirostat_mode=self.mirostat_mode,
            mirostat_tau=self.mirostat_tau,
            mirostat_eta=self.mirostat_eta,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            stop_token_ids=self.stop_token_ids,
            custom_token_bans=self.custom_token_bans,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            stop=self.stop,
            best_of=self.best_of,
            include_stop_str_in_output=self.include_stop_str_in_output,
            seed=self.seed,
            logits_processors=logits_processors,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data
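
# Example: building a chat request and lowering it to engine sampling
# parameters (illustrative sketch; the model name, token id, and vocab
# size are placeholders):
#
#   request = ChatCompletionRequest(
#       model="my-model",
#       messages=[{"role": "user", "content": "Hello!"}],
#       logprobs=True,
#       top_logprobs=5,
#       logit_bias={"1000": -100.0},
#   )
#   sampling_params = request.to_sampling_params(vocab_size=32000)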


class CompletionRequest(BaseModel):
    model: str
    # a string, array of strings, array of tokens, or array of token arrays
    prompt: Union[List[int], List[List[int]], str, List[str]]
    suffix: Optional[str] = None
    max_tokens: Optional[int] = 16
    min_tokens: Optional[int] = 0
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    tfs: Optional[float] = 1.0
    eta_cutoff: Optional[float] = 0.0
    epsilon_cutoff: Optional[float] = 0.0
    typical_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    seed: Optional[int] = Field(None,
                                ge=torch.iinfo(torch.long).min,
                                le=torch.iinfo(torch.long).max)
    include_stop_str_in_output: Optional[bool] = False
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    repetition_penalty: Optional[float] = 1.0
    best_of: Optional[int] = None
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    top_k: Optional[int] = -1
    top_a: Optional[float] = 0.0
    min_p: Optional[float] = 0.0
    mirostat_mode: Optional[int] = 0
    mirostat_tau: Optional[float] = 0.0
    mirostat_eta: Optional[float] = 0.0
    dynatemp_min: Optional[float] = Field(0.0,
                                          validation_alias=AliasChoices(
                                              "dynatemp_min", "dynatemp_low"),
                                          description="Aliases: dynatemp_low")
    dynatemp_max: Optional[float] = Field(0.0,
                                          validation_alias=AliasChoices(
                                              "dynatemp_max", "dynatemp_high"),
                                          description="Aliases: dynatemp_high")
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False
    prompt_logprobs: Optional[int] = None
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    custom_token_bans: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[conint(ge=1)] = None
    grammar: Optional[str] = None
    length_penalty: Optional[float] = 1.0
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None
    guided_grammar: Optional[str] = None
    response_format: Optional[ResponseFormat] = None
    guided_decoding_backend: Optional[str] = Field(
        default="outlines",
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    guided_whitespace_pattern: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default whitespace pattern "
            "for guided json decoding."))

    def to_sampling_params(self, vocab_size: int) -> SamplingParams:
        # echo with max_tokens == 0 asks for the prompt back without
        # generation; a single token is still requested from the engine.
        echo_without_generation = self.echo and self.max_tokens == 0
        # top_k == 0 is treated as "disabled", which SamplingParams
        # expresses as -1.
        if self.top_k == 0:
            self.top_k = -1
        logits_processors = []
        if self.logit_bias:
            # Clamp OpenAI-style biases to [-100, 100] and drop token ids
            # outside the vocabulary.
            biases = {
                int(tok): max(-100, min(float(bias), 100))
                for tok, bias in self.logit_bias.items()
                if 0 < int(tok) < vocab_size
            }
            logits_processors.append(BiasLogitsProcessor(biases))
        return SamplingParams(
            n=self.n,
            max_tokens=self.max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            temperature=self.temperature,
            top_p=self.top_p,
            tfs=self.tfs,
            eta_cutoff=self.eta_cutoff,
            epsilon_cutoff=self.epsilon_cutoff,
            typical_p=self.typical_p,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            top_k=self.top_k,
            top_a=self.top_a,
            min_p=self.min_p,
            mirostat_mode=self.mirostat_mode,
            mirostat_tau=self.mirostat_tau,
            mirostat_eta=self.mirostat_eta,
            dynatemp_min=self.dynatemp_min,
            dynatemp_max=self.dynatemp_max,
            dynatemp_exponent=self.dynatemp_exponent,
            smoothing_factor=self.smoothing_factor,
            smoothing_curve=self.smoothing_curve,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            logprobs=self.logprobs,
            prompt_logprobs=self.prompt_logprobs if self.echo else None,
            stop_token_ids=self.stop_token_ids,
            custom_token_bans=self.custom_token_bans,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            stop=self.stop,
            best_of=self.best_of,
            include_stop_str_in_output=self.include_stop_str_in_output,
            seed=self.seed,
            logits_processors=logits_processors,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data
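
# Example: check_guided_decoding_count rejects requests that mix guided
# decoding modes (illustrative sketch; raises pydantic.ValidationError):
#
#   CompletionRequest(model="my-model", prompt="Hi",
#                     guided_regex=r"[0-9]+",
#                     guided_choice=["yes", "no"])
#   # -> "You can only use one kind of guided decoding ..."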


class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None


class CompletionResponseChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
    stop_reason: Union[None, int, str] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion to stop; "
            "None if the completion finished for some other reason, "
            "including encountering the EOS token."),
    )


class CompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
    stop_reason: Union[None, int, str] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion to stop; "
            "None if the completion finished for some other reason, "
            "including encountering the EOS token."),
    )


class CompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class EmbeddingResponseData(BaseModel):
    index: int
    object: str = "embedding"
    embedding: List[float]


class EmbeddingResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "list"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    data: List[EmbeddingResponseData]
    usage: UsageInfo


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
    stop_reason: Union[None, int, str] = None


class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo


class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None
    stop_reason: Union[None, int, str] = None


class ChatCompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
    logprobs: Optional[LogProbs] = None
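
# Example: one chunk of a streamed chat completion, serialized for
# server-sent events (illustrative sketch; the model name is a
# placeholder):
#
#   chunk = ChatCompletionStreamResponse(
#       model="my-model",
#       choices=[ChatCompletionResponseStreamChoice(
#           index=0, delta=DeltaMessage(content="Hel"))])
#   data = f"data: {chunk.model_dump_json()}\n\n"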


class EmbeddingRequest(BaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings
    model: str
    input: Union[List[int], List[List[int]], str, List[str]]
    encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$')
    dimensions: Optional[int] = None
    user: Optional[str] = None
    # doc: begin-embedding-pooling-params
    additional_data: Optional[Any] = None
    # doc: end-embedding-pooling-params

    def to_pooling_params(self):
        return PoolingParams(additional_data=self.additional_data)
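
# Example: lowering an embedding request to pooling parameters
# (illustrative sketch; the model name and input text are placeholders):
#
#   request = EmbeddingRequest(model="my-model", input="Hello world")
#   pooling_params = request.to_pooling_params()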


class Prompt(BaseModel):
    prompt: str


# ========== KoboldAI ========== #
class KoboldSamplingParams(BaseModel):
    n: int = Field(1, alias="n")
    best_of: Optional[int] = Field(None, alias="best_of")
    presence_penalty: float = Field(0.0, alias="presence_penalty")
    frequency_penalty: float = Field(0.0, alias="rep_pen")
    temperature: float = Field(1.0, alias="temperature")
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    top_p: float = Field(1.0, alias="top_p")
    top_k: float = Field(-1, alias="top_k")
    min_p: float = Field(0.0, alias="min_p")
    top_a: float = Field(0.0, alias="top_a")
    tfs: float = Field(1.0, alias="tfs")
    eta_cutoff: float = Field(0.0, alias="eta_cutoff")
    epsilon_cutoff: float = Field(0.0, alias="epsilon_cutoff")
    typical_p: float = Field(1.0, alias="typical_p")
    use_beam_search: bool = Field(False, alias="use_beam_search")
    length_penalty: float = Field(1.0, alias="length_penalty")
    early_stopping: Union[bool, str] = Field(False, alias="early_stopping")
    stop: Union[None, str, List[str]] = Field(None, alias="stop_sequence")
    include_stop_str_in_output: Optional[bool] = False
    ignore_eos: bool = Field(False, alias="ignore_eos")
    max_tokens: int = Field(16, alias="max_length")
    logprobs: Optional[int] = Field(None, alias="logprobs")
    custom_token_bans: Optional[List[int]] = Field(None,
                                                   alias="custom_token_bans")

    @root_validator(pre=False, skip_on_failure=True)
    def validate_best_of(cls, values):  # pylint: disable=no-self-argument
        best_of = values.get("best_of")
        n = values.get("n")
        if best_of is not None and (best_of <= 0 or best_of > n):
            raise ValueError(
                "best_of must be a positive integer less than or equal to n")
        return values
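
# Example: validate_best_of rejects best_of values that exceed n
# (illustrative values; raises pydantic.ValidationError):
#
#   KoboldSamplingParams(n=1, best_of=4)
#   # -> "best_of must be a positive integer less than or equal to n"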


class KAIGenerationInputSchema(BaseModel):
    genkey: Optional[str] = None
    prompt: str
    n: Optional[int] = 1
    max_context_length: int
    max_length: int
    rep_pen: Optional[float] = 1.0
    rep_pen_range: Optional[int] = None
    rep_pen_slope: Optional[float] = None
    top_k: Optional[int] = 0
    top_a: Optional[float] = 0.0
    top_p: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    eps_cutoff: Optional[float] = 0.0
    eta_cutoff: Optional[float] = 0.0
    typical: Optional[float] = 1.0
    temperature: Optional[float] = 1.0
    dynatemp_range: Optional[float] = 0.0
    dynatemp_exponent: Optional[float] = 1.0
    smoothing_factor: Optional[float] = 0.0
    smoothing_curve: Optional[float] = 1.0
    use_memory: Optional[bool] = None
    use_story: Optional[bool] = None
    use_authors_note: Optional[bool] = None
    use_world_info: Optional[bool] = None
    use_userscripts: Optional[bool] = None
    soft_prompt: Optional[str] = None
    disable_output_formatting: Optional[bool] = None
    frmtrmblln: Optional[bool] = None
    frmtrmspch: Optional[bool] = None
    singleline: Optional[bool] = None
    use_default_badwordsids: Optional[bool] = None
    mirostat: Optional[int] = 0
    mirostat_tau: Optional[float] = 0.0
    mirostat_eta: Optional[float] = 0.0
    disable_input_formatting: Optional[bool] = None
    frmtadsnsp: Optional[bool] = None
    quiet: Optional[bool] = None
    # pylint: disable=unexpected-keyword-arg
    sampler_order: Optional[Union[List, str]] = Field(default_factory=list)
    sampler_seed: Optional[int] = None
    sampler_full_determinism: Optional[bool] = None
    stop_sequence: Optional[List[str]] = None
    include_stop_str_in_output: Optional[bool] = False

    @root_validator(pre=False, skip_on_failure=True)
    def check_context(cls, values):  # pylint: disable=no-self-argument
        assert values.get("max_length") <= values.get(
            "max_context_length"
        ), "max_length must not be larger than max_context_length"
        return values
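
# Example: check_context enforces max_length <= max_context_length
# (illustrative values; raises pydantic.ValidationError):
#
#   KAIGenerationInputSchema(prompt="Hi", max_context_length=1024,
#                            max_length=2048)
#   # -> "max_length must not be larger than max_context_length"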