test_config.py

import pytest

from aphrodite.common.config import ModelConfig

# (model_id, expected max_model_len once the sliding window is disabled)
MODEL_IDS_EXPECTED = [
    ("Qwen/Qwen1.5-7B", 32768),
    ("mistralai/Mistral-7B-v0.1", 4096),
    ("mistralai/Mistral-7B-Instruct-v0.2", 32768),
]


@pytest.mark.parametrize("model_id_expected", MODEL_IDS_EXPECTED)
def test_disable_sliding_window(model_id_expected):
    # With disable_sliding_window=True, max_model_len is capped at the
    # model's sliding window size (4096 for Mistral-7B-v0.1); models that
    # do not use sliding-window attention keep their full context length.
    model_id, expected = model_id_expected
    model_config = ModelConfig(
        model_id,
        model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
        disable_sliding_window=True,
    )
    assert model_config.max_model_len == expected


def test_get_sliding_window():
    TEST_SLIDING_WINDOW = 4096
    # Test that the sliding window is correctly computed.
    # For Qwen1.5/Qwen2, get_sliding_window() should be None
    # when use_sliding_window is False.
    qwen2_model_config = ModelConfig(
        "Qwen/Qwen1.5-7B",
        "Qwen/Qwen1.5-7B",
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    qwen2_model_config.hf_config.use_sliding_window = False
    qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
    assert qwen2_model_config.get_sliding_window() is None

    qwen2_model_config.hf_config.use_sliding_window = True
    assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW

    # For Mistral, the window is taken directly from hf_config.sliding_window.
    mistral_model_config = ModelConfig(
        "mistralai/Mistral-7B-v0.1",
        "mistralai/Mistral-7B-v0.1",
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    mistral_model_config.hf_config.sliding_window = None
    assert mistral_model_config.get_sliding_window() is None

    mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
    assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
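

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original test file): the assertions in
# test_get_sliding_window() above pin down the observable contract of
# ModelConfig.get_sliding_window(). It returns None when the HF config either
# disables sliding-window attention (use_sliding_window=False, as on
# Qwen1.5/Qwen2) or has no window configured (sliding_window=None, as on
# Mistral), and otherwise returns the configured window size. The helper below
# mirrors that contract on a bare config object; its name is hypothetical and
# it is not aphrodite's actual implementation.
def _sketch_get_sliding_window(hf_config):
    """Mirror the sliding-window behaviour asserted in the test above."""
    if getattr(hf_config, "use_sliding_window", True) is False:
        return None
    return getattr(hf_config, "sliding_window", None)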


def test_rope_customization():
    TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0}
    TEST_ROPE_THETA = 16_000_000.0
    LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0}

    # Without overrides, Llama 3 keeps its default RoPE settings.
    llama_model_config = ModelConfig(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        "meta-llama/Meta-Llama-3-8B-Instruct",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None
    assert getattr(llama_model_config.hf_config, "rope_theta", None) == 500_000
    assert llama_model_config.max_model_len == 8192

    # rope_scaling/rope_theta overrides are written back to hf_config, and
    # the dynamic factor of 2.0 doubles max_model_len to 16384.
    llama_model_config = ModelConfig(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        "meta-llama/Meta-Llama-3-8B-Instruct",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
        rope_scaling=TEST_ROPE_SCALING,
        rope_theta=TEST_ROPE_THETA,
    )
    assert getattr(llama_model_config.hf_config, "rope_scaling",
                   None) == TEST_ROPE_SCALING
    assert getattr(llama_model_config.hf_config, "rope_theta",
                   None) == TEST_ROPE_THETA
    assert llama_model_config.max_model_len == 16384

    longchat_model_config = ModelConfig(
        "lmsys/longchat-13b-16k",
        "lmsys/longchat-13b-16k",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    # Check that the LONGCHAT_ROPE_SCALING entries are present in
    # longchat_model_config's hf_config.
    assert all(
        longchat_model_config.hf_config.rope_scaling.get(key) == value
        for key, value in LONGCHAT_ROPE_SCALING.items())
    assert longchat_model_config.max_model_len == 16384

    # An explicit rope_scaling override replaces longchat's built-in linear
    # scaling, so max_model_len drops from 16384 to 4096.
    longchat_model_config = ModelConfig(
        "lmsys/longchat-13b-16k",
        "lmsys/longchat-13b-16k",
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
        rope_scaling=TEST_ROPE_SCALING,
    )
    assert getattr(longchat_model_config.hf_config, "rope_scaling",
                   None) == TEST_ROPE_SCALING
    assert longchat_model_config.max_model_len == 4096