test_sequence.py 3.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import pytest
  2. from aphrodite.common.sequence import (CompletionSequenceGroupOutput,
  3. SequenceData, SequenceOutput)
  4. from aphrodite.modeling.layers.sampler import SamplerOutput
  5. from .core.utils import create_dummy_prompt
  6. @pytest.fixture
  7. def sample_outputs():
  8. return [
  9. CompletionSequenceGroupOutput(samples=[
  10. SequenceOutput(parent_seq_id=0, output_token=i, logprobs={})
  11. ],
  12. prompt_logprobs=None) for i in range(5)
  13. ]
  14. @pytest.fixture
  15. def sampler_output(sample_outputs):
  16. return SamplerOutput(outputs=sample_outputs)
  17. def test_sampler_output_initialization(sampler_output, sample_outputs):
  18. assert len(sampler_output) == len(sample_outputs)
  19. assert sampler_output.sampled_token_probs is None
  20. assert sampler_output.sampled_token_ids is None
  21. assert sampler_output.spec_decode_worker_metrics is None
  22. def test_sampler_output_getitem(sampler_output, sample_outputs):
  23. assert sampler_output[2] == sample_outputs[2]
  24. def test_sampler_output_setitem(sampler_output):
  25. new_output = CompletionSequenceGroupOutput(samples=[
  26. SequenceOutput(parent_seq_id=0, output_token=99, logprobs={})
  27. ],
  28. prompt_logprobs=None)
  29. sampler_output[2] = new_output
  30. assert sampler_output[2] == new_output
  31. def test_sampler_output_len(sampler_output, sample_outputs):
  32. assert len(sampler_output) == len(sample_outputs)
  33. def test_sampler_output_eq(sample_outputs):
  34. sampler_output1 = SamplerOutput(outputs=sample_outputs)
  35. sampler_output2 = SamplerOutput(outputs=sample_outputs.copy())
  36. sampler_output3 = SamplerOutput(outputs=sample_outputs[:-1])
  37. assert sampler_output1 == sampler_output2
  38. assert sampler_output1 != sampler_output3
  39. def test_sequence_data_prefill():
  40. seq_data = SequenceData.from_seqs([1, 2, 3, 4])
  41. assert seq_data.get_num_uncomputed_tokens() == 4
  42. assert seq_data.get_num_computed_tokens() == 0
  43. # advance by 2
  44. seq_data.update_num_computed_tokens(2)
  45. assert seq_data.get_num_uncomputed_tokens() == 2
  46. assert seq_data.get_num_computed_tokens() == 2
  47. # advance by 1
  48. seq_data.update_num_computed_tokens(1)
  49. assert seq_data.get_num_uncomputed_tokens() == 1
  50. assert seq_data.get_num_computed_tokens() == 3
  51. # append tokens and reset, simulating recompute
  52. seq_data.append_token_id(1, logprob=0.0)
  53. seq_data.reset_state_for_recompute()
  54. assert seq_data.get_num_uncomputed_tokens() == 5
  55. assert seq_data.get_num_computed_tokens() == 0
  56. def test_sequence_group_stage():
  57. _, seq_group = create_dummy_prompt("1", 12)
  58. assert seq_group.is_prefill() is True
  59. seq_group.update_num_computed_tokens(6)
  60. assert seq_group.is_prefill() is True
  61. seq_group.update_num_computed_tokens(5)
  62. assert seq_group.is_prefill() is True
  63. seq_group.update_num_computed_tokens(1)
  64. assert seq_group.is_prefill() is False
  65. seqs = seq_group.get_seqs()
  66. assert len(seqs) == 1
  67. seqs[0].data.append_token_id(1, logprob=0.0)
  68. for seq in seq_group.get_seqs():
  69. seq.reset_state_for_recompute()
  70. assert seq_group.is_prefill() is True
  71. seq_group.update_num_computed_tokens(5)
  72. assert seq_group.is_prefill() is True
  73. seq_group.update_num_computed_tokens(7)
  74. assert seq_group.is_prefill() is True
  75. seq_group.update_num_computed_tokens(1)
  76. assert seq_group.is_prefill() is False