import torch

from aphrodite.common.sequence import ExecuteModelRequest
from aphrodite.spec_decode.ngram_worker import NGramWorker
from aphrodite.spec_decode.top1_proposer import Top1Proposer

from .utils import create_seq_group_metadata_from_prompts, create_worker


def test_ngram_algo_correctness_for_single_no_match():
    """Verify that the n-gram algorithm finds the right candidate in the
    prompt, for the scenario where no candidate can be found in a
    single-sequence batch.
    """
    block_size = 32
    num_gpu_blocks = 2048 // block_size
    seed = 100
    model_name = 'JackFram/llama-68m'
    vocab_size = 32_000
    device = 'cuda:0'

    ngram_worker = create_worker(
        NGramWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    proposer = Top1Proposer(
        worker=ngram_worker,
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=20,
    )

    # set ngram window [1, 3], which is window=1/2/3
    ngram_worker.set_ngram_window_size(1, 3)
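    # (The worker is expected to prefer the longest matching window,
    # falling back to shorter ones; the prompts below do not depend on
    # that ordering.)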

    prompts = [
        # shall find no candidate
        [1, 2, 3, 4, 5, 6, 7],
    ]
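    # All seven tokens above are distinct, so no tail n-gram recurs
    # earlier in the prompt and the lookup shall come up empty.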

    proposal_len = 5
    final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        final_prompt_lens=final_prompt_lens)

    proposals = proposer.get_spec_proposals(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=proposal_len),
        seq_ids_with_bonus_token_in_last_step=None)
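
    # Even with no match, the proposer returns fixed-shape tensors; the
    # empty proposal is signalled via proposal_lens, not the shapes.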
    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([1, proposal_len])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([1, proposal_len])
    assert proposals.proposal_lens.shape == torch.Size([1])
    assert proposals.proposal_lens.tolist() == [0]


def test_ngram_algo_correctness_for_batches_not_match_all():
    """Verify that the n-gram algorithm finds the right candidate in the
    prompt, for the scenario where some, but not all, sequences in the
    batch have a candidate.
    """
    block_size = 32
    num_gpu_blocks = 2048 // block_size
    seed = 100
    model_name = 'JackFram/llama-68m'
    vocab_size = 32_000
    device = 'cuda:0'

    ngram_worker = create_worker(
        NGramWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    proposer = Top1Proposer(
        worker=ngram_worker,
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=20,
    )

    # set ngram window [1, 3], which is window=1/2/3
    ngram_worker.set_ngram_window_size(1, 3)

    prompts = [
        # shall find no candidate
        [1, 2, 3, 4, 5, 6, 7],
        # shall find candidate 12, 13, 14, 15, 16
        [11, 12, 13, 14, 15, 16, 11],
        # shall find candidate 23, 24, 25, 26, 21
        [21, 21, 22, 23, 24, 25, 26, 21, 22],
        # shall find candidate 34, 35, 36, 37, 38
        [31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
        # shall find no candidate, as the sequence exceeds max_proposal_len
        [
            31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 31, 32, 33, 34, 35, 36,
            37, 38, 31, 32, 33
        ],
    ]
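    # In the three matching prompts, the tail of the prompt recurs earlier
    # (e.g. the trailing 11 in the second prompt matches position 0), and
    # the tokens that followed the earlier occurrence become the proposal.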

    proposal_len = 5
    final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        final_prompt_lens=final_prompt_lens)

    proposals = proposer.get_spec_proposals(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=proposal_len),
        seq_ids_with_bonus_token_in_last_step=None)

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([5, proposal_len])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([5, proposal_len])
    assert proposals.proposal_lens.shape == torch.Size([5])
    # The first sequence has no match and the last exceeds max_proposal_len,
    # so their proposal lengths are overwritten to 0.
    assert (proposals.proposal_lens.tolist() ==
            [0] + [proposal_len for _ in range(3)] + [0])

    # Sequences without a proposal carry the sentinel token id -1 in
    # every slot.
    for i in range(proposal_len):
        assert proposals.proposal_token_ids[0][i] == -1
        assert proposals.proposal_token_ids[1][i] == prompts[1][i + 1]
        assert proposals.proposal_token_ids[2][i] == prompts[2][i + 3]
        assert proposals.proposal_token_ids[3][i] == prompts[3][i + 5]
        assert proposals.proposal_token_ids[4][i] == -1


def test_ngram_algo_correctness_for_batches_match_all():
    """Verify that the n-gram algorithm finds the right candidate in the
    prompt, for the scenario where every sequence in the batch has a
    candidate.
    """
    block_size = 32
    num_gpu_blocks = 2048 // block_size
    seed = 100
    model_name = 'JackFram/llama-68m'
    vocab_size = 32_000
    device = 'cuda:0'

    ngram_worker = create_worker(
        NGramWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    proposer = Top1Proposer(
        worker=ngram_worker,
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=20,
    )
    # set ngram window [1, 3], which is window=1/2/3
    ngram_worker.set_ngram_window_size(1, 3)

    prompts = [
        # shall find candidate 12, 13, 14, 15, 16
        [11, 12, 13, 14, 15, 16, 11],
        # shall find candidate 23, 24, 25, 26, 21
        [21, 21, 22, 23, 24, 25, 26, 21, 22],
        # shall find candidate 34, 35, 36, 37, 38
        [31, 32, 31, 32, 33, 34, 35, 36, 37, 38, 31, 32, 33],
    ]

    proposal_len = 5
    final_prompt_lens = [len(prompt) + proposal_len for prompt in prompts]
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        final_prompt_lens=final_prompt_lens)

    proposals = proposer.get_spec_proposals(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            num_lookahead_slots=proposal_len),
        seq_ids_with_bonus_token_in_last_step=None)

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([3, proposal_len])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([3, proposal_len])
    assert proposals.proposal_lens.shape == torch.Size([3])
    assert proposals.proposal_lens.tolist() == [proposal_len for _ in range(3)]

    for i in range(proposal_len):
        assert proposals.proposal_token_ids[0][i] == prompts[0][i + 1]
        assert proposals.proposal_token_ids[1][i] == prompts[1][i + 3]
        assert proposals.proposal_token_ids[2][i] == prompts[2][i + 5]
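

# The tests above encode a plain prompt-lookup rule. What follows is a
# minimal pure-Python sketch of that rule, assuming the behavior the test
# expectations imply; it is NOT the actual NGramWorker implementation,
# which operates on tensors. The helper name `ngram_lookup_sketch` is
# hypothetical and exists only for illustration.
def ngram_lookup_sketch(tokens, min_n, max_n, proposal_len):
    # Try window sizes from largest to smallest.
    for n in range(max_n, min_n - 1, -1):
        if n >= len(tokens):
            continue
        suffix = tokens[-n:]
        # Scan for an earlier occurrence of the suffix, most recent first.
        # (The scan order is immaterial for the prompts above, which each
        # contain a single earlier match.)
        for start in range(len(tokens) - n - 1, -1, -1):
            if tokens[start:start + n] == suffix:
                follow = tokens[start + n:start + n + proposal_len]
                # Assumed: only full-length continuations count, which is
                # all the prompts above exercise.
                if len(follow) == proposal_len:
                    return follow
    return None


def test_ngram_lookup_sketch_matches_expectations():
    # Mirrors the candidates asserted above. Note the max_proposal_len
    # cutoff is enforced by Top1Proposer, not by the lookup itself, so it
    # is not modelled here.
    assert ngram_lookup_sketch([1, 2, 3, 4, 5, 6, 7], 1, 3, 5) is None
    assert ngram_lookup_sketch([11, 12, 13, 14, 15, 16, 11], 1, 3,
                               5) == [12, 13, 14, 15, 16]
    assert ngram_lookup_sketch([21, 21, 22, 23, 24, 25, 26, 21, 22], 1, 3,
                               5) == [23, 24, 25, 26, 21]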