# test_logits_processor.py
  1. import random
  2. from typing import Tuple
  3. from unittest.mock import patch
  4. import pytest
  5. import torch
  6. from aphrodite.common.sequence import (SamplingParams, SequenceData,
  7. SequenceGroupMetadata)
  8. from aphrodite.common.utils import is_pin_memory_available
  9. from aphrodite.modeling.layers.logits_processor import LogitsProcessor
  10. from aphrodite.modeling.sampling_metadata import SamplingMetadata
  11. from aphrodite.modeling.utils import set_random_seed
  12. class MockLogitsProcessor(LogitsProcessor):
  13. def __init__(self, vocab_size: int, scale: float,
  14. fake_logits: torch.Tensor):
  15. super().__init__(vocab_size=vocab_size, scale=scale)
  16. self.fake_logits = fake_logits.clone()
  17. def forward(self, *args, **kwargs):
  18. with patch(
  19. "aphrodite.modeling.layers.logits_processor._prune_hidden_states",
  20. lambda x, y: x
  21. ), patch(
  22. "aphrodite.modeling.layers.logits_processor.LogitsProcessor._get_logits",
  23. lambda *args, **kwargs: self.fake_logits):
  24. return super().forward(*args, **kwargs)
  25. def _prepare_test(
  26. batch_size: int
  27. ) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]:
  28. vocab_size = 32000
  29. input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
  30. fake_logits = torch.full((batch_size, vocab_size),
  31. 1e-2,
  32. dtype=input_tensor.dtype)
  33. logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
  34. return input_tensor, fake_logits, logits_processor
  35. RANDOM_SEEDS = list(range(128))
  36. CUDA_DEVICES = [
  37. f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
  38. ]
  39. @pytest.mark.parametrize("seed", RANDOM_SEEDS)
  40. @pytest.mark.parametrize("device", CUDA_DEVICES)
  41. def test_logits_processors(seed: int, device: str):
  42. set_random_seed(seed)
  43. torch.set_default_device(device)
  44. batch_size = random.randint(1, 256)
  45. input_tensor, fake_logits, logits_processor = _prepare_test(batch_size)
  46. # This sample logits processor gives infinite score to the i-th token,
  47. # where i is the length of the input sequence.
  48. # We therefore expect the output token sequence to be [0, 1, 2, ...]
  49. def pick_ith(token_ids, logits):
  50. logits[len(token_ids)] = float("inf")
  51. return logits
  52. seq_group_metadata_list = []
  53. seq_lens = []
  54. for i in range(batch_size):
  55. seq_group_metadata_list.append(
  56. SequenceGroupMetadata(
  57. request_id=f"test_{i}",
  58. is_prompt=True,
  59. seq_data={0: SequenceData([1, 2, 3])},
  60. sampling_params=SamplingParams(temperature=0,
  61. logits_processors=[pick_ith]),
  62. block_tables={0: [1]},
  63. ))
  64. seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
  65. sampling_metadata = SamplingMetadata.prepare(
  66. seq_group_metadata_list,
  67. seq_lens,
  68. query_lens=seq_lens,
  69. device=device,
  70. pin_memory=is_pin_memory_available())
  71. logits_processor_output = logits_processor(
  72. lm_head=None,
  73. hidden_states=input_tensor,
  74. sampling_metadata=sampling_metadata)
  75. assert torch.isinf(logits_processor_output[:, 0]).all()
  76. fake_logits *= logits_processor.scale
  77. torch.testing.assert_close(logits_processor_output[:, 1],
  78. fake_logits[:, 1],
  79. rtol=1e-4,
  80. atol=0.0)