evaluator.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877
  1. ## Copyright (C) 2024, Nicholas Carlini <nicholas@carlini.com>.
  2. ##
  3. ## This program is free software: you can redistribute it and/or modify
  4. ## it under the terms of the GNU General Public License as published by
  5. ## the Free Software Foundation, either version 3 of the License, or
  6. ## (at your option) any later version.
  7. ##
  8. ## This program is distributed in the hope that it will be useful,
  9. ## but WITHOUT ANY WARRANTY; without even the implied warranty of
  10. ## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  11. ## GNU General Public License for more details.
  12. ##
  13. ## You should have received a copy of the GNU General Public License
  14. ## along with this program. If not, see <http://www.gnu.org/licenses/>.
  15. import subprocess
  16. import pickle
  17. import random
  18. import json
  19. import os
  20. import time
  21. import io
  22. import docker
  23. import inspect
  24. import re
  25. import numpy as np
  26. from PIL import Image
  27. import docker_controller
  28. from docker_controller import invoke_docker, DockerJob
  29. ## Constants that define which model we're supposed to be using:
  30. LLM = "llm" # The LLM under evaluation
  31. EVAL_LLM = "eval_llm" # A good LLM that can act as a judge
  32. VISION_EVAL_LLM = "vision_eval_llm" # And a good judge for vision tasks
  33. PYTHON_ENV = "python3" # The version of python to use
  34. class Env:
  35. """
  36. An environment that holds the local variables for each test case.
  37. """
  38. # The docker object we're running the test in
  39. docker = None
  40. # (Optionally, if in unsafe mode, the fake docker object)
  41. fake_docker_id = None
  42. # The docker container we're running the tests in
  43. container = None
  44. # A DockerJob object, if the test case requires it.
  45. # These objects allow the test to interact with stdin/out
  46. # of a process running in the docker container and must be
  47. # persistant across multiple classes in the test case.
  48. docker_job = None
  49. class Reason:
  50. """
  51. A class to keep track of the solution path of a test.
  52. """
  53. def __init__(self, node, children):
  54. self.node = node
  55. self.children = children
  56. def __repr__(self):
  57. return repr((self.node, self.children))
  58. class Node:
  59. """
  60. A node forms the operations in the computation graph for evaluating a test case;
  61. the most important object in this file. A test case might look like
  62. Node1 >> Node2 >> (Node3 & Node4)
  63. Each of these operators that connects nodes return a new node. So this graph
  64. would be equivalent to writing:
  65. ThenNode(ThenNode(Node1, Node2), AndNode(Node3, Node4))
  66. Once the computation graph has been constructed, evaluation is performed by
  67. calling __call__ on the root node, that then passes off the evalaution process
  68. as defined by each of the node types.
  69. """
  70. def __init__(self, runner):
  71. """
  72. Many sub-classes take a single argument, the runner, which is a function
  73. that should be executed for performing this node's computation.
  74. """
  75. self.runner = runner
  76. def setup(self, env, conv, llm, eval_llm, vision_eval_llm):
  77. """
  78. Once the graph has been constructed, before running __call__ to evaluate
  79. the test case, we run setup() on each of the nodes to pass all the
  80. necessary context.
  81. """
  82. self.env = env
  83. self.conv = conv
  84. self.llm = llm
  85. self.eval_llm = eval_llm
  86. self.vision_eval_llm = vision_eval_llm
  87. def __call__(self, orig_output=""):
  88. """
  89. Evaluate the test case, starting at this node. This is the main entry
  90. point for the evaluation process.
  91. Returns two arguments:
  92. 1. The output of the current node that should be passed to the next node.
  93. 2. A Reason object that explains how the output was generated for debugging.
  94. """
  95. raise NotImplementedError()
  96. def __rshift__(self, other_node):
  97. """
  98. Add the >> operator, which creates a ThenNode.
  99. Wrap any strings in a StringNode first, to allow for code like
  100. SetupNode >> "command to run" >> LLMRunNode
  101. """
  102. if isinstance(other_node, str):
  103. other_node = StringNode(other_node)
  104. return ThenNode(self, other_node)
  105. def __rrshift__(self, other_node):
  106. """
  107. If a string is the first node, we need to special case the
  108. rrshift operator, since we can't override the string class.
  109. Allows the (very common) pattern of
  110. "command to run" >> LLMRunNode
  111. """
  112. if isinstance(other_node, str):
  113. other_node = StringNode(other_node)
  114. return ThenNode(other_node, self)
  115. def __and__(self, other_node):
  116. return AndNode(self, other_node)
  117. def __or__(self, other_node):
  118. return OrNode(self, other_node)
  119. def __invert__(self):
  120. return NotNode(self)
  121. class StringNode(Node):
  122. def __init__(self, string):
  123. """
  124. A boring node, just returns the string.
  125. """
  126. self.string = string
  127. def __call__(self, orig_output=""):
  128. """
  129. Just pass whatever the provided constant string is to the next node.
  130. """
  131. yield self.string, Reason(type(self), self.string)
  132. class ThenNode(Node):
  133. """
  134. Perform two operations in sequence. The output of node1 is passed to node2.
  135. """
  136. def __init__(self, node1, node2):
  137. self.node1 = node1
  138. self.node2 = node2
  139. def setup(self, env, conv, llm, eval_llm, vision_eval_llm):
  140. super().setup(env, conv, llm, eval_llm, vision_eval_llm)
  141. self.node1.setup(env, conv, llm, eval_llm, vision_eval_llm)
  142. self.node2.setup(env=env, conv=conv, llm=llm, eval_llm=eval_llm, vision_eval_llm=vision_eval_llm)
  143. def __call__(self, orig_output=None):
  144. for output1, response1 in self.node1(orig_output):
  145. for output2, response2 in self.node2(output1):
  146. yield output2, Reason(type(self), (response1, response2))
  147. class AndNode(ThenNode):
  148. """
  149. An evaluation node that returns true if both outputs are true.
  150. """
  151. def __init__(self, node1, node2):
  152. self.node1 = node1
  153. self.node2 = node2
  154. def __call__(self, orig_output):
  155. for output1, txt1 in self.node1(orig_output):
  156. for output2, txt2 in self.node2(orig_output):
  157. yield output1 and output2, Reason(type(self), (txt1, txt2, output1 and output2))
  158. class OrNode(ThenNode):
  159. """
  160. An evaluation node that returns true if either outputs are true.
  161. """
  162. def __init__(self, node1, node2):
  163. self.node1 = node1
  164. self.node2 = node2
  165. def __call__(self, orig_output):
  166. for output1, txt1 in self.node1(orig_output):
  167. for output2, txt2 in self.node2(orig_output):
  168. yield output1 or output2, Reason(type(self), (txt1, txt2, output1 or output2))
  169. class NotNode(Node):
  170. """
  171. An evaluation node that negates the prior answer.
  172. """
  173. def __init__(self, node1):
  174. self.node1 = node1
  175. def setup(self, env, conv, llm, eval_llm, vision_eval_llm):
  176. super().setup(env, conv, llm, eval_llm, vision_eval_llm)
  177. self.node1.setup(env, conv, llm, eval_llm, vision_eval_llm)
  178. def __call__(self, orig_output):
  179. for output1, txt1 in self.node1(orig_output):
  180. yield not output1, Reason(type(self), [txt1, not output1])
  181. class PyFunc(Node):
  182. """
  183. A node that just runs a python function on the prior result.
  184. If the code crashes then just return an error.
  185. """
  186. def __call__(self, x):
  187. try:
  188. out = self.runner(x)
  189. if type(out) == tuple:
  190. ok, log = out
  191. return [(ok, Reason(type(self), (log, ok)))]
  192. else:
  193. return [(out, Reason(type(self), ("", out)))]
  194. except:
  195. return [("", Reason(type(self), ["Error", False]))]
  196. class Echo(Node):
  197. """
  198. A no-op node that helps debug test cases by printing whatever's being
  199. passed along the pipe. Kind of like the Unix tee command.
  200. """
  201. def __init__(self):
  202. pass
  203. def __call__(self, x):
  204. print('ECHOING:', x)
  205. yield x, Reason(type(self), None)
  206. class Setup(Node):
  207. """
  208. A node that starts up a new Docker environment with a specific setup file.
  209. Even though the argument is a method, this function needs to be able to
  210. extract the string representation of that function so it can be executed
  211. in the context of the docker environment.
  212. """
  213. def __call__(self, x):
  214. docker_controller.setup_docker(self.env)
  215. code = inspect.getsource(self.runner)
  216. to_invoke = self.runner.__name__
  217. code = code + f"\n\n{to_invoke}()"
  218. out = invoke_docker(self.env, {"setup.py": code.encode()}, [PYTHON_ENV, "setup.py"])
  219. return [(out, Reason(type(self), None))]
  220. class PyEvaluator(Node):
  221. """
  222. A node that runs a python program within the docker environment to judge whether
  223. or not the test case is solved.
  224. Even though the argument is a method, this function needs to be able to
  225. extract the string representation of that function so it can be executed
  226. in the context of the docker environment.
  227. """
  228. def __call__(self, x):
  229. code = inspect.getsource(self.runner)
  230. to_invoke = self.runner.__name__
  231. code = code + f"\n\nprint('final: ' + str({to_invoke}()))"
  232. out = invoke_docker(self.env, {"check.py": code.encode()}, [PYTHON_ENV, "check.py"])
  233. return [("final: True" in out, Reason(type(self), [out, "final: True" in out]))]
  234. class SubstringEvaluator(Node):
  235. """
  236. An evaluation node that checks if a substring is in the output.
  237. """
  238. def __init__(self, substr, lower=False):
  239. self.substr = substr
  240. self.lower = lower
  241. def __call__(self, output):
  242. if self.lower:
  243. cond = self.substr.lower() in output.lower()
  244. else:
  245. cond = self.substr in output
  246. if cond:
  247. yield True, Reason(type(self), [self.substr, True])
  248. else:
  249. yield False, Reason(type(self), [self.substr, False])
  250. class RegexEvaluator(Node):
  251. """
  252. An evaluation node that checks if a regex pattern matches the output.
  253. """
  254. def __init__(self, pattern, ignore_case=False):
  255. self.pattern = pattern
  256. self.ignore_case = ignore_case
  257. def __call__(self, output):
  258. import re
  259. flags = re.IGNORECASE if self.ignore_case else 0
  260. match = re.search(self.pattern, output, flags)
  261. if match:
  262. yield True, Reason(type(self), [self.pattern, True])
  263. else:
  264. yield False, Reason(type(self), [self.pattern, False])
  265. class ContainsIntEvaluator(Node):
  266. """
  267. An evaluation node that checks if a given integer is in the output.
  268. """
  269. def __init__(self, num):
  270. self.num = num
  271. def __call__(self, output):
  272. all_integers = re.findall(r'-?[\d,]*\d+\.?\d*', output)
  273. all_integers = [x.replace(",", "") for x in all_integers]
  274. if str(self.num) in all_integers:
  275. yield True, Reason(type(self), [self.num, True])
  276. else:
  277. yield False, Reason(type(self), [self.num, False])
  278. class EqualEvaluator(Node):
  279. """
  280. An evaluation node that checks if the output is equal to a given string.
  281. """
  282. def __init__(self, goal):
  283. self.goal = goal
  284. def __call__(self, output):
  285. if self.goal == output:
  286. yield True, Reason(type(self), [self.goal, True])
  287. else:
  288. yield False, Reason(type(self), [self.goal, False])
  289. class UntilDone(Node):
  290. """
  291. A node that will loop a specific body node until the condition returns true and it's finished.
  292. This node is useful when you want a model to, e.g., iterative interact
  293. with a sqlite database until it's completed some task.
  294. """
  295. def __init__(self, cond, body, max_iters=100):
  296. self.cond = cond
  297. self.body = body
  298. self.max_iters = max_iters
  299. def setup(self, env, conv, llm, eval_llm, vision_eval_llm):
  300. super().setup(env, conv, llm, eval_llm, vision_eval_llm)
  301. self.cond.setup(env, conv, llm, eval_llm, vision_eval_llm)
  302. self.body.setup(env, conv, llm, eval_llm, vision_eval_llm)
  303. def __call__(self, orig_output=None):
  304. log = []
  305. for i in range(self.max_iters):
  306. for output, txt in self.cond(orig_output):
  307. if output:
  308. yield orig_output, Reason(type(self), log)
  309. return
  310. orig_output, partial = next(self.body(orig_output))
  311. log.append(partial)
  312. yield orig_output, Reason(type(self), log)
  313. class ExtractJSON(Node):
  314. """
  315. A node that extracts a JSON object from the response.
  316. Usually you can just extract the json blob out of the response,
  317. but if the response contains multiple possible JSON blobs,
  318. then this node queries the model again asking it for just the JSON.
  319. """
  320. def __init__(self):
  321. pass
  322. def try_extract(self, output):
  323. output = output.replace("```json", "```")
  324. if "```" in output:
  325. yield output.split("```")[1]
  326. out1 = "\n".join(output.split("```")[1::2])
  327. yield out1
  328. else:
  329. yield output
  330. def __call__(self, orig_output):
  331. if orig_output.count("```") == 2:
  332. for maybe in self.try_extract(orig_output):
  333. yield maybe, Reason(type(self), [maybe])
  334. else:
  335. output = self.llm("Take the below answer to my question asking for a JSON output and just return the JSON object directly, with no other description, so I can copy it into an editor directly:\n" + orig_output)
  336. for maybe in self.try_extract(output):
  337. yield maybe, Reason(type(self), [maybe])
  338. class ExtractCode(Node):
  339. """
  340. A node that extracts code from the response
  341. Usually you can just extract the code out of the response,
  342. but if the response contains multiple possible code objects,
  343. then this node queries the model again asking it for just the code.
  344. """
  345. def __init__(self, keep_main=False, postfix="", manual=None, lang=None):
  346. self.keep_main = keep_main
  347. self.postfix = postfix
  348. self.manual = manual
  349. self.lang = lang
  350. def try_extract(self, output):
  351. output = re.sub('```[a-z]*', '```', output)
  352. if "```" in output:
  353. ans = output.split("```")[1] + "\n" + self.postfix
  354. else:
  355. ans = output + "\n" + self.postfix
  356. yield ans
  357. def __call__(self, orig_output):
  358. if orig_output.count("```") == 2:
  359. for maybe in self.try_extract(orig_output):
  360. yield maybe, Reason(type(self), maybe)
  361. return
  362. language = ""
  363. if self.lang is not None:
  364. language = f"(in {self.lang})"
  365. if self.manual is not None:
  366. output = self.llm(self.manual.replace("<A>", orig_output))
  367. elif self.keep_main:
  368. assert self.postfix == ""
  369. output = self.llm(f"Take the below answer to my programming question {language} and return just the complete code in a single file so I can copy and paste it into an editor and directly run it. Include any header and main necessary so I can run it by copying this one file. DO NOT MODIFY THE CODE OR WRITE NEW CODE. Here is the code: \n" + orig_output)
  370. else:
  371. output = self.llm(f"Take the below answer to my programming question {language} and return just the complete code in a single file so I can copy and paste it into an editor and directly run it. Remove any test cases or example code after the function definition. Remove any main function. I will write those myself. Do include header imports. DO NOT MODIFY THE CODE OR WRITE NEW CODE. Here is the code: \n" + orig_output + ("\nI will be running this code with the following helper functions:\n" + self.postfix if self.postfix else ""))
  372. for maybe in self.try_extract(output):
  373. yield maybe, Reason(type(self), maybe)
  374. class MakeFile(Node):
  375. """
  376. A node that makes a new file within the docker environment.
  377. """
  378. def __init__(self, name):
  379. self.name = name
  380. def __call__(self, code):
  381. out = invoke_docker(self.env, {self.name: code.encode()}, ["echo"])
  382. yield out, Reason(type(self), (code, out))
  383. class MakeFilesFromJSON(Node):
  384. """
  385. A node that makes a new file within the docker environment.
  386. """
  387. def __init__(self):
  388. pass
  389. def __call__(self, json_str):
  390. try:
  391. json_obj = json.loads(json_str)
  392. except:
  393. json_obj = {}
  394. for k in json_obj.keys():
  395. if not isinstance(json_obj[k], bytes):
  396. json_obj[k] = json_obj[k].encode()
  397. out = invoke_docker(self.env, json_obj, ["echo"])
  398. yield out, Reason(type(self), (json_str, out))
  399. class PythonRun(Node):
  400. """
  401. A node that runs the output from the prior command as a python function.
  402. Optionally append a set of test cases to the code that's been provided.
  403. """
  404. def __init__(self, test_case="", out_bytes=False):
  405. self.test_case = test_case
  406. self.out_bytes = out_bytes
  407. def __call__(self, code):
  408. code = code + "\n\n" + self.test_case
  409. out = invoke_docker(self.env, {"main.py": code.encode()}, [PYTHON_ENV, "main.py"], out_bytes=self.out_bytes)
  410. yield out, Reason(type(self), (code, out))
  411. class SQLRun(Node):
  412. """
  413. A node that runs the output from the prior command as a sqlite function.
  414. """
  415. def __init__(self):
  416. pass
  417. def __call__(self, code):
  418. out = invoke_docker(self.env, {"run.sql": code.encode()}, ["sqlite3", "-init", "run.sql", "database.db", ".exit"])
  419. yield out, Reason(type(self), (code, out))
  420. class BashRun(Node):
  421. """
  422. A node that runs the output from the prior command as a bash script.
  423. """
  424. def __init__(self, test_case="", args=[]):
  425. self.test_case = test_case
  426. self.args = args
  427. def __call__(self, code):
  428. code = code + "\n\n" + self.test_case
  429. out = invoke_docker(self.env, {"main.sh": code.encode()}, ["bash", "main.sh", *self.args])
  430. yield out, Reason(type(self), (code, out))
  431. class TerminalRun(Node):
  432. """
  433. A node that directly runs a command line argument in the terminal.
  434. """
  435. def __init__(self):
  436. return
  437. def __call__(self, code):
  438. if code:
  439. out = invoke_docker(self.env, {"main.sh": code.encode()}, ["bash", "main.sh"])
  440. else:
  441. out = ""
  442. yield out, Reason(type(self), (code, out))
  443. class RustRun(Node):
  444. """
  445. A node that compiles and runs the output Rust code from the prior command.
  446. Optionally append a set of test cases to the code that's been provided.
  447. """
  448. def __init__(self, test_case=""):
  449. self.test_case = test_case
  450. def __call__(self, code):
  451. if 'fn main' in code and 'fn main' in self.test_case:
  452. code = code.replace('fn main', 'fn __delete_this__main')
  453. code = code + "\n\n" + self.test_case
  454. out = invoke_docker(self.env, {"main.rs": code.encode(),
  455. "main.sh": "rustc -o a.out main.rs\n./a.out".encode()},
  456. ["bash", "main.sh"])
  457. yield out, Reason(type(self), (code, out))
  458. class CRun(Node):
  459. """
  460. A node that runs the output from the prior command as a c function.
  461. Optionally append a set of test cases to the code that's been provided.
  462. """
  463. def __init__(self, test_case="", out_bytes=False, gccflags="", argv=""):
  464. self.test_case = test_case
  465. self.out_bytes = out_bytes
  466. self.gccflags = gccflags
  467. self.argv = argv
  468. def __call__(self, code):
  469. if 'int main' in code and 'int main' in self.test_case:
  470. code = code.replace('int main', 'int __delete_this__main')
  471. code = code + "\n\n" + self.test_case
  472. out = invoke_docker(self.env, {"main.c": code.encode(),
  473. "main.sh": f"gcc -o a.out main.c -lm {self.gccflags}\n./a.out {self.argv}".encode()},
  474. ["bash", "main.sh"], out_bytes=self.out_bytes)
  475. yield out, Reason(type(self), (code, out))
  476. class CppRun(Node):
  477. """
  478. A node that runs the output from the prior command as a c++ function.
  479. Optionally append a set of test cases to the code that's been provided.
  480. """
  481. def __init__(self, test_case="", out_bytes=False):
  482. self.test_case = test_case
  483. self.out_bytes = out_bytes
  484. def __call__(self, code):
  485. if 'int main' in code and 'int main' in self.test_case:
  486. code = code.replace('int main', 'int __delete_this__main')
  487. code = code + "\n\n" + self.test_case
  488. out = invoke_docker(self.env, {"main.cpp": code.encode(),
  489. "main.sh": "g++ -o a.out main.cpp -lm\n./a.out".encode()},
  490. ["bash", "main.sh"], out_bytes=self.out_bytes)
  491. yield out, Reason(type(self), (code, out))
  492. class StartDockerJob(Node):
  493. """
  494. Start a new process within the docker container that's termainl interactive.
  495. This lets us test models that expect to be able to interface with other pieces
  496. of software by connecting the llm to stdin and stdout, sending data to the
  497. program and then reading the output back.
  498. """
  499. def __init__(self, command, eos_string):
  500. self.command = command
  501. self.eos_string = eos_string
  502. def __call__(self, text):
  503. self.env.docker_job = DockerJob(self.env.container.id if 'id' in dir(self.env.container) else self.env.container, self.eos_string)
  504. out = self.env.docker_job(self.command)
  505. yield out, Reason(type(self), (text, out))
  506. class SendStdoutReceiveStdin(Node):
  507. """
  508. This node takes a given piece of text and sends it to the stdin of whatever
  509. the current running DockerJob is. It then waits for the running process to handle
  510. this input, and returns the output that the DockerJob returned from stdout.
  511. """
  512. def __init__(self):
  513. pass
  514. def __call__(self, text):
  515. out = self.env.docker_job(text)
  516. yield out, Reason(type(self), (out,))
  517. class LLMRun(Node):
  518. """
  519. A node to invoke a language model on any given text.
  520. This is the core function that allows us to evaluate the capabilities of any model.
  521. """
  522. def __init__(self, check_prompt="<A>", llm=LLM, json=False):
  523. self.check_prompt = check_prompt
  524. self.which_llm = llm
  525. self.json = json
  526. def __call__(self, output):
  527. llm = getattr(self, self.which_llm)
  528. to_send = self.check_prompt.replace("<A>", output)
  529. out = llm(to_send, json=self.json)
  530. yield out, Reason(type(self), (to_send, out))
  531. class LLMConversation(Node):
  532. """
  533. A node to invoke a language model on any given text, but keeps state.
  534. This node allows us to send messages that refer to prior messages, whereas
  535. LLMRun is just a stateless operation.
  536. """
  537. def __init__(self, check_prompt="<A>"):
  538. self.check_prompt = check_prompt
  539. def __call__(self, output):
  540. to_send = self.check_prompt.replace("<A>", output)
  541. out = self.conv(to_send)
  542. yield out, Reason(type(self), (to_send, out))
  543. class SeleniumDraw(Node):
  544. """
  545. A node that creates a new HTML page, renders it in chrome, and then
  546. captures the output with Selenium.
  547. """
  548. def __init__(self):
  549. pass
  550. def __call__(self, code):
  551. try:
  552. #if 1:
  553. from selenium import webdriver
  554. from selenium.webdriver.chrome.options import Options
  555. chrome_options = Options()
  556. #chrome_options.add_argument("--headless")
  557. chrome_options.add_argument("--no-sandbox")
  558. chrome_options.add_argument("--disable-dev-shm-usage")
  559. r = random.randint(0, 1000000)
  560. open("/tmp/a%r.html"%r, "w").write(code)
  561. url = 'file:///tmp/a%d.html'%r
  562. browser = webdriver.Chrome(options=chrome_options)
  563. browser.get(url)
  564. time.sleep(2)
  565. screenshot_path = '/tmp/a%d.png'%r
  566. browser.save_screenshot(screenshot_path)
  567. browser.quit()
  568. time.sleep(1)
  569. img = Image.open(screenshot_path).convert('RGB')
  570. # get png data
  571. img_data = io.BytesIO()
  572. img.save(img_data, format="PNG")
  573. img_data.seek(0)
  574. img_data = img_data.read()
  575. yield img_data, Reason(type(self), img_data)
  576. #try:
  577. pass
  578. except:
  579. yield b"", Reason(type(self), b"")
  580. class JSONSubsetEvaluator(Node):
  581. def __init__(self, goal):
  582. self.goal = goal
  583. def check(self, goal, output):
  584. if isinstance(goal, dict) and isinstance(output, dict):
  585. # Iterate over all key-value pairs in the goal dictionary
  586. for key, value in goal.items():
  587. # Check if the key is present in the output
  588. if key not in output:
  589. return False
  590. # If the value is a dict or list, recursively check
  591. if isinstance(value, (dict, list)):
  592. if not self.check(value, output[key]):
  593. return False
  594. # Otherwise, check if the value matches
  595. elif output[key] != value:
  596. return False
  597. elif isinstance(goal, list) and isinstance(output, list):
  598. # Check each element in the goal list
  599. for item in goal:
  600. if item not in output:
  601. return False, Reason(self, ["Item not present", item])
  602. else:
  603. # Not a dict or list, so check if the values are equal
  604. if goal == output:
  605. return True
  606. else:
  607. return False
  608. # todo better error message
  609. return True
  610. def __call__(self, output):
  611. try:
  612. output = json.loads(output)
  613. except:
  614. yield False, Reason(type(self), [self.goal, False])
  615. return
  616. ok = self.check(self.goal, output)
  617. yield ok, Reason(type(self), [self.goal, ok])
  618. class LLMVisionRun(Node):
  619. """
  620. A node to evalaute an image output from a prior operation. Invokes the
  621. vision evaluation model.
  622. """
  623. def __init__(self, check_prompt="<A>", llm=VISION_EVAL_LLM):
  624. self.check_prompt = check_prompt
  625. self.which_llm = llm
  626. def __call__(self, output):
  627. llm = getattr(self, self.which_llm)
  628. try:
  629. if isinstance(output, bytes):
  630. img = Image.open(io.BytesIO(output))
  631. else:
  632. img = output
  633. out = llm(self.check_prompt, add_image=img, max_tokens=512)
  634. except Exception as e:
  635. out = str(e)
  636. yield out, Reason(type(self), (self.check_prompt, out))
  637. class Conversation:
  638. """
  639. An object that keeps track of the conversation history between the
  640. model and the test case prior questions/steps.
  641. """
  642. def __init__(self, llm,preample = ''):
  643. self.llm = llm
  644. self.history = []
  645. self.preample = preample
  646. def __call__(self, msg):
  647. if len(self.history)==0:
  648. msg = self.preample + msg
  649. self.history.append(msg)
  650. output = self.llm(self.history)
  651. self.history.append(output)
  652. return output
  653. def __repr__(self):
  654. return "Conversation(" + repr(self.history) + ")"
  655. def run_test(test):
  656. """
  657. A helper function to run just one specific test case.
  658. Used to debug tests by running each file directly.
  659. """
  660. from llm import llm, eval_llm, vision_eval_llm
  661. env = Env()
  662. test.setup(env, Conversation(llm), llm, eval_llm, vision_eval_llm)
  663. ok = False
  664. for success, output in test():
  665. if success:
  666. ok = True
  667. break
  668. import create_results_html
  669. fmt = create_results_html.format_markdown(output)
  670. while '\n\n' in fmt:
  671. fmt = fmt.replace('\n\n', '\n')
  672. fmt = fmt.replace("\n#", "\n\n#")
  673. print(fmt)
  674. if env.container:
  675. docker_controller.async_kill_container(env.docker, env.container)
  676. return ok
  677. def make_python_test(q_and_a, header=""):
  678. qs = [header]
  679. for q, a in q_and_a:
  680. qs.append(f"""
  681. answer = {q}
  682. expected = {a}
  683. assert answer == expected, f'Wrong answer; got {{answer}} instead of {{expected}}'""")
  684. qs.append("print('All tests passed')")
  685. return "\n".join(qs), "All tests passed"
  686. def make_c_test(q_and_a, header="", extra_methods=""):
  687. qs = []
  688. qs.append("#include<stdio.h>\n#include<stdlib.h>\n" + extra_methods + "\nint main() {")
  689. qs.append(header)
  690. for q, a in q_and_a:
  691. qs.append(f"""
  692. int answer = {q};
  693. int expected = {a};
  694. if (answer != expected) {{
  695. printf("Wrong answer; got %d instead of %d.\\n", answer, expected);
  696. exit(1);
  697. }}""")
  698. qs.append('printf("All tests passed\\n");')
  699. qs.append("}");
  700. return "\n".join(qs), "All tests passed"