import multiprocessing
import sys
import time

import torch
from openai import OpenAI, OpenAIError

from aphrodite import ModelRegistry
from aphrodite.common.utils import get_open_port
from aphrodite.modeling.models.opt import OPTForCausalLM
from aphrodite.modeling.sampling_metadata import SamplingMetadata


class MyOPTForCausalLM(OPTForCausalLM):

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        # this dummy model always predicts the first token
        logits = super().compute_logits(hidden_states, sampling_metadata)
        logits.zero_()
        logits[:, 0] += 1.0
        return logits


def server_function(port):
    # register our dummy model before the server starts, so the
    # out-of-tree class is used instead of the built-in OPT model
    ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
    # launch the OpenAI-compatible API server in this process, as if
    # it had been started from the command line
    sys.argv = ["placeholder.py"] + \
        ("--model facebook/opt-125m --dtype"
         f" float32 --api-keys token-abc123 --port {port}").split()
    import runpy
    runpy.run_module('aphrodite.endpoints.openai.api_server',
                     run_name='__main__')


def test_oot_registration_for_api_server():
    port = get_open_port()
    server = multiprocessing.Process(target=server_function, args=(port, ))
    server.start()
    client = OpenAI(
        base_url=f"http://localhost:{port}/v1",
        api_key="token-abc123",
    )
    # the server needs time to load the model; retry on connection
    # errors until it is ready to serve requests
    while True:
        try:
            completion = client.chat.completions.create(
                model="facebook/opt-125m",
                messages=[{
                    "role": "system",
                    "content": "You are a helpful assistant."
                }, {
                    "role": "user",
                    "content": "Hello!"
                }],
                temperature=0,
            )
            break
        except OpenAIError as e:
            if "Connection error" in str(e):
                time.sleep(3)
            else:
                raise
    server.kill()
    generated_text = completion.choices[0].message.content
    # the dummy model always predicts token 0, so after stripping the
    # BOS token ("<s>") nothing should remain
    rest = generated_text.replace("<s>", "")
    assert rest == ""
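

# A minimal sketch for running this check directly, outside a pytest
# session (assumes the aphrodite package is installed and the
# facebook/opt-125m weights are available locally or downloadable):
if __name__ == "__main__":
    test_oot_registration_for_api_server()
    print("out-of-tree model registration works")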