import os
from typing import Dict, List

import pytest
import torch
from safetensors.torch import load_file
from torch import nn

from aphrodite.common.config import LoRAConfig
from aphrodite.lora.layers import (ColumnParallelLinearWithLoRA,
                                   MergedColumnParallelLinearWithLoRA,
                                   RowParallelLinearWithLoRA)
from aphrodite.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from aphrodite.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                                   LRUCacheLoRAModelManager)
from aphrodite.lora.request import LoRARequest
from aphrodite.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                           WorkerLoRAManager)
from aphrodite.modeling.layers.linear import RowParallelLinear
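
# Modules whose LoRA adapters may carry extra embedding tensors, mapped to the
# keys those tensors use in new_embeddings.safetensors. "lm_head" additionally
# has its embedding weight padded (see EMBEDDING_PADDING_MODULES).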
EMBEDDING_MODULES = {
    "embed_tokens": "input_embeddings",
    "lm_head": "output_embeddings",
}

EMBEDDING_PADDING_MODULES = ["lm_head"]


def test_from_lora_tensors(sql_lora_files):
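    """Load the SQL LoRA adapter from safetensors and check that every
    LoRALayerWeights it produces has the expected rank, alpha, shapes and,
    for embedding modules, the extra embedding tensor."""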
    tensors = load_file(
        os.path.join(sql_lora_files, "adapter_model.safetensors"))
    new_embeddings = load_file(
        os.path.join(sql_lora_files, "new_embeddings.safetensors"))
    lora_model = LoRAModel.from_lora_tensors(
        1,
        8,
        16,
        tensors,
        "cuda",
        embeddings=new_embeddings,
        embedding_modules=EMBEDDING_MODULES,
        embedding_padding_modules=EMBEDDING_PADDING_MODULES)
    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == 8
        assert lora.lora_alpha == 16
        assert lora.lora_a is not None
        assert lora.lora_b is not None
        assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
                ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        assert lora.lora_a.shape[1] == 8
        embeddings_module = next(
            (k for k in EMBEDDING_MODULES if k in module_name), None)
        if embeddings_module:
            assert torch.equal(
                lora.embeddings_tensor,
                new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
                    device=lora.embeddings_tensor.device))
        else:
            assert lora.embeddings_tensor is None


def create_lora(lora_id: int, model: nn.Module,
                sub_modules: List[str]) -> LoRAModel:
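    """Create a rank-8 LoRAModel with random weights for the given
    sub-modules of `model`."""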
    loras: Dict[str, LoRALayerWeights] = {}
    for name in sub_modules:
        w = model.get_submodule(name).weight
        loras[name] = LoRALayerWeights(
            name,
            8,
            16,
            torch.rand([w.shape[1], 8], device="cuda"),
            torch.rand([8, w.shape[0]], device="cuda"),
        )
    return LoRAModel(lora_id, 8, loras)


def create_packed_lora(
    lora_id: int,
    model: nn.Module,
    module_name,
    replaced_module_names,
    empty_replaced_module_name=None,
) -> LoRAModel:
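    """Create a rank-8 LoRAModel for the sub-modules that a packed module
    replaces, optionally leaving one of them without LoRA weights."""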
    w = model.get_submodule(module_name).weight
    loras: Dict[str, LoRALayerWeights] = {}
    for replaced_module_name in replaced_module_names:
        if replaced_module_name == empty_replaced_module_name:
            continue
        loras[replaced_module_name] = LoRALayerWeights(
            replaced_module_name,
            8,
            16,
            torch.rand([w.shape[1], 8], device="cuda"),
            torch.rand([8, w.shape[0] // len(replaced_module_names)],
                       device="cuda"),
        )
    return LoRAModel(lora_id, 8, loras)


def test_replace_submodules(dist_init, dummy_model):
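    """The manager should swap every supported module for its LoRA-enabled
    counterpart and leave unsupported modules untouched."""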
    model = dummy_model
    model.supported_lora_modules = ["dense1", "layer1.dense2"]
    model.packed_modules_mapping = {}
    manager = LoRAModelManager(
        model, 1, 1, 1,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8))
    model = manager.model
    assert isinstance(model.get_submodule("dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("dense2"), RowParallelLinear)
    assert isinstance(model.get_submodule("layer1.dense2"),
                      RowParallelLinearWithLoRA)


def test_lora_model_manager(dist_init, dummy_model):
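    """Exercise add, activate, and remove on the basic (non-LRU)
    LoRAModelManager and verify the GPU slot assignments after each step."""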
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
    manager = LoRAModelManager(
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
    assert all(x is None for x in manager.lora_index_to_id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
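    # Both GPU slots are occupied and the base manager never evicts, so
    # activating a third adapter must fail.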
    with pytest.raises(ValueError):
        assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    assert manager.add_adapter(model_lora1)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] is None
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] is None
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2


def test_lora_lru_cache_model_manager(dist_init, dummy_model):
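    """Same flow as test_lora_model_manager, but with the LRU-cache manager:
    activating an adapter when all slots are full evicts the least recently
    used one, and pinned adapters are never evicted."""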
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
    manager = LRUCacheLoRAModelManager(
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
    assert all(x is None for x in manager.lora_index_to_id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 1
    assert not manager.add_adapter(model_lora1)
    assert not manager.activate_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert not manager.add_adapter(model_lora2)
    assert not manager.activate_adapter(2)
    assert manager.add_adapter(model_lora3)
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2
    assert manager.remove_adapter(model_lora2.id)
    assert manager.lora_index_to_id[1] is None
    assert not manager.remove_adapter(model_lora2.id)
    assert manager.remove_adapter(model_lora1.id)
    assert not manager.remove_adapter(model_lora1.id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    assert manager.add_adapter(model_lora2)
    assert manager.deactivate_adapter(3)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 3
    assert manager.activate_adapter(1)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.deactivate_adapter(2)
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 1
    assert manager.activate_adapter(3)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    assert manager.pin_adapter(3)
    assert manager.pin_adapter(1)
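    # With both active adapters pinned there is nothing to evict, so pinning
    # or activating another adapter must fail.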
    with pytest.raises(RuntimeError):
        assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 1
    with pytest.raises(RuntimeError):
        assert manager.activate_adapter(2)
    assert manager.deactivate_adapter(3)
    assert manager.pin_adapter(2)
    assert manager.lora_index_to_id[0] == 2
    assert manager.lora_index_to_id[1] == 1
    assert manager.remove_adapter(3)
    with pytest.raises(ValueError):
        assert manager.pin_adapter(3)


def test_lru_lora_model_manager(dist_init, dummy_model):
    """This tests just the LRU cache functionality; everything else is
    covered by test_lora_model_manager."""
    model = dummy_model
    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
    model.packed_modules_mapping = {}
    model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
    model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"])
    manager = LRUCacheLoRAModelManager(
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))
    assert all(x is None for x in manager.lora_index_to_id)
    # Add up to capacity
    assert manager.add_adapter(model_lora1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(1)
    assert manager.activate_adapter(2)
    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    # Add over capacity
    assert manager.add_adapter(model_lora3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(3)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4
    # Add 3 again to move it to the top, then add 2; re-adding and
    # re-activating 3 should return False since it is already present.
    assert not manager.add_adapter(model_lora3)
    assert not manager.activate_adapter(3)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert set(manager.list_adapters()) == {3, 2}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 2
    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)
    assert set(manager.list_adapters()) == {2}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 2
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    assert manager.lora_index_to_id[0] == 3
    assert manager.lora_index_to_id[1] == 4
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)
    assert not manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == set()
    assert all(x is None for x in manager.lora_index_to_id)
    # Pinning
    assert manager.add_adapter(model_lora3)
    assert manager.activate_adapter(3)
    assert manager.add_adapter(model_lora4)
    assert manager.activate_adapter(4)
    assert set(manager.list_adapters()) == {3, 4}
    with pytest.raises(ValueError):
        assert manager.pin_adapter(1)
    assert manager.pin_adapter(3)
    # Remove manually
    assert manager.remove_adapter(3)
    assert not manager.remove_adapter(3)
    assert set(manager.list_adapters()) == {4}
    assert manager.lora_index_to_id[0] is None
    assert manager.lora_index_to_id[1] == 4
    assert manager.add_adapter(model_lora1)
    assert manager.pin_adapter(1)
    assert manager.add_adapter(model_lora2)
    assert manager.activate_adapter(2)
    assert set(manager.list_adapters()) == {1, 2}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] == 2
    assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {1}
    assert manager.lora_index_to_id[0] == 1
    assert manager.lora_index_to_id[1] is None
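    # The only remaining adapter is pinned, so there is nothing left to evict.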
    with pytest.raises(RuntimeError):
        assert manager.remove_oldest_adapter()
    assert set(manager.list_adapters()) == {1}


def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                          sql_lora_files):
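    """Drive LRUCacheWorkerLoRAManager through several set_active_adapters
    calls and check which adapters stay resident and which slots they
    occupy."""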
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = LRUCacheWorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 6

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)


def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                sql_lora_files):
    """Should remove every LoRA not specified in the request."""
    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
    worker_adapter_manager = WorkerLoRAManager(
        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
    worker_adapter_manager.create_lora_manager(
        llama_2_7b_model_extra_embeddings)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("3", 3, sql_lora_files),
        LoRARequest("4", 4, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("2", 2, sql_lora_files),
        LoRARequest("5", 5, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1, 2, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files),
        LoRARequest("1", 1, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {1}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None

    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, sql_lora_files),
        LoRARequest("7", 7, sql_lora_files),
        LoRARequest("8", 8, sql_lora_files)
    ], mapping)
    assert worker_adapter_manager.list_adapters() == {6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 7

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, sql_lora_files),
            LoRARequest("11", 11, sql_lora_files),
            LoRARequest("12", 12, sql_lora_files),
            LoRARequest("13", 13, sql_lora_files),
            LoRARequest("14", 14, sql_lora_files)
        ], mapping)


def test_packed_loras(dist_init, dummy_model_gate_up):
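    """LoRA weights targeting gate_proj and up_proj should be packed into a
    single PackedLoRALayerWeights entry for the fused gate_up_proj module."""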
    model = dummy_model_gate_up
    model.supported_lora_modules = ["gate_up_proj"]
    model.packed_modules_mapping = {
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    model_lora = create_packed_lora(
        1,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"])
    model_lora1 = create_packed_lora(
        2,
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        empty_replaced_module_name="gate_proj",
    )

    manager = LoRAModelManager(
        model, 2, 2, 2,
        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))
    model = manager.model

    assert isinstance(model.get_submodule("gate_up_proj"),
                      MergedColumnParallelLinearWithLoRA)
    assert manager.add_adapter(model_lora)
    assert manager.add_adapter(model_lora1)

    packed_lora = model_lora.get_lora("gate_up_proj")
    assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)

    torch.testing.assert_close(packed_lora.lora_a[0],
                               model_lora.get_lora("gate_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[0],
                               model_lora.get_lora("gate_proj").lora_b)
    torch.testing.assert_close(packed_lora.lora_a[1],
                               model_lora.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[1],
                               model_lora.get_lora("up_proj").lora_b)

    packed_lora1 = model_lora1.get_lora("gate_up_proj")
    assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)

    assert packed_lora1.lora_a[0] is None
    assert packed_lora1.lora_b[0] is None
    torch.testing.assert_close(packed_lora1.lora_a[1],
                               model_lora1.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora1.lora_b[1],
                               model_lora1.get_lora("up_proj").lora_b)
|