import contextlib
import functools
import gc
from typing import Callable, TypeVar

import pytest
import ray
import torch
from typing_extensions import ParamSpec

from aphrodite.distributed import (destroy_distributed_environment,
                                   destroy_model_parallel)
from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig


@pytest.fixture(autouse=True)
def cleanup():
    """Tear down distributed state and free GPU memory before each test."""
    destroy_model_parallel()
    destroy_distributed_environment()
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    ray.shutdown()
    gc.collect()
    torch.cuda.empty_cache()


_P = ParamSpec("_P")
_R = TypeVar("_R")


def retry_until_skip(n: int):
    """Retry a flaky test up to ``n`` times, then skip it instead of failing."""

    def decorator_retry(func: Callable[_P, _R]) -> Callable[_P, _R]:

        @functools.wraps(func)
        def wrapper_retry(*args: _P.args, **kwargs: _P.kwargs) -> _R:
            for i in range(n):
                try:
                    return func(*args, **kwargs)
                except AssertionError:
                    # Free GPU memory between attempts before retrying.
                    gc.collect()
                    torch.cuda.empty_cache()
                    if i == n - 1:
                        pytest.skip(f"Skipping test after {n} attempts.")
            raise AssertionError("Code should not be reached")

        return wrapper_retry

    return decorator_retry
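
# A minimal usage sketch (the test name below is hypothetical, not part of this
# suite): decorating a memory-sensitive test retries it on AssertionError and
# skips it after ``n`` failed attempts instead of reporting a hard failure.
#
#     @retry_until_skip(3)
#     def test_deserialized_model_matches_reference():
#         ...  # assertions that may fail transiently under GPU memory pressure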


@pytest.fixture(autouse=True)
def tensorizer_config():
    """Provide a default TensorizerConfig to every test in this directory."""
    config = TensorizerConfig(tensorizer_uri="aphrodite")
    return config
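
# A minimal usage sketch (hypothetical test shown for illustration): because the
# fixture is autouse, a test only needs to name it to inspect or override fields.
#
#     def test_default_uri(tensorizer_config):
#         assert tensorizer_config.tensorizer_uri == "aphrodite"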