test_lora_checkpoints.py

from typing import List

import pytest

from aphrodite.lora.models import LoRAModel
from aphrodite.modeling.models.baichuan import BaiChuanBaseForCausalLM

lora_lst = ["baichuan7B", "baichuan7B-zero", "chatglm3-6b"]


@pytest.mark.parametrize("lora_name", lora_lst)
def test_load_checkpoints(
    lora_name,
    baichuan_lora_files,
    baichuan_zero_lora_files,
    chatglm3_lora_files,
):
    supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
    packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
    embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
    embed_padding_modules = BaiChuanBaseForCausalLM.embedding_padding_modules
    # Build the full list of module names the checkpoint may target,
    # expanding packed modules into their constituent sub-modules.
    expected_lora_modules: List[str] = []
    for module in supported_lora_modules:
        if module in packed_modules_mapping:
            expected_lora_modules.extend(packed_modules_mapping[module])
        else:
            expected_lora_modules.append(module)
    if lora_name == "baichuan7B":
        # For the baichuan7B model, load its LoRA,
        # and the test should pass.
        LoRAModel.from_local_checkpoint(
            baichuan_lora_files,
            expected_lora_modules,
            lora_model_id=1,
            device="cpu",
            embedding_modules=embedding_modules,
            embedding_padding_modules=embed_padding_modules)
    elif lora_name == "baichuan7B-zero":
        # Test that the target_modules contain a prefix,
        # such as "model.layers.0.self_attn.W_pack", and
        # the test should pass.
        LoRAModel.from_local_checkpoint(
            baichuan_zero_lora_files,
            expected_lora_modules,
            lora_model_id=1,
            device="cpu",
            embedding_modules=embedding_modules,
            embedding_padding_modules=embed_padding_modules)
    else:
        # For the baichuan7B model, load chatglm3-6b's LoRA,
        # and the test should raise the following error.
        expected_error = "Please verify that the loaded LoRA module is correct"  # noqa: E501
        with pytest.raises(ValueError, match=expected_error):
            LoRAModel.from_local_checkpoint(
                chatglm3_lora_files,
                expected_lora_modules,
                lora_model_id=1,
                device="cpu",
                embedding_modules=embedding_modules,
                embedding_padding_modules=embed_padding_modules)
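
The three *_lora_files arguments are pytest fixtures, presumably defined in the test suite's conftest.py, that resolve to local directories holding the LoRA adapters. A minimal sketch of what such fixtures might look like, assuming the adapters are fetched from the Hugging Face Hub with snapshot_download; the repo IDs below are hypothetical placeholders, not the actual adapters used by the project:

# Sketch of conftest.py fixtures assumed by test_load_checkpoints.
# The repo IDs are placeholders; substitute the adapters the suite really uses.
import pytest
from huggingface_hub import snapshot_download


@pytest.fixture(scope="session")
def baichuan_lora_files():
    # LoRA adapter trained against a Baichuan-7B base model.
    return snapshot_download(repo_id="<org>/baichuan7b-lora")


@pytest.fixture(scope="session")
def baichuan_zero_lora_files():
    # Same base model, but with fully qualified target_modules
    # (e.g. "model.layers.0.self_attn.W_pack") in adapter_config.json.
    return snapshot_download(repo_id="<org>/baichuan7b-zero-lora")


@pytest.fixture(scope="session")
def chatglm3_lora_files():
    # Adapter for a different architecture (ChatGLM3-6B), used to check
    # that loading a mismatched adapter raises ValueError.
    return snapshot_download(repo_id="<org>/chatglm3-6b-lora")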