test_cached_tokenizer.py

from copy import deepcopy

from transformers import AutoTokenizer

from aphrodite.transformers_utils.tokenizer import get_cached_tokenizer


def test_cached_tokenizer():
    # Build a reference tokenizer and register extra special tokens so the
    # cached copy has non-trivial special-token state to mirror.
    reference_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    reference_tokenizer.add_special_tokens({"cls_token": "<CLS>"})
    reference_tokenizer.add_special_tokens(
        {"additional_special_tokens": ["<SEP>"]})
    cached_tokenizer = get_cached_tokenizer(deepcopy(reference_tokenizer))

    # The cached tokenizer must encode identically and expose the same
    # special-token ids and strings as the original.
    assert reference_tokenizer.encode("prompt") == cached_tokenizer.encode(
        "prompt")
    assert set(reference_tokenizer.all_special_ids) == set(
        cached_tokenizer.all_special_ids)
    assert set(reference_tokenizer.all_special_tokens) == set(
        cached_tokenizer.all_special_tokens)
    assert set(reference_tokenizer.all_special_tokens_extended) == set(
        cached_tokenizer.all_special_tokens_extended)
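
For context, the sketch below illustrates the caching idea this test exercises. It is not the aphrodite implementation of get_cached_tokenizer; it only assumes the wrapper precomputes the special-token properties once and serves them from fixed values while leaving encode untouched. The name get_cached_tokenizer_sketch is hypothetical.

def get_cached_tokenizer_sketch(tokenizer):
    # Illustrative only: snapshot the (relatively expensive) special-token
    # properties once, then expose them through cheap overriding properties.
    cached_all_special_ids = frozenset(tokenizer.all_special_ids)
    cached_all_special_tokens = tuple(tokenizer.all_special_tokens)
    cached_all_special_tokens_extended = tuple(
        tokenizer.all_special_tokens_extended)

    class CachedTokenizer(tokenizer.__class__):
        @property
        def all_special_ids(self):
            return cached_all_special_ids

        @property
        def all_special_tokens(self):
            return cached_all_special_tokens

        @property
        def all_special_tokens_extended(self):
            return cached_all_special_tokens_extended

    # Swap the instance's class so the cached properties shadow the originals
    # without copying the tokenizer's state.
    tokenizer.__class__ = CachedTokenizer
    return tokenizer


# Example usage (hypothetical helper name):
#   cached = get_cached_tokenizer_sketch(AutoTokenizer.from_pretrained("gpt2"))
#   cached.encode("prompt")  # behaves exactly like the uncached tokenizer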