import functools
import os
import signal
import subprocess
import sys
import time
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import openai
import requests
from transformers import AutoTokenizer
from typing_extensions import ParamSpec

from aphrodite.common.utils import (FlexibleArgumentParser, get_open_port,
                                    is_hip)
from aphrodite.distributed import (ensure_model_parallel_initialized,
                                   init_distributed_environment)
from aphrodite.endpoints.openai.args import make_arg_parser
from aphrodite.platforms import current_platform

if current_platform.is_rocm():
    from amdsmi import (amdsmi_get_gpu_vram_usage,
                        amdsmi_get_processor_handles, amdsmi_init,
                        amdsmi_shut_down)

    @contextmanager
    def _nvml():
        try:
            amdsmi_init()
            yield
        finally:
            amdsmi_shut_down()
elif current_platform.is_cuda():
    from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
                        nvmlInit, nvmlShutdown)

    @contextmanager
    def _nvml():
        try:
            nvmlInit()
            yield
        finally:
            nvmlShutdown()
else:

    @contextmanager
    def _nvml():
        yield


APHRODITE_PATH = Path(__file__).parent.parent
"""Path to root of the Aphrodite repository."""


class RemoteOpenAIServer:
    # Aphrodite's OpenAI server does not require an API key.
    DUMMY_API_KEY = "token-abc123"
    # Wait up to 240 seconds for the server to start.
    MAX_START_WAIT_S = 240

    def __init__(
        self,
        model: str,
        cli_args: List[str],
        *,
        env_dict: Optional[Dict[str, str]] = None,
        auto_port: bool = True,
    ) -> None:
        if auto_port:
            if "-p" in cli_args or "--port" in cli_args:
                raise ValueError("You have manually specified the port "
                                 "when `auto_port=True`.")
            cli_args = cli_args + ["--port", str(get_open_port())]

        parser = FlexibleArgumentParser(
            description="Aphrodite's remote OpenAI server.")
        parser = make_arg_parser(parser)
        args = parser.parse_args(cli_args)
        self.host = str(args.host or 'localhost')
        self.port = int(args.port)

        env = os.environ.copy()
        # The current process might have initialized CUDA already;
        # to be safe, run the workers with the spawn start method.
        env['APHRODITE_WORKER_MULTIPROC_METHOD'] = 'spawn'
        if env_dict is not None:
            env.update(env_dict)
        self.proc = subprocess.Popen(["aphrodite", "run"] + [model] + cli_args,
                                     env=env,
                                     stdout=sys.stdout,
                                     stderr=sys.stderr)
        self._wait_for_server(url=self.url_for("health"),
                              timeout=self.MAX_START_WAIT_S)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.proc.terminate()
        try:
            self.proc.wait(3)
        except subprocess.TimeoutExpired:
            # force kill if needed
            self.proc.kill()
    def _wait_for_server(self, *, url: str, timeout: float):
        # run health check
        start = time.time()
        while True:
            try:
                if requests.get(url).status_code == 200:
                    break
            except Exception as err:
                result = self.proc.poll()
                if result is not None and result != 0:
                    raise RuntimeError("Server exited unexpectedly.") from err

                time.sleep(0.5)
                if time.time() - start > timeout:
                    raise RuntimeError(
                        "Server failed to start in time.") from err
    @property
    def url_root(self) -> str:
        return f"http://{self.host}:{self.port}"

    def url_for(self, *parts: str) -> str:
        return self.url_root + "/" + "/".join(parts)

    def get_client(self):
        return openai.OpenAI(
            base_url=self.url_for("v1"),
            api_key=self.DUMMY_API_KEY,
        )

    def get_async_client(self):
        return openai.AsyncOpenAI(
            base_url=self.url_for("v1"),
            api_key=self.DUMMY_API_KEY,
        )
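

# Illustrative sketch only (not used by the utilities above): the intended
# pattern for `RemoteOpenAIServer` is to start the server as a context
# manager, grab an OpenAI client, and let `__exit__` tear the subprocess
# down. The model name and CLI flags below are hypothetical example values.
def _example_remote_server_usage() -> None:
    with RemoteOpenAIServer("facebook/opt-125m",
                            ["--max-model-len", "256"]) as server:
        client = server.get_client()
        completion = client.completions.create(model="facebook/opt-125m",
                                                prompt="Hello, my name is",
                                                max_tokens=5,
                                                temperature=0.0)
        assert completion.choices[0].text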


def compare_two_settings(model: str,
                         arg1: List[str],
                         arg2: List[str],
                         env1: Optional[Dict[str, str]] = None,
                         env2: Optional[Dict[str, str]] = None):
    """
    Launch the API server with two different sets of arguments/environments
    and compare the results of the API calls.

    Args:
        model: The model to test.
        arg1: The first set of arguments to pass to the API server.
        arg2: The second set of arguments to pass to the API server.
        env1: The first set of environment variables to pass to the API
            server.
        env2: The second set of environment variables to pass to the API
            server.
    """
    tokenizer = AutoTokenizer.from_pretrained(model)

    prompt = "Hello, my name is"
    token_ids = tokenizer(prompt)["input_ids"]
    results = []
    for args, env in ((arg1, env1), (arg2, env2)):
        with RemoteOpenAIServer(model, args, env_dict=env) as server:
            client = server.get_client()

            # test models list
            models = client.models.list()
            models = models.data
            served_model = models[0]
            results.append({
                "test": "models_list",
                "id": served_model.id,
                "root": served_model.root,
            })

            # test with text prompt
            completion = client.completions.create(model=model,
                                                   prompt=prompt,
                                                   max_tokens=5,
                                                   temperature=0.0)
            results.append({
                "test": "single_completion",
                "text": completion.choices[0].text,
                "finish_reason": completion.choices[0].finish_reason,
                "usage": completion.usage,
            })

            # test using token IDs
            completion = client.completions.create(
                model=model,
                prompt=token_ids,
                max_tokens=5,
                temperature=0.0,
            )
            results.append({
                "test": "token_ids",
                "text": completion.choices[0].text,
                "finish_reason": completion.choices[0].finish_reason,
                "usage": completion.usage,
            })

            # test seeded random sampling
            completion = client.completions.create(model=model,
                                                   prompt=prompt,
                                                   max_tokens=5,
                                                   seed=33,
                                                   temperature=1.0)
            results.append({
                "test": "seeded_sampling",
                "text": completion.choices[0].text,
                "finish_reason": completion.choices[0].finish_reason,
                "usage": completion.usage,
            })

            # test seeded random sampling with multiple prompts
            completion = client.completions.create(model=model,
                                                   prompt=[prompt, prompt],
                                                   max_tokens=5,
                                                   seed=33,
                                                   temperature=1.0)
            results.append({
                "test": "seeded_sampling",
                "text": [choice.text for choice in completion.choices],
                "finish_reason":
                [choice.finish_reason for choice in completion.choices],
                "usage": completion.usage,
            })
            # test simple list
            batch = client.completions.create(
                model=model,
                prompt=[prompt, prompt],
                max_tokens=5,
                temperature=0.0,
            )
            results.append({
                "test": "simple_list",
                "text0": batch.choices[0].text,
                "text1": batch.choices[1].text,
            })

            # test streaming
            batch = client.completions.create(
                model=model,
                prompt=[prompt, prompt],
                max_tokens=5,
                temperature=0.0,
                stream=True,
            )
            texts = [""] * 2
            for chunk in batch:
                assert len(chunk.choices) == 1
                choice = chunk.choices[0]
                texts[choice.index] += choice.text
            results.append({
                "test": "streaming",
                "texts": texts,
            })

    n = len(results) // 2
    arg1_results = results[:n]
    arg2_results = results[n:]
    for arg1_result, arg2_result in zip(arg1_results, arg2_results):
        assert arg1_result == arg2_result, (
            f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
            f"{arg1_result=} != {arg2_result=}")


def init_test_distributed_environment(
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
    local_rank: int = -1,
) -> None:
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    init_distributed_environment(
        world_size=pp_size * tp_size,
        rank=rank,
        distributed_init_method=distributed_init_method,
        local_rank=local_rank)
    ensure_model_parallel_initialized(tp_size, pp_size)


def multi_process_parallel(
    tp_size: int,
    pp_size: int,
    test_target: Any,
) -> None:
    import ray

    # Using ray makes failures easier to debug than plain multiprocessing.
    # NOTE: We need to set working_dir for distributed tests,
    # otherwise we may get import errors on ray workers.
    ray.init(runtime_env={"working_dir": APHRODITE_PATH})

    distributed_init_port = get_open_port()
    refs = []
    for rank in range(tp_size * pp_size):
        refs.append(
            test_target.remote(tp_size, pp_size, rank, distributed_init_port))
    ray.get(refs)

    ray.shutdown()
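

# Illustrative sketch only: how a ray-remote test target plugs into
# `multi_process_parallel`. The worker body is a hypothetical placeholder;
# real tests would run distributed assertions after initialization.
def _example_multi_process_parallel_usage() -> None:
    import ray

    @ray.remote(num_gpus=1, max_calls=1)
    def _worker(tp_size, pp_size, rank, distributed_init_port):
        init_test_distributed_environment(tp_size, pp_size, rank,
                                          distributed_init_port)
        # ... collective-communication checks would go here ...

    multi_process_parallel(tp_size=2, pp_size=1, test_target=_worker)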


@contextmanager
def error_on_warning():
    """
    Within the scope of this context manager, tests will fail if any warning
    is emitted.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("error")

        yield
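

# Illustrative sketch only: wrap the code under test so that any warning it
# emits is raised as an error and fails the test.
def _example_error_on_warning_usage() -> None:
    with error_on_warning():
        # code under test goes here; any warning it emits becomes an error
        pass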


@_nvml()
def wait_for_gpu_memory_to_clear(devices: List[int],
                                 threshold_bytes: int,
                                 timeout_s: float = 120) -> None:
    # Use nvml instead of pytorch to reduce measurement error from torch cuda
    # context.
    start_time = time.time()
    while True:
        output: Dict[int, str] = {}
        output_raw: Dict[int, float] = {}
        for device in devices:
            if is_hip():
                dev_handle = amdsmi_get_processor_handles()[device]
                mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
                # amdsmi reports VRAM usage in MiB, so convert to GiB.
                gb_used = mem_info["vram_used"] / 2**10
            else:
                dev_handle = nvmlDeviceGetHandleByIndex(device)
                mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
                gb_used = mem_info.used / 2**30
            output_raw[device] = gb_used
            output[device] = f'{gb_used:.02f}'

        print('gpu memory used (GB): ', end='')
        for k, v in output.items():
            print(f'{k}={v}; ', end='')
        print('')

        dur_s = time.time() - start_time
        if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
            print(f'Done waiting for free GPU memory on devices {devices=} '
                  f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
            break

        if dur_s >= timeout_s:
            raise ValueError(f'Memory of devices {devices=} not free after '
                             f'{dur_s=:.02f} ({threshold_bytes/2**30=})')

        time.sleep(5)
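

# Illustrative sketch only: block until GPU 0 is using less than ~2 GiB before
# starting a memory-sensitive test. The threshold and timeout are arbitrary
# example values.
def _example_wait_for_gpu_memory_usage() -> None:
    wait_for_gpu_memory_to_clear(devices=[0],
                                 threshold_bytes=2 * 2**30,
                                 timeout_s=60)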


_P = ParamSpec("_P")


def fork_new_process_for_each_test(
        f: Callable[_P, None]) -> Callable[_P, None]:
    """Decorator to fork a new process for each test function."""

    @functools.wraps(f)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
        # Make the process the leader of its own process group
        # to avoid sending SIGTERM to the parent process.
        os.setpgrp()
        from _pytest.outcomes import Skipped
        pid = os.fork()
        print(f"Forked a new process {pid} to run the test.")
        if pid == 0:
            try:
                f(*args, **kwargs)
            except Skipped as e:
                # convert Skipped to exit code 0
                print(str(e))
                os._exit(0)
            except Exception:
                import traceback
                traceback.print_exc()
                os._exit(1)
            else:
                os._exit(0)
        else:
            pgid = os.getpgid(pid)
            _pid, _exitcode = os.waitpid(pid, 0)
            # ignore SIGTERM signal itself
            old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
            # kill all child processes
            os.killpg(pgid, signal.SIGTERM)
            # restore the signal handler
            signal.signal(signal.SIGTERM, old_signal_handler)
            assert _exitcode == 0, (f"function {f} failed when called with"
                                    f" args {args} and kwargs {kwargs}")

    return wrapper
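

# Illustrative sketch only: decorating a test so it runs in a forked child
# process, which keeps CUDA and other process-global state out of the parent
# pytest process. The test body is a hypothetical placeholder.
@fork_new_process_for_each_test
def _example_forked_test() -> None:
    assert 1 + 1 == 2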