protocol.py 4.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. from typing import List, Optional, Union
  2. from pydantic import BaseModel, Field, root_validator
  3. class SamplingParams(BaseModel):
  4. n: int = Field(1, alias="n")
  5. best_of: Optional[int] = Field(None, alias="best_of")
  6. presence_penalty: float = Field(0.0, alias="presence_penalty")
  7. frequency_penalty: float = Field(0.0, alias="rep_pen")
  8. temperature: float = Field(1.0, alias="temperature")
  9. dynatemp_range: Optional[float] = 0.0
  10. dynatemp_exponent: Optional[float] = 1.0
  11. smoothing_factor: Optional[float] = 0.0
  12. smoothing_curve: Optional[float] = 1.0
  13. top_p: float = Field(1.0, alias="top_p")
  14. top_k: float = Field(-1, alias="top_k")
  15. min_p: float = Field(0.0, alias="min_p")
  16. top_a: float = Field(0.0, alias="top_a")
  17. tfs: float = Field(1.0, alias="tfs")
  18. eta_cutoff: float = Field(0.0, alias="eta_cutoff")
  19. epsilon_cutoff: float = Field(0.0, alias="epsilon_cutoff")
  20. typical_p: float = Field(1.0, alias="typical_p")
  21. use_beam_search: bool = Field(False, alias="use_beam_search")
  22. length_penalty: float = Field(1.0, alias="length_penalty")
  23. early_stopping: Union[bool, str] = Field(False, alias="early_stopping")
  24. stop: Union[None, str, List[str]] = Field(None, alias="stop_sequence")
  25. include_stop_str_in_output: Optional[bool] = False
  26. ignore_eos: bool = Field(False, alias="ignore_eos")
  27. max_tokens: int = Field(16, alias="max_length")
  28. logprobs: Optional[int] = Field(None, alias="logprobs")
  29. custom_token_bans: Optional[List[int]] = Field(None,
  30. alias="custom_token_bans")
  31. @root_validator(pre=False, skip_on_failure=True)
  32. def validate_best_of(cls, values): # pylint: disable=no-self-argument
  33. best_of = values.get("best_of")
  34. n = values.get("n")
  35. if best_of is not None and (best_of <= 0 or best_of > n):
  36. raise ValueError(
  37. "best_of must be a positive integer less than or equal to n")
  38. return values
  39. class KAIGenerationInputSchema(BaseModel):
  40. genkey: Optional[str] = None
  41. prompt: str
  42. n: Optional[int] = 1
  43. max_context_length: int
  44. max_length: int
  45. rep_pen: Optional[float] = 1.0
  46. rep_pen_range: Optional[int] = None
  47. rep_pen_slope: Optional[float] = None
  48. top_k: Optional[int] = 0
  49. top_a: Optional[float] = 0.0
  50. top_p: Optional[float] = 1.0
  51. min_p: Optional[float] = 0.0
  52. tfs: Optional[float] = 1.0
  53. eps_cutoff: Optional[float] = 0.0
  54. eta_cutoff: Optional[float] = 0.0
  55. typical: Optional[float] = 1.0
  56. temperature: Optional[float] = 1.0
  57. dynatemp_range: Optional[float] = 0.0
  58. dynatemp_exponent: Optional[float] = 1.0
  59. smoothing_factor: Optional[float] = 0.0
  60. smoothing_curve: Optional[float] = 1.0
  61. use_memory: Optional[bool] = None
  62. use_story: Optional[bool] = None
  63. use_authors_note: Optional[bool] = None
  64. use_world_info: Optional[bool] = None
  65. use_userscripts: Optional[bool] = None
  66. soft_prompt: Optional[str] = None
  67. disable_output_formatting: Optional[bool] = None
  68. frmtrmblln: Optional[bool] = None
  69. frmtrmspch: Optional[bool] = None
  70. singleline: Optional[bool] = None
  71. use_default_badwordsids: Optional[bool] = None
  72. mirostat: Optional[int] = 0
  73. mirostat_tau: Optional[float] = 0.0
  74. mirostat_eta: Optional[float] = 0.0
  75. disable_input_formatting: Optional[bool] = None
  76. frmtadsnsp: Optional[bool] = None
  77. quiet: Optional[bool] = None
  78. # pylint: disable=unexpected-keyword-arg
  79. sampler_order: Optional[Union[List, str]] = Field(default_factory=list)
  80. sampler_seed: Optional[int] = None
  81. sampler_full_determinism: Optional[bool] = None
  82. stop_sequence: Optional[List[str]] = None
  83. include_stop_str_in_output: Optional[bool] = False
  84. @root_validator(pre=False, skip_on_failure=True)
  85. def check_context(cls, values): # pylint: disable=no-self-argument
  86. assert values.get("max_length") <= values.get(
  87. "max_context_length"
  88. ), "max_length must not be larger than max_context_length"
  89. return values