test_generate_multiple_loras.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import weakref
  2. import pytest
  3. # downloading lora to test lora requests
  4. from huggingface_hub import snapshot_download
  5. from aphrodite import LLM
  6. from aphrodite.lora.request import LoRARequest
  7. from ...conftest import cleanup
  8. MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
  9. PROMPTS = [
  10. "Hello, my name is",
  11. "The president of the United States is",
  12. "The capital of France is",
  13. "The future of AI is",
  14. ]
  15. LORA_NAME = "typeof/zephyr-7b-beta-lora"
  16. @pytest.fixture(scope="module")
  17. def llm():
  18. # pytest caches the fixture so we use weakref.proxy to
  19. # enable garbage collection
  20. llm = LLM(model=MODEL_NAME,
  21. tensor_parallel_size=1,
  22. max_model_len=8192,
  23. enable_lora=True,
  24. max_loras=4,
  25. max_lora_rank=64,
  26. max_num_seqs=128,
  27. enforce_eager=True)
  28. with llm.deprecate_legacy_api():
  29. yield weakref.proxy(llm)
  30. del llm
  31. cleanup()
  32. @pytest.fixture(scope="module")
  33. def zephyr_lora_files():
  34. return snapshot_download(repo_id=LORA_NAME)
  35. @pytest.mark.skip_global_cleanup
  36. def test_multiple_lora_requests(llm: LLM, zephyr_lora_files):
  37. lora_request = [
  38. LoRARequest(LORA_NAME, idx + 1, zephyr_lora_files)
  39. for idx in range(len(PROMPTS))
  40. ]
  41. # Multiple SamplingParams should be matched with each prompt
  42. outputs = llm.generate(PROMPTS, lora_request=lora_request)
  43. assert len(PROMPTS) == len(outputs)
  44. # Exception raised, if the size of params does not match the size of prompts
  45. with pytest.raises(ValueError):
  46. outputs = llm.generate(PROMPTS, lora_request=lora_request[:1])
  47. # Single LoRARequest should be applied to every prompt
  48. single_lora_request = lora_request[0]
  49. outputs = llm.generate(PROMPTS, lora_request=single_lora_request)
  50. assert len(PROMPTS) == len(outputs)