12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455 |
- import pytest
- from transformers import AutoTokenizer, PreTrainedTokenizerBase
- from aphrodite.lora.request import LoRARequest
- from aphrodite.transformers_utils.tokenizer import get_lora_tokenizer
- from aphrodite.transformers_utils.tokenizer_group import get_tokenizer_group
- from ..conftest import get_tokenizer_pool_config
@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", [None, "ray"])
async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
    """Verify a LoRA-enabled tokenizer group resolves per-request tokenizers.

    A tokenizer group is built for the base ``"gpt2"`` tokenizer with LoRA
    enabled; its pool config comes from ``get_tokenizer_pool_config`` for the
    parametrized ``tokenizer_group_type`` (``None`` or ``"ray"``).

    Checks performed:
      * encoding ``"prompt"`` with the LoRA request (sync and async) yields
        the same ids as a tokenizer loaded directly from ``sql_lora_files``;
      * for both ``None`` and the LoRA request, the resolved tokenizer is a
        ``PreTrainedTokenizerBase`` and the sync/async lookups agree;
      * the LoRA request resolves to a different tokenizer than ``None``.
    """
    reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files)
    pool_config = get_tokenizer_pool_config(tokenizer_group_type)
    tokenizer_group = get_tokenizer_group(
        pool_config,
        tokenizer_id="gpt2",
        enable_lora=True,
        max_num_seqs=1,
        max_input_length=None,
    )
    lora_request = LoRARequest("1", 1, sql_lora_files)

    # Token ids produced by the adapter's own tokenizer are the ground truth.
    expected_ids = reference_tokenizer.encode("prompt")
    encode_kwargs = {
        "request_id": "request_id",
        "prompt": "prompt",
        "lora_request": lora_request,
    }
    assert expected_ids == tokenizer_group.encode(**encode_kwargs)
    assert expected_ids == await tokenizer_group.encode_async(**encode_kwargs)

    # Base (no LoRA) and adapter lookups must both return real HF tokenizers,
    # and the sync and async lookup paths must agree for each request.
    for request in (None, lora_request):
        assert isinstance(tokenizer_group.get_lora_tokenizer(request),
                          PreTrainedTokenizerBase)
        sync_tokenizer = tokenizer_group.get_lora_tokenizer(request)
        async_tokenizer = await tokenizer_group.get_lora_tokenizer_async(
            request)
        assert sync_tokenizer == async_tokenizer

    # The LoRA request must not resolve to the base tokenizer.
    adapter_tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
    base_tokenizer = tokenizer_group.get_lora_tokenizer(None)
    assert adapter_tokenizer != base_tokenizer
def test_get_lora_tokenizer(sql_lora_files, tmpdir):
    """Exercise ``get_lora_tokenizer`` across its three observable cases.

    * ``None`` request -> falsy result.
    * Request pointing at ``sql_lora_files`` -> a tokenizer whose added
      vocabulary is non-empty.
    * Request pointing at the pytest ``tmpdir`` directory -> falsy result.
    """
    # No LoRA request at all.
    assert not get_lora_tokenizer(None)

    # Adapter directory; expected to provide a tokenizer with added tokens.
    adapter_tokenizer = get_lora_tokenizer(
        LoRARequest("1", 1, sql_lora_files))
    assert adapter_tokenizer.get_added_vocab()

    # Fresh temporary directory with no tokenizer files in it.
    empty_dir_request = LoRARequest("1", 1, str(tmpdir))
    assert not get_lora_tokenizer(empty_dir_request)
|