configure.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
import os
import shlex
import signal
import subprocess
import sys

from blessed import Terminal
from fuzzywuzzy import process
  7. def configure(args):
  8. term = Terminal()
  9. class Field:
  10. def __init__(self, name, prompt, field_type, options=None,
  11. validator=None, optional=True, default=None,
  12. advanced=False):
  13. self.name = name
  14. self.prompt = prompt
  15. self.field_type = field_type
  16. self.options = options
  17. self.validator = validator
  18. self.optional = optional
  19. self.value = default
  20. self.advanced = advanced
  21. def validate_int(value):
  22. return lambda x: x.isdigit() or x == ""
  23. def validate_float(value):
  24. return lambda x: x.replace('.', '').isdigit() and 0 <= float(x) <= value or x == ""
  25. fields = [
  26. Field("model", "Model Name", "input", optional=False),
  27. Field("tensor_parallel_size", "Tensor Parallel Size", "input",
  28. validator=validate_int, optional=True),
  29. Field("gpu_memory_utilization", "GPU Memory Utilization (%)", "input",
  30. validator=validate_float, optional=True),
  31. Field("enable_chunked_prefill", "Enable Chunked Prefill", "checkbox", optional=True),
  32. Field("revision", "Model Revision", "input", optional=True, advanced=True),
  33. Field("max_logprobs", "Max Logprobs", "input", validator=validate_int, optional=True, advanced=True),
  34. Field("trust_remote_code", "Trust Remote Code", "checkbox", optional=True, advanced=True),
  35. Field("dtype", "Data Type", "multichoice", options=["auto", "float16", "bfloat16", "float32"], optional=True, advanced=True),
  36. Field("enforce_eager", "Enforce Eager", "boolean", optional=True, default=True, advanced=True),
  37. ]
  38. def draw_ui(current_field, show_advanced):
  39. print(term.clear)
  40. print(term.bold + term.cyan + "Aphrodite Configuration" + term.normal)
  41. print(term.yellow + "Use ↑/↓ arrows to navigate, Space to toggle/cycle, Tab for completion" + term.normal)
  42. print(term.yellow + "Press F2 to toggle advanced options, Enter to finish, Esc or Ctrl+C to cancel" + term.normal)
  43. print()
  44. print(term.bold + "Required Fields:" + term.normal)
  45. for i, field in enumerate(fields):
  46. if not field.optional:
  47. draw_field(i, field, current_field)
  48. print()
  49. print(term.bold + "Optional Fields:" + term.normal)
  50. for i, field in enumerate(fields):
  51. if field.optional and not field.advanced:
  52. draw_field(i, field, current_field)
  53. if show_advanced:
  54. print()
  55. print(term.bold + "Advanced Options:" + term.normal)
  56. for i, field in enumerate(fields):
  57. if field.advanced:
  58. draw_field(i, field, current_field)
  59. print("\n" + term.yellow + "Press Enter when finished to launch the engine" + term.normal)
  60. def draw_field(i, field, current_field):
  61. if i == current_field:
  62. print(term.bold + term.green, end="")
  63. if field.field_type == "checkbox":
  64. checkbox = "[x]" if field.value else "[ ]"
  65. print(f"{'>' if i == current_field else ' '} {checkbox} {field.prompt}" + term.normal)
  66. elif field.field_type == "boolean":
  67. value = "True" if field.value is True else "False" if field.value is False else "Default"
  68. print(f"{'>' if i == current_field else ' '} {field.prompt}: {value}" + term.normal)
  69. elif field.field_type == "multichoice":
  70. print(f"{'>' if i == current_field else ' '} {field.prompt}: {field.value or 'Not set'}" + term.normal)
  71. else:
  72. print(f"{'>' if i == current_field else ' '} {field.prompt}: {field.value or ''}", end="")
  73. if i == current_field:
  74. print(term.blue + "█" + term.normal)
  75. else:
  76. print(term.normal)
  77. def expand_user_path(path):
  78. return os.path.expanduser(path)
  79. def path_complete(path):
  80. path = expand_user_path(path)
  81. directory = os.path.dirname(path) or '.'
  82. filename = os.path.basename(path)
  83. try:
  84. files = os.listdir(directory)
  85. except OSError:
  86. return []
  87. if not filename.startswith('.'):
  88. files = [f for f in files if not f.startswith('.')]
  89. matches = [f for f in files if f.lower().startswith(filename.lower())]
  90. if matches:
  91. return [os.path.join(directory, sorted(matches)[0])]
  92. fuzzy_matches = process.extractOne(filename, files)
  93. if fuzzy_matches and fuzzy_matches[1] > 70:
  94. return [os.path.join(directory, fuzzy_matches[0])]
  95. return []
  96. current_field = 0
  97. show_advanced_options = False
  98. def signal_handler(sig, frame):
  99. print(term.normal + term.clear)
  100. print("Configuration cancelled.")
  101. sys.exit(0)
  102. signal.signal(signal.SIGINT, signal_handler)
  103. try:
  104. with term.cbreak(), term.hidden_cursor():
  105. while True:
  106. draw_ui(current_field, show_advanced_options)
  107. key = term.inkey()
  108. if key.name == 'KEY_ESCAPE':
  109. raise KeyboardInterrupt
  110. elif key.name == 'KEY_TAB':
  111. if current_field == 0: # Model Name field
  112. completions = path_complete(fields[0].value)
  113. if completions:
  114. fields[0].value = completions[0]
  115. if os.path.isdir(completions[0]):
  116. fields[0].value += os.path.sep
  117. elif key.name == 'KEY_ENTER':
  118. if not fields[0].value: # Ensure required field is filled
  119. continue
  120. break # Finished entering data
  121. elif key.name == 'KEY_UP':
  122. current_field = (current_field - 1) % len(fields)
  123. while fields[current_field].advanced and not show_advanced_options:
  124. current_field = (current_field - 1) % len(fields)
  125. elif key.name == 'KEY_DOWN':
  126. current_field = (current_field + 1) % len(fields)
  127. while fields[current_field].advanced and not show_advanced_options:
  128. current_field = (current_field + 1) % len(fields)
  129. elif key == ' ':
  130. field = fields[current_field]
  131. if field.field_type == "checkbox":
  132. field.value = not field.value
  133. elif field.field_type == "boolean":
  134. if field.value is True:
  135. field.value = False
  136. elif field.value is False:
  137. field.value = None
  138. else:
  139. field.value = True
  140. elif field.field_type == "multichoice":
  141. if field.value is None:
  142. field.value = field.options[0]
  143. else:
  144. index = (field.options.index(field.value) + 1) % len(field.options)
  145. field.value = field.options[index]
  146. elif key.name == 'KEY_F2':
  147. show_advanced_options = not show_advanced_options
  148. if not show_advanced_options:
  149. current_field = next((i for i, f in enumerate(fields) if not f.advanced), 0)
  150. elif key.name == 'KEY_BACKSPACE':
  151. if not fields[current_field].field_type in ["checkbox", "boolean", "multichoice"]:
  152. fields[current_field].value = fields[current_field].value[:-1] if fields[current_field].value else ""
  153. elif not key.is_sequence and fields[current_field].field_type == "input":
  154. field = fields[current_field]
  155. new_value = (field.value or "") + key
  156. if not field.validator or field.validator(new_value):
  157. field.value = new_value
  158. except KeyboardInterrupt:
  159. print(term.normal + term.clear)
  160. print("Configuration cancelled.")
  161. return
  162. # Construct and execute command
  163. command = f"aphrodite run {fields[0].value}"
  164. for field in fields[1:]:
  165. if field.value:
  166. if field.field_type == "checkbox" and field.value:
  167. command += f" --{field.name.replace('_', '-')}"
  168. elif field.field_type == "boolean" and field.value is not None:
  169. command += f" --{field.name.replace('_', '-')} {str(field.value).lower()}"
  170. elif field.field_type in ["input", "multichoice"]:
  171. command += f" --{field.name.replace('_', '-')} {field.value}"
  172. elif field.name == "gmu":
  173. command += f" -gmu {float(field.value) / 100:.2f}"
  174. print(term.clear)
  175. print(term.green + "Executing command:" + term.normal)
  176. print(command)
  177. print()
  178. try:
  179. subprocess.run(command, shell=True, check=True)
  180. except subprocess.CalledProcessError as e:
  181. print(term.red + f"Error executing command: {e}" + term.normal,
  182. file=sys.stderr)
  183. sys.exit(1)