# test_server_oot_registration.py
  1. import multiprocessing
  2. import sys
  3. import time
  4. import torch
  5. from openai import OpenAI, OpenAIError
  6. from aphrodite import ModelRegistry
  7. from aphrodite.common.utils import get_open_port
  8. from aphrodite.modeling.models.opt import OPTForCausalLM
  9. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  10. class MyOPTForCausalLM(OPTForCausalLM):
  11. def compute_logits(self, hidden_states: torch.Tensor,
  12. sampling_metadata: SamplingMetadata) -> torch.Tensor:
  13. # this dummy model always predicts the first token
  14. logits = super().compute_logits(hidden_states, sampling_metadata)
  15. logits.zero_()
  16. logits[:, 0] += 1.0
  17. return logits
  18. def server_function(port):
  19. # register our dummy model
  20. ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
  21. sys.argv = ["placeholder.py"] + \
  22. ("--model facebook/opt-125m --dtype"
  23. f" float32 --api-keys token-abc123 --port {port}").split()
  24. import runpy
  25. runpy.run_module('aphrodite.endpoints.openai.api_server',
  26. run_name='__main__')
  27. def test_oot_registration_for_api_server():
  28. port = get_open_port()
  29. server = multiprocessing.Process(target=server_function, args=(port, ))
  30. server.start()
  31. client = OpenAI(
  32. base_url=f"http://localhost:{port}/v1",
  33. api_key="token-abc123",
  34. )
  35. while True:
  36. try:
  37. completion = client.chat.completions.create(
  38. model="facebook/opt-125m",
  39. messages=[{
  40. "role": "system",
  41. "content": "You are a helpful assistant."
  42. }, {
  43. "role": "user",
  44. "content": "Hello!"
  45. }],
  46. temperature=0,
  47. )
  48. break
  49. except OpenAIError as e:
  50. if "Connection error" in str(e):
  51. time.sleep(3)
  52. else:
  53. raise e
  54. server.kill()
  55. generated_text = completion.choices[0].message.content
  56. # make sure only the first token is generated
  57. rest = generated_text.replace("<s>", "")
  58. assert rest == ""