test_models.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import openai # use the official client for correctness check
  2. import pytest
  3. # downloading lora to test lora requests
  4. from huggingface_hub import snapshot_download
  5. from ...utils import RemoteOpenAIServer
  # Base model handed to RemoteOpenAIServer and expected back as the first
  # entry of the model listing; any model with a chat template should work here.
  MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
  # Hub repo id of the LoRA adapter fetched with snapshot_download.
  # Technically this needs Mistral-7B-v0.1 as base, but we're not testing
  # generation quality here.
  LORA_NAME = "typeof/zephyr-7b-beta-lora"
  @pytest.fixture(scope="module")
  def zephyr_lora_files():
      """Fetch the LoRA adapter repo and hand back snapshot_download's result.

      The fixture is module-scoped, so the value is shared by every test
      and fixture in this module that requests it.
      """
      adapter_snapshot = snapshot_download(repo_id=LORA_NAME)
      return adapter_snapshot
  @pytest.fixture(scope="module")
  def server(zephyr_lora_files):
      """Yield a RemoteOpenAIServer for MODEL_NAME with LoRA enabled.

      The same adapter files are registered twice, under the names
      ``zephyr-lora`` and ``zephyr-lora2``. The server is used as a context
      manager, so it is exited once the module-scoped fixture is torn down.
      """
      # use half precision for speed and memory savings in CI environment
      engine_flags = [
          "--dtype",
          "bfloat16",
          "--max-model-len",
          "8192",
          "--enforce-eager",
      ]
      # lora config: one set of adapter files exposed under two module names
      lora_flags = [
          "--enable-lora",
          "--lora-modules",
          f"zephyr-lora={zephyr_lora_files}",
          f"zephyr-lora2={zephyr_lora_files}",
          "--max-lora-rank",
          "64",
          "--max-cpu-loras",
          "2",
      ]
      seq_limit_flags = ["--max-num-seqs", "128"]

      # Flag order is kept exactly as a single flat list would have it.
      cli_args = [*engine_flags, *lora_flags, *seq_limit_flags]

      with RemoteOpenAIServer(MODEL_NAME, cli_args) as running_server:
          yield running_server
  @pytest.fixture(scope="module")
  def client(server):
      """Return the async client obtained from the ``server`` fixture."""
      async_client = server.get_async_client()
      return async_client
  @pytest.mark.asyncio
  async def test_check_models(client: openai.AsyncOpenAI, zephyr_lora_files):
      """Check the model listing: base model first, then the LoRA modules.

      The first listed entry must be MODEL_NAME (both ``id`` and ``root``);
      every following entry must have ``root`` equal to the adapter files
      value, and the first two of them must be named ``zephyr-lora`` and
      ``zephyr-lora2``.
      """
      listing = await client.models.list()
      entries = listing.data

      base_entry = entries[0]
      adapter_entries = entries[1:]

      assert base_entry.id == MODEL_NAME
      assert base_entry.root == MODEL_NAME

      # Every entry after the base model must point at the adapter files.
      for adapter_entry in adapter_entries:
          assert adapter_entry.root == zephyr_lora_files

      assert adapter_entries[0].id == "zephyr-lora"
      assert adapter_entries[1].id == "zephyr-lora2"