import os
import re
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from huggingface_hub import HfApi, hf_hub_download
# yapf: disable
from mistral_common.tokens.tokenizers.mistral import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import (
    MistralTokenizer as PublicMistralTokenizer)
# yapf: enable
from mistral_common.tokens.tokenizers.sentencepiece import (
    SentencePieceTokenizer)
from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
                                                     Tekkenizer)

from aphrodite.common.logger import log_once

if TYPE_CHECKING:
    from aphrodite.endpoints.chat_utils import ChatCompletionMessageParam


@dataclass
class Encoding:
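    """Minimal stand-in for an HF ``BatchEncoding``: ``__call__`` below only
    needs to hand back ``input_ids``, so that is all it carries."""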
    input_ids: List[int]


def find_tokenizer_file(files: List[str]):
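    """Pick the single Mistral tokenizer file out of a file listing.

    Matches either a SentencePiece model (``tokenizer.model.v*``) or a
    Tekken model (``tekken.json``). Example (illustrative file names):

        find_tokenizer_file(["config.json", "tekken.json"])  # -> "tekken.json"
    """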
    file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")

    matched_files = [file for file in files if file_pattern.match(file)]
    if len(matched_files) > 1:
        raise OSError(f"Found {len(matched_files)} files matching the "
                      f"pattern {file_pattern.pattern!r}: {matched_files}. "
                      "Make sure only one Mistral tokenizer is present.")
    elif len(matched_files) == 0:
        raise OSError("Found no files matching the pattern "
                      f"{file_pattern.pattern!r}. Make sure that a Mistral "
                      "tokenizer is present.")

    return matched_files[0]


class MistralTokenizer:
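    """Adapter exposing a ``mistral_common`` tokenizer through the subset of
    the HF tokenizer interface the engine relies on (``encode``, ``decode``,
    ``apply_chat_template``, vocab accessors, ...)."""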

    def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
        self.mistral = tokenizer
        self.instruct = tokenizer.instruct_tokenizer

        tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
        if isinstance(tokenizer_, Tekkenizer):
            # Make sure special tokens will not raise
            tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE

            self._vocab = {
                token: idx
                for idx, token in enumerate(tokenizer_.vocab())
            }
        elif isinstance(tokenizer_, SentencePieceTokenizer):
            self._vocab = {
                token: idx
                for idx, token in enumerate(tokenizer_.vocab())
            }
        else:
            raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")

        self.tokenizer = tokenizer_

    @classmethod
    def from_pretrained(cls,
                        path_or_repo_id: str,
                        *,
                        revision: Optional[str] = None) -> "MistralTokenizer":
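        """Load a tokenizer from a local tokenizer file, a local directory
        containing one, or an HF Hub repo id of the form ``org/name``."""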
        if not Path(path_or_repo_id).exists():
            assert len(path_or_repo_id.split("/")) == 2, (
                "You have either provided a non-existent path: "
                f"{path_or_repo_id} or an invalid HF Hub repo id.")
            tokenizer_file = cls._download_mistral_tokenizer_from_hf(
                path_or_repo_id, revision)
        elif Path(path_or_repo_id).is_dir():
            tokenizer_file_name = find_tokenizer_file(
                os.listdir(path_or_repo_id))
            tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name)
        else:
            assert Path(
                path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}"
            tokenizer_file = path_or_repo_id

        mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
        return cls(mistral_tokenizer)

    @staticmethod
    def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
                                            revision: Optional[str]) -> str:
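        """List the repo's files, pick out the Mistral tokenizer file with
        ``find_tokenizer_file``, and download it, returning the local path."""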
        api = HfApi()
        repo_info = api.model_info(tokenizer_name, revision=revision)
        files = [s.rfilename for s in repo_info.siblings]

        filename = find_tokenizer_file(files)
        tokenizer_file = hf_hub_download(tokenizer_name,
                                         filename=filename,
                                         revision=revision)
        return tokenizer_file

    # The following attributes are set to fit vLLM's design.
    @property
    def all_special_tokens_extended(self) -> List[str]:
        return []

    @property
    def all_special_tokens(self) -> List[str]:
        return []

    @property
    def all_special_ids(self) -> List[int]:
        return []

    @property
    def bos_token_id(self) -> int:
        return self.tokenizer.bos_id

    @property
    def eos_token_id(self) -> int:
        return self.tokenizer.eos_id

    @property
    def is_fast(self) -> bool:
        return True

    @property
    def vocab_size(self) -> int:
        return len(self._vocab)

    def __len__(self) -> int:
        return self.vocab_size

    def __call__(
        self,
        prompt: str,
        add_special_tokens: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
    ):
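        """HF-style call: tokenize ``prompt`` and return an ``Encoding``.

        ``add_special_tokens`` is accepted for interface compatibility but
        ignored; the Mistral tokenizer manages special tokens itself.
        """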
        # Mistral tokenizers should not add special tokens
        input_ids = self.encode(prompt)

        if truncation:
            input_ids = input_ids[:max_length]

        return Encoding(input_ids=input_ids)

    def get_vocab(self) -> Dict[str, int]:
        return self._vocab

    def get_added_vocab(self) -> Dict[str, int]:
        # Mistral tokenizers have no added vocabulary
        return {}

    def encode(self, prompt: str) -> List[int]:
        # `encode` should only be used for prompt completion;
        # it should never be used for chat completion.
        # For chat completion use `apply_chat_template`.
        return self.tokenizer.encode(prompt, bos=True, eos=False)

    def apply_chat_template(self,
                            messages: List["ChatCompletionMessageParam"],
                            tools: Optional[List[Dict[str, Any]]] = None,
                            **kwargs) -> List[int]:
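        """Render a chat conversation to token ids via ``mistral_common``'s
        ``ChatCompletionRequest``; this, not ``encode``, is the entry point
        for chat completions."""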
        request = ChatCompletionRequest(messages=messages,
                                        tools=tools)  # type: ignore[type-var]
        encoded = self.mistral.encode_chat_completion(request)

        # encoded.tokens already contains the fully rendered prompt
        return encoded.tokens

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
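        """Join decoded token pieces back into a string.

        For Tekken tokenizers, special tokens are dropped and byte-level
        pieces are mapped back to ids and decoded; for SentencePiece, the
        underlying ``decode`` handles the pieces directly.
        """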
        if isinstance(self.tokenizer, Tekkenizer):
            tokens = [
                t for t in tokens
                if t not in self.tokenizer._all_special_tokens
            ]

            if any(isinstance(t, bytes) for t in tokens):
                # we need to encode and decode all tokens again
                shift = self.tokenizer.num_special_tokens
                byte_tokens = [
                    t.encode("utf-8") if not isinstance(t, bytes) else t
                    for t in tokens
                ]
                ids = [
                    self.tokenizer._tekken_token2id_nospecial[t] + shift
                    for t in byte_tokens
                ]
                decoded = self.tokenizer.decode(ids)
            else:
                decoded = "".join(tokens)
        else:
            decoded = self.tokenizer.decode(tokens)  # type: ignore[arg-type]

        return decoded

    def decode(self, ids: Union[List[int], int]) -> str:
        if isinstance(ids, int):
            ids = [ids]
        return self.tokenizer.decode(ids)

    def convert_ids_to_tokens(
        self,
        ids: List[int],
        skip_special_tokens: bool = True,
    ) -> List[str]:
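        """Map token ids to their string pieces, falling back to byte pieces
        when a piece decodes to invalid unicode. ``skip_special_tokens=False``
        is not supported and only logs an error."""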
        # TODO(Patrick) - potentially allow special tokens to not be skipped
        if not skip_special_tokens:
            log_once(
                level="ERROR",
                message="skip_special_tokens=False is not supported for "
                "Mistral tokenizers.")

        assert isinstance(self.tokenizer,
                          (Tekkenizer, SentencePieceTokenizer)), type(
                              self.tokenizer)

        tokens = [self.tokenizer.id_to_piece(id) for id in ids]

        if any(t.strip() == "�" for t in tokens):
            # if any stripped decoded token is undefined
            # because it's invalid unicode, pass bytes instead
            tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids]

        return tokens
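

# Minimal usage sketch (illustrative only; assumes HF Hub access and that the
# repo ships a `tokenizer.model.v*` or `tekken.json` file):
#
#     tokenizer = MistralTokenizer.from_pretrained(
#         "mistralai/Mistral-7B-Instruct-v0.3")
#     conversation = [{"role": "user", "content": "Hello!"}]
#     prompt_ids = tokenizer.apply_chat_template(conversation)
#     text = tokenizer.decode(prompt_ids)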