import contextlib
import gc
import tempfile
from collections import OrderedDict
from typing import Dict, List, TypedDict
from unittest.mock import MagicMock, patch

import pytest
import ray
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download

import aphrodite
from aphrodite.common.config import LoRAConfig
from aphrodite.distributed import (destroy_distributed_environment,
                                   destroy_model_parallel,
                                   init_distributed_environment,
                                   initialize_model_parallel)
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              MergedColumnParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.modeling.model_loader import get_model


class ContextIDInfo(TypedDict):
    """Entry of ``LONG_LORA_INFOS``: a LoRA id and its context-length label."""
    # Integer id; ``long_context_infos`` maps 1, 2 and 3 to checkpoints.
    lora_id: int
    # Context-length label such as "16k" or "32k".
    context_length: str


class ContextInfo(TypedDict):
    """Value type of the mapping returned by ``long_context_infos``."""
    # Value returned by ``snapshot_download`` for the LoRA checkpoint.
    lora: str
    # Context-length label copied from the matching ``ContextIDInfo``.
    context_length: str


LONG_LORA_INFOS: List[ContextIDInfo] = [{
    "lora_id": 1,
    "context_length": "16k",
}, {
    "lora_id": 2,
    "context_length": "16k",
}, {
    "lora_id": 3,
    "context_length": "32k",
}]


def cleanup():
    """Tear down distributed state and release cached resources.

    Destroys the model-parallel groups and the distributed environment,
    then the torch process group (an ``AssertionError`` raised by
    ``destroy_process_group`` is suppressed), runs the garbage collector,
    empties the CUDA cache and shuts Ray down.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    torch.cuda.empty_cache()
    ray.shutdown()


def _make_init_file() -> str:
    """Create an empty temporary file and return its path.

    The path is used to build a ``file://`` init method for the
    distributed fixtures. ``tempfile.mkstemp()[1]`` would discard the open
    OS-level descriptor that ``mkstemp`` returns and leak it on every call;
    here the handle is closed when the ``with`` block exits, while
    ``delete=False`` keeps the file itself on disk.
    """
    with tempfile.NamedTemporaryFile(delete=False) as handle:
        return handle.name


def _build_dummy_model(projection_name, make_projection) -> nn.Module:
    """Build the small layer stack shared by the dummy-model fixtures.

    Args:
        projection_name: Name under which the projection layer is
            registered (between ``act2`` and ``outact``).
        make_projection: Zero-argument callable returning that layer. A
            factory is used so the layer is constructed at the same point
            in the sequence as the other layers.

    Returns:
        The ``nn.Sequential`` model with ``config`` set to a ``MagicMock``.
    """
    model = nn.Sequential(
        OrderedDict([
            ("dense1", ColumnParallelLinear(764, 100)),
            ("dense2", RowParallelLinear(100, 50)),
            (
                "layer1",
                nn.Sequential(
                    OrderedDict([
                        ("dense1", ColumnParallelLinear(100, 10)),
                        ("dense2", RowParallelLinear(10, 50)),
                    ])),
            ),
            ("act2", nn.ReLU()),
            (projection_name, make_projection()),
            ("outact", nn.Sigmoid()),
            # Special handling for lm_head & sampler
            ("lm_head", ParallelLMHead(512, 10)),
            ("logits_processor", LogitsProcessor(512)),
            ("sampler", Sampler()),
        ]))
    model.config = MagicMock()
    return model


@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Allow subdirectories to skip global cleanup by overriding this fixture.

    This can provide a ~10x speedup for non-GPU unit tests since they don't
    need to initialize torch.

    Returns:
        ``False`` when the requesting test carries the
        ``skip_global_cleanup`` marker, ``True`` otherwise.
    """
    return not request.node.get_closest_marker("skip_global_cleanup")


@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    """Run ``cleanup()`` after each test unless global cleanup is skipped."""
    yield
    if should_do_global_cleanup_after_test:
        cleanup()


@pytest.fixture
def dist_init():
    """Initialize a single-process distributed environment for a test.

    Sets up the distributed environment (world size 1, rank 0, NCCL
    backend, file-based init method) and model parallelism with sizes
    ``(1, 1)``, then calls ``cleanup()`` on teardown.
    """
    init_file = _make_init_file()
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=f"file://{init_file}",
        local_rank=0,
        backend="nccl",
    )
    initialize_model_parallel(1, 1)
    yield
    cleanup()


@pytest.fixture
def dist_init_torch_only():
    """Initialize only the torch process group, if not already initialized.

    Does nothing when ``torch.distributed`` is already initialized;
    otherwise creates a single-rank NCCL process group with a file-based
    init method. This fixture performs no teardown of its own.
    """
    if torch.distributed.is_initialized():
        return
    init_file = _make_init_file()
    torch.distributed.init_process_group(
        backend="nccl",
        world_size=1,
        rank=0,
        init_method=f"file://{init_file}",
    )


@pytest.fixture
def dummy_model() -> nn.Module:
    """Dummy model whose projection layer is a ``ColumnParallelLinear``
    registered as ``output``."""
    return _build_dummy_model("output", lambda: ColumnParallelLinear(50, 10))


@pytest.fixture
def dummy_model_gate_up() -> nn.Module:
    """Dummy model whose projection layer is a ``MergedColumnParallelLinear``
    registered as ``gate_up_proj``."""
    return _build_dummy_model(
        "gate_up_proj", lambda: MergedColumnParallelLinear(50, [5, 5]))


@pytest.fixture(scope="session")
def sql_lora_huggingface_id():
    """Return the Hugging Face repo id of the SQL LoRA test adapter."""
    # huggingface repo id is used to test lora runtime downloading.
    return "yard1/llama-2-7b-sql-lora-test"


@pytest.fixture(scope="session")
def sql_lora_files(sql_lora_huggingface_id):
    """Download the SQL LoRA adapter and return the snapshot location."""
    return snapshot_download(repo_id=sql_lora_huggingface_id)


@pytest.fixture(scope="session")
def mixtral_lora_files():
    """Download the Mixtral LoRA adapter and return the snapshot location."""
    # Note: this module has incorrect adapter_config.json to test
    # https://github.com/aphrodite-project/aphrodite/pull/5909/files.
    return snapshot_download(repo_id="SangBinCho/mixtral-lora")


@pytest.fixture(scope="session")
def gemma_lora_files():
    """Download the Gemma LoRA adapter and return the snapshot location."""
    return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")


@pytest.fixture(scope="session")
def chatglm3_lora_files():
    """Download the ChatGLM3 LoRA adapter and return the snapshot location."""
    return snapshot_download(repo_id="jeeejeee/chatglm3-text2sql-spider")


@pytest.fixture(scope="session")
def baichuan_lora_files():
    """Download the Baichuan LoRA adapter and return the snapshot location."""
    return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")


@pytest.fixture(scope="session")
def baichuan_zero_lora_files():
    """Download the zero-initialized Baichuan LoRA adapter and return the
    snapshot location."""
    # all the lora_B weights are initialized to zero.
    return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")


@pytest.fixture(scope="session")
def tinyllama_lora_files():
    """Download the TinyLlama LoRA adapter and return the snapshot location."""
    return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")


@pytest.fixture(scope="session")
def phi2_lora_files():
    """Download the Phi-2 LoRA adapter and return the snapshot location."""
    return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora")


@pytest.fixture(scope="session")
def long_context_lora_files_16k_1():
    """Download the first 16k long-context LoRA and return its location."""
    return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1")


@pytest.fixture(scope="session")
def long_context_lora_files_16k_2():
    """Download the second 16k long-context LoRA and return its location."""
    return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_2")


@pytest.fixture(scope="session")
def long_context_lora_files_32k():
    """Download the 32k long-context LoRA and return its location."""
    return snapshot_download(repo_id="SangBinCho/long_context_32k_testing")


@pytest.fixture(scope="session")
def long_context_infos(long_context_lora_files_16k_1,
                       long_context_lora_files_16k_2,
                       long_context_lora_files_32k):
    """Map each id in ``LONG_LORA_INFOS`` to its checkpoint and context length.

    Calls ``cleanup()`` first, then builds the mapping.

    Returns:
        Dict keyed by LoRA id whose values hold the ``"context_length"``
        label and the downloaded ``"lora"`` location.

    Raises:
        AssertionError: If ``LONG_LORA_INFOS`` contains an id other than
            1, 2 or 3.
    """
    cleanup()
    checkpoint_by_id = {
        1: long_context_lora_files_16k_1,
        2: long_context_lora_files_16k_2,
        3: long_context_lora_files_32k,
    }
    infos: Dict[int, ContextInfo] = {}
    for lora_checkpoint_info in LONG_LORA_INFOS:
        lora_id = lora_checkpoint_info["lora_id"]
        if lora_id not in checkpoint_by_id:
            raise AssertionError("Unknown lora id")
        infos[lora_id] = {
            "context_length": lora_checkpoint_info["context_length"],
            "lora": checkpoint_by_id[lora_id],
        }
    return infos


@pytest.fixture
def llama_2_7b_engine_extra_embeddings():
    """Yield the engine of a Llama-2-7b ``LLM`` built with a forced LoRA config.

    ``get_model`` is patched while the ``LLM`` is constructed so that a
    ``LoRAConfig(max_loras=4, max_lora_rank=8)`` is passed to the model
    loader even though the ``LLM`` is created with ``enable_lora=False``.
    ``cleanup()`` runs before construction and again on teardown.
    """
    cleanup()
    get_model_old = get_model

    def get_model_patched(*, model_config, device_config, **kwargs):
        # Override whatever lora_config the caller supplied.
        kwargs["lora_config"] = LoRAConfig(max_loras=4, max_lora_rank=8)
        return get_model_old(model_config=model_config,
                             device_config=device_config,
                             **kwargs)

    with patch("aphrodite.task_handler.model_runner.get_model",
               get_model_patched):
        engine = aphrodite.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
    yield engine.llm_engine
    del engine
    cleanup()


@pytest.fixture
def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
    """Yield the model held by the driver worker's model runner of the
    ``llama_2_7b_engine_extra_embeddings`` engine."""
    yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
           model_runner.model)