explore_sql_db.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. from evaluator import *
# One-line summary of what this test checks.
DESCRIPTION = "Test if the model can interact with an SQLite database and provide the correct command to add a new person with specific criteria."
# Labels for this test; how they are consumed is up to the evaluator
# framework (not visible in this file).
TAGS = ['sql', 'agent']
# Opening message sent to the model. It pre-fills two INPUT:/OUTPUT: rounds of
# an interactive `sqlite3 people.db` session (`.version` and `.database`) so the
# model sees the turn format it must continue, then asks for one command.
# NOTE(review): the text has typos ("placholder", a doubled "just"); they are
# left as-is because the exact prompt is part of the benchmark input.
question = """
I've just run `sqlite3 people.db` and you are now in the interpreter. I want to add another person named Victor. They should be as old as the youngest person. And have the most popular profession. What command should I run?
On each round you should just give me just one command to type, I'll run it exactly as-is, and give you the exact output. I will not make any changes. So do not write any placholder text. I'll get started:
INPUT: `.version`
OUTPUT: ```SQLite 3.39.5 2022-10-14 20:58:05```
INPUT: `.database`
OUTPUT: ```main: people.db r/w```
Now your turn. What do you want to run? Do not explain yourself, just give me exactly one command, ending with ; if necessary. Run small queries one at a time; think step by step. Start your response with INPUT:
"""
  13. def setup():
  14. import sqlite3
  15. # Create a connection object using the connect function
  16. conn = sqlite3.connect('people.db')
  17. # Create a cursor object using the cursor method
  18. cursor = conn.cursor()
  19. # Create table named people with columns for name, favorite friend, age, and profession
  20. cursor.execute('''CREATE TABLE people
  21. (a_name TEXT, favorite_friend TEXT, person_age INT, the_profession TEXT)''')
  22. # List of people data to insert into the people table
  23. people_data = [
  24. ('Alice', 'Bob', 29, 'Engineer'),
  25. ('Bob', 'Alice', 32, 'Doctor'),
  26. ('Charlie', 'Alice', 28, 'Artist'),
  27. ('David', 'Eve', 35, 'Doctor'),
  28. ('Eve', 'Frank', 26, 'Teacher'),
  29. ('Frank', 'Alice', 31, 'Scientist'),
  30. ('Grace', 'Heidi', 30, 'Nurse'),
  31. ('Heidi', 'Ivy', 25, 'Lawyer'),
  32. ('Ivy', 'Charlie', 34, 'Chef'),
  33. ('Judy', 'Grace', 27, 'Accountant')
  34. ]
  35. # Insert each person into the people table
  36. cursor.executemany('INSERT INTO people VALUES (?,?,?,?)', people_data)
  37. # Commit the changes
  38. conn.commit()
  39. # Close the connection
  40. conn.close()
  41. def test_ok():
  42. import sqlite3
  43. # Create a connection object using the connect function
  44. conn = sqlite3.connect('people.db')
  45. # Create a cursor object using the cursor method
  46. cursor = conn.cursor()
  47. # Get Victor's age and profession
  48. cursor.execute('SELECT person_age, the_profession FROM people WHERE a_name = "Victor"')
  49. out = cursor.fetchone()
  50. if out is None:
  51. return False
  52. age, profession = out
  53. return age == 25 and profession == 'Doctor'
  54. def do_extract(x):
  55. if 'INPUT:' in x:
  56. x = x.split("INPUT:")[1]
  57. try:
  58. to_send = x.split("`")[1]
  59. except:
  60. to_send = x.split("\n")[0]
  61. return to_send
  62. def do_prepare(x):
  63. print("Preparing to pass back", x)
  64. x = x.replace("\nsqlite>","")
  65. return f"I get an OUTPUT: ```{x}```.\n\nWhat is the exact command I should run next? Start your response with INPUT:"
  66. TestSqlExplore = Setup(setup) >> StartDockerJob("sqlite3 people.db", eos_string="sqlite>") >> question >> UntilDone(PyEvaluator(test_ok), (LLMConversation() >> PyFunc(do_extract) >> SendStdoutReceiveStdin() >> PyFunc(do_prepare)), max_iters=10) >> PyEvaluator(test_ok)
  67. if __name__ == "__main__":
  68. print(run_test(TestSqlExplore))