# mistral.py
  1. import os
  2. import re
  3. from dataclasses import dataclass
  4. from pathlib import Path
  5. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
  6. from huggingface_hub import HfApi, hf_hub_download
  7. # yapf: disable
  8. from mistral_common.tokens.tokenizers.mistral import ChatCompletionRequest
  9. from mistral_common.tokens.tokenizers.mistral import (
  10. MistralTokenizer as PublicMistralTokenizer)
  11. # yapf: enable
  12. from mistral_common.tokens.tokenizers.sentencepiece import (
  13. SentencePieceTokenizer)
  14. from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy,
  15. Tekkenizer)
  16. from aphrodite.common.logger import log_once
# Only needed for type annotations; not imported at runtime.
if TYPE_CHECKING:
    from aphrodite.endpoints.chat_utils import ConversationMessage
@dataclass
class Encoding:
    """Container for tokenized output; only the token ids are carried."""

    # Token ids produced by MistralTokenizer.__call__.
    input_ids: List[int]
  22. def find_tokenizer_file(files: List[str]):
  23. file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$")
  24. matched_files = [file for file in files if file_pattern.match(file)]
  25. if len(matched_files) > 1:
  26. raise OSError(f"Found {len(matched_files)} files matching the "
  27. "pattern: {matched_files}. Make sure only one Mistral "
  28. "tokenizer is present in {tokenizer_name}.")
  29. elif len(matched_files) == 0:
  30. raise OSError(f"Found {len(matched_files)} files matching the "
  31. "pattern: {matched_files}. Make sure that a Mistral "
  32. "tokenizer is present in {tokenizer_name}.")
  33. return matched_files[0]
  34. class MistralTokenizer:
  35. def __init__(self, tokenizer: PublicMistralTokenizer) -> None:
  36. self.mistral = tokenizer
  37. self.instruct = tokenizer.instruct_tokenizer
  38. self.tokenizer = tokenizer.instruct_tokenizer.tokenizer
  39. self.vocab_size = len(self.tokenizer.vocab())
  40. assert isinstance(self.tokenizer,
  41. (Tekkenizer, SentencePieceTokenizer)), type(
  42. self.tokenizer)
  43. self._is_tekken = isinstance(self.tokenizer, Tekkenizer)
  44. if self._is_tekken:
  45. # Make sure special tokens will not raise
  46. self.tokenizer.special_token_policy = SpecialTokenPolicy.IGNORE
  47. # the following attributes are set to fit VLLM's design
  48. self.is_fast = True
  49. self.chat_template = True
  50. self.all_special_ids: List[Any] = []
  51. self.all_special_tokens: List[Any] = []
  52. self.all_special_tokens_extended: List[Any] = []
  53. @classmethod
  54. def from_pretrained(cls,
  55. path_or_repo_id: str,
  56. *,
  57. revision: Optional[str] = None) -> "MistralTokenizer":
  58. if not Path(path_or_repo_id).exists():
  59. assert len(path_or_repo_id.split("/")) == 2, (
  60. "You have either provided a non-existent path: "
  61. "{path_or_repo_id} or an invalid HF Hub repo id.")
  62. tokenizer_file = cls._download_mistral_tokenizer_from_hf(
  63. path_or_repo_id, revision)
  64. elif Path(path_or_repo_id).is_dir():
  65. tokenizer_file_name = find_tokenizer_file(
  66. os.listdir(path_or_repo_id))
  67. tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name)
  68. else:
  69. assert Path(
  70. path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}"
  71. mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
  72. return cls(mistral_tokenizer)
  73. @staticmethod
  74. def _download_mistral_tokenizer_from_hf(tokenizer_name: str,
  75. revision: Optional[str]) -> str:
  76. api = HfApi()
  77. repo_info = api.model_info(tokenizer_name)
  78. files = [s.rfilename for s in repo_info.siblings]
  79. filename = find_tokenizer_file(files)
  80. tokenizer_file = hf_hub_download(tokenizer_name,
  81. filename=filename,
  82. revision=revision)
  83. return tokenizer_file
  84. def __call__(
  85. self,
  86. prompt: str,
  87. add_special_tokens: bool = False,
  88. truncation: bool = False,
  89. max_length: Optional[int] = None,
  90. ):
  91. # Mistral Tokenizers should not add special tokens
  92. input_ids = self.encode(prompt)
  93. if truncation:
  94. input_ids = input_ids[:max_length]
  95. return Encoding(input_ids=input_ids)
  96. def get_added_vocab(self) -> List[str]:
  97. # Mistral tokenizers have no added vocabulary
  98. return []
  99. def encode(self, prompt: str) -> List[int]:
  100. # `encode ` should only be used for prompt completion
  101. # it should never be used for chat_completion.
  102. # For chat completion use `apply_chat_template`
  103. return self.tokenizer.encode(prompt, bos=True, eos=False)
  104. def apply_chat_template(self,
  105. conversation: List["ConversationMessage"],
  106. tools: Optional[Dict[str, Any]] = None,
  107. **kwargs) -> List[int]:
  108. assert tools is None, "`tools` are not yet supported."
  109. request = ChatCompletionRequest(
  110. messages=conversation) # type: ignore[type-var]
  111. encoded = self.mistral.encode_chat_completion(request)
  112. # encode-decode to get clean prompt
  113. return encoded.tokens
  114. def convert_tokens_to_string(self, tokens: List[str]) -> str:
  115. if self._is_tekken:
  116. return "".join(tokens)
  117. else:
  118. return self.tokenizer.decode(tokens) # type: ignore[arg-type]
  119. def decode(self, ids: Union[List[int], int]) -> str:
  120. if isinstance(ids, int):
  121. ids = [ids]
  122. return self.tokenizer.decode(ids)
  123. @property
  124. def eos_token_id(self):
  125. return self.tokenizer.eos_id
  126. def convert_ids_to_tokens(
  127. self,
  128. ids: List[int],
  129. skip_special_tokens: Optional[bool] = True) -> List[str]:
  130. # TODO(Patrick) - potentially allow special tokens to not be skipped
  131. if not skip_special_tokens:
  132. log_once(
  133. level="ERROR",
  134. message="skip_special_tokens=False is not supported for "
  135. "Mistral tokenizers.")
  136. assert isinstance(self.tokenizer,
  137. (Tekkenizer, SentencePieceTokenizer)), type(
  138. self.tokenizer)
  139. tokens = [self.tokenizer.id_to_piece(id) for id in ids]
  140. return tokens
  141. def __len__(self):
  142. return self.vocab_size