conftest.py

import contextlib
import gc
import tempfile
from collections import OrderedDict
from typing import Dict, List, TypedDict
from unittest.mock import MagicMock, patch

import pytest
import ray
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download

import aphrodite
from aphrodite.common.config import LoRAConfig
from aphrodite.distributed import (destroy_distributed_environment,
                                   destroy_model_parallel,
                                   init_distributed_environment,
                                   initialize_model_parallel)
from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                              MergedColumnParallelLinear,
                                              RowParallelLinear)
from aphrodite.modeling.layers.logits_processor import LogitsProcessor
from aphrodite.modeling.layers.sampler import Sampler
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.modeling.model_loader import get_model


class ContextIDInfo(TypedDict):
    lora_id: int
    context_length: str


class ContextInfo(TypedDict):
    lora: str
    context_length: str


LONG_LORA_INFOS: List[ContextIDInfo] = [{
    "lora_id": 1,
    "context_length": "16k",
}, {
    "lora_id": 2,
    "context_length": "16k",
}, {
    "lora_id": 3,
    "context_length": "32k",
}]


def cleanup():
    destroy_model_parallel()
    destroy_distributed_environment()
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    gc.collect()
    torch.cuda.empty_cache()
    ray.shutdown()


@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
    """Allow subdirectories to skip global cleanup by overriding this fixture.
    This can provide a ~10x speedup for non-GPU unit tests since they don't
    need to initialize torch.
    """
    if request.node.get_closest_marker("skip_global_cleanup"):
        return False
    return True


@pytest.fixture(autouse=True)
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
    yield
    if should_do_global_cleanup_after_test:
        cleanup()
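

# Hedged usage sketch (not part of the original file): a test that never
# touches the GPU can opt out of the global cleanup above via the
# `skip_global_cleanup` marker that `should_do_global_cleanup_after_test`
# checks for. The test below is purely illustrative.
#
# @pytest.mark.skip_global_cleanup
# def test_pure_python_helper():
#     assert sorted([3, 1, 2]) == [1, 2, 3]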


@pytest.fixture
def dist_init():
    temp_file = tempfile.mkstemp()[1]
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=f"file://{temp_file}",
        local_rank=0,
        backend="nccl",
    )
    initialize_model_parallel(1, 1)
    yield
    cleanup()
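

# Hedged usage sketch (hypothetical test, not from this file): requesting
# `dist_init` gives a test a single-rank NCCL process group plus 1x1 model
# parallelism, so the parallel layers imported above can be constructed
# directly inside the test.
#
# def test_row_parallel_linear_builds(dist_init):
#     layer = RowParallelLinear(16, 8)
#     assert isinstance(layer, nn.Module)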


@pytest.fixture
def dist_init_torch_only():
    if torch.distributed.is_initialized():
        return
    temp_file = tempfile.mkstemp()[1]
    torch.distributed.init_process_group(
        backend="nccl",
        world_size=1,
        rank=0,
        init_method=f"file://{temp_file}",
    )


@pytest.fixture
def dummy_model() -> nn.Module:
    model = nn.Sequential(
        OrderedDict([
            ("dense1", ColumnParallelLinear(764, 100)),
            ("dense2", RowParallelLinear(100, 50)),
            (
                "layer1",
                nn.Sequential(
                    OrderedDict([
                        ("dense1", ColumnParallelLinear(100, 10)),
                        ("dense2", RowParallelLinear(10, 50)),
                    ])),
            ),
            ("act2", nn.ReLU()),
            ("output", ColumnParallelLinear(50, 10)),
            ("outact", nn.Sigmoid()),
            # Special handling for lm_head & sampler
            ("lm_head", ParallelLMHead(512, 10)),
            ("logits_processor", LogitsProcessor(512)),
            ("sampler", Sampler())
        ]))
    model.config = MagicMock()
    return model


@pytest.fixture
def dummy_model_gate_up() -> nn.Module:
    model = nn.Sequential(
        OrderedDict([
            ("dense1", ColumnParallelLinear(764, 100)),
            ("dense2", RowParallelLinear(100, 50)),
            (
                "layer1",
                nn.Sequential(
                    OrderedDict([
                        ("dense1", ColumnParallelLinear(100, 10)),
                        ("dense2", RowParallelLinear(10, 50)),
                    ])),
            ),
            ("act2", nn.ReLU()),
            ("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])),
            ("outact", nn.Sigmoid()),
            # Special handling for lm_head & sampler
            ("lm_head", ParallelLMHead(512, 10)),
            ("logits_processor", LogitsProcessor(512)),
            ("sampler", Sampler())
        ]))
    model.config = MagicMock()
    return model


@pytest.fixture(scope="session")
def sql_lora_huggingface_id():
    # The Hugging Face repo id is used to test LoRA runtime downloading.
    return "yard1/llama-2-7b-sql-lora-test"


@pytest.fixture(scope="session")
def sql_lora_files(sql_lora_huggingface_id):
    return snapshot_download(repo_id=sql_lora_huggingface_id)


@pytest.fixture(scope="session")
def mixtral_lora_files():
    # Note: this adapter has an incorrect adapter_config.json, used to test
    # https://github.com/aphrodite-project/aphrodite/pull/5909/files.
    return snapshot_download(repo_id="SangBinCho/mixtral-lora")


@pytest.fixture(scope="session")
def gemma_lora_files():
    return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")


@pytest.fixture(scope="session")
def chatglm3_lora_files():
    return snapshot_download(repo_id="jeeejeee/chatglm3-text2sql-spider")


@pytest.fixture(scope="session")
def baichuan_lora_files():
    return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")


@pytest.fixture(scope="session")
def baichuan_zero_lora_files():
    # All of the lora_B weights are initialized to zero.
    return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")


@pytest.fixture(scope="session")
def tinyllama_lora_files():
    return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")


@pytest.fixture(scope="session")
def phi2_lora_files():
    return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora")


@pytest.fixture(scope="session")
def long_context_lora_files_16k_1():
    return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_1")


@pytest.fixture(scope="session")
def long_context_lora_files_16k_2():
    return snapshot_download(repo_id="SangBinCho/long_context_16k_testing_2")


@pytest.fixture(scope="session")
def long_context_lora_files_32k():
    return snapshot_download(repo_id="SangBinCho/long_context_32k_testing")


@pytest.fixture(scope="session")
def long_context_infos(long_context_lora_files_16k_1,
                       long_context_lora_files_16k_2,
                       long_context_lora_files_32k):
    cleanup()
    infos: Dict[int, ContextInfo] = {}
    for lora_checkpoint_info in LONG_LORA_INFOS:
        lora_id = lora_checkpoint_info["lora_id"]
        if lora_id == 1:
            lora = long_context_lora_files_16k_1
        elif lora_id == 2:
            lora = long_context_lora_files_16k_2
        elif lora_id == 3:
            lora = long_context_lora_files_32k
        else:
            raise AssertionError("Unknown lora id")
        infos[lora_id] = {
            "context_length": lora_checkpoint_info["context_length"],
            "lora": lora,
        }
    return infos
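

# For reference, a hedged sketch of the mapping `long_context_infos` returns;
# the paths below are placeholders for whatever cache directories
# snapshot_download resolves to locally.
#
# {
#     1: {"context_length": "16k", "lora": "/path/to/long_context_16k_testing_1"},
#     2: {"context_length": "16k", "lora": "/path/to/long_context_16k_testing_2"},
#     3: {"context_length": "32k", "lora": "/path/to/long_context_32k_testing"},
# }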


@pytest.fixture
def llama_2_7b_engine_extra_embeddings():
    cleanup()
    get_model_old = get_model

    def get_model_patched(*, model_config, device_config, **kwargs):
        kwargs["lora_config"] = LoRAConfig(max_loras=4, max_lora_rank=8)
        return get_model_old(model_config=model_config,
                             device_config=device_config,
                             **kwargs)

    with patch("aphrodite.worker.model_runner.get_model", get_model_patched):
        engine = aphrodite.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
    yield engine.llm_engine
    del engine
    cleanup()


@pytest.fixture
def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings):
    yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
           model_runner.model)
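

# Hedged usage sketch (hypothetical test, not part of this file): the fixtures
# above load meta-llama/Llama-2-7b-hf with a LoRAConfig forced into get_model,
# and `llama_2_7b_model_extra_embeddings` hands the underlying nn.Module to a
# test.
#
# def test_patched_model_is_module(llama_2_7b_model_extra_embeddings):
#     assert isinstance(llama_2_7b_model_extra_embeddings, nn.Module)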