test_sequence.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. from array import array
  2. import pytest
  3. from aphrodite.common.sequence import (APHRODITE_TOKEN_ID_ARRAY_TYPE,
  4. CompletionSequenceGroupOutput,
  5. SequenceData, SequenceOutput)
  6. from aphrodite.modeling.layers.sampler import SamplerOutput
  7. from .core.utils import create_dummy_prompt
  8. @pytest.fixture
  9. def sample_outputs():
  10. return [
  11. CompletionSequenceGroupOutput(samples=[
  12. SequenceOutput(parent_seq_id=0, output_token=i, logprobs={})
  13. ],
  14. prompt_logprobs=None) for i in range(5)
  15. ]
  16. @pytest.fixture
  17. def sampler_output(sample_outputs):
  18. return SamplerOutput(outputs=sample_outputs)
  19. def test_sampler_output_initialization(sampler_output, sample_outputs):
  20. assert len(sampler_output) == len(sample_outputs)
  21. assert sampler_output.sampled_token_probs is None
  22. assert sampler_output.sampled_token_ids is None
  23. assert sampler_output.spec_decode_worker_metrics is None
  24. def test_sampler_output_getitem(sampler_output, sample_outputs):
  25. assert sampler_output[2] == sample_outputs[2]
  26. def test_sampler_output_setitem(sampler_output):
  27. new_output = CompletionSequenceGroupOutput(samples=[
  28. SequenceOutput(parent_seq_id=0, output_token=99, logprobs={})
  29. ],
  30. prompt_logprobs=None)
  31. sampler_output[2] = new_output
  32. assert sampler_output[2] == new_output
  33. def test_sampler_output_len(sampler_output, sample_outputs):
  34. assert len(sampler_output) == len(sample_outputs)
  35. def test_sampler_output_eq(sample_outputs):
  36. sampler_output1 = SamplerOutput(outputs=sample_outputs)
  37. sampler_output2 = SamplerOutput(outputs=sample_outputs.copy())
  38. sampler_output3 = SamplerOutput(outputs=sample_outputs[:-1])
  39. assert sampler_output1 == sampler_output2
  40. assert sampler_output1 != sampler_output3
  41. def test_sequence_data_prefill():
  42. seq_data = SequenceData(array(APHRODITE_TOKEN_ID_ARRAY_TYPE, [1, 2, 3, 4]))
  43. assert seq_data.get_num_uncomputed_tokens() == 4
  44. assert seq_data.get_num_computed_tokens() == 0
  45. # advance by 2
  46. seq_data.update_num_computed_tokens(2)
  47. assert seq_data.get_num_uncomputed_tokens() == 2
  48. assert seq_data.get_num_computed_tokens() == 2
  49. # advance by 1
  50. seq_data.update_num_computed_tokens(1)
  51. assert seq_data.get_num_uncomputed_tokens() == 1
  52. assert seq_data.get_num_computed_tokens() == 3
  53. # append tokens and reset, simulating recompute
  54. seq_data.append_token_id(1, logprob=0.0)
  55. seq_data.reset_state_for_recompute()
  56. assert seq_data.get_num_uncomputed_tokens() == 5
  57. assert seq_data.get_num_computed_tokens() == 0
  58. def test_sequence_group_stage():
  59. _, seq_group = create_dummy_prompt("1", 12)
  60. assert seq_group.is_prefill() is True
  61. seq_group.update_num_computed_tokens(6)
  62. assert seq_group.is_prefill() is True
  63. seq_group.update_num_computed_tokens(5)
  64. assert seq_group.is_prefill() is True
  65. seq_group.update_num_computed_tokens(1)
  66. assert seq_group.is_prefill() is False
  67. seqs = seq_group.get_seqs()
  68. assert len(seqs) == 1
  69. seqs[0].data.append_token_id(1, logprob=0.0)
  70. for seq in seq_group.get_seqs():
  71. seq.reset_state_for_recompute()
  72. assert seq_group.is_prefill() is True
  73. seq_group.update_num_computed_tokens(5)
  74. assert seq_group.is_prefill() is True
  75. seq_group.update_num_computed_tokens(7)
  76. assert seq_group.is_prefill() is True
  77. seq_group.update_num_computed_tokens(1)
  78. assert seq_group.is_prefill() is False