# protocol.py
  1. from typing import List, Optional, Union
  2. from pydantic import BaseModel, Field, root_validator, conint, confloat, conlist, NonNegativeFloat, NonNegativeInt, PositiveInt
  3. class SamplingParams(BaseModel):
  4. n: int = Field(1, alias="n")
  5. best_of: Optional[int] = Field(None, alias="best_of")
  6. presence_penalty: float = Field(0.0, alias="presence_penalty")
  7. frequency_penalty: float = Field(0.0, alias="rep_pen")
  8. temperature: float = Field(1.0, alias="temperature")
  9. top_p: float = Field(1.0, alias="top_p")
  10. top_k: float = Field(-1, alias="top_k")
  11. tfs: float = Field(1.0, alias="tfs")
  12. eta_cutoff: float = Field(0.0, alias="eta_cutoff")
  13. epsilon_cutoff: float = Field(0.0, alias="epsilon_cutoff")
  14. typical_p: float = Field(1.0, alias="typical_p")
  15. use_beam_search: bool = Field(False, alias="use_beam_search")
  16. length_penalty: float = Field(1.0, alias="length_penalty")
  17. early_stopping: Union[bool, str] = Field(False, alias="early_stopping")
  18. stop: Union[None, str, List[str]] = Field(None, alias="stop_sequence")
  19. ignore_eos: bool = Field(False, alias="ignore_eos")
  20. max_tokens: int = Field(16, alias="max_length")
  21. logprobs: Optional[int] = Field(None, alias="logprobs")
  22. @root_validator
  23. def validate_best_of(cls, values):
  24. best_of = values.get("best_of")
  25. n = values.get("n")
  26. if best_of is not None and (best_of <= 0 or best_of > n):
  27. raise ValueError("best_of must be a positive integer less than or equal to n")
  28. return values
  29. class KAIGenerationInputSchema(BaseModel):
  30. prompt: str
  31. n: Optional[conint(ge=1, le=5)] = 1
  32. max_context_length: PositiveInt
  33. max_length: PositiveInt
  34. rep_pen: Optional[confloat(ge=1)] = 1.0
  35. rep_pen_range: Optional[NonNegativeInt]
  36. rep_pen_slope: Optional[NonNegativeFloat]
  37. top_k: Optional[NonNegativeInt] = 0.0
  38. top_a: Optional[NonNegativeFloat] = 0.0
  39. top_p: Optional[confloat(ge=0, le=1)] = 1.0
  40. tfs: Optional[confloat(ge=0, le=1)] = 1.0
  41. eps_cutoff: Optional[confloat(ge=0,le=1000)] = 0.0
  42. eta_cutoff: Optional[NonNegativeFloat] = 0.0
  43. typical: Optional[confloat(ge=0, le=1)] = 1.0
  44. temperature: Optional[NonNegativeFloat] = 1.0
  45. use_memory: Optional[bool]
  46. use_story: Optional[bool]
  47. use_authors_note: Optional[bool]
  48. use_world_info: Optional[bool]
  49. use_userscripts: Optional[bool]
  50. soft_prompt: Optional[str]
  51. disable_output_formatting: Optional[bool]
  52. frmtrmblln: Optional[bool]
  53. frmtrmspch: Optional[bool]
  54. singleline: Optional[bool]
  55. use_default_badwordsids: Optional[bool]
  56. disable_input_formatting: Optional[bool]
  57. frmtadsnsp: Optional[bool]
  58. quiet: Optional[bool]
  59. sampler_order: Optional[conlist(int, min_items=6)]
  60. sampler_seed: Optional[conint(ge=0, le=2**64 - 1)]
  61. sampler_full_determinism: Optional[bool]
  62. stop_sequence: Optional[List[str]]
  63. @root_validator
  64. def check_context(cls, values):
  65. assert values.get("max_length") <= values.get("max_context_length"), f"max_length must not be larger than max_context_length"
  66. return values