# test_serialization.py
import msgspec

from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.executor.msgspec_utils import decode_hook, encode_hook

from ..spec_decode.utils import create_batch


  5. def test_msgspec_serialization():
  6. num_lookahead_slots = 4
  7. seq_group_metadata_list, _, _ = create_batch(16, num_lookahead_slots)
  8. execute_model_req = ExecuteModelRequest(
  9. seq_group_metadata_list=seq_group_metadata_list,
  10. num_lookahead_slots=num_lookahead_slots,
  11. running_queue_size=4)
  12. encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
  13. decoder = msgspec.msgpack.Decoder(ExecuteModelRequest,
  14. dec_hook=decode_hook)
  15. req = decoder.decode(encoder.encode(execute_model_req))
  16. expected = execute_model_req.seq_group_metadata_list
  17. actual = req.seq_group_metadata_list
  18. assert (len(expected) == len(actual))
  19. expected = expected[0]
  20. actual = actual[0]
  21. assert expected.block_tables == actual.block_tables
  22. assert expected.is_prompt == actual.is_prompt
  23. assert expected.request_id == actual.request_id
  24. assert (expected.seq_data[0].prompt_token_ids ==
  25. actual.seq_data[0].prompt_token_ids)
  26. assert (expected.seq_data[0].output_token_ids ==
  27. actual.seq_data[0].output_token_ids)