1
0

mistral.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. import os
  2. import re
  3. from dataclasses import dataclass
  4. from pathlib import Path
  5. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
  6. from huggingface_hub import HfApi, hf_hub_download
  7. # yapf: disable
  8. from mistral_common.tokens.tokenizers.mistral import ChatCompletionRequest
  9. from mistral_common.tokens.tokenizers.mistral import (
  10. MistralTokenizer as PublicMistralTokenizer)
  11. # yapf: enable
  12. from mistral_common.tokens.tokenizers.sentencepiece import (
  13. SentencePieceTokenizer)
  14. from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
  15. Tekkenizer)
  16. if TYPE_CHECKING:
  17. from aphrodite.endpoints.chat_utils import ConversationMessage
  18. @dataclass
  19. class Encoding:
  20. input_ids: List[int]
  21. def find_tokenizer_file(files: List[str]):
  22. file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")
  23. matched_files = [file for file in files if file_pattern.match(file)]
  24. if len(matched_files) > 1:
  25. raise OSError(f"Found {len(matched_files)} files matching the "
  26. "pattern: {matched_files}. Make sure only one Mistral "
  27. "tokenizer is present in {tokenizer_name}.")
  28. elif len(matched_files) == 0:
  29. raise OSError(f"Found {len(matched_files)} files matching the "
  30. "pattern: {matched_files}. Make sure that a Mistral "
  31. "tokenizer is present in {tokenizer_name}.")
  32. return matched_files[0]
  33. class MistralTokenizer:
  34. def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
  35. self.mistral = tokenizer
  36. self.instruct = tokenizer.instruct_tokenizer
  37. self.tokenizer = tokenizer.instruct_tokenizer.tokenizer
  38. self.vocab_size = len(self.tokenizer.vocab())
  39. assert isinstance(self.tokenizer,
  40. (Tekkenizer, SentencePieceTokenizer)), type(
  41. self.tokenizer)
  42. self._is_tekken = isinstance(self.tokenizer, Tekkenizer)
  43. if self._is_tekken:
  44. # Make sure special tokens will not raise
  45. self.tokenizer.special_token_policy = SpecialTokenPolicy.IGNORE
  46. # the following attributes are set to fit VLLM's design
  47. self.is_fast = True
  48. self.chat_template = True
  49. self.all_special_ids: List[Any] = []
  50. self.all_special_tokens: List[Any] = []
  51. self.all_special_tokens_extended: List[Any] = []
  52. @classmethod
  53. def from_pretrained(cls,
  54. path_or_repo_id: str,
  55. *,
  56. revision: Optional[str] = None) -> "MistralTokenizer":
  57. if not Path(path_or_repo_id).exists():
  58. assert len(path_or_repo_id.split("/")) == 2, (
  59. "You have either provided a non-existent path: "
  60. "{path_or_repo_id} or an invalid HF Hub repo id.")
  61. tokenizer_file = cls._download_mistral_tokenizer_from_hf(
  62. path_or_repo_id, revision)
  63. elif Path(path_or_repo_id).is_dir():
  64. tokenizer_file_name = find_tokenizer_file(
  65. os.listdir(path_or_repo_id))
  66. tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name)
  67. else:
  68. assert Path(
  69. path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}"
  70. mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
  71. return cls(mistral_tokenizer)
  72. @staticmethod
  73. def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
  74. revision: Optional[str]) -> str:
  75. api = HfApi()
  76. repo_info = api.model_info(tokenizer_name)
  77. files = [s.rfilename for s in repo_info.siblings]
  78. filename = find_tokenizer_file(files)
  79. tokenizer_file = hf_hub_download(tokenizer_name,
  80. filename=filename,
  81. revision=revision)
  82. return tokenizer_file
  83. def __call__(
  84. self,
  85. prompt: str,
  86. add_special_tokens: bool = False,
  87. truncation: bool = False,
  88. max_length: Optional[int] = None,
  89. ):
  90. # Mistral Tokenizers should not add special tokens
  91. input_ids = self.encode(prompt)
  92. if truncation:
  93. input_ids = input_ids[:max_length]
  94. return Encoding(input_ids=input_ids)
  95. def get_added_vocab(self) -> List[str]:
  96. # Mistral tokenizers have no added vocabulary
  97. return []
  98. def encode(self, prompt: str) -> List[int]:
  99. # `encode ` should only be used for prompt completion
  100. # it should never be used for chat_completion.
  101. # For chat completion use `apply_chat_template`
  102. return self.tokenizer.encode(prompt, bos=True, eos=False)
  103. def apply_chat_template(self,
  104. conversation: List["ConversationMessage"],
  105. tools: Optional[Dict[str, Any]] = None,
  106. **kwargs) -> List[int]:
  107. assert tools is None, "`tools` are not yet supported."
  108. request = ChatCompletionRequest(
  109. messages=conversation) # type: ignore[type-var]
  110. encoded = self.mistral.encode_chat_completion(request)
  111. # encode-decode to get clean prompt
  112. return encoded.tokens
  113. def convert_tokens_to_string(self, tokens: List[str]) -> str:
  114. if self._is_tekken:
  115. return "".join(tokens)
  116. else:
  117. return self.tokenizer.decode(tokens) # type: ignore[arg-type]
  118. def decode(self, ids: Union[List[int], int]) -> str:
  119. if isinstance(ids, int):
  120. ids = [ids]
  121. return self.tokenizer.decode(ids)
  122. @property
  123. def eos_token_id(self):
  124. return self.tokenizer.eos_id
  125. def convert_ids_to_tokens(
  126. self,
  127. ids: List[int],
  128. skip_special_tokens: Optional[bool] = True) -> List[str]:
  129. # TODO(Patrick) - potentially allow special tokens to not be skipped
  130. assert (
  131. skip_special_tokens
  132. ), "Skipping special tokens is not supported for Mistral tokenizers."
  133. assert isinstance(self.tokenizer,
  134. (Tekkenizer, SentencePieceTokenizer)), type(
  135. self.tokenizer)
  136. tokens = [self.tokenizer.id_to_piece(id) for id in ids]
  137. return tokens
  138. def __len__(self):
  139. return self.vocab_size