
  1. """Test hashing of cache blocks.
  2. Run `pytest tests/test_cache_block_hashing.py`.
  3. """
  4. from typing import List, Optional
  5. import pytest
  6. from aphrodite.lora.request import LoRARequest
  7. from aphrodite.common.sequence import Sequence
  8. from aphrodite.transformers_utils.tokenizer_group import TokenizerGroup
# Make two prefixes with different first blocks.
prefix_start = [("You are an expert"), ("You are a")]
prefix_common = (
    " school principal, skilled in effectively managing "
    "faculty and staff. Draft 10-15 questions for a potential first grade "
    "Head Teacher for my K-12, all-girls', independent school that emphasizes "
    "community, joyful discovery, and life-long learning. The candidate is "
  16. "coming in for a first-round panel interview for a 8th grade Math "
  17. "teaching role. They have 5 years of previous teaching experience "
  18. "as an assistant teacher at a co-ed, public school with experience "
  19. "in middle school math teaching. Based on this, fulfill "
  20. "the following: ")
  21. prefixes = [start + prefix_common for start in prefix_start]
  22. # Sample prompts.
  23. sample_prompts = [
  24. "Hello, my name is", "The president of the United States is",
  25. "The capital of France is", "The future of AI is"
  26. ]
  27. # Helper function.
  28. def flatten_2d(li):
  29. return [lss for ls in li for lss in ls]
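
# For reference: flatten_2d([[1, 2], [3, 4]]) == [1, 2, 3, 4].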


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("max_num_seqs", [256])
@pytest.mark.parametrize("concurrent_lora_int_ids",
                         [[None], [1], [None, 1], [None, 1, 2], [1, 2]])
def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
                             concurrent_lora_int_ids: List[Optional[int]]):
    tokenizer = TokenizerGroup(
        tokenizer_id="facebook/opt-125m",
        enable_lora=False,
        max_num_seqs=max_num_seqs,
        max_input_length=None,
    )

    hashes: List[List[List[int]]] = []
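    # Layout: hashes[i][j][k] is the hash of block k of prompt j for the
    # i-th (prefix, LoRA id) combination produced by the loops below.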

    for prefix in prefixes:
        for lora_int_id in concurrent_lora_int_ids:
            lora_request = None

            if lora_int_id is not None:
                lora_request = LoRARequest(
                    f"example_lora_{lora_int_id}",
                    lora_int_id,
                    f"example/path/to/lora_{lora_int_id}",
                )

            hashes.append([])
            prompts = [prefix + prompt for prompt in sample_prompts]
            seq_id = 0
            for prompt in prompts:
                hashes[-1].append([])
                prompt_token_ids = tokenizer.encode(prompt)
                seq = Sequence(seq_id,
                               inputs={
                                   "prompt": prompt,
                                   "prompt_token_ids": prompt_token_ids,
                               },
                               block_size=block_size,
                               eos_token_id=tokenizer.tokenizer.eos_token_id,
                               lora_request=lora_request)

                num_blocks = len(prompt_token_ids) // block_size
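                # Only complete blocks are hashed; a trailing partial block
                # is dropped by the floor division above.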
                for idx in range(num_blocks):
                    hashes[-1][-1].append(seq.hash_of_block(idx))

                seq_id += 1

    # Check that hashes made with two prefixes with different first blocks
    # are different everywhere. (When several LoRA ids run concurrently, the
    # first two runs share a prefix but differ in LoRA id, which must change
    # every hash just the same.)
    for hash0, hash1 in zip(flatten_2d(hashes[0]), flatten_2d(hashes[1])):
        assert hash0 != hash1

    # Check that, within each run sharing a prefix, the hashes of different
    # prompts agree on every block except the last, which is the first block
    # to contain tokens from the unique prompt.
    for hash_pref in hashes:
        same_hashes = [tuple(h[:-1]) for h in hash_pref]
        different_hashes = [h[-1] for h in hash_pref]
        assert len(set(same_hashes)) == 1
        assert len(set(different_hashes)) == len(different_hashes)
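

# For intuition, a minimal sketch (not Aphrodite's actual implementation) of
# the block-hash semantics this test relies on: the hash of block `idx` is
# assumed to cover every token from the start of the sequence through the end
# of that block, plus the LoRA id, so runs that share a prefix share the
# hashes of their fully-overlapping leading blocks.
def _reference_hash_of_block(token_ids: List[int],
                             block_size: int,
                             idx: int,
                             lora_int_id: int = 0) -> int:
    num_tokens = block_size * (idx + 1)
    return hash((tuple(token_ids[:num_tokens]), lora_int_id))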