llm.py

## Copyright (C) 2024, Nicholas Carlini <nicholas@carlini.com>.
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see <http://www.gnu.org/licenses/>.
from io import BytesIO
import os
import base64
import requests
import json
import pickle
import time

from llms.openai_model import OpenAIModel
from llms.anthropic_model import AnthropicModel
from llms.mistral_model import MistralModel
from llms.vertexai_model import VertexAIModel
from llms.cohere_model import CohereModel
from llms.moonshot_model import MoonshotAIModel
from llms.bagel_dpo34_model import BagelDPOModel
from llms.custom_model import CustomModel
from llms.groq_model import GroqModel
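
# The LLM class wraps every provider-specific backend behind one interface:
# it picks a backend by substring-matching the model name, memoizes
# responses in an on-disk pickle cache, and retries failed API calls.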
class LLM:
    def __init__(self, name="gpt-3.5-turbo", use_cache=True, override_hparams={}):
        self.name = name
        if 'gpt' in name or name.startswith('o1'):
            self.model = OpenAIModel(name)
        elif 'bagel' in name:
            self.model = BagelDPOModel(name)
        # elif 'llama' in name:
        #     self.model = LLAMAModel(name)
        elif 'mistral' in name:
            self.model = MistralModel(name)
        elif 'bison' in name or 'gemini' in name:
            self.model = VertexAIModel(name)
        # elif 'gemini' in name:
        #     self.model = GeminiModel(name)
        elif 'claude' in name:
            self.model = AnthropicModel(name)
        elif 'moonshot' in name:
            self.model = MoonshotAIModel(name)
        elif 'command' in name:
            self.model = CohereModel(name)
        elif 'llama3' in name or 'mixtral' in name or 'gemma' in name:
            self.model = GroqModel(name)
        else:
            self.model = CustomModel(name)
            print("Evaluating custom model: %s" % name)
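
        # Apply caller-supplied hyperparameter overrides, then load the
        # on-disk response cache (one pickle file per model under tmp/).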
        self.model.hparams.update(override_hparams)

        self.use_cache = use_cache
        if use_cache:
            try:
                if not os.path.exists("tmp"):
                    os.mkdir("tmp")
                self.cache = pickle.load(open(f"tmp/cache-{name.split('/')[-1]}.p", "rb"))
            except Exception:
                # Start with an empty cache if the file is missing or unreadable.
                self.cache = {}
        else:
            self.cache = {}
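
    # Turn a prompt (a single string, or a list of alternating
    # user/assistant turns) into a cache key, answer from the cache when
    # possible, and otherwise forward the request to the underlying model.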
    def __call__(self, conversation, add_image=None, max_tokens=None, skip_cache=False, json=False):
        if isinstance(conversation, str):
            conversation = [conversation]

        cache_key = tuple(conversation) if add_image is None else tuple(conversation + [add_image.tobytes()])

        if cache_key in self.cache and not skip_cache and self.use_cache:
            print(self.name, "GETCACHE", repr(self.cache[cache_key]))
            if len(self.cache[cache_key]) > 0:
                return self.cache[cache_key]
            else:
                print("Empty cache hit")

        print(self.name, "CACHE MISS", repr(conversation))
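
        # Try the request up to three times. Each attempt runs in a worker
        # thread so a hung API call can be abandoned after a hard timeout.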
        import traceback
        from concurrent.futures import ThreadPoolExecutor, TimeoutError

        response = "Model API request failed"

        for _ in range(3):
            try:
                extra = {}
                if json:
                    extra['json'] = json

                def request_with_timeout():
                    return self.model.make_request(conversation, add_image=add_image, max_tokens=max_tokens, **extra)

                with ThreadPoolExecutor() as executor:
                    future = executor.submit(request_with_timeout)
                    try:
                        response = future.result(timeout=60*10)  # 10 minutes
                        break  # If successful, break out of the retry loop
                    except TimeoutError:
                        print("Request timed out after 10 minutes")
                        response = "Model API request failed due to timeout"
                        # Continue to the next retry
            except Exception as e:
                traceback.print_exc()
                print("RUN FAILED", e)
                time.sleep(10)

        if self.use_cache and response != "Model API request failed":
            self.cache[cache_key] = response
            pickle.dump(self.cache, open(f"tmp/cache-{self.name.split('/')[-1]}.p", "wb"))

        return response
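
# Module-level model handles used by the rest of the benchmark. Swap which
# line is uncommented to evaluate a different model.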
#llm = LLM("command")
#llm = LLM("gpt-3.5-turbo")
#llm = LLM("gpt-4-1106-preview")
#llm = LLM("claude-instant-1.2")
#llm = LLM("gpt-4-turbo-2024-04-09")
#llm = LLM("gemini-1.5-pro-preview-0409")
llm = LLM("o1-mini")
#llm = LLM("claude-3-opus-20240229")
#llm = LLM("claude-3-5-sonnet-20240620")
#llm = LLM("mistral-tiny")
#llm = LLM("gemini-pro", override_hparams={'temperature': 0.3}, use_cache=False)
#llm = LLM("bagel")
#llm = LLM("nebula")
#llm = LLM("noushermes")

#eval_llm = LLM("gpt-4-1106-preview")
eval_llm = LLM("gpt-4o", override_hparams={'temperature': 0.1})
#eval_llm = LLM("gpt-3.5-turbo", override_hparams={'temperature': 0.1})

vision_eval_llm = LLM("gpt-4o", override_hparams={'temperature': 0.1})
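
# A minimal usage sketch (the prompts here are hypothetical; the benchmark
# harness is what normally drives these objects):
#
#   answer = llm("Write a Python function that reverses a string.")
#   verdict = eval_llm(["Here is a function:", answer, "Is it correct? Answer yes or no."])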