1
0

test_lora_lineage.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. import json
  2. import openai # use the official client for correctness check
  3. import pytest
  4. import pytest_asyncio
  5. # downloading lora to test lora requests
  6. from huggingface_hub import snapshot_download
  7. from ...utils import RemoteOpenAIServer
  8. # any model with a chat template should work here
  9. MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
  10. # technically this needs Mistral-7B-v0.1 as base, but we're not testing
  11. # generation quality here
  12. LORA_NAME = "typeof/zephyr-7b-beta-lora"
  13. @pytest.fixture(scope="module")
  14. def zephyr_lora_files():
  15. return snapshot_download(repo_id=LORA_NAME)
  16. @pytest.fixture(scope="module")
  17. def server_with_lora_modules_json(zephyr_lora_files):
  18. # Define the json format LoRA module configurations
  19. lora_module_1 = {
  20. "name": "zephyr-lora",
  21. "path": zephyr_lora_files,
  22. "base_model_name": MODEL_NAME,
  23. }
  24. lora_module_2 = {
  25. "name": "zephyr-lora2",
  26. "path": zephyr_lora_files,
  27. "base_model_name": MODEL_NAME,
  28. }
  29. args = [
  30. # use half precision for speed and memory savings in CI environment
  31. "--dtype",
  32. "bfloat16",
  33. "--max-model-len",
  34. "8192",
  35. "--enforce-eager",
  36. # lora config below
  37. "--enable-lora",
  38. "--lora-modules",
  39. json.dumps(lora_module_1),
  40. json.dumps(lora_module_2),
  41. "--max-lora-rank",
  42. "64",
  43. "--max-cpu-loras",
  44. "2",
  45. "--max-num-seqs",
  46. "64",
  47. ]
  48. with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
  49. yield remote_server
  50. @pytest_asyncio.fixture
  51. async def client_for_lora_lineage(server_with_lora_modules_json):
  52. async with server_with_lora_modules_json.get_async_client() as async_client:
  53. yield async_client
  54. @pytest.mark.asyncio
  55. async def test_check_lora_lineage(
  56. client_for_lora_lineage: openai.AsyncOpenAI, zephyr_lora_files
  57. ):
  58. models = await client_for_lora_lineage.models.list()
  59. models = models.data
  60. served_model = models[0]
  61. lora_models = models[1:]
  62. assert served_model.id == MODEL_NAME
  63. assert served_model.root == MODEL_NAME
  64. assert served_model.parent is None
  65. assert all(
  66. lora_model.root == zephyr_lora_files for lora_model in lora_models
  67. )
  68. assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
  69. assert lora_models[0].id == "zephyr-lora"
  70. assert lora_models[1].id == "zephyr-lora2"