test_chat_template.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import pytest
  2. from aphrodite.endpoints.chat_utils import (apply_chat_template,
  3. load_chat_template)
  4. from aphrodite.endpoints.openai.protocol import ChatCompletionRequest
  5. from aphrodite.transformers_utils.tokenizer import get_tokenizer
  6. from ..utils import APHRODITE_PATH
# Path to the ChatML Jinja template shipped in the repo's examples directory.
chatml_jinja_path = APHRODITE_PATH / "examples/chat_templates/chatml.jinja"
# Fail fast at import time if the template file is missing.
# NOTE(review): module-level `assert` is stripped under `python -O`;
# acceptable in a test module, where pytest never runs optimized.
assert chatml_jinja_path.exists()
# Define models, templates, and their corresponding expected outputs.
# Each entry is (model_name, template_path, add_generation_prompt,
# expected rendered prompt for TEST_MESSAGES).
# NOTE(review): "GENERATON" is a typo for "GENERATION"; kept as-is since the
# parametrize decorator below references this exact name.
MODEL_TEMPLATE_GENERATON_OUTPUT = [
    # add_generation_prompt=True: the template appends a trailing
    # "<|im_start|>assistant\n" header so the model continues as assistant.
    ("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of<|im_end|>
<|im_start|>assistant
"""),
    # add_generation_prompt=False: the final user turn is left open with no
    # closing "<|im_end|>" and no assistant header.
    ("facebook/opt-125m", chatml_jinja_path, False, """<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there!<|im_end|>
<|im_start|>user
What is the capital of""")
]
# A short multi-turn conversation (user / assistant / user) fed to the chat
# template in the tests below; the golden outputs above correspond to it.
TEST_MESSAGES = [
    {
        'role': 'user',
        'content': 'Hello'
    },
    {
        'role': 'assistant',
        'content': 'Hi there!'
    },
    {
        'role': 'user',
        'content': 'What is the capital of'
    },
]
  40. def test_load_chat_template():
  41. # Testing chatml template
  42. template_content = load_chat_template(chat_template=chatml_jinja_path)
  43. # Test assertions
  44. assert template_content is not None
  45. # Hard coded value for template_chatml.jinja
  46. assert template_content == """{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\\n'}}{% endif %}{% endfor %}
  47. {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\\n' }}{% endif %}""" # noqa: E501
  48. def test_no_load_chat_template_filelike():
  49. # Testing chatml template
  50. template = "../../examples/does_not_exist"
  51. with pytest.raises(ValueError, match="looks like a file path"):
  52. load_chat_template(chat_template=template)
  53. def test_no_load_chat_template_literallike():
  54. # Testing chatml template
  55. template = "{{ messages }}"
  56. template_content = load_chat_template(chat_template=template)
  57. assert template_content == template
  58. @pytest.mark.parametrize(
  59. "model,template,add_generation_prompt,expected_output",
  60. MODEL_TEMPLATE_GENERATON_OUTPUT)
  61. def test_get_gen_prompt(model, template, add_generation_prompt,
  62. expected_output):
  63. # Initialize the tokenizer
  64. tokenizer = get_tokenizer(tokenizer_name=model)
  65. template_content = load_chat_template(chat_template=template)
  66. # Create a mock request object using keyword arguments
  67. mock_request = ChatCompletionRequest(
  68. model=model,
  69. messages=TEST_MESSAGES,
  70. add_generation_prompt=add_generation_prompt)
  71. # Call the function and get the result
  72. result = apply_chat_template(
  73. tokenizer,
  74. conversation=mock_request.messages,
  75. chat_template=mock_request.chat_template or template_content,
  76. add_generation_prompt=mock_request.add_generation_prompt,
  77. )
  78. # Test assertion
  79. assert result == expected_output, (
  80. f"The generated prompt does not match the expected output for "
  81. f"model {model} and template {template}")