test_scheduler_encoder_decoder.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. from typing import List
  2. import pytest # noqa
  3. from aphrodite.common.config import CacheConfig, SchedulerConfig
  4. from aphrodite.common.sequence import SequenceGroup
  5. from aphrodite.processing.scheduler import Scheduler
  6. from .utils import (append_new_token, create_dummy_prompt_encoder_decoder,
  7. get_sequence_groups, schedule_and_update_computed_tokens)
  8. def test_scheduler_schedule_simple_encoder_decoder():
  9. '''
  10. Test basic scheduler functionality in the context
  11. of an encoder/decoder model. Focus on testing
  12. enc/dec-specific functionality sense tests already
  13. exist for decoder-only functionality
  14. Test behavior:
  15. * Construct Scheduler
  16. * Construct dummy encoder/decoder sequence groups
  17. * Add dummy seq groups to scheduler backlog
  18. * Schedule the next seq group & validate:
  19. * Cross-attn block tables
  20. * Updated states of seq groups
  21. * Number of batched tokens
  22. * Number of blocks to copy/swap-in/swap-out
  23. * Number of scheduled seq groups
  24. * Repeat for both prefill- and decode-phase
  25. * Abort scheduled seq groups
  26. * Assert that aborted seq groups no longer appear in
  27. cross-attention block table
  28. '''
  29. block_size = 4
  30. num_seq_group = 4
  31. max_model_len = 16
  32. scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len)
  33. cache_config = CacheConfig(block_size, 1.0, 1, "auto")
  34. cache_config.num_cpu_blocks = 16 # enc and dec prompts per seq_group
  35. cache_config.num_gpu_blocks = 16 # enc and dec prompts per seq_group
  36. scheduler = Scheduler(scheduler_config, cache_config, None)
  37. running: List[SequenceGroup] = []
  38. # Add seq groups to scheduler.
  39. req_id_list = []
  40. for i in range(num_seq_group):
  41. req_id = str(i)
  42. req_id_list.append(req_id)
  43. _, _, seq_group = create_dummy_prompt_encoder_decoder(
  44. req_id, block_size, block_size, block_size)
  45. scheduler.add_seq_group(seq_group)
  46. running.append(seq_group)
  47. # Schedule seq groups prefill.
  48. num_tokens = block_size * num_seq_group
  49. seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler)
  50. # - Verify that sequence group cross-attention block tables are
  51. # registered with the block manager
  52. assert all([(req_id in scheduler.block_manager.cross_block_tables)
  53. for req_id in req_id_list])
  54. # - Validate sequence-group status
  55. assert set(get_sequence_groups(out)) == set(running)
  56. # - Validate number of batched tokens
  57. assert out.num_batched_tokens == num_tokens
  58. # - Validate there are no remaining blocks to swap
  59. assert (not out.blocks_to_copy and not out.blocks_to_swap_in
  60. and not out.blocks_to_swap_out)
  61. # - Validate all seq groups were scheduled
  62. assert len(seq_group_meta_list) == num_seq_group
  63. append_new_token(out, 1)
  64. # Schedule seq groups decode.
  65. seq_group_meta_list, out = schedule_and_update_computed_tokens(scheduler)
  66. # - Verify that sequence group metadata includes encoder attention
  67. # and cross-attention metadata
  68. assert all([
  69. not ((seq_group_meta.encoder_seq_data is None) or
  70. (seq_group_meta.cross_block_table is None))
  71. for seq_group_meta in seq_group_meta_list
  72. ])
  73. # - Validate sequence-group status
  74. assert set(get_sequence_groups(out)) == set(running)
  75. # - Validate there is one batched token per seq group
  76. assert out.num_batched_tokens == num_seq_group
  77. # - Validate there are no remaining blocks to swap
  78. assert (not out.blocks_to_copy and not out.blocks_to_swap_in
  79. and not out.blocks_to_swap_out)
  80. # - Validate that all seq groups were scheduled
  81. assert len(seq_group_meta_list) == num_seq_group
  82. append_new_token(out, 1)
  83. # Abort sequences
  84. for req_id in req_id_list:
  85. scheduler.abort_seq_group(req_id)
  86. # - Verify that sequence group cross-attention block tables are
  87. # NO LONGER registered with the block manager
  88. assert req_id not in scheduler.block_manager.cross_block_tables