cohere_model.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. from io import BytesIO
  2. from PIL import Image
  3. import base64
  4. import cohere
  5. import json
  6. class CohereModel:
  7. def __init__(self, name):
  8. config = json.load(open("config.json"))
  9. api_key = config['llms']['cohere']['api_key'].strip()
  10. self.client = cohere.Client(api_key)
  11. self.name = name
  12. self.hparams = config['hparams']
  13. self.hparams.update(config['llms']['cohere'].get('hparams') or {})
  14. def make_request(self, conversation, add_image=None, max_tokens=None):
  15. prior_messages = [{"role": "USER" if i%2 == 0 else "CHATBOT", "message": content} for i,content in enumerate(conversation[:-1])]
  16. kwargs = {
  17. "chat_history": prior_messages,
  18. "message": conversation[-1],
  19. "max_tokens": max_tokens,
  20. "model": self.name
  21. }
  22. kwargs.update(self.hparams)
  23. for k,v in list(kwargs.items()):
  24. if v is None:
  25. del kwargs[k]
  26. out = self.client.chat(
  27. prompt_truncation='AUTO',
  28. **kwargs
  29. )
  30. return out.text
  31. if __name__ == "__main__":
  32. import sys
  33. #q = sys.stdin.read().strip()
  34. q = "what specific date?"
  35. print(q+":", CohereModel("command").make_request(["Who discovered relativity?", "Einstein.", q]))