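"""Tests for the LoRA adapter managers: LoRAModelManager,
LRUCacheLoRAModelManager, WorkerLoRAManager and LRUCacheWorkerLoRAManager."""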
import os
from typing import Dict, List

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

from aphrodite.common.config import LoRAConfig
from aphrodite.lora.layers import (ColumnParallelLinearWithLoRA,
                                   MergedColumnParallelLinearWithLoRA,
                                   RowParallelLinearWithLoRA)
from aphrodite.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from aphrodite.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                                   LRUCacheLoRAModelManager)
from aphrodite.lora.request import LoRARequest
from aphrodite.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                           WorkerLoRAManager)
from aphrodite.modeling.layers.linear import RowParallelLinear

EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}

EMBEDDING_PADDING_MODULES = ["lm_head"]

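# Loading a LoRA adapter from safetensors should preserve its rank and alpha
# and attach the extra embedding tensors for the embedding modules.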
def test_from_lora_tensors(sql_lora_files):
    tensors = load_file(
        os.path.join(sql_lora_files, "adapter_model.safetensors"))
    new_embeddings = load_file(
        os.path.join(sql_lora_files, "new_embeddings.safetensors"))
    lora_model = LoRAModel.from_lora_tensors(
        1,
        8,
        16,
        tensors,
        "cuda",
        embeddings=new_embeddings,
        embedding_modules=EMBEDDING_MODULES,
        embedding_padding_modules=EMBEDDING_PADDING_MODULES)
    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 16
        assert lora.lora_a is not None
        assert lora.lora_b is not None
        assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
                ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        assert lora.lora_a.shape[1] == 8
        embeddings_module = next(
            (k for k in EMBEDDING_MODULES if k in module_name), None)
        if embeddings_module:
            assert torch.equal(
                lora.embeddings_tensor,
                new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
                    device=lora.embeddings_tensor.device))
        else:
            assert lora.embeddings_tensor is None

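# Helper: build a LoRAModel with random rank-8 weights for the given
# sub-modules of `model`.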
def create_lora(lora_id: int, model: nn.Module,
                sub_modules: List[str]) -> LoRAModel:
    loras: Dict[str, LoRALayerWeights] = {}
    for name in sub_modules:
        w = model.get_submodule(name).weight
        loras[name] = LoRALayerWeights(
            name,
            8,
            16,
            torch.rand([w.shape[1], 8], device="cuda"),
            torch.rand([8, w.shape[0]], device="cuda"),
        )
    return LoRAModel(lora_id, 8, loras)

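# Helper: build a LoRAModel for a packed module whose lora_b output dim is
# split evenly across the replaced sub-modules; one slice can be left without
# weights via `empty_replaced_module_name`.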
def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name,
    replaced_module_names,
    empty_replaced_module_name=None,
) -> LoRAModel:
    w = model.get_submodule(module_name).weight
    loras: Dict[str, LoRALayerWeights] = {}
    for replaced_module_name in replaced_module_names:
        if replaced_module_name == empty_replaced_module_name:
            continue
        loras[replaced_module_name] = LoRALayerWeights(
            replaced_module_name,
            8,
            16,
            torch.rand([w.shape[1], 8], device="cuda"),
            torch.rand([8, w.shape[0] // len(replaced_module_names)],
                       device="cuda"),
        )
    return LoRAModel(lora_id, 8, loras)

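# Only sub-modules matched by supported_lora_modules should be replaced with
# their LoRA-enabled counterparts; other modules keep their original class.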
def test_replace_submodules(dist_init, dummy_model):
    model = dummy_model
    model.supported_lora_modules = ["dense1", "layer1.dense2"]
    model.packed_modules_mapping = {}
    manager = LoRAModelManager(
        model, 1, 1, 1,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8))
    model = manager.model
    assert isinstance(model.get_submodule("dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("dense2"), RowParallelLinear)
    assert isinstance(model.get_submodule("layer1.dense2"),
                      RowParallelLinearWithLoRA)

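# Basic add/activate/remove bookkeeping of LoRAModelManager with two active
# slots (max_loras=2): activating a third adapter without freeing a slot
# raises ValueError.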
def test_lora_model_manager(dist_init, dummy_model):
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
    manager = LoRAModelManager(
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
    assert all(x is None for x in manager.lora_index_to_id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    with pytest.raises(ValueError):
        assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    assert manager.add_adapter(model_lora1)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] is None
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] is None
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2

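# Same bookkeeping, but LRUCacheLoRAModelManager evicts the least recently
# used adapter when a new one is activated, and pinned adapters cannot be
# evicted (pinning or activating with no evictable slot raises RuntimeError).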
def test_lora_lru_cache_model_manager(dist_init, dummy_model):
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
    manager = LRUCacheLoRAModelManager(
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
    assert all(x is None for x in manager.lora_index_to_id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    assert manager.add_adapter(model_lora2)
    assert manager.deactivate_adapter(3)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.deactivate_adapter(2)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    assert manager.pin_adapter(3)
    assert manager.pin_adapter(1)
    with pytest.raises(RuntimeError):
        assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    with pytest.raises(RuntimeError):
        assert manager.activate_adapter(2)
    assert manager.deactivate_adapter(3)
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.remove_adapter(3)
    with pytest.raises(ValueError):
        assert manager.pin_adapter(3)

def test_lru_lora_model_manager(dist_init, dummy_model):
    # This tests just the LRU cache functionality, everything else is
    # tested in test_lora_model_manager
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
    model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"])
    manager = LRUCacheLoRAModelManager(
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))
    assert all(x is None for x in manager.lora_index_to_id)
    # Add up to capacity
    assert manager.add_adapter(model_lora1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(1)
    assert manager.activate_adapter(2)
    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    # Add over capacity
    assert manager.add_adapter(model_lora3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(3)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4
    # Add 3 again to move it to the top and then add 2
    # should return false since it's in already
    assert not manager.add_adapter(model_lora3)
    assert not manager.activate_adapter(3)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert set(manager.list_adapters()) == {3, 2}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2
    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)
    assert set(manager.list_adapters()) == {2}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 2
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)
    assert not manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)
    # pinning
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    with pytest.raises(ValueError):
        assert manager.pin_adapter(1)
    assert manager.pin_adapter(3)
    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)
    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4
    assert manager.add_adapter(model_lora1)
    assert manager.pin_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {1}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] is None
    with pytest.raises(RuntimeError):
        assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {1}

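# LRUCacheWorkerLoRAManager keeps previously requested adapters cached across
# calls to set_active_adapters and evicts the least recently used ones once
# its four slots are full; requesting more adapters than slots at once raises
# RuntimeError.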
def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                          sql_lora_files):
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = LRUCacheWorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)
    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4
    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 6
    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)

def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                sql_lora_files):
    # Should remove every LoRA not specified in the request.
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = WorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)
    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None
    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 7
    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)

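# Packed modules ("gate_up_proj" -> ["gate_proj", "up_proj"]) should be mapped
# to MergedColumnParallelLinearWithLoRA and their weights combined into a
# PackedLoRALayerWeights, with None entries for slices that carry no weights.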
def test_packed_loras(dist_init, dummy_model_gate_up):
    model = dummy_model_gate_up
    model.supported_lora_modules = ["gate_up_proj"]
    model.packed_modules_mapping = {
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    model_lora = create_packed_lora(
        1,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"])
    model_lora1 = create_packed_lora(
        2,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        empty_replaced_module_name="gate_proj",
    )
    manager = LoRAModelManager(
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))
    model = manager.model
    assert isinstance(model.get_submodule("gate_up_proj"),
                      MergedColumnParallelLinearWithLoRA)
    assert manager.add_adapter(model_lora)
    assert manager.add_adapter(model_lora1)
    packed_lora = model_lora.get_lora("gate_up_proj")
    assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)
    torch.testing.assert_close(packed_lora.lora_a[0],
                               model_lora.get_lora("gate_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[0],
                               model_lora.get_lora("gate_proj").lora_b)
    torch.testing.assert_close(packed_lora.lora_a[1],
                               model_lora.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[1],
                               model_lora.get_lora("up_proj").lora_b)
    packed_lora1 = model_lora1.get_lora("gate_up_proj")
    assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)
    assert packed_lora1.lora_a[0] is None
    assert packed_lora1.lora_b[0] is None
    torch.testing.assert_close(packed_lora1.lora_a[1],
                               model_lora1.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora1.lora_b[1],
                               model_lora1.get_lora("up_proj").lora_b)