# baichuan.py

# yapf: disable
# Adapted from
# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/8f6e343d545c503b91429582231d1d354dac2740/tokenization_baichuan.py
# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple

import sentencepiece as spm
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
from transformers.utils import logging

logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {},
    "tokenizer_file": {},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}


class BaichuanTokenizer(PreTrainedTokenizer):
    """
    Construct a Baichuan tokenizer. Based on byte-level Byte-Pair-Encoding.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token=None,
        sp_model_kwargs: Optional[Dict[str, Any]] = None,
        add_bos_token=True,
        add_eos_token=False,
        clean_up_tokenization_spaces=False,
        **kwargs,
    ):
        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
        bos_token = (
            AddedToken(bos_token, lstrip=False, rstrip=False)
            if isinstance(bos_token, str)
            else bos_token
        )
        eos_token = (
            AddedToken(eos_token, lstrip=False, rstrip=False)
            if isinstance(eos_token, str)
            else eos_token
        )
        unk_token = (
            AddedToken(unk_token, lstrip=False, rstrip=False)
            if isinstance(unk_token, str)
            else unk_token
        )
        pad_token = (
            AddedToken(pad_token, lstrip=False, rstrip=False)
            if isinstance(pad_token, str)
            else pad_token
        )
        self.vocab_file = vocab_file
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(vocab_file)
        super().__init__(
            bos_token=bos_token,
            eos_token=eos_token,
            unk_token=unk_token,
            pad_token=pad_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            sp_model_kwargs=self.sp_model_kwargs,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **kwargs,
        )
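
    # Illustrative usage (an assumption, not part of the upstream file); the
    # vocab path is a placeholder for a real SentencePiece model file:
    #
    #     tok = BaichuanTokenizer(vocab_file="/path/to/tokenizer.model")
    #     tok("hi")["input_ids"]  # begins with tok.bos_token_id, since
    #                             # add_bos_token defaults to True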

    def __getstate__(self):
        state = self.__dict__.copy()
        state["sp_model"] = None
        return state

    def __setstate__(self, d):
        self.__dict__ = d
        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
        self.sp_model.Load(self.vocab_file)

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self.sp_model.get_piece_size()

    def get_vocab(self):
        """Returns vocab as a dict"""
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text):
        """Returns a tokenized string."""
        return self.sp_model.encode(text, out_type=str)

    def _convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        token = self.sp_model.IdToPiece(index)
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) to a single string."""
        current_sub_tokens = []
        out_string = ""
        prev_is_special = False
        for i, token in enumerate(tokens):
            # make sure that special tokens are not decoded using sentencepiece model
            if token in self.all_special_tokens:
                if not prev_is_special and i != 0:
                    out_string += " "
                out_string += self.sp_model.decode(current_sub_tokens) + token
                prev_is_special = True
                current_sub_tokens = []
            else:
                current_sub_tokens.append(token)
                prev_is_special = False
        out_string += self.sp_model.decode(current_sub_tokens)
        return out_string
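
    # Behavioural note with a hedged example (the piece strings are assumptions):
    # plain pieces are decoded in runs by SentencePiece, e.g.
    # convert_tokens_to_string(["▁Hello", "▁world"]) -> "Hello world", while any
    # special token in the sequence is spliced in verbatim rather than being
    # passed to sp_model.decode().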

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "")
            + VOCAB_FILES_NAMES["vocab_file"],
        )
        if os.path.abspath(self.vocab_file) != os.path.abspath(
            out_vocab_file
        ) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)
        return (out_vocab_file,)
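
    # Hedged example (the directory name is a placeholder): calling
    # tok.save_vocabulary("./exported") either copies the original
    # tokenizer.model into ./exported or, if the source file is no longer on
    # disk, re-serializes the loaded SentencePiece model proto instead.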

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []
        output = bos_token_id + token_ids_0 + eos_token_id
        if token_ids_1 is not None:
            output = output + bos_token_id + token_ids_1 + eos_token_id
        return output
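
    # With the default flags (add_bos_token=True, add_eos_token=False) the
    # method above returns [bos_token_id] + token_ids_0 for a single sequence
    # and [bos] + ids_0 + [bos] + ids_1 for a pair; eos ids are only appended
    # when add_eos_token is enabled.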

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        bos_token_id = [1] if self.add_bos_token else []
        eos_token_id = [1] if self.add_eos_token else []
        if token_ids_1 is None:
            return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
        return (
            bos_token_id
            + ([0] * len(token_ids_0))
            + eos_token_id
            + bos_token_id
            + ([0] * len(token_ids_1))
            + eos_token_id
        )
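
    # Sketch of the resulting mask under the default flags: for a three-token
    # token_ids_0 this returns [1, 0, 0, 0], i.e. a single leading 1 marking
    # <s> followed by 0 for every ordinary sequence token.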

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
        sequence pair mask has the following format:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If token_ids_1 is None, only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (`List[int]`):
                List of ids.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []
        output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
        if token_ids_1 is not None:
            output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
        return output
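

# Hedged, self-contained smoke test; this block is not part of the upstream
# file, and "tokenizer.model" below is a placeholder path that must point at a
# real SentencePiece model shipped with a Baichuan checkpoint.
if __name__ == "__main__":
    import sys

    vocab_path = sys.argv[1] if len(sys.argv) > 1 else "tokenizer.model"
    tok = BaichuanTokenizer(vocab_file=vocab_path)
    pieces = tok.tokenize("Hello, Baichuan!")
    ids = tok.convert_tokens_to_ids(pieces)
    print("pieces:", pieces)
    print("ids with special tokens:", tok.build_inputs_with_special_tokens(ids))
    print("round trip:", tok.convert_tokens_to_string(pieces))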