test_outlines.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
# TODO: Move this unit test to a new tests/test_guided_decoding directory.
  3. from transformers import AutoTokenizer
  4. import torch
  5. from aphrodite.modeling.outlines_logits_processors import (
  6. RegexLogitsProcessor, JSONLogitsProcessor)
  7. TEST_SCHEMA = {
  8. "type": "object",
  9. "properties": {
  10. "name": {
  11. "type": "string"
  12. },
  13. "age": {
  14. "type": "integer"
  15. },
  16. "skills": {
  17. "type": "array",
  18. "items": {
  19. "type": "string",
  20. "maxLength": 10
  21. },
  22. "minItems": 3
  23. },
  24. "work history": {
  25. "type": "array",
  26. "items": {
  27. "type": "object",
  28. "properties": {
  29. "company": {
  30. "type": "string"
  31. },
  32. "duration": {
  33. "type": "string"
  34. },
  35. "position": {
  36. "type": "string"
  37. }
  38. },
  39. "required": ["company", "position"]
  40. }
  41. }
  42. },
  43. "required": ["name", "age", "skills", "work history"]
  44. }
  45. TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \
  46. r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"
  47. def test_guided_logits_processors():
  48. """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
  49. tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
  50. regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer)
  51. json_LP = JSONLogitsProcessor(TEST_SCHEMA, tokenizer)
  52. regex_LP.init_state()
  53. token_ids = tokenizer.encode(
  54. f"Give an example IPv4 address with this regex: {TEST_REGEX}")
  55. tensor = torch.rand(32000)
  56. original_tensor = torch.clone(tensor)
  57. regex_LP(token_ids, tensor)
  58. assert tensor.shape == original_tensor.shape
  59. assert not torch.allclose(tensor, original_tensor)
  60. json_LP.init_state()
  61. token_ids = tokenizer.encode(
  62. f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
  63. tensor = torch.rand(32000)
  64. original_tensor = torch.clone(tensor)
  65. json_LP(token_ids, tensor)
  66. assert tensor.shape == original_tensor.shape
  67. assert not torch.allclose(tensor, original_tensor)