1
0

test_guided_processors.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
# TODO: move this unit test to a new tests/test_guided_decoding directory.
  3. import pytest
  4. import torch
  5. from transformers import AutoTokenizer
  6. from aphrodite.endpoints.openai.protocol import CompletionRequest
  7. from aphrodite.modeling.guided_decoding import (
  8. get_guided_decoding_logits_processor)
  9. from aphrodite.modeling.guided_decoding.outlines_logits_processors import (
  10. JSONLogitsProcessor, RegexLogitsProcessor)
  11. def test_guided_logits_processors(sample_regex, sample_json_schema):
  12. """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
  13. tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
  14. regex_LP = RegexLogitsProcessor(sample_regex, tokenizer)
  15. json_LP = JSONLogitsProcessor(sample_json_schema,
  16. tokenizer,
  17. whitespace_pattern=None)
  18. token_ids = tokenizer.encode(
  19. f"Give an example IPv4 address with this regex: {sample_regex}")
  20. tensor = torch.rand(32000)
  21. original_tensor = torch.clone(tensor)
  22. regex_LP(token_ids, tensor)
  23. assert tensor.shape == original_tensor.shape
  24. assert not torch.allclose(tensor, original_tensor)
  25. token_ids = tokenizer.encode(
  26. f"Give an employee profile that fits this schema: {sample_json_schema}"
  27. )
  28. tensor = torch.rand(32000)
  29. original_tensor = torch.clone(tensor)
  30. json_LP(token_ids, tensor)
  31. assert tensor.shape == original_tensor.shape
  32. assert not torch.allclose(tensor, original_tensor)
  33. @pytest.mark.asyncio
  34. @pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
  35. async def test_guided_logits_processor_black_box(backend: str, sample_regex,
  36. sample_json_schema):
  37. tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
  38. token_ids = tokenizer.encode(
  39. f"Give an example IPv4 address with this regex: {sample_regex}")
  40. regex_request = CompletionRequest(model='test',
  41. prompt=token_ids,
  42. guided_regex=sample_regex)
  43. regex_lp = await get_guided_decoding_logits_processor(
  44. backend, regex_request, tokenizer)
  45. assert regex_lp is not None
  46. tensor = torch.rand(32000)
  47. original_tensor = torch.clone(tensor)
  48. tensor = regex_lp(token_ids, tensor)
  49. assert tensor.shape == original_tensor.shape
  50. assert not torch.allclose(tensor, original_tensor)
  51. token_ids = tokenizer.encode(
  52. f"Give an employee profile that fits this schema: {sample_json_schema}"
  53. )
  54. json_request = CompletionRequest(model='test',
  55. prompt=token_ids,
  56. guided_json=sample_json_schema)
  57. json_lp = await get_guided_decoding_logits_processor(
  58. backend, json_request, tokenizer)
  59. assert json_lp is not None
  60. tensor = torch.rand(32000)
  61. original_tensor = torch.clone(tensor)
  62. tensor = json_lp(token_ids, tensor)
  63. assert tensor.shape == original_tensor.shape
  64. assert not torch.allclose(tensor, original_tensor)