# main.py

## Copyright (C) 2024, Nicholas Carlini <nicholas@carlini.com>.
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see <http://www.gnu.org/licenses/>.

import sys
import re
import importlib
import tests
import os
import llm
import json
import argparse
import pickle
import subprocess
import create_results_html

from evaluator import Env, Conversation, run_test
import multiprocessing as mp

def run_one_test(test, test_llm, eval_llm, vision_eval_llm):
    """
    Run a single test case and return a (success, output) tuple.
    """
    import docker_controller

    env = Env()
    test.setup(env, Conversation(test_llm), test_llm, eval_llm, vision_eval_llm)

    output = ""  # keep a defined value even if the test yields nothing
    for success, output in test():
        if success:
            if env.container:
                docker_controller.async_kill_container(env.docker, env.container)
            return True, output
        else:
            pass

    if env.container:
        docker_controller.async_kill_container(env.docker, env.container)

    return False, output
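
# Hypothetical direct use of run_one_test(), outside the normal CLI flow.
# The test module and class names below are illustrative placeholders;
# substitute any Test* class defined under tests/:
#
#   import llm
#   from tests.print_hello import TestPrintHello   # hypothetical test class
#   ok, output = run_one_test(TestPrintHello, llm.LLM("gpt-4o"),
#                             llm.eval_llm, llm.vision_eval_llm)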

def run_all_tests(test_llm, use_cache=True, which_tests=None):
    """
    Run every test case in the benchmark, returning a dictionary of the results
    of the format { "test_name": (success, output) }.
    """
    test_llm = llm.LLM(test_llm, use_cache=use_cache)

    sr = {}

    for f in os.listdir("tests"):
        if not f.endswith(".py"): continue
        if which_tests is not None and f[:-3] not in which_tests:
            continue
        try:
            spec = importlib.util.spec_from_file_location(f[:-3], "tests/" + f)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
        except:
            print("SKIPPING TEST", f)
            continue

        test_case = [x for x in dir(module) if x.startswith("Test") and x != "TestCase"]
        if len(test_case) == 0:
            pass
        else:
            print(f)

        for t in test_case:
            print("Run Job", t)

            # Silence the test's own stdout while it runs; only the
            # pass/fail summary below is printed.
            tmp = sys.stdout
            sys.stdout = open(os.devnull, 'w')
            test = getattr(module, t)
            ok, reason = run_one_test(test, test_llm, llm.eval_llm, llm.vision_eval_llm)
            sys.stdout = tmp

            if ok:
                print("Test Passes:", t)
            else:
                print("Test Fails:", t, 'from', f)
            sr[f+"."+t] = (ok, reason)

    return sr
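
# Sketch of invoking the benchmark programmatically on a subset of tests
# (the test file names here are illustrative, not guaranteed to exist):
#
#   results = run_all_tests("gpt-4o", use_cache=True,
#                           which_tests={"print_hello", "explain_code"})
#   passed = [name for name, (ok, _) in results.items() if ok]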

def get_tags():
    """
    Each test has a description and a set of tags. This returns two dictionaries,
    { "test_name": ["tag1", "tag2"] } and { "test_name": "description" },
    in that order.
    """
    descriptions = {}
    tags = {}

    for f in os.listdir("tests"):
        if not f.endswith(".py"): continue
        try:
            spec = importlib.util.spec_from_file_location(f[:-3], "tests/" + f)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
        except:
            continue

        if 'TAGS' in dir(module):
            test_case = [x for x in dir(module) if x.startswith("Test") and x != "TestCase"]
            for t in test_case:
                tags[f+"."+t] = module.TAGS
                descriptions[f+"."+t] = module.DESCRIPTION

    return tags, descriptions
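
# Illustrative shapes of the two return values of get_tags(); the keys and
# values below are hypothetical examples, not actual benchmark entries:
#
#   tags         == {"print_hello.py.TestPrintHello": ["code", "python"]}
#   descriptions == {"print_hello.py.TestPrintHello": "Test that the model can print hello."}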

def get_ordered_logs(logdir):
    """
    Return the commit hashes under logdir that have saved results, ordered
    newest-first according to `git log`.
    """
    hashes = []
    for githash in os.listdir(logdir):
        if '-run' in githash:
            print("There was a breaking change in how results are stored. Please move the runs into args.logdir/[git commit hash]/[the results].")
            exit(1)
        hashes.append(githash)

    command = ['git', 'log', '--pretty=format:%H']
    result = subprocess.run(command, capture_output=True, text=True)
    commit_hashes = result.stdout.strip().split('\n')
    commit_hashes = [x for x in commit_hashes if x in hashes]

    return commit_hashes
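
# Note: main() iterates these hashes in reverse (oldest first), so when the
# same test key appears in results from several commits, the newest commit's
# saved run is the one that ends up in the report.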

def load_saved_runs(output_dir, model):
    """
    Load saved runs from the output directory for a specific model.
    """
    saved_runs = {}
    for file in sorted(os.listdir(output_dir)):
        if file.startswith(model+"-run"):
            try:
                if '.json' in file:
                    with open(os.path.join(output_dir, file), 'r') as f:
                        one_run = json.loads(f.readlines()[-1])
                elif '.p' in file:
                    with open(os.path.join(output_dir, file), 'rb') as f:
                        one_run = pickle.load(f)
                else:
                    continue

                for k, (v1, v2) in one_run.items():
                    if k not in saved_runs:
                        saved_runs[k] = ([], [])
                    saved_runs[k][0].append(v1)
                    saved_runs[k][1].append(v2)
            except json.JSONDecodeError:
                print(f"Warning: Invalid JSON in file {file}")

    return saved_runs
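
# The merged result maps each test key to a pair of parallel lists with one
# entry per saved run, i.e. roughly
#   { "test_name": ([success_run0, success_run1, ...],
#                   [output_run0, output_run1, ...]) }.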

def main():
    parser = argparse.ArgumentParser(description="Run tests on language models.")
    parser.add_argument('--model', help='Specify a specific model to run.', type=str, action="append")
    parser.add_argument('--all-models', help='Run all models.', action='store_true')
    parser.add_argument('--test', help='Specify a specific test to run.', type=str, action="append")
    parser.add_argument('--times', help='Number of times to run the model(s).', type=int, default=1)
    parser.add_argument('--runid', help='Offset of the run ID for saving.', type=int, default=0)
    parser.add_argument('--logdir', help='Output path for the results.', type=str, default='results')
    parser.add_argument('--generate-report', help='Generate an HTML report.', action='store_true')
    parser.add_argument('--load-saved', help='Load saved evaluations.', action='store_true')
    parser.add_argument('--run-tests', help='Run a batch of tests.', action='store_true')
    parser.add_argument('--only-changed', help='Only run tests that have changed since the given commit (INCLUSIVE).')

    args = parser.parse_args()
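
    # Typical invocations (illustrative; the test and model names are examples only):
    #   python main.py --run-tests --model gpt-4o --test print_hello
    #   python main.py --load-saved --all-models --generate-report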

    assert args.run_tests ^ args.load_saved, "Exactly one of --run-tests or --load-saved must be specified."

    if args.all_models and args.model:
        parser.error("The arguments --all-models and --model cannot be used together.")

    # Create the results directory if it doesn't exist
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)

    models_to_run = []
    if args.model:
        models_to_run = args.model
    elif args.all_models:
        models_to_run = ["gpt-4o", "gpt-4-0125-preview", "claude-3-opus-20240229", "claude-3-sonnet-20240229", "gpt-3.5-turbo-0125", "gemini-pro", "mistral-large-latest", "mistral-medium"]

    data = {}
    for model in models_to_run:
        if args.load_saved:
            data[model] = {}

            commit_hashes = get_ordered_logs(args.logdir)
            print("Loading data from commits")
            for githash in commit_hashes[::-1]:
                print(githash)
                kvs = load_saved_runs(os.path.join(args.logdir, githash), model)
                for k, v in kvs.items():
                    data[model][k] = v
        elif args.run_tests:
            tests_subset = None  # run all of them
            if args.test:
                tests_subset = args.test  # run the ones the user said
            elif args.only_changed:
                latest_commit_finished = args.only_changed
                command = ['git', 'diff', '--name-only', latest_commit_finished+"^", 'HEAD']
                result = subprocess.run(command, capture_output=True, text=True)
                changed_files = result.stdout.strip().split('\n')
                changed_files = [x.split("tests/")[1].split(".py")[0] for x in changed_files if x.startswith("tests/")]
                print("Running the following tests:\n -",
                      "\n - ".join(changed_files))
                tests_subset = set(changed_files)

            command = ['git', 'rev-parse', 'HEAD']
            result = subprocess.run(command, capture_output=True, text=True)
            current_commit_hash = result.stdout.strip()

            data[model] = {}
            for i in range(args.times):
                print(f"Running {model}, iteration {i+args.runid}")
                result = run_all_tests(model, use_cache=False,
                                       which_tests=tests_subset)

                for k, (v1, v2) in result.items():
                    if k not in data[model]:
                        data[model][k] = ([], [])
                    data[model][k][0].append(v1)
                    data[model][k][1].append(v2)

                if not os.path.exists(os.path.join(args.logdir, current_commit_hash)):
                    os.mkdir(os.path.join(args.logdir, current_commit_hash))

                with open(f"{args.logdir}/{current_commit_hash}/{model}-run{i+args.runid}.p", 'wb') as f:
                    pickle.dump(result, f)
        else:
            # Unreachable: the assert above guarantees exactly one mode is set.
            raise Exception("Unreachable")

    if args.generate_report:
        tags, descriptions = get_tags()
        create_results_html.generate_report(data, tags, descriptions)


if __name__ == "__main__":
    main()