1
0

test_utils.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. from collections import OrderedDict
  2. from unittest.mock import patch
  3. import pytest
  4. from huggingface_hub.utils import HfHubHTTPError
  5. from torch import nn
  6. from aphrodite.common.utils import LRUCache
  7. from aphrodite.lora.utils import (get_adapter_absolute_path,
  8. parse_fine_tuned_lora_name,
  9. replace_submodule)
  10. def test_parse_fine_tuned_lora_name_valid():
  11. fixture = {
  12. ("base_model.model.lm_head.lora_A.weight", "lm_head", True),
  13. ("base_model.model.lm_head.lora_B.weight", "lm_head", False),
  14. (
  15. "base_model.model.model.embed_tokens.lora_embedding_A",
  16. "model.embed_tokens",
  17. True,
  18. ),
  19. (
  20. "base_model.model.model.embed_tokens.lora_embedding_B",
  21. "model.embed_tokens",
  22. False,
  23. ),
  24. (
  25. "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
  26. "model.layers.9.mlp.down_proj",
  27. True,
  28. ),
  29. (
  30. "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
  31. "model.layers.9.mlp.down_proj",
  32. False,
  33. ),
  34. }
  35. for name, module_name, is_lora_a in fixture:
  36. assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name)
  37. def test_parse_fine_tuned_lora_name_invalid():
  38. fixture = {
  39. "weight",
  40. "base_model.weight",
  41. "base_model.model.weight",
  42. }
  43. for name in fixture:
  44. with pytest.raises(ValueError, match="unsupported LoRA weight"):
  45. parse_fine_tuned_lora_name(name)
  46. def test_replace_submodule():
  47. model = nn.Sequential(
  48. OrderedDict([
  49. ("dense1", nn.Linear(764, 100)),
  50. ("act1", nn.ReLU()),
  51. ("dense2", nn.Linear(100, 50)),
  52. (
  53. "seq1",
  54. nn.Sequential(
  55. OrderedDict([
  56. ("dense1", nn.Linear(100, 10)),
  57. ("dense2", nn.Linear(10, 50)),
  58. ])),
  59. ),
  60. ("act2", nn.ReLU()),
  61. ("output", nn.Linear(50, 10)),
  62. ("outact", nn.Sigmoid()),
  63. ]))
  64. sigmoid = nn.Sigmoid()
  65. replace_submodule(model, "act1", sigmoid)
  66. assert dict(model.named_modules())["act1"] == sigmoid
  67. dense2 = nn.Linear(1, 5)
  68. replace_submodule(model, "seq1.dense2", dense2)
  69. assert dict(model.named_modules())["seq1.dense2"] == dense2
  70. class TestLRUCache(LRUCache):
  71. def _on_remove(self, key, value):
  72. if not hasattr(self, "_remove_counter"):
  73. self._remove_counter = 0
  74. self._remove_counter += 1
  75. def test_lru_cache():
  76. cache = TestLRUCache(3)
  77. cache.put(1, 1)
  78. assert len(cache) == 1
  79. cache.put(1, 1)
  80. assert len(cache) == 1
  81. cache.put(2, 2)
  82. assert len(cache) == 2
  83. cache.put(3, 3)
  84. assert len(cache) == 3
  85. assert set(cache.cache) == {1, 2, 3}
  86. cache.put(4, 4)
  87. assert len(cache) == 3
  88. assert set(cache.cache) == {2, 3, 4}
  89. assert cache._remove_counter == 1
  90. assert cache.get(2) == 2
  91. cache.put(5, 5)
  92. assert set(cache.cache) == {2, 4, 5}
  93. assert cache._remove_counter == 2
  94. assert cache.pop(5) == 5
  95. assert len(cache) == 2
  96. assert set(cache.cache) == {2, 4}
  97. assert cache._remove_counter == 3
  98. cache.pop(10)
  99. assert len(cache) == 2
  100. assert set(cache.cache) == {2, 4}
  101. assert cache._remove_counter == 3
  102. cache.get(10)
  103. assert len(cache) == 2
  104. assert set(cache.cache) == {2, 4}
  105. assert cache._remove_counter == 3
  106. cache.put(6, 6)
  107. assert len(cache) == 3
  108. assert set(cache.cache) == {2, 4, 6}
  109. assert 2 in cache
  110. assert 4 in cache
  111. assert 6 in cache
  112. cache.remove_oldest()
  113. assert len(cache) == 2
  114. assert set(cache.cache) == {2, 6}
  115. assert cache._remove_counter == 4
  116. cache.clear()
  117. assert len(cache) == 0
  118. assert cache._remove_counter == 6
  119. cache._remove_counter = 0
  120. cache[1] = 1
  121. assert len(cache) == 1
  122. cache[1] = 1
  123. assert len(cache) == 1
  124. cache[2] = 2
  125. assert len(cache) == 2
  126. cache[3] = 3
  127. assert len(cache) == 3
  128. assert set(cache.cache) == {1, 2, 3}
  129. cache[4] = 4
  130. assert len(cache) == 3
  131. assert set(cache.cache) == {2, 3, 4}
  132. assert cache._remove_counter == 1
  133. assert cache[2] == 2
  134. cache[5] = 5
  135. assert set(cache.cache) == {2, 4, 5}
  136. assert cache._remove_counter == 2
  137. del cache[5]
  138. assert len(cache) == 2
  139. assert set(cache.cache) == {2, 4}
  140. assert cache._remove_counter == 3
  141. cache.pop(10)
  142. assert len(cache) == 2
  143. assert set(cache.cache) == {2, 4}
  144. assert cache._remove_counter == 3
  145. cache[6] = 6
  146. assert len(cache) == 3
  147. assert set(cache.cache) == {2, 4, 6}
  148. assert 2 in cache
  149. assert 4 in cache
  150. assert 6 in cache
  151. # Unit tests for get_adapter_absolute_path
  152. @patch('os.path.isabs')
  153. def test_get_adapter_absolute_path_absolute(mock_isabs):
  154. path = '/absolute/path/to/lora'
  155. mock_isabs.return_value = True
  156. assert get_adapter_absolute_path(path) == path
  157. @patch('os.path.expanduser')
  158. def test_get_adapter_absolute_path_expanduser(mock_expanduser):
  159. # Path with ~ that needs to be expanded
  160. path = '~/relative/path/to/lora'
  161. absolute_path = '/home/user/relative/path/to/lora'
  162. mock_expanduser.return_value = absolute_path
  163. assert get_adapter_absolute_path(path) == absolute_path
  164. @patch('os.path.exists')
  165. @patch('os.path.abspath')
  166. def test_get_adapter_absolute_path_local_existing(mock_abspath, mock_exist):
  167. # Relative path that exists locally
  168. path = 'relative/path/to/lora'
  169. absolute_path = '/absolute/path/to/lora'
  170. mock_exist.return_value = True
  171. mock_abspath.return_value = absolute_path
  172. assert get_adapter_absolute_path(path) == absolute_path
  173. @patch('huggingface_hub.snapshot_download')
  174. @patch('os.path.exists')
  175. def test_get_adapter_absolute_path_huggingface(mock_exist,
  176. mock_snapshot_download):
  177. # Hugging Face model identifier
  178. path = 'org/repo'
  179. absolute_path = '/mock/snapshot/path'
  180. mock_exist.return_value = False
  181. mock_snapshot_download.return_value = absolute_path
  182. assert get_adapter_absolute_path(path) == absolute_path
  183. @patch('huggingface_hub.snapshot_download')
  184. @patch('os.path.exists')
  185. def test_get_adapter_absolute_path_huggingface_error(mock_exist,
  186. mock_snapshot_download):
  187. # Hugging Face model identifier with download error
  188. path = 'org/repo'
  189. mock_exist.return_value = False
  190. mock_snapshot_download.side_effect = HfHubHTTPError(
  191. "failed to query model info")
  192. assert get_adapter_absolute_path(path) == path