1
0

conftest.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import contextlib
  2. import functools
  3. import gc
  4. from typing import Callable, TypeVar
  5. import pytest
  6. import ray
  7. import torch
  8. from typing_extensions import ParamSpec
  9. from aphrodite.distributed import (destroy_distributed_environment,
  10. destroy_model_parallel)
  11. from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
  12. @pytest.fixture(autouse=True)
  13. def cleanup():
  14. destroy_model_parallel()
  15. destroy_distributed_environment()
  16. with contextlib.suppress(AssertionError):
  17. torch.distributed.destroy_process_group()
  18. ray.shutdown()
  19. gc.collect()
  20. torch.cuda.empty_cache()
  21. _P = ParamSpec("_P")
  22. _R = TypeVar("_R")
  23. def retry_until_skip(n: int):
  24. def decorator_retry(func: Callable[_P, _R]) -> Callable[_P, _R]:
  25. @functools.wraps(func)
  26. def wrapper_retry(*args: _P.args, **kwargs: _P.kwargs) -> _R:
  27. for i in range(n):
  28. try:
  29. return func(*args, **kwargs)
  30. except AssertionError:
  31. gc.collect()
  32. torch.cuda.empty_cache()
  33. if i == n - 1:
  34. pytest.skip(f"Skipping test after {n} attempts.")
  35. raise AssertionError("Code should not be reached")
  36. return wrapper_retry
  37. return decorator_retry
  38. @pytest.fixture(autouse=True)
  39. def tensorizer_config():
  40. config = TensorizerConfig(tensorizer_uri="aphrodite")
  41. return config