# gguf_to_torch.py

import json
import os

import torch
import gguf
from sentencepiece import sentencepiece_model_pb2
from safetensors.torch import save_file as safe_save_file
from transformers.modeling_utils import shard_checkpoint
from transformers.utils import (WEIGHTS_NAME, WEIGHTS_INDEX_NAME,
                                SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME)
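
# Convert a llama-architecture GGUF checkpoint into a Hugging Face style model
# directory: tokenizer.model, tokenizer_config.json, config.json, and the
# model weights (optionally sharded and/or saved as safetensors).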


def convert_to_state_dict(checkpoint, save_dir, max_shard_size,
                          safe_serialization):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    state_dict = {}
    result = gguf.GGUFReader(checkpoint)
    architecture = result.fields['general.architecture']
    architecture = str(bytes(architecture.parts[architecture.data[0]]),
                       encoding='utf-8')
    if architecture != "llama":
        print(f"Unsupported architecture {architecture}")
        return
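
    # Rebuild a sentencepiece ModelProto from the GGUF tokenizer fields so the
    # result can be saved as a regular tokenizer.model file.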
    # write vocab
    vocab = sentencepiece_model_pb2.ModelProto()
    vocab_size = len(result.fields['tokenizer.ggml.token_type'].data)
    vocab.trainer_spec.model_type = 2  # BPE
    vocab.trainer_spec.vocab_size = vocab_size
    vocab.trainer_spec.byte_fallback = True
    vocab.normalizer_spec.remove_extra_whitespaces = False
    tokens = result.fields['tokenizer.ggml.tokens']
    scores = result.fields['tokenizer.ggml.scores']
    types = result.fields['tokenizer.ggml.token_type']
    for i in range(vocab_size):
        new_token = vocab.SentencePiece()
        new_token.piece = str(bytes(tokens.parts[tokens.data[i]]),
                              encoding='utf-8')
        new_token.score = scores.parts[scores.data[i]]
        # llama.cpp token types use the same values as sentencepiece token types
        new_token.type = int(types.parts[types.data[i]])
        vocab.pieces.append(new_token)
    with open(os.path.join(save_dir, "tokenizer.model"), 'wb') as f:
        f.write(vocab.SerializeToString())
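
    # The tokenizer config mirrors whatever special-token and chat-template
    # metadata is present in the GGUF file; missing fields are simply omitted.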
    tokenizer_config = {
        "tokenizer_class": "LlamaTokenizer",
        "legacy": False,
        "clean_up_tokenization_spaces": False,
    }
    if 'tokenizer.ggml.bos_token_id' in result.fields:
        tokenizer_config["bos_token"] = vocab.pieces[int(
            result.fields['tokenizer.ggml.bos_token_id'].parts[-1])].piece
    if 'tokenizer.ggml.eos_token_id' in result.fields:
        tokenizer_config["eos_token"] = vocab.pieces[int(
            result.fields['tokenizer.ggml.eos_token_id'].parts[-1])].piece
    if 'tokenizer.ggml.padding_token_id' in result.fields:
        tokenizer_config["pad_token"] = vocab.pieces[int(
            result.fields['tokenizer.ggml.padding_token_id'].parts[-1])].piece
    if 'tokenizer.ggml.unknown_token_id' in result.fields:
        tokenizer_config["unk_token"] = vocab.pieces[int(
            result.fields['tokenizer.ggml.unknown_token_id'].parts[-1])].piece
    if 'tokenizer.ggml.add_bos_token' in result.fields:
        tokenizer_config["add_bos_token"] = bool(
            result.fields['tokenizer.ggml.add_bos_token'].parts[-1])
    if 'tokenizer.ggml.add_eos_token' in result.fields:
        tokenizer_config["add_eos_token"] = bool(
            result.fields['tokenizer.ggml.add_eos_token'].parts[-1])
    if 'tokenizer.chat_template' in result.fields:
        tokenizer_config["chat_template"] = str(bytes(
            result.fields['tokenizer.chat_template'].parts[-1]),
            encoding="utf-8")
    with open(os.path.join(save_dir, "tokenizer_config.json"), 'w') as f:
        json.dump(tokenizer_config, f, indent=2)
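
    # Translate the GGUF 'llama.*' metadata keys into the equivalent
    # Hugging Face config.json fields.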
    # write config
    context_length = int(result.fields['llama.context_length'].parts[-1])
    n_layer = int(result.fields['llama.block_count'].parts[-1])
    n_head = int(result.fields['llama.attention.head_count'].parts[-1])
    n_local_heads = int(
        result.fields['llama.attention.head_count_kv'].parts[-1])
    intermediate_size = int(
        result.fields['llama.feed_forward_length'].parts[-1])
    norm_eps = float(
        result.fields['llama.attention.layer_norm_rms_epsilon'].parts[-1])
    dim = int(result.fields['llama.embedding_length'].parts[-1])
    kv_dim = dim // n_head * n_local_heads
    # Expert metadata marks a Mixtral-style MoE model; otherwise plain Llama.
    if 'llama.expert_count' in result.fields:
        arch = "MixtralForCausalLM"
        name = "mixtral"
    else:
        arch = "LlamaForCausalLM"
        name = "llama"
    model_config = {
        "architectures": [arch],
        "bos_token_id": 1,
        "eos_token_id": 2,
        "hidden_act": "silu",
        "hidden_size": dim,
        "intermediate_size": intermediate_size,
        "max_position_embeddings": context_length,
        "model_type": name,
        "num_attention_heads": n_head,
        "num_hidden_layers": n_layer,
        "num_key_value_heads": n_local_heads,
        "rms_norm_eps": norm_eps,
        "torch_dtype": "float16",
        "vocab_size": vocab_size
    }
    if 'llama.rope.freq_base' in result.fields:
        model_config['rope_theta'] = float(
            result.fields['llama.rope.freq_base'].parts[-1])
    if 'llama.expert_count' in result.fields:
        model_config['num_local_experts'] = int(
            result.fields['llama.expert_count'].parts[-1])
        model_config['num_experts_per_tok'] = int(
            result.fields['llama.expert_used_count'].parts[-1])
    with open(os.path.join(save_dir, "config.json"), 'w') as f:
        json.dump(model_config, f, indent=2)
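
    # Map GGUF tensor name prefixes to Hugging Face module paths. The second
    # tuple element is the tensor's output dimension, used below to reshape
    # the flat GGUF data; -1 means the tensor is left 1-D (norms, inv_freq).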
    # write tensor
    tensor_mapping = {
        "token_embd": ("model.embed_tokens", vocab_size),
        "output": ("lm_head", vocab_size),
        "output_norm": ("model.norm", -1),
        "blk.{bid}.attn_norm": ("model.layers.{bid}.input_layernorm", -1),
        "blk.{bid}.attn_q": ("model.layers.{bid}.self_attn.q_proj", dim),
        "blk.{bid}.attn_k": ("model.layers.{bid}.self_attn.k_proj", kv_dim),
        "blk.{bid}.attn_v": ("model.layers.{bid}.self_attn.v_proj", kv_dim),
        "blk.{bid}.attn_output": ("model.layers.{bid}.self_attn.o_proj", dim),
        "blk.{bid}.attn_rot_embd":
            ("model.layers.{bid}.self_attn.rotary_emb.inv_freq", -1),
        "blk.{bid}.ffn_norm": ("model.layers.{bid}.post_attention_layernorm",
                               -1),
        "blk.{bid}.ffn_up": ("model.layers.{bid}.mlp.up_proj",
                             intermediate_size),
        "blk.{bid}.ffn_down": ("model.layers.{bid}.mlp.down_proj", dim),
        "blk.{bid}.ffn_gate": ("model.layers.{bid}.mlp.gate_proj",
                               intermediate_size),
        "blk.{bid}.ffn_up.{xid}":
            ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w3",
             intermediate_size),
        "blk.{bid}.ffn_down.{xid}":
            ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", dim),
        "blk.{bid}.ffn_gate.{xid}":
            ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w1",
             intermediate_size),
        "blk.{bid}.ffn_gate_inp": ("model.layers.{bid}.block_sparse_moe.gate",
                                   model_config.get('num_local_experts', 1)),
    }
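
    # Pre-expand the name templates for every possible block/expert index so
    # tensor names read from the file can be looked up directly.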
    mapping = {}
    max_block_num = 200
    max_expert_num = 8
    for k, v in tensor_mapping.items():
        for i in range(max_block_num):
            for j in range(max_expert_num):
                fk = k.format(bid=i, xid=j)
                fv = v[0].format(bid=i, xid=j)
                if fk not in mapping:
                    mapping[fk] = (fv, v[1])
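
    # GGML tensor types 0 and 1 are F32/F16; larger values indicate a
    # quantized format. For quantized tensors the raw block data is stored
    # as-is and the type id is recorded under '<name>.weight_type',
    # presumably so a downstream loader can dequantize it.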
    for ts in result.tensors:
        weight_type = torch.tensor(int(ts.tensor_type), dtype=torch.int)
        layer, suffix = ts.name.rsplit(".", 1)
        new_key, output_dim = mapping[layer]
        new_key += f".{suffix}"
        data = torch.tensor(ts.data)
        if output_dim != -1:
            data = data.view(output_dim, -1)
        if weight_type > 1:
            state_dict[new_key.replace("weight", "weight_type")] = weight_type
        state_dict[new_key] = data
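
    # Save either a single checkpoint file or, if a maximum shard size was
    # given, a set of shards plus the weight index produced by transformers'
    # shard_checkpoint helper.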
    if max_shard_size == "0":
        if safe_serialization:
            safe_save_file(state_dict,
                           os.path.join(save_dir, SAFE_WEIGHTS_NAME),
                           metadata={"format": "pt"})
        else:
            torch.save(state_dict, os.path.join(save_dir, WEIGHTS_NAME))
    else:
        shards, index = shard_checkpoint(
            state_dict, max_shard_size,
            SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME)
        for shard_file, shard in shards.items():
            if safe_serialization:
                safe_save_file(shard,
                               os.path.join(save_dir, shard_file),
                               metadata={"format": "pt"})
            else:
                torch.save(shard, os.path.join(save_dir, shard_file))
        if index is not None:
            if safe_serialization:
                save_index_file = SAFE_WEIGHTS_INDEX_NAME
            else:
                save_index_file = WEIGHTS_INDEX_NAME
            save_index_file = os.path.join(save_dir, save_index_file)
            # Save the index as well
            with open(save_index_file, "w", encoding="utf-8") as f:
                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
                f.write(content)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(
        description='Convert GGUF checkpoints to torch')
    parser.add_argument('--input', type=str, help='The path to the GGUF file')
    parser.add_argument('--output',
                        type=str,
                        help='The path to the output directory')
    parser.add_argument(
        '--max-shard-size',
        default="0",
        type=str,
        help='Shard the model into files of the given size, e.g. 10GB; '
             '0 disables sharding')
    parser.add_argument('--safetensors',
                        action='store_true',
                        help='Save in .safetensors format')
    args = parser.parse_args()
    convert_to_state_dict(args.input, args.output, args.max_shard_size,
                          args.safetensors)
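
# Example invocation (paths are illustrative):
#   python gguf_to_torch.py --input ./llama-2-7b.Q4_0.gguf \
#       --output ./llama-2-7b-hf --max-shard-size 5GB --safetensors
# The output directory follows the Hugging Face layout (config.json,
# tokenizer files, and the weight file(s) plus index). Quantized GGUF tensors
# are stored in their raw block format alongside '*.weight_type' entries, so
# they still need a loader that knows how to dequantize them.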