test_jamba.py

import pytest

from aphrodite.worker.model_runner import _get_graph_batch_size

from ...utils import check_outputs_equal

MODELS = ["ai21labs/Jamba-tiny-random"]
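# NOTE: "Jamba-tiny-random" is, as its name suggests, a tiny randomly
# initialized Jamba checkpoint, so these tests exercise the architecture
# and its state management rather than output quality.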


# Fails due to the usage of MoE as an MLP (E=1), which is different from
# the HF impl.
# TODO: Fix this with a trained model.
@pytest.mark.skip()
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [10])
def test_models(
    hf_runner,
    aphrodite_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
        aphrodite_outputs = aphrodite_model.generate_greedy(
            example_prompts, max_tokens)

    for i in range(len(example_prompts)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        aphrodite_output_ids, aphrodite_output_str = aphrodite_outputs[i]
        assert hf_output_str == aphrodite_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nAPHRODITE: "
            f"{aphrodite_output_str!r}")
        assert hf_output_ids == aphrodite_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nAPHRODITE: "
            f"{aphrodite_output_ids}")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
def test_batching(
    aphrodite_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    # To pass the small model tests, we need full precision.
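    # Generating each prompt on its own and then the whole list in one batch
    # should yield identical greedy outputs; a mismatch would point at
    # cross-sequence leakage in the batched mamba state.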
    for_loop_outputs = []
    with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
        for prompt in example_prompts:
            for_loop_outputs.append(
                aphrodite_model.generate_greedy([prompt], max_tokens)[0])

        batched_outputs = aphrodite_model.generate_greedy(
            example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=for_loop_outputs,
        outputs_1_lst=batched_outputs,
        name_0="for_loop_aphrodite",
        name_1="batched_aphrodite",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [20])
def test_mamba_cache_cg_padding(
    aphrodite_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    # This test verifies that the mamba cache is padded to the CUDA-graph
    # captured batch size. If it isn't, a torch RuntimeError will be raised
    # because the tensor dimensions aren't compatible.
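    # (For context, and assuming the usual behavior of _get_graph_batch_size:
    # it rounds a batch size up to the nearest CUDA-graph-captured size,
    # e.g. 3 -> 4, 7 -> 8, 9 -> 16 with an alignment of 8. The loop below
    # grows the prompt list until its length is NOT a captured size.)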
    while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
        example_prompts.append(example_prompts[0])

    try:
        with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
            aphrodite_model.generate_greedy(example_prompts, max_tokens)
    except RuntimeError:
        pytest.fail(
            "Couldn't run a batch size which is not equal to a CUDA Graph "
            "captured batch size. "
            "Could be related to the mamba cache not being padded correctly.")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [20])
def test_models_preemption_recompute(
    hf_runner,
    aphrodite_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    # Tests that outputs are identical with and without preemptions
    # (recompute).
    assert dtype == "float"
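    # ENABLE_ARTIFICIAL_PREEMPT is a scheduler test hook that forces
    # periodic preemptions; with recompute-style preemption the inner
    # (mamba) state is rebuilt from scratch, so the greedy outputs below
    # should match the unpreempted run exactly.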
    with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
        aphrodite_model.model.llm_engine.scheduler[
            0].ENABLE_ARTIFICIAL_PREEMPT = True
        preempt_aphrodite_outputs = aphrodite_model.generate_greedy(
            example_prompts, max_tokens)

        aphrodite_model.model.llm_engine.scheduler[
            0].ENABLE_ARTIFICIAL_PREEMPT = False
        aphrodite_outputs = aphrodite_model.generate_greedy(
            example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=preempt_aphrodite_outputs,
        outputs_1_lst=aphrodite_outputs,
        name_0="aphrodite_preemptions",
        name_1="aphrodite",
    )


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
    aphrodite_runner,
    model: str,
    dtype: str,
    example_prompts,
) -> None:
    # This test verifies that Jamba's inner state management doesn't
    # collapse when the number of incoming requests and
    # finished_requests_ids is larger than the maximum mamba block capacity.
    # This can generally happen because Jamba supports a statelessness
    # mechanism that lets it clean up new incoming requests in a single step.
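    # With max_num_seqs=10 and 100 copies of the same prompt, finished
    # requests and newly scheduled ones keep turning over, so
    # finished_requests_ids can exceed the available mamba cache slots.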
    try:
        with aphrodite_runner(
                model,
                dtype=dtype,
                max_num_seqs=10,
        ) as aphrodite_model:
            aphrodite_model.generate_greedy([example_prompts[0]] * 100, 10)
    except ValueError:
        pytest.fail("Jamba inner state wasn't cleaned up properly between "
                    "steps; finished requests were registered unnecessarily.")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_state_cleanup(
    aphrodite_runner,
    model: str,
    dtype: str,
    example_prompts,
) -> None:
    # This test verifies that the Jamba state is cleaned up between steps.
    # If it isn't, an error is expected.
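    # Ten consecutive rounds of 100 identical requests reuse the same mamba
    # cache slots; stale state left over from an earlier round would be
    # expected to surface as a ValueError here.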
    try:
        with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
            for _ in range(10):
                aphrodite_model.generate_greedy([example_prompts[0]] * 100, 1)
    except ValueError:
        pytest.fail("Jamba inner state wasn't cleaned up between steps, "
                    "could be related to finished_requests_ids")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_model_print(
    aphrodite_runner,
    model: str,
    dtype: str,
) -> None:
    with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
        # This test is for verifying whether the model's extra_repr
        # can be printed correctly.
        print(aphrodite_model.model.llm_engine.model_executor.driver_worker.
              model_runner.model)