moonshot_model.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. from io import BytesIO
  2. from PIL import Image
  3. import base64
  4. from openai import OpenAI
  5. import json
  6. class MoonshotAIModel:
  7. def __init__(self, name):
  8. config = json.load(open("config.json"))
  9. api_key = config['llms']['moonshot']['api_key'].strip()
  10. self.client = OpenAI(api_key=api_key, base_url='https://api.moonshot.cn/v1')
  11. self.name = name
  12. self.hparams = config['hparams']
  13. self.hparams.update(config['llms']['moonshot'].get('hparams') or {})
  14. def make_request(self, conversation, add_image=None, max_tokens=None):
  15. conversation = [{"role": "user" if i%2 == 0 else "assistant", "content": content} for i,content in enumerate(conversation)]
  16. if add_image:
  17. buffered = BytesIO()
  18. add_image.convert("RGB").save(buffered, format="JPEG")
  19. img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
  20. img_str = f"data:image/jpeg;base64,{img_str}"
  21. conversation[0]['content'] = [{"type": "text", "text": conversation[0]['content']},
  22. {
  23. "type": "image_url",
  24. "image_url": {
  25. "url": img_str
  26. }
  27. }
  28. ]
  29. kwargs = {
  30. "messages": conversation,
  31. "max_tokens": max_tokens,
  32. }
  33. kwargs.update(self.hparams)
  34. for k,v in list(kwargs.items()):
  35. if v is None:
  36. del kwargs[k]
  37. out = self.client.chat.completions.create(
  38. model=self.name,
  39. **kwargs
  40. )
  41. return out.choices[0].message.content
  42. if __name__ == "__main__":
  43. import sys
  44. #q = sys.stdin.read().strip()
  45. q = "hello there"
  46. print(q+":", MoonshotAIModel("moonshot-v1-8k").make_request([q]))