vertexai_model.py

import vertexai
from vertexai.language_models import ChatModel, InputOutputTextPair
from vertexai.preview.generative_models import GenerativeModel
import json
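
# Expected shape of config.json, inferred from the lookups in __init__
# below; the values shown are illustrative placeholders, not real settings:
#
#   {
#     "hparams": {"temperature": 0.7},
#     "llms": {
#       "vertexai": {
#         "project_id": "my-gcp-project",
#         "hparams": {"top_p": 0.95}
#       }
#     }
#   }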


class VertexAIModel:
    def __init__(self, name):
        self.name = name
        with open("config.json") as f:
            config = json.load(f)
        # Start from the global hparams, then overlay any
        # vertexai-specific overrides (which may be absent).
        self.hparams = config['hparams']
        self.hparams.update(config['llms']['vertexai'].get('hparams') or {})
        project_id = config['llms']['vertexai']['project_id'].strip()
        vertexai.init(project=project_id, location="us-central1")
        # Gemini models use the GenerativeModel API; older PaLM chat
        # models go through ChatModel instead.
        if 'gemini' in name:
            self.chat_model = GenerativeModel(name)
        else:
            self.chat_model = ChatModel.from_pretrained(name)

    def make_request(self, conversation, add_image=None, max_tokens=2048, stream=False):
        # add_image and stream are accepted for interface compatibility
        # but unused in this implementation.
        if 'gemini' in self.name:
            # Gemini rejects empty messages, so substitute a single space.
            conversation = [" " if c == "" else c for c in conversation]
            conf = {
                "max_output_tokens": max_tokens,
            }
            conf.update(self.hparams)
            response = self.chat_model.generate_content(conversation, generation_config=conf)
        else:
            # ChatModel has no direct multi-turn history argument here, so
            # prior turns are passed as input/output example pairs and only
            # the final message is sent as the actual prompt.
            conversation_pairs = conversation[:-1]
            conversation_pairs = [(a, b) for a, b in zip(conversation_pairs[::2], conversation_pairs[1::2])]
            chat = self.chat_model.start_chat(
                examples=[
                    InputOutputTextPair(
                        input_text=a,
                        output_text=b,
                    ) for a, b in conversation_pairs]
            )
            conf = {
                "max_output_tokens": max_tokens,
            }
            conf.update(self.hparams)
            response = chat.send_message(
                conversation[-1],
                **conf
            )
        try:
            return response.text
        except Exception:
            # .text raises if the response was blocked or came back empty.
            return ''


if __name__ == "__main__":
    import sys
    # q = sys.stdin.read().strip()
    q = "why?"
    print(VertexAIModel("gemini-1.5-pro-preview-0409").make_request(["hi, how are you doing", "i'm a bit sad", q]))
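
    # A non-Gemini model takes the ChatModel path through make_request, e.g.
    # (illustrative; "chat-bison" is a PaLM chat model name on Vertex AI):
    # print(VertexAIModel("chat-bison").make_request(["hi, how are you doing", "i'm a bit sad", q]))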