test_lora_huggingface.py

from typing import List

import pytest

from aphrodite.lora.models import LoRAModel
from aphrodite.lora.utils import get_adapter_absolute_path
from aphrodite.modeling.models.llama import LlamaForCausalLM

# Provide an absolute local path and a Hugging Face LoRA id
lora_fixture_name = ["sql_lora_files", "sql_lora_huggingface_id"]


@pytest.mark.parametrize("lora_fixture_name", lora_fixture_name)
def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
    lora_name = request.getfixturevalue(lora_fixture_name)
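    # The parametrized fixture resolves to either local adapter files or a
    # Hugging Face repo id.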
    supported_lora_modules = LlamaForCausalLM.supported_lora_modules
    packed_modules_mapping = LlamaForCausalLM.packed_modules_mapping
    embedding_modules = LlamaForCausalLM.embedding_modules
    embed_padding_modules = LlamaForCausalLM.embedding_padding_modules
    expected_lora_modules: List[str] = []
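    # Packed modules (e.g. qkv_proj) fuse several projections into one layer,
    # so expand them into their underlying module names; the rest are kept
    # as-is.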
    for module in supported_lora_modules:
        if module in packed_modules_mapping:
            expected_lora_modules.extend(packed_modules_mapping[module])
        else:
            expected_lora_modules.append(module)
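
    # A Hugging Face repo id is first resolved to a local download; an
    # existing absolute path is returned unchanged.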
    lora_path = get_adapter_absolute_path(lora_name)

    # LoRA loading should work for either an absolute path or a Hugging Face id.
    lora_model = LoRAModel.from_local_checkpoint(
        lora_path,
        expected_lora_modules,
        lora_model_id=1,
        device="cpu",
        embedding_modules=embedding_modules,
        embedding_padding_modules=embed_padding_modules)

    # Assertions to ensure the model is loaded correctly
    assert lora_model is not None, "LoRAModel is not loaded correctly"