# test_registry.py
import re

import pytest
import transformers

from aphrodite.modeling.models import _MODELS, ModelRegistry
  4. @pytest.mark.parametrize("model_cls", _MODELS)
  5. def test_registry_imports(model_cls):
  6. if (model_cls in ("LlavaOnevisionForConditionalGeneration",
  7. "Qwen2VLForConditionalGeneration")
  8. and transformers.__version__ < "4.45"):
  9. pytest.skip("Waiting for next transformers release")
  10. # Ensure all model classes can be imported successfully
  11. ModelRegistry.resolve_model_cls([model_cls])