mistral_model.py

import json
import time

import requests


class MistralModel:
    def __init__(self, name):
        self.name = name

        # Load shared hyperparameters and merge in any Mistral-specific overrides.
        with open("config.json") as f:
            config = json.load(f)
        self.hparams = config['hparams']
        self.hparams.update(config['llms']['mistral'].get('hparams') or {})

        # The Mistral API authenticates with a Bearer token.
        self.api_key = config['llms']['mistral']['api_key'].strip()
        self.headers = {
            'Authorization': f'Bearer {self.api_key}',
            'Content-Type': 'application/json',
            'Accept': 'application/json'
        }
        self.endpoint = "https://api.mistral.ai/v1/chat/completions"
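
    # Expected config.json layout (illustrative sketch only): the key names
    # ("hparams", "llms" -> "mistral" -> "api_key"/"hparams") are the ones read
    # in __init__ above; the hyperparameter names and values are placeholders.
    #
    # {
    #     "hparams": {"temperature": 0.7},
    #     "llms": {
    #         "mistral": {
    #             "api_key": "YOUR_MISTRAL_API_KEY",
    #             "hparams": {"top_p": 0.9}
    #         }
    #     }
    # }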

    def make_request(self, conversation, add_image=None, max_tokens=None):
        # `add_image` is accepted for interface compatibility but is not used here.
        # Alternate roles turn by turn: even indices are user messages, odd are assistant.
        formatted_conversation = [
            {"role": "user" if i % 2 == 0 else "assistant", "content": content}
            for i, content in enumerate(conversation)
        ]

        # Construct the request payload; configured hyperparameters override the defaults.
        data = {
            "model": self.name,
            "messages": formatted_conversation,
            "max_tokens": max_tokens or 2048,
        }
        data.update(self.hparams)

        # POST to the chat completions endpoint, then pause briefly as crude rate limiting.
        response = requests.post(self.endpoint, headers=self.headers, data=json.dumps(data))
        time.sleep(1)

        if response.status_code == 200:
            # The response follows an OpenAI-style schema; return the assistant's reply text.
            return response.json()['choices'][0]['message']['content']
        else:
            # Return the failure as a string instead of raising, so callers can log it.
            return f"API request failed with status code {response.status_code}"


if __name__ == "__main__":
    import sys

    # Read a single prompt from stdin and print the model's reply.
    q = sys.stdin.read().strip()
    print(q + ":", MistralModel("mistral-small").make_request([q]))
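
# Example usage (illustrative prompt; assumes a valid config.json in the working directory):
#   echo "What is the capital of France?" | python mistral_model.py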