mistral.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. import os
  2. import re
  3. from dataclasses import dataclass
  4. from pathlib import Path
  5. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
  6. from huggingface_hub import HfApi, hf_hub_download
  7. # yapf: disable
  8. from mistral_common.tokens.tokenizers.mistral import ChatCompletionRequest
  9. from mistral_common.tokens.tokenizers.mistral import (
  10. MistralTokenizer as PublicMistralTokenizer)
  11. # yapf: enable
  12. from mistral_common.tokens.tokenizers.sentencepiece import (
  13. SentencePieceTokenizer)
  14. from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
  15. Tekkenizer)
  16. from aphrodite.common.logger import log_once
  17. if TYPE_CHECKING:
  18. from aphrodite.endpoints.chat_utils import ChatCompletionMessageParam
  @dataclass
  class Encoding:
      """Result container returned by ``MistralTokenizer.__call__``."""

      # Token ids produced for the prompt.
      input_ids: List[int]
  def find_tokenizer_file(files: List[str]):
      """Return the one Mistral tokenizer file name contained in ``files``.

      A name qualifies when it starts with ``tokenizer.model.v`` or is exactly
      ``tekken.json``.

      Args:
          files: Candidate file names (for example a directory listing or the
              file list of a Hub repository).

      Returns:
          The single matching file name.

      Raises:
          OSError: If no file matches, or if more than one file matches.
      """
      file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")
      matched_files = [file for file in files if file_pattern.match(file)]

      # Every fragment of the messages below must be an f-string; otherwise the
      # placeholders are emitted literally instead of being substituted.
      if len(matched_files) > 1:
          raise OSError(
              f"Found {len(matched_files)} files matching the pattern "
              f"`{file_pattern.pattern}`: {matched_files}. Make sure only one "
              f"Mistral tokenizer is present in {files}.")
      if not matched_files:
          raise OSError(
              f"Found 0 files matching the pattern "
              f"`{file_pattern.pattern}`. Make sure that a Mistral "
              f"tokenizer is present in {files}.")

      return matched_files[0]
  class MistralTokenizer:
      """Wrapper around a ``mistral_common`` tokenizer.

      The wrapper keeps the public Mistral tokenizer, its instruct tokenizer
      and the underlying Tekken / SentencePiece tokenizer, and exposes the
      attributes and methods the surrounding code expects from a tokenizer
      (``encode``, ``decode``, ``get_vocab``, ``bos_token_id``, ...).
      """

      def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
          """Wrap ``tokenizer`` and cache its vocabulary.

          Raises:
              TypeError: If the underlying tokenizer is neither a ``Tekkenizer``
                  nor a ``SentencePieceTokenizer``.
          """
          self.mistral = tokenizer
          self.instruct = tokenizer.instruct_tokenizer

          tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
          if isinstance(tokenizer_, Tekkenizer):
              # Make sure special tokens will not raise
              tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE
          elif not isinstance(tokenizer_, SentencePieceTokenizer):
              raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")

          # Both supported tokenizer types expose ``vocab()``; map each token
          # to its position in that listing.
          self._vocab = {
              token: idx
              for idx, token in enumerate(tokenizer_.vocab())
          }
          self.tokenizer = tokenizer_

      @classmethod
      def from_pretrained(cls,
                          path_or_repo_id: str,
                          *,
                          revision: Optional[str] = None) -> "MistralTokenizer":
          """Build a tokenizer from a local path or a Hugging Face Hub repo id.

          Args:
              path_or_repo_id: An existing tokenizer file, an existing directory
                  containing exactly one tokenizer file, or (when the path does
                  not exist) a Hub repo id of the form ``owner/name``.
              revision: Hub revision to use; only relevant for the Hub case.

          Returns:
              A ``MistralTokenizer`` wrapping the loaded tokenizer.

          Raises:
              AssertionError: If the path does not exist and is not of the form
                  ``owner/name``, or if it exists but is neither a directory
                  nor a file.
              OSError: If zero or several tokenizer files are found.
          """
          path = Path(path_or_repo_id)

          if not path.exists():
              assert len(path_or_repo_id.split("/")) == 2, (
                  "You have either provided a non-existent path: "
                  f"{path_or_repo_id} or an invalid HF Hub repo id.")
              tokenizer_file = cls._download_mistral_tokenizer_from_hf(
                  path_or_repo_id, revision)
          elif path.is_dir():
              tokenizer_file_name = find_tokenizer_file(
                  os.listdir(path_or_repo_id))
              tokenizer_file = str(path / tokenizer_file_name)
          else:
              assert path.is_file(), f"Invalid path: {path_or_repo_id}"
              # The path itself is the tokenizer file; without this assignment
              # the load below would fail on an unbound local.
              tokenizer_file = str(path)

          mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
          return cls(mistral_tokenizer)

      @staticmethod
      def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
                                              revision: Optional[str]) -> str:
          """Download the tokenizer file of a Hub repo and return its local path.

          Args:
              tokenizer_name: Hub repo id.
              revision: Revision used both to list the repo files and to
                  download the tokenizer file.

          Raises:
              OSError: If zero or several tokenizer files are found.
          """
          api = HfApi()
          # List the files at the same revision that is downloaded below, so
          # the chosen file name is guaranteed to exist at that revision.
          repo_info = api.model_info(tokenizer_name, revision=revision)
          files = [s.rfilename for s in repo_info.siblings]

          filename = find_tokenizer_file(files)

          tokenizer_file = hf_hub_download(tokenizer_name,
                                           filename=filename,
                                           revision=revision)
          return tokenizer_file

      # the following attributes are set to fit VLLM's design
      @property
      def all_special_tokens_extended(self) -> List[str]:
          """Always an empty list."""
          return []

      @property
      def all_special_tokens(self) -> List[str]:
          """Always an empty list."""
          return []

      @property
      def all_special_ids(self) -> List[int]:
          """Always an empty list."""
          return []

      @property
      def bos_token_id(self) -> int:
          """``bos_id`` of the underlying tokenizer."""
          return self.tokenizer.bos_id

      @property
      def eos_token_id(self) -> int:
          """``eos_id`` of the underlying tokenizer."""
          return self.tokenizer.eos_id

      @property
      def is_fast(self) -> bool:
          """Always ``True``."""
          return True

      @property
      def vocab_size(self) -> int:
          """Number of entries in the cached vocabulary."""
          return len(self._vocab)

      def __len__(self) -> int:
          return self.vocab_size

      def __call__(
          self,
          prompt: str,
          add_special_tokens: bool = False,
          truncation: bool = False,
          max_length: Optional[int] = None,
      ):
          """Encode ``prompt`` and return an ``Encoding``.

          Args:
              prompt: Text to encode.
              add_special_tokens: Accepted for signature compatibility; it is
                  not used.
              truncation: When true, keep only the first ``max_length`` ids.
              max_length: Truncation length; ``None`` keeps all ids.
          """
          # Mistral Tokenizers should not add special tokens
          input_ids = self.encode(prompt)

          if truncation:
              input_ids = input_ids[:max_length]

          return Encoding(input_ids=input_ids)

      def get_vocab(self) -> Dict[str, int]:
          """Return the cached token -> index mapping."""
          return self._vocab

      def get_added_vocab(self) -> Dict[str, int]:
          """Return an empty mapping."""
          # Mistral tokenizers have no added vocabulary
          return {}

      def encode(self, prompt: str) -> List[int]:
          """Encode ``prompt`` with a BOS token and without an EOS token."""
          # `encode` should only be used for prompt completion
          # it should never be used for chat_completion.
          # For chat completion use `apply_chat_template`
          return self.tokenizer.encode(prompt, bos=True, eos=False)

      def apply_chat_template(self,
                              messages: List["ChatCompletionMessageParam"],
                              tools: Optional[Dict[str, Any]] = None,
                              **kwargs) -> List[int]:
          """Encode a chat conversation and return its token ids.

          Args:
              messages: Chat messages forwarded to ``ChatCompletionRequest``.
              tools: Optional tool definitions forwarded to the request.
              **kwargs: Accepted and ignored.
          """
          request = ChatCompletionRequest(messages=messages,
                                          tools=tools)  # type: ignore[type-var]
          encoded = self.mistral.encode_chat_completion(request)

          return encoded.tokens

      def convert_tokens_to_string(self, tokens: List[str]) -> str:
          """Join token pieces (``str`` or ``bytes``) back into text."""
          if isinstance(self.tokenizer, Tekkenizer):
              # Special tokens are dropped before reassembling the text.
              tokens = [
                  t for t in tokens
                  if t not in self.tokenizer._all_special_tokens
              ]

              if any(isinstance(t, bytes) for t in tokens):
                  # we need to encode and decode all tokens again
                  shift = self.tokenizer.num_special_tokens
                  byte_tokens = [
                      t.encode("utf-8") if not isinstance(t, bytes) else t
                      for t in tokens
                  ]
                  # Ids of regular tokens are offset by the number of
                  # special tokens.
                  ids = [
                      self.tokenizer._tekken_token2id_nospecial[t] + shift
                      for t in byte_tokens
                  ]
                  decoded = self.tokenizer.decode(ids)
              else:
                  decoded = "".join(tokens)
          else:
              decoded = self.tokenizer.decode(tokens)  # type: ignore[arg-type]

          return decoded

      def decode(self, ids: Union[List[int], int]) -> str:
          """Decode a single id or a list of ids into text."""
          if isinstance(ids, int):
              ids = [ids]
          return self.tokenizer.decode(ids)

      def convert_ids_to_tokens(
          self,
          ids: List[int],
          skip_special_tokens: bool = True,
      ) -> List[str]:
          """Convert token ids to their pieces.

          Pieces are returned as strings unless one of them is the Unicode
          replacement character, in which case all pieces are returned as
          byte pieces instead.

          Args:
              ids: Token ids to convert.
              skip_special_tokens: Only ``True`` is supported; ``False`` is
                  reported through ``log_once`` and otherwise ignored.
          """
          # TODO(Patrick) - potentially allow special tokens to not be skipped
          if not skip_special_tokens:
              log_once(
                  level="ERROR",
                  message="skip_special_tokens=False is not supported for "
                  "Mistral tokenizers.")

          assert isinstance(self.tokenizer,
                            (Tekkenizer, SentencePieceTokenizer)), type(
                                self.tokenizer)

          tokens = [self.tokenizer.id_to_piece(token_id) for token_id in ids]

          if any(t.strip() == "�" for t in tokens):
              # if any stripped decoded token is undefined
              # because it's invalid unicode then pass bytes
              tokens = [
                  self.tokenizer.id_to_byte_piece(token_id) for token_id in ids
              ]

          return tokens