# test_cli_args.py: tests for parsing the `--lora-modules` argument of the
# OpenAI-compatible server's command-line parser.

import contextlib
import io
import json
import unittest

from aphrodite.common.utils import FlexibleArgumentParser
from aphrodite.endpoints.openai.args import make_arg_parser
from aphrodite.endpoints.openai.serving_engine import LoRAModulePath
  6. LORA_MODULE = {
  7. "name": "module2",
  8. "path": "/path/to/module2",
  9. "base_model_name": "llama",
  10. }
  11. class TestLoraParserAction(unittest.TestCase):
  12. def setUp(self):
  13. # Setting up argparse parser for tests
  14. parser = FlexibleArgumentParser(
  15. description="Aphrodite's remote OpenAI server."
  16. )
  17. self.parser = make_arg_parser(parser)
  18. def test_valid_key_value_format(self):
  19. # Test old format: name=path
  20. args = self.parser.parse_args(
  21. [
  22. "--lora-modules",
  23. "module1=/path/to/module1",
  24. ]
  25. )
  26. expected = [LoRAModulePath(name="module1", path="/path/to/module1")]
  27. self.assertEqual(args.lora_modules, expected)
  28. def test_valid_json_format(self):
  29. # Test valid JSON format input
  30. args = self.parser.parse_args(
  31. [
  32. "--lora-modules",
  33. json.dumps(LORA_MODULE),
  34. ]
  35. )
  36. expected = [
  37. LoRAModulePath(
  38. name="module2", path="/path/to/module2", base_model_name="llama"
  39. )
  40. ]
  41. self.assertEqual(args.lora_modules, expected)
  42. def test_invalid_json_format(self):
  43. # Test invalid JSON format input, missing closing brace
  44. with self.assertRaises(SystemExit):
  45. self.parser.parse_args(
  46. [
  47. "--lora-modules",
  48. '{"name": "module3", "path": "/path/to/module3"',
  49. ]
  50. )
  51. def test_invalid_type_error(self):
  52. # Test type error when values are not JSON or key=value
  53. with self.assertRaises(SystemExit):
  54. self.parser.parse_args(
  55. [
  56. "--lora-modules",
  57. "invalid_format", # This is not JSON or key=value format
  58. ]
  59. )
  60. def test_invalid_json_field(self):
  61. # Test valid JSON format but missing required fields
  62. with self.assertRaises(SystemExit):
  63. self.parser.parse_args(
  64. [
  65. "--lora-modules",
  66. '{"name": "module4"}', # Missing required 'path' field
  67. ]
  68. )
  69. def test_empty_values(self):
  70. # Test when no LoRA modules are provided
  71. args = self.parser.parse_args(["--lora-modules", ""])
  72. self.assertEqual(args.lora_modules, [])
  73. def test_multiple_valid_inputs(self):
  74. # Test multiple valid inputs (both old and JSON format)
  75. args = self.parser.parse_args(
  76. [
  77. "--lora-modules",
  78. "module1=/path/to/module1",
  79. json.dumps(LORA_MODULE),
  80. ]
  81. )
  82. expected = [
  83. LoRAModulePath(name="module1", path="/path/to/module1"),
  84. LoRAModulePath(
  85. name="module2", path="/path/to/module2", base_model_name="llama"
  86. ),
  87. ]
  88. self.assertEqual(args.lora_modules, expected)
  89. if __name__ == "__main__":
  90. unittest.main()