test_jamba.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. import pytest
  2. from aphrodite.worker.model_runner import _get_graph_batch_size
  3. from tests.models.utils import check_outputs_equal
  4. MODELS = ["ai21labs/Jamba-tiny-random"]
  5. # Fails due to usage of MoE as MLP(E=1_, which is different than the HF impl
  6. # TODO: Fix this with trained model
  7. @pytest.mark.skip()
  8. @pytest.mark.parametrize("model", MODELS)
  9. @pytest.mark.parametrize("dtype", ["bfloat16"])
  10. @pytest.mark.parametrize("max_tokens", [10])
  11. def test_models(
  12. hf_runner,
  13. aphrodite_runner,
  14. example_prompts,
  15. model: str,
  16. dtype: str,
  17. max_tokens: int,
  18. ) -> None:
  19. with hf_runner(model, dtype=dtype) as hf_model:
  20. hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
  21. with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
  22. aphrodite_outputs = aphrodite_model.generate_greedy(
  23. example_prompts, max_tokens)
  24. for i in range(len(example_prompts)):
  25. hf_output_ids, hf_output_str = hf_outputs[i]
  26. aphrodite_output_ids, aphrodite_output_str = aphrodite_outputs[i]
  27. assert hf_output_str == aphrodite_output_str, (
  28. f"Test{i}:\nHF: {hf_output_str!r}\nAphrodite: "
  29. f"{aphrodite_output_str!r}")
  30. assert hf_output_ids == aphrodite_output_ids, (
  31. f"Test{i}:\nHF: {hf_output_ids}\nAphrodite: {aphrodite_output_ids}")
  32. @pytest.mark.parametrize("model", MODELS)
  33. @pytest.mark.parametrize("dtype", ["half"])
  34. @pytest.mark.parametrize("max_tokens", [5])
  35. def test_batching(
  36. aphrodite_runner,
  37. example_prompts,
  38. model: str,
  39. dtype: str,
  40. max_tokens: int,
  41. ) -> None:
  42. # To pass the small model tests, we need full precision.
  43. for_loop_outputs = []
  44. with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
  45. for prompt in example_prompts:
  46. for_loop_outputs.append(
  47. aphrodite_model.generate_greedy([prompt], max_tokens)[0])
  48. batched_outputs = aphrodite_model.generate_greedy(example_prompts,
  49. max_tokens)
  50. check_outputs_equal(
  51. outputs_0_lst=for_loop_outputs,
  52. outputs_1_lst=batched_outputs,
  53. name_0="for_loop_aphrodite",
  54. name_1="batched_aphrodite",
  55. )
  56. @pytest.mark.parametrize("model", MODELS)
  57. @pytest.mark.parametrize("dtype", ["bfloat16"])
  58. @pytest.mark.parametrize("max_tokens", [20])
  59. def test_mamba_cache_cg_padding(
  60. aphrodite_runner,
  61. example_prompts,
  62. model: str,
  63. dtype: str,
  64. max_tokens: int,
  65. ) -> None:
  66. # This test is for verifying that mamba cache is padded to CG captured
  67. # batch size. If it's not, a torch RuntimeError will be raised because
  68. # tensor dimensions aren't compatible
  69. while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
  70. example_prompts.append(example_prompts[0])
  71. try:
  72. with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
  73. aphrodite_model.generate_greedy(example_prompts, max_tokens)
  74. except RuntimeError:
  75. pytest.fail(
  76. "Couldn't run batch size which is not equal to a Cuda Graph "
  77. "captured batch size. "
  78. "Could be related to mamba cache not padded correctly")
  79. @pytest.mark.parametrize("model", MODELS)
  80. @pytest.mark.parametrize("dtype", ["float"])
  81. @pytest.mark.parametrize("max_tokens", [20])
  82. def test_models_preemption_recompute(
  83. hf_runner,
  84. aphrodite_runner,
  85. example_prompts,
  86. model: str,
  87. dtype: str,
  88. max_tokens: int,
  89. ) -> None:
  90. # Tests that outputs are identical with and w/o preemtions (recompute)
  91. assert dtype == "float"
  92. with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
  93. aphrodite_model.model.llm_engine.scheduler[
  94. 0].ENABLE_ARTIFICIAL_PREEMPT = True
  95. preempt_aphrodite_outputs = aphrodite_model.generate_greedy(
  96. example_prompts, max_tokens)
  97. aphrodite_model.model.llm_engine.scheduler[
  98. 0].ENABLE_ARTIFICIAL_PREEMPT = False
  99. aphrodite_outputs = aphrodite_model.generate_greedy(
  100. example_prompts, max_tokens)
  101. check_outputs_equal(
  102. outputs_0_lst=preempt_aphrodite_outputs,
  103. outputs_1_lst=aphrodite_outputs,
  104. name_0="aphrodite_preepmtions",
  105. name_1="aphrodite",
  106. )
  107. @pytest.mark.parametrize("model", MODELS)
  108. @pytest.mark.parametrize("dtype", ["float"])
  109. def test_fail_upon_inc_requests_and_finished_requests_lt_available_blocks(
  110. aphrodite_runner,
  111. model: str,
  112. dtype: str,
  113. example_prompts,
  114. ) -> None:
  115. # This test is for verifying that the Jamba inner state management doesn't
  116. # collapse in case where the number of incoming requests and
  117. # finished_requests_ids is larger than the maximum mamba block capacity.
  118. # This could generally happen due to the fact that Jamba does support
  119. # statelessness mechanism where it can cleanup new incoming requests in
  120. # a single step.
  121. try:
  122. with aphrodite_runner(model, dtype=dtype,
  123. max_num_seqs=10) as aphrodite_model:
  124. aphrodite_model.generate_greedy([example_prompts[0]] * 100, 10)
  125. except ValueError:
  126. pytest.fail("Jamba inner state wasn't cleaned up properly between"
  127. "steps finished requests registered unnecessarily ")
  128. @pytest.mark.parametrize("model", MODELS)
  129. @pytest.mark.parametrize("dtype", ["float"])
  130. def test_state_cleanup(
  131. aphrodite_runner,
  132. model: str,
  133. dtype: str,
  134. example_prompts,
  135. ) -> None:
  136. # This test is for verifying that the Jamba state is cleaned up between
  137. # steps, If its not cleaned, an error would be expected.
  138. try:
  139. with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
  140. for _ in range(10):
  141. aphrodite_model.generate_greedy([example_prompts[0]] * 100, 1)
  142. except ValueError:
  143. pytest.fail("Jamba inner state wasn't cleaned up between states, "
  144. "could be related to finished_requests_ids")
  145. @pytest.mark.parametrize("model", MODELS)
  146. @pytest.mark.parametrize("dtype", ["float"])
  147. def test_model_print(
  148. aphrodite_runner,
  149. model: str,
  150. dtype: str,
  151. ) -> None:
  152. with aphrodite_runner(model, dtype=dtype) as aphrodite_model:
  153. # This test is for verifying whether the model's extra_repr
  154. # can be printed correctly.
  155. print(aphrodite_model.model.llm_engine.model_executor.driver_worker.
  156. model_runner.model)