# baichuan.py

# yapf: disable
# Adapted from
# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/8f6e343d545c503b91429582231d1d354dac2740/tokenization_baichuan.py
# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple

import sentencepiece as spm
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
from transformers.utils import logging

logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {},
    "tokenizer_file": {},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}


class BaichuanTokenizer(PreTrainedTokenizer):
    """
    Construct a Baichuan tokenizer, based on a SentencePiece
    Byte-Pair-Encoding model.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token=None,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        **kwargs,
    ):
        if sp_model_kwargs is None:
            self.sp_model_kwargs = {}
        else:
            self.sp_model_kwargs = sp_model_kwargs
        # Wrap plain-string special tokens in AddedToken so surrounding
        # whitespace is preserved exactly.
        bos_token = (
            AddedToken(bos_token, lstrip=False, rstrip=False)
            if isinstance(bos_token, str)
            else bos_token
        )
        eos_token = (
            AddedToken(eos_token, lstrip=False, rstrip=False)
            if isinstance(eos_token, str)
            else eos_token
        )
        unk_token = (
            AddedToken(unk_token, lstrip=False, rstrip=False)
            if isinstance(unk_token, str)
            else unk_token
        )
        pad_token = (
            AddedToken(pad_token, lstrip=False, rstrip=False)
            if isinstance(pad_token, str)
            else pad_token
        )
        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        # Load the SentencePiece model before calling the parent constructor,
        # which may already query the vocabulary.
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            sp_model_kwargs=self.sp_model_kwargs,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )

    def __getstate__(self):
        # The SentencePiece processor is not picklable; drop it here and
        # reload it from ``vocab_file`` in ``__setstate__``.
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self.sp_model.get_piece_size()

    def get_vocab(self):
        """Returns vocab as a dict"""
        vocab = {
            self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)
        }
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text):
        """Returns a tokenized string."""
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) to a single string."""
        current_sub_tokens = []
        out_string = ""
        prev_is_special = False
        for i, token in enumerate(tokens):
            # make sure that special tokens are not decoded using
            # the sentencepiece model
            if token in self.all_special_tokens:
                if not prev_is_special and i != 0:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string
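
    # Illustrative sketch, not part of the upstream file: assuming a loaded
    # tokenizer ``tok``, ordinary SentencePiece pieces are decoded through
    # ``sp_model.decode`` while special tokens are appended verbatim, e.g.
    #
    #   tok.convert_tokens_to_string(["▁Hello", "▁world", "</s>"])
    #   # decodes ["▁Hello", "▁world"] with SentencePiece and appends "</s>"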

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be"
                         " a directory")
            return
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "")
            + VOCAB_FILES_NAMES["vocab_file"],
        )
        if os.path.abspath(self.vocab_file) != os.path.abspath(
            out_vocab_file
        ) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)
        return (out_vocab_file,)
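
    # Hypothetical usage sketch (paths are illustrative only):
    #
    #   tok.save_vocabulary("/tmp/baichuan-vocab")
    #   # -> ("/tmp/baichuan-vocab/tokenizer.model",)
    #   # Copies the source tokenizer.model when it exists at a different
    #   # path; otherwise writes the serialized SentencePiece model proto.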

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []
        output = bos_token_id + token_ids_0 + eos_token_id
        if token_ids_1 is not None:
            output = output + bos_token_id + token_ids_1 + eos_token_id
        return output
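
    # Illustrative sketch, assuming the defaults add_bos_token=True and
    # add_eos_token=False, and a hypothetical bos_token_id of 1:
    #
    #   tok.build_inputs_with_special_tokens([10, 11])        # -> [1, 10, 11]
    #   tok.build_inputs_with_special_tokens([10, 11], [12])  # -> [1, 10, 11, 1, 12]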

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens
        added. This method is called when adding special tokens using the
        tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with
                special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a
            special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        bos_token_id = [1] if self.add_bos_token else []
        eos_token_id = [1] if self.add_eos_token else []
        if token_ids_1 is None:
            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
        return (
            bos_token_id
            + ([0] * len(token_ids_0))
            + eos_token_id
            + bos_token_id
            + ([0] * len(token_ids_1))
            + eos_token_id
        )
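
    # Illustrative sketch, assuming the defaults add_bos_token=True and
    # add_eos_token=False:
    #
    #   tok.get_special_tokens_mask([10, 11])        # -> [1, 0, 0]
    #   tok.get_special_tokens_mask([10, 11], [12])  # -> [1, 0, 0, 1, 0]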

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a
        sequence-pair classification task. An ALBERT sequence pair mask has
        the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence | second sequence |
        ```

        If `token_ids_1` is `None`, only the first portion of the mask (0s)
        is returned.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids)
            according to the given sequence(s).
        """
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []
        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
        if token_ids_1 is not None:
            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
        return output
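

# Minimal smoke-test sketch, not part of the upstream file. It assumes a
# local SentencePiece model at ``./tokenizer.model`` (e.g. taken from the
# Baichuan2 checkpoint referenced in the header above).
if __name__ == "__main__":
    tok = BaichuanTokenizer(vocab_file="./tokenizer.model")
    ids = tok("你好，世界")["input_ids"]
    # With the default add_bos_token=True, the sequence starts with <s>.
    print(ids)
    print(tok.decode(ids))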