# test_tokenizer_group.py
  1. import pytest
  2. from transformers import AutoTokenizer, PreTrainedTokenizerBase
  3. from aphrodite.lora.request import LoRARequest
  4. from aphrodite.transformers_utils.tokenizer import get_lora_tokenizer
  5. from aphrodite.transformers_utils.tokenizer_group import get_tokenizer_group
  6. from ..conftest import get_tokenizer_pool_config
  7. @pytest.mark.asyncio
  8. @pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
  9. async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
  10. reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files)
  11. tokenizer_group = get_tokenizer_group(
  12. get_tokenizer_pool_config(tokenizer_group_type),
  13. tokenizer_id="gpt2",
  14. enable_lora=True,
  15. max_num_seqs=1,
  16. max_input_length=None,
  17. )
  18. lora_request = LoRARequest("1", 1, sql_lora_files)
  19. assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
  20. request_id="request_id", prompt="prompt", lora_request=lora_request)
  21. assert reference_tokenizer.encode(
  22. "prompt") == await tokenizer_group.encode_async(
  23. request_id="request_id",
  24. prompt="prompt",
  25. lora_request=lora_request)
  26. assert isinstance(tokenizer_group.get_lora_tokenizer(None),
  27. PreTrainedTokenizerBase)
  28. assert tokenizer_group.get_lora_tokenizer(
  29. None) == await tokenizer_group.get_lora_tokenizer_async(None)
  30. assert isinstance(tokenizer_group.get_lora_tokenizer(lora_request),
  31. PreTrainedTokenizerBase)
  32. assert tokenizer_group.get_lora_tokenizer(
  33. lora_request) != tokenizer_group.get_lora_tokenizer(None)
  34. assert tokenizer_group.get_lora_tokenizer(
  35. lora_request) == await tokenizer_group.get_lora_tokenizer_async(
  36. lora_request)
  37. def test_get_lora_tokenizer(sql_lora_files, tmpdir):
  38. lora_request = None
  39. tokenizer = get_lora_tokenizer(lora_request)
  40. assert not tokenizer
  41. lora_request = LoRARequest("1", 1, sql_lora_files)
  42. tokenizer = get_lora_tokenizer(lora_request)
  43. assert tokenizer.get_added_vocab()
  44. lora_request = LoRARequest("1", 1, str(tmpdir))
  45. tokenizer = get_lora_tokenizer(lora_request)
  46. assert not tokenizer