1
0

test_worker.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import os
  2. import random
  3. import tempfile
  4. from unittest.mock import patch
  5. from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
  6. LoRAConfig, ModelConfig, ParallelConfig,
  7. SchedulerConfig)
  8. from aphrodite.lora.models import LoRAMapping
  9. from aphrodite.lora.request import LoRARequest
  10. from aphrodite.task_handler.worker import Worker
  @patch.dict(os.environ, {"RANK": "0"})
  def test_worker_apply_lora(sql_lora_files):
      """Exercise LoRA activation bookkeeping on a single-process ``Worker``.

      Builds a ``Worker`` for ``meta-llama/Llama-2-7b-hf`` with dummy weights
      and a LoRA config sized for ``n_loras`` adapters, then checks that:

      * activating an empty request list leaves ``list_loras()`` empty;
      * activating ``n_loras`` distinct requests reports exactly their ids;
      * over ``n_iterations`` seeded random request lists (which may contain
        duplicates), every requested id is reported after activation. Only a
        superset is asserted, presumably because adapters activated in earlier
        iterations may remain loaded.

      Args:
          sql_lora_files: presumably a pytest fixture yielding the path of a
              LoRA adapter; it is passed as the path to every ``LoRARequest``.
      """
      n_loras = 32
      n_iterations = 32

      # mkstemp returns an *open* descriptor plus a path. Only the path is
      # needed for the file:// rendezvous, so close the descriptor right away
      # instead of leaking it for the lifetime of the test process.
      init_fd, init_file = tempfile.mkstemp()
      os.close(init_fd)

      worker = Worker(
          model_config=ModelConfig(
              "meta-llama/Llama-2-7b-hf",
              "meta-llama/Llama-2-7b-hf",
              tokenizer_mode="auto",
              trust_remote_code=False,
              seed=0,
              dtype="float16",
              revision=None,
          ),
          load_config=LoadConfig(
              download_dir=None,
              load_format="dummy",
          ),
          parallel_config=ParallelConfig(1, 1, False),
          scheduler_config=SchedulerConfig(32, 32, 32),
          device_config=DeviceConfig("cuda"),
          cache_config=CacheConfig(block_size=16,
                                   gpu_memory_utilization=1.,
                                   swap_space=0,
                                   cache_dtype="auto"),
          local_rank=0,
          rank=0,
          lora_config=LoRAConfig(max_lora_rank=8,
                                 max_cpu_loras=n_loras,
                                 max_loras=n_loras),
          distributed_init_method=f"file://{init_file}",
      )
      worker.init_device()
      worker.load_model()

      # No requests -> no adapters reported.
      worker.model_runner.set_active_loras([], LoRAMapping([], []))
      assert worker.list_loras() == set()

      # Ids run 1..n_loras; each request's name is its id as a string.
      lora_requests = [
          LoRARequest(str(lora_id), lora_id, sql_lora_files)
          for lora_id in range(1, n_loras + 1)
      ]
      worker.model_runner.set_active_loras(lora_requests, LoRAMapping([], []))
      assert worker.list_loras() == {
          lora_request.lora_int_id
          for lora_request in lora_requests
      }

      for i in range(n_iterations):
          # A per-iteration Random(i) yields the same draws as the former
          # random.seed(i) + module-level calls, without resetting the global
          # RNG state that other tests in the process may rely on.
          rng = random.Random(i)
          iter_lora_requests = rng.choices(lora_requests,
                                           k=rng.randint(1, n_loras))
          rng.shuffle(iter_lora_requests)
          # Drop the last n_drop entries. The former [:-n_drop] slice was wrong
          # for n_drop == 0: [:-0] is [:0], an empty list, not "drop nothing".
          n_drop = rng.randint(0, n_loras)
          n_keep = max(0, len(iter_lora_requests) - n_drop)
          iter_lora_requests = iter_lora_requests[:n_keep]
          worker.model_runner.set_active_loras(iter_lora_requests,
                                               LoRAMapping([], []))
          assert worker.list_loras().issuperset(
              {lora_request.lora_int_id
               for lora_request in iter_lora_requests})