bagel_dpo34_model.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. from io import BytesIO
  2. from PIL import Image
  3. import base64
  4. from openai import OpenAI
  5. import json
  6. class BagelDPOModel:
  7. def __init__(self, name):
  8. config = json.load(open("config.json"))
  9. api_key = config['llms']['bagel']['api_key'].strip()
  10. api_base = config['llms']['bagel']['api_base'].strip()
  11. self.model_name = config['llms']['bagel']['model_name'].strip()
  12. self.client = OpenAI(api_key=api_key, base_url=api_base)
  13. self.name = name
  14. config = json.load(open("config.json"))
  15. self.hparams = config['hparams']
  16. self.hparams.update(config['llms']['bagel'].get('hparams') or {})
  17. def make_request(self, conversation, add_image=None, max_tokens=None):
  18. conversation = [{"role": "user" if i%2 == 0 else "assistant", "content": content} for i,content in enumerate(conversation)]
  19. if add_image:
  20. buffered = BytesIO()
  21. add_image.convert("RGB").save(buffered, format="JPEG")
  22. img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
  23. img_str = f"data:image/jpeg;base64,{img_str}"
  24. conversation[0]['content'] = [{"type": "text", "text": conversation[0]['content']},
  25. {
  26. "type": "image_url",
  27. "image_url": {
  28. "url": img_str
  29. }
  30. }
  31. ]
  32. kwargs = {
  33. "messages": conversation,
  34. "max_tokens": max_tokens,
  35. }
  36. kwargs.update(self.hparams)
  37. for k,v in list(kwargs.items()):
  38. if v is None:
  39. del kwargs[k]
  40. out = self.client.chat.completions.create(
  41. model=self.model_name,
  42. **kwargs
  43. )
  44. return out.choices[0].message.content
  45. if __name__ == "__main__":
  46. import sys
  47. #q = sys.stdin.read().strip()
  48. q = "hello there"
  49. print(q+":", BagelDPOModel("mixtral").make_request([q]))