gguf_to_torch.py
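
"""Convert a llama-architecture GGUF checkpoint into a Hugging Face-style model directory.

Reads the GGUF metadata and tensors, then writes tokenizer.model, tokenizer_config.json,
config.json, and the model weights (optionally sharded and/or in safetensors format) to the
output directory. Quantized tensors are kept as raw GGUF data, with a companion
``*.weight_type`` entry recording their ggml tensor type.
"""
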
import json
import os

import torch
import gguf
from sentencepiece import sentencepiece_model_pb2
from safetensors.torch import save_file as safe_save_file
from transformers.modeling_utils import shard_checkpoint
from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME


def convert_to_state_dict(checkpoint, save_dir, max_shard_size, safe_serialization):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    state_dict = {}
    result = gguf.GGUFReader(checkpoint)
    architecture = result.fields['general.architecture']
    architecture = str(bytes(architecture.parts[architecture.data[0]]), encoding='utf-8')
    if architecture != "llama":
        print(f"Unsupported architecture {architecture}")
        return
    # Rebuild the SentencePiece vocabulary (tokenizer.model) from the GGUF tokenizer metadata.
    vocab = sentencepiece_model_pb2.ModelProto()
    vocab_size = len(result.fields['tokenizer.ggml.token_type'].data)
    vocab.trainer_spec.model_type = 2  # BPE
    vocab.trainer_spec.vocab_size = vocab_size
    vocab.trainer_spec.byte_fallback = True
    vocab.normalizer_spec.remove_extra_whitespaces = False
    tokens = result.fields['tokenizer.ggml.tokens']
    scores = result.fields['tokenizer.ggml.scores']
    types = result.fields['tokenizer.ggml.token_type']
    for i in range(vocab_size):
        new_token = vocab.SentencePiece()
        new_token.piece = str(bytes(tokens.parts[tokens.data[i]]), encoding='utf-8')
        new_token.score = scores.parts[scores.data[i]]
        # llama.cpp token types use the same values as SentencePiece token types.
        new_token.type = int(types.parts[types.data[i]])
        vocab.pieces.append(new_token)
    with open(os.path.join(save_dir, "tokenizer.model"), 'wb') as f:
        f.write(vocab.SerializeToString())
    # Write tokenizer_config.json, carrying over special tokens from the GGUF metadata when present.
    tokenizer_config = {
        "tokenizer_class": "LlamaTokenizer",
        "legacy": False,
        "clean_up_tokenization_spaces": False,
    }
    if 'tokenizer.ggml.bos_token_id' in result.fields:
        tokenizer_config["bos_token"] = vocab.pieces[int(result.fields['tokenizer.ggml.bos_token_id'].parts[-1])].piece
    if 'tokenizer.ggml.eos_token_id' in result.fields:
        tokenizer_config["eos_token"] = vocab.pieces[int(result.fields['tokenizer.ggml.eos_token_id'].parts[-1])].piece
    if 'tokenizer.ggml.padding_token_id' in result.fields:
        tokenizer_config["pad_token"] = vocab.pieces[int(result.fields['tokenizer.ggml.padding_token_id'].parts[-1])].piece
    if 'tokenizer.ggml.unknown_token_id' in result.fields:
        tokenizer_config["unk_token"] = vocab.pieces[int(result.fields['tokenizer.ggml.unknown_token_id'].parts[-1])].piece
    if 'tokenizer.ggml.add_bos_token' in result.fields:
        tokenizer_config["add_bos_token"] = bool(result.fields['tokenizer.ggml.add_bos_token'].parts[-1])
    if 'tokenizer.ggml.add_eos_token' in result.fields:
        tokenizer_config["add_eos_token"] = bool(result.fields['tokenizer.ggml.add_eos_token'].parts[-1])
    if 'tokenizer.chat_template' in result.fields:
        tokenizer_config["chat_template"] = str(bytes(result.fields['tokenizer.chat_template'].parts[-1]), encoding="utf-8")
    with open(os.path.join(save_dir, "tokenizer_config.json"), 'w') as f:
        json.dump(tokenizer_config, f, indent=2)
    # Write config.json from the llama.* hyperparameters in the GGUF metadata.
    context_length = int(result.fields['llama.context_length'].parts[-1])
    n_layer = int(result.fields['llama.block_count'].parts[-1])
    n_head = int(result.fields['llama.attention.head_count'].parts[-1])
    n_local_heads = int(result.fields['llama.attention.head_count_kv'].parts[-1])
    intermediate_size = int(result.fields['llama.feed_forward_length'].parts[-1])
    norm_eps = float(result.fields['llama.attention.layer_norm_rms_epsilon'].parts[-1])
    dim = int(result.fields['llama.embedding_length'].parts[-1])
    kv_dim = dim // n_head * n_local_heads
    if 'llama.expert_count' in result.fields:
        arch = "MixtralForCausalLM"
        name = "mixtral"
    else:
        arch = "LlamaForCausalLM"
        name = "llama"
    model_config = {
        "architectures": [arch],
        "bos_token_id": 1,
        "eos_token_id": 2,
        "hidden_act": "silu",
        "hidden_size": dim,
        "intermediate_size": intermediate_size,
        "max_position_embeddings": context_length,
        "model_type": name,
        "num_attention_heads": n_head,
        "num_hidden_layers": n_layer,
        "num_key_value_heads": n_local_heads,
        "rms_norm_eps": norm_eps,
        "torch_dtype": "float16",
        "vocab_size": vocab_size
    }
    if 'llama.rope.freq_base' in result.fields:
        model_config['rope_theta'] = float(result.fields['llama.rope.freq_base'].parts[-1])
    if 'llama.expert_count' in result.fields:
        model_config['num_local_experts'] = int(result.fields['llama.expert_count'].parts[-1])
        model_config['num_experts_per_tok'] = int(result.fields['llama.expert_used_count'].parts[-1])
    with open(os.path.join(save_dir, "config.json"), 'w') as f:
        json.dump(model_config, f, indent=2)
    # Write the tensors. Each entry maps a llama.cpp tensor name template to
    # (Hugging Face name template, output dimension); the output dimension is used as the
    # first axis when reshaping the raw tensor data, and -1 leaves the data as loaded.
    tensor_mapping = {
        "token_embd": ("model.embed_tokens", vocab_size),
        "output": ("lm_head", vocab_size),
        "output_norm": ("model.norm", -1),
        "blk.{bid}.attn_norm": ("model.layers.{bid}.input_layernorm", -1),
        "blk.{bid}.attn_q": ("model.layers.{bid}.self_attn.q_proj", dim),
        "blk.{bid}.attn_k": ("model.layers.{bid}.self_attn.k_proj", kv_dim),
        "blk.{bid}.attn_v": ("model.layers.{bid}.self_attn.v_proj", kv_dim),
        "blk.{bid}.attn_output": ("model.layers.{bid}.self_attn.o_proj", dim),
        "blk.{bid}.attn_rot_embd": ("model.layers.{bid}.self_attn.rotary_emb.inv_freq", -1),
        "blk.{bid}.ffn_norm": ("model.layers.{bid}.post_attention_layernorm", -1),
        "blk.{bid}.ffn_up": ("model.layers.{bid}.mlp.up_proj", intermediate_size),
        "blk.{bid}.ffn_down": ("model.layers.{bid}.mlp.down_proj", dim),
        "blk.{bid}.ffn_gate": ("model.layers.{bid}.mlp.gate_proj", intermediate_size),
        "blk.{bid}.ffn_up.{xid}": ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", intermediate_size),
        "blk.{bid}.ffn_down.{xid}": ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", dim),
        "blk.{bid}.ffn_gate.{xid}": ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", intermediate_size),
        "blk.{bid}.ffn_gate_inp": ("model.layers.{bid}.block_sparse_moe.gate", model_config.get('num_local_experts', 1)),
    }
    # Expand the templates into concrete names for every possible block and expert index.
    mapping = {}
    max_block_num = 200
    max_expert_num = 8
    for k, v in tensor_mapping.items():
        for i in range(max_block_num):
            for j in range(max_expert_num):
                fk = k.format(bid=i, xid=j)
                fv = v[0].format(bid=i, xid=j)
                if fk not in mapping:
                    mapping[fk] = (fv, v[1])
    for ts in result.tensors:
        weight_type = torch.tensor(int(ts.tensor_type), dtype=torch.int)
        layer, suffix = ts.name.rsplit(".", 1)
        new_key, output_dim = mapping[layer]
        new_key += f".{suffix}"
        data = torch.tensor(ts.data)
        if output_dim != -1:
            data = data.view(output_dim, -1)
        if weight_type > 1:
            # ggml tensor types above 1 are quantized; record the type next to the raw weight data.
            state_dict[new_key.replace("weight", "weight_type")] = weight_type
        state_dict[new_key] = data
    if max_shard_size == "0":
        if safe_serialization:
            safe_save_file(state_dict, os.path.join(save_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
        else:
            torch.save(state_dict, os.path.join(save_dir, WEIGHTS_NAME))
    else:
        shards, index = shard_checkpoint(state_dict, max_shard_size, SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME)
        for shard_file, shard in shards.items():
            if safe_serialization:
                safe_save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
            else:
                torch.save(shard, os.path.join(save_dir, shard_file))
        if index is not None:
            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
            save_index_file = os.path.join(save_dir, save_index_file)
            # Save the index as well
            with open(save_index_file, "w", encoding="utf-8") as f:
                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
                f.write(content)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Convert GGUF checkpoints to torch')
    parser.add_argument('--input', type=str, help='Path to the GGUF file')
    parser.add_argument('--output', type=str, help='Path to the output directory')
    parser.add_argument('--max-shard-size', default="0", type=str, help='Shard the model into pieces of the given size, e.g. 10GB; 0 disables sharding')
    parser.add_argument('--safetensors', action='store_true', help='Save in .safetensors format')
    args = parser.parse_args()
    convert_to_state_dict(args.input, args.output, args.max_shard_size, args.safetensors)
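
# Example invocation (paths are illustrative; substitute your own GGUF file and output directory):
#   python gguf_to_torch.py --input ./model.gguf --output ./converted --max-shard-size 5GB --safetensors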