test_model_runner.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import random
  2. import torch
  3. from aphrodite.common.sequence import (SamplingParams, SequenceData,
  4. SequenceGroupMetadata)
  5. from aphrodite.task_handler.model_runner import ModelRunner
  6. def test_prepare_prompt():
  7. model_runner = ModelRunner(None, None, None)
  8. model_runner.set_block_size(16)
  9. batch_size = random.randint(1, 256)
  10. prompt_lens = []
  11. seq_group_metadata_list = []
  12. for i in range(batch_size):
  13. # make sure all tokens fit into one block
  14. prompt_len = i % (model_runner.block_size - 1) + 1
  15. prompt_lens.append(prompt_len)
  16. seq_data = list(range(prompt_len))
  17. seq_group_metadata_list.append(
  18. SequenceGroupMetadata(
  19. request_id=f"test_{i}",
  20. is_prompt=True,
  21. seq_data={0: SequenceData(seq_data)},
  22. sampling_params=SamplingParams(temperature=0),
  23. block_tables={0: [1]},
  24. persistent_data={},
  25. ))
  26. expected_selected_token_indices = []
  27. selected_token_start_idx = 0
  28. max_seq_len = max(prompt_lens)
  29. for prompt_len in prompt_lens:
  30. expected_selected_token_indices.append(selected_token_start_idx +
  31. prompt_len - 1)
  32. selected_token_start_idx += max_seq_len
  33. input_tokens, input_positions, return_prompt_lens = (
  34. model_runner._prepare_prompt(seq_group_metadata_list))
  35. assert return_prompt_lens == prompt_lens
  36. sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
  37. prompt_lens)
  38. assert input_tokens.shape == (batch_size, max_seq_len)
  39. assert input_positions.shape == (batch_size, max_seq_len)
  40. torch.testing.assert_close(input_tokens, input_positions)
  41. actual = sampling_metadata.selected_token_indices
  42. expected = torch.tensor(expected_selected_token_indices,
  43. device=actual.device,
  44. dtype=actual.dtype)
  45. torch.testing.assert_close(actual, expected)