"""Tests for the private helpers of ``BatchExpansionTop1Scorer`` in
``aphrodite.spec_decode.batch_expansion``."""
  1. from typing import List
  2. import pytest
  3. import torch
  4. from aphrodite.spec_decode.batch_expansion import BatchExpansionTop1Scorer
  5. from .utils import create_seq_group_metadata_from_prompts, mock_worker
  6. @pytest.mark.parametrize('num_target_seq_ids', [100])
  7. @pytest.mark.skip_global_cleanup
  8. def test_create_target_seq_id_iterator(num_target_seq_ids: int):
  9. """Verify all new sequence ids are greater than all input
  10. seq ids.
  11. """
  12. scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
  13. all_seq_ids = [
  14. [1, 3, 5, 7],
  15. list(range(100)) + [0],
  16. [100],
  17. ]
  18. for seq_ids in all_seq_ids:
  19. max_seq_id = max(seq_ids)
  20. iterator = scorer._create_target_seq_id_iterator(seq_ids) # pylint: disable=protected-access
  21. for _ in range(num_target_seq_ids):
  22. assert next(iterator) > max_seq_id
  23. @pytest.mark.parametrize('k', [1, 2, 6])
  24. @pytest.mark.skip_global_cleanup
  25. def test_get_token_ids_to_score(k: int):
  26. """Verify correct tokens are selected for scoring.
  27. """
  28. proposal_token_ids = torch.tensor(
  29. list(range(k)),
  30. dtype=torch.int64,
  31. device='cuda',
  32. )
  33. expected_output: List[List[int]] = [
  34. [],
  35. ]
  36. for i in range(proposal_token_ids.shape[0]):
  37. expected_output.append(proposal_token_ids[:i + 1].tolist())
  38. scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
  39. actual_output = scorer._get_token_ids_to_score(proposal_token_ids.tolist()) # pylint: disable=protected-access
  40. actual_output = [
  41. x.tolist() if isinstance(x, torch.Tensor) else x for x in actual_output
  42. ]
  43. assert actual_output == expected_output
  44. @pytest.mark.parametrize('k', [1, 2, 6])
  45. @pytest.mark.skip_global_cleanup
  46. def test_create_single_target_seq_group_metadata(k: int):
  47. """Verify correct creation of a batch-expanded seq group metadata.
  48. """
  49. prompt_tokens = [1, 2, 3]
  50. prev_output_tokens = [4, 5, 6]
  51. token_ids = list(range(k))
  52. num_tokens_processed = len(prompt_tokens) + len(prev_output_tokens) - 1
  53. final_seq_len = len(prompt_tokens) + len(prev_output_tokens) + len(
  54. token_ids)
  55. block_size = 32
  56. input_seq_group_metadata = create_seq_group_metadata_from_prompts(
  57. [prompt_tokens], 2048 // block_size, block_size, [final_seq_len],
  58. [prev_output_tokens], [num_tokens_processed])[0]
  59. input_seq_id = list(input_seq_group_metadata.seq_data.keys())[0]
  60. target_seq_id = 100
  61. scorer = BatchExpansionTop1Scorer(mock_worker(), 'cuda:0', 32_000)
  62. output = scorer._create_single_target_seq_group_metadata( # pylint: disable=protected-access
  63. input_seq_group_metadata,
  64. input_seq_id,
  65. target_seq_id,
  66. token_ids,
  67. input_seq_group_metadata.sampling_params,
  68. )
  69. assert output.request_id == input_seq_group_metadata.request_id
  70. assert len(output.seq_data) == 1
  71. assert output.seq_data[target_seq_id].get_prompt_token_ids() == tuple(
  72. prompt_tokens)
  73. assert output.seq_data[target_seq_id].get_output_token_ids() == tuple(
  74. prev_output_tokens + token_ids)
  75. assert len(output.block_tables) == 1
  76. assert output.block_tables[
  77. target_seq_id] == input_seq_group_metadata.block_tables[input_seq_id]