test_utils.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. import asyncio
  2. import os
  3. import socket
  4. from functools import partial
  5. from typing import AsyncIterator, Tuple
  6. import pytest
  7. from aphrodite.common.utils import (FlexibleArgumentParser, deprecate_kwargs,
  8. get_open_port, merge_async_iterators)
  9. from .utils import error_on_warning
  10. @pytest.mark.asyncio
  11. async def test_merge_async_iterators():
  12. async def mock_async_iterator(idx: int):
  13. try:
  14. while True:
  15. yield f"item from iterator {idx}"
  16. await asyncio.sleep(0.1)
  17. except asyncio.CancelledError:
  18. print(f"iterator {idx} cancelled")
  19. iterators = [mock_async_iterator(i) for i in range(3)]
  20. merged_iterator = merge_async_iterators(*iterators,
  21. is_cancelled=partial(asyncio.sleep,
  22. 0,
  23. result=False))
  24. async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
  25. async for idx, output in generator:
  26. print(f"idx: {idx}, output: {output}")
  27. task = asyncio.create_task(stream_output(merged_iterator))
  28. await asyncio.sleep(0.5)
  29. task.cancel()
  30. with pytest.raises(asyncio.CancelledError):
  31. await task
  32. for iterator in iterators:
  33. try:
  34. # Can use anext() in python >= 3.10
  35. await asyncio.wait_for(iterator.__anext__(), 1)
  36. except StopAsyncIteration:
  37. # All iterators should be cancelled and print this message.
  38. print("Iterator was cancelled normally")
  39. except (Exception, asyncio.CancelledError) as e:
  40. raise AssertionError() from e
  41. def test_deprecate_kwargs_always():
  42. @deprecate_kwargs("old_arg", is_deprecated=True)
  43. def dummy(*, old_arg: object = None, new_arg: object = None):
  44. pass
  45. with pytest.warns(DeprecationWarning, match="'old_arg'"):
  46. dummy(old_arg=1)
  47. with error_on_warning():
  48. dummy(new_arg=1)
  49. def test_deprecate_kwargs_never():
  50. @deprecate_kwargs("old_arg", is_deprecated=False)
  51. def dummy(*, old_arg: object = None, new_arg: object = None):
  52. pass
  53. with error_on_warning():
  54. dummy(old_arg=1)
  55. with error_on_warning():
  56. dummy(new_arg=1)
  57. def test_deprecate_kwargs_dynamic():
  58. is_deprecated = True
  59. @deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated)
  60. def dummy(*, old_arg: object = None, new_arg: object = None):
  61. pass
  62. with pytest.warns(DeprecationWarning, match="'old_arg'"):
  63. dummy(old_arg=1)
  64. with error_on_warning():
  65. dummy(new_arg=1)
  66. is_deprecated = False
  67. with error_on_warning():
  68. dummy(old_arg=1)
  69. with error_on_warning():
  70. dummy(new_arg=1)
  71. def test_deprecate_kwargs_additional_message():
  72. @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd")
  73. def dummy(*, old_arg: object = None, new_arg: object = None):
  74. pass
  75. with pytest.warns(DeprecationWarning, match="abcd"):
  76. dummy(old_arg=1)
  77. def test_get_open_port():
  78. os.environ["APHRODITE_PORT"] = "5678"
  79. # make sure we can get multiple ports, even if the env var is set
  80. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1:
  81. s1.bind(("localhost", get_open_port()))
  82. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s2:
  83. s2.bind(("localhost", get_open_port()))
  84. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s3:
  85. s3.bind(("localhost", get_open_port()))
  86. os.environ.pop("APHRODITE_PORT")
  87. # Tests for FlexibleArgumentParser
  88. @pytest.fixture
  89. def parser():
  90. parser = FlexibleArgumentParser()
  91. parser.add_argument('--image-input-type',
  92. choices=['pixel_values', 'image_features'])
  93. parser.add_argument('--model-name')
  94. parser.add_argument('--batch-size', type=int)
  95. parser.add_argument('--enable-feature', action='store_true')
  96. return parser
  97. def test_underscore_to_dash(parser):
  98. args = parser.parse_args(['--image_input_type', 'pixel_values'])
  99. assert args.image_input_type == 'pixel_values'
  100. def test_mixed_usage(parser):
  101. args = parser.parse_args([
  102. '--image_input_type', 'image_features', '--model-name',
  103. 'facebook/opt-125m'
  104. ])
  105. assert args.image_input_type == 'image_features'
  106. assert args.model_name == 'facebook/opt-125m'
  107. def test_with_equals_sign(parser):
  108. args = parser.parse_args(
  109. ['--image_input_type=pixel_values', '--model-name=facebook/opt-125m'])
  110. assert args.image_input_type == 'pixel_values'
  111. assert args.model_name == 'facebook/opt-125m'
  112. def test_with_int_value(parser):
  113. args = parser.parse_args(['--batch_size', '32'])
  114. assert args.batch_size == 32
  115. args = parser.parse_args(['--batch-size', '32'])
  116. assert args.batch_size == 32
  117. def test_with_bool_flag(parser):
  118. args = parser.parse_args(['--enable_feature'])
  119. assert args.enable_feature is True
  120. args = parser.parse_args(['--enable-feature'])
  121. assert args.enable_feature is True
  122. def test_invalid_choice(parser):
  123. with pytest.raises(SystemExit):
  124. parser.parse_args(['--image_input_type', 'invalid_choice'])
  125. def test_missing_required_argument(parser):
  126. parser.add_argument('--required-arg', required=True)
  127. with pytest.raises(SystemExit):
  128. parser.parse_args([])