gguf_to_torch.py

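"""Convert a llama-architecture GGUF checkpoint into Hugging Face style artifacts.

Writes tokenizer.model, tokenizer_config.json, config.json and pytorch_model.bin
into the output directory.
"""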
import json
import os

import torch
import gguf
from sentencepiece import sentencepiece_model_pb2


def convert_to_state_dict(checkpoint, save_dir):
    """Read a GGUF file and dump a torch state dict plus tokenizer/config files."""
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    state_dict = {}
    result = gguf.GGUFReader(checkpoint)
    architecture = result.fields['general.architecture']
    architecture = str(bytes(architecture.parts[architecture.data[0]]), encoding='utf-8')
    if architecture != "llama":
        print(f"Unsupported architecture {architecture}")
        return

    # Rebuild a sentencepiece model proto from the GGUF tokenizer metadata and write it out.
    vocab = sentencepiece_model_pb2.ModelProto()
    vocab_size = len(result.fields['tokenizer.ggml.token_type'].data)
    vocab.trainer_spec.model_type = 2  # BPE
    vocab.trainer_spec.vocab_size = vocab_size
    vocab.trainer_spec.byte_fallback = True
    vocab.normalizer_spec.remove_extra_whitespaces = False
    tokens = result.fields['tokenizer.ggml.tokens']
    scores = result.fields['tokenizer.ggml.scores']
    types = result.fields['tokenizer.ggml.token_type']
    for i in range(vocab_size):
        new_token = vocab.SentencePiece()
        new_token.piece = str(bytes(tokens.parts[tokens.data[i]]), encoding='utf-8')
        new_token.score = scores.parts[scores.data[i]]
        # The llama.cpp token type enum matches the sentencepiece token type enum.
        new_token.type = int(types.parts[types.data[i]])
        vocab.pieces.append(new_token)
    with open(os.path.join(save_dir, "tokenizer.model"), 'wb') as f:
        f.write(vocab.SerializeToString())

    # Tokenizer settings consumed by transformers' LlamaTokenizer.
    tokenizer_config = {
        "tokenizer_class": "LlamaTokenizer",
        "legacy": False,
        "clean_up_tokenization_spaces": False,
    }
    if 'tokenizer.ggml.bos_token_id' in result.fields:
        tokenizer_config["bos_token"] = vocab.pieces[int(result.fields['tokenizer.ggml.bos_token_id'].parts[-1])].piece
    if 'tokenizer.ggml.eos_token_id' in result.fields:
        tokenizer_config["eos_token"] = vocab.pieces[int(result.fields['tokenizer.ggml.eos_token_id'].parts[-1])].piece
    if 'tokenizer.ggml.padding_token_id' in result.fields:
        tokenizer_config["pad_token"] = vocab.pieces[int(result.fields['tokenizer.ggml.padding_token_id'].parts[-1])].piece
    if 'tokenizer.ggml.unknown_token_id' in result.fields:
        tokenizer_config["unk_token"] = vocab.pieces[int(result.fields['tokenizer.ggml.unknown_token_id'].parts[-1])].piece
    if 'tokenizer.ggml.add_bos_token' in result.fields:
        tokenizer_config["add_bos_token"] = bool(result.fields['tokenizer.ggml.add_bos_token'].parts[-1])
    if 'tokenizer.ggml.add_eos_token' in result.fields:
        tokenizer_config["add_eos_token"] = bool(result.fields['tokenizer.ggml.add_eos_token'].parts[-1])
    if 'tokenizer.chat_template' in result.fields:
        tokenizer_config["chat_template"] = str(bytes(result.fields['tokenizer.chat_template'].parts[-1]), encoding="utf-8")
    with open(os.path.join(save_dir, "tokenizer_config.json"), 'w') as f:
        json.dump(tokenizer_config, f, indent=2)

    # Model hyperparameters pulled from the GGUF metadata.
    context_length = int(result.fields['llama.context_length'].parts[-1])
    n_layer = int(result.fields['llama.block_count'].parts[-1])
    n_head = int(result.fields['llama.attention.head_count'].parts[-1])
    n_local_heads = int(result.fields['llama.attention.head_count_kv'].parts[-1])
    intermediate_size = int(result.fields['llama.feed_forward_length'].parts[-1])
    norm_eps = float(result.fields['llama.attention.layer_norm_rms_epsilon'].parts[-1])
    dim = int(result.fields['llama.embedding_length'].parts[-1])
    kv_dim = dim // n_head * n_local_heads
    if 'llama.expert_count' in result.fields:
        arch = "MixtralForCausalLM"
        name = "mixtral"
    else:
        arch = "LlamaForCausalLM"
        name = "llama"
    model_config = {
        "architectures": [arch],
        "bos_token_id": 1,
        "eos_token_id": 2,
        "hidden_act": "silu",
        "hidden_size": dim,
        "intermediate_size": intermediate_size,
        "max_position_embeddings": context_length,
        "model_type": name,
        "num_attention_heads": n_head,
        "num_hidden_layers": n_layer,
        "num_key_value_heads": n_local_heads,
        "rms_norm_eps": norm_eps,
        "torch_dtype": "float16",
        "vocab_size": vocab_size
    }
    if 'llama.rope.freq_base' in result.fields:
        model_config['rope_theta'] = float(result.fields['llama.rope.freq_base'].parts[-1])
    if 'llama.expert_count' in result.fields:
        model_config['num_local_experts'] = int(result.fields['llama.expert_count'].parts[-1])
        model_config['num_experts_per_tok'] = int(result.fields['llama.expert_used_count'].parts[-1])
    with open(os.path.join(save_dir, "config.json"), 'w') as f:
        json.dump(model_config, f, indent=2)

    # Map GGUF tensor name prefixes to (Hugging Face name, output_dim);
    # output_dim == -1 means the tensor keeps its stored shape.
    tensor_mapping = {
        "token_embd": ("model.embed_tokens", vocab_size),
        "output": ("lm_head", vocab_size),
        "output_norm": ("model.norm", -1),
        "blk.{bid}.attn_norm": ("model.layers.{bid}.input_layernorm", -1),
        "blk.{bid}.attn_q": ("model.layers.{bid}.self_attn.q_proj", dim),
        "blk.{bid}.attn_k": ("model.layers.{bid}.self_attn.k_proj", kv_dim),
        "blk.{bid}.attn_v": ("model.layers.{bid}.self_attn.v_proj", kv_dim),
        "blk.{bid}.attn_output": ("model.layers.{bid}.self_attn.o_proj", dim),
        "blk.{bid}.attn_rot_embd": ("model.layers.{bid}.self_attn.rotary_emb.inv_freq", -1),
        "blk.{bid}.ffn_norm": ("model.layers.{bid}.post_attention_layernorm", -1),
        "blk.{bid}.ffn_up": ("model.layers.{bid}.mlp.up_proj", intermediate_size),
        "blk.{bid}.ffn_down": ("model.layers.{bid}.mlp.down_proj", dim),
        "blk.{bid}.ffn_gate": ("model.layers.{bid}.mlp.gate_proj", intermediate_size),
        "blk.{bid}.ffn_up.{xid}": ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w3", intermediate_size),
        "blk.{bid}.ffn_down.{xid}": ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w2", dim),
        "blk.{bid}.ffn_gate.{xid}": ("model.layers.{bid}.block_sparse_moe.experts.{xid}.w1", intermediate_size),
        "blk.{bid}.ffn_gate_inp": ("model.layers.{bid}.block_sparse_moe.gate", model_config.get('num_local_experts', 1)),
    }

    # Expand the templated keys above for every possible layer/expert index.
    mapping = {}
    max_block_num = 200
    max_expert_num = 8
    for k, v in tensor_mapping.items():
        for i in range(max_block_num):
            for j in range(max_expert_num):
                fk = k.format(bid=i, xid=j)
                fv = v[0].format(bid=i, xid=j)
                if fk not in mapping:
                    mapping[fk] = (fv, v[1])

    for ts in result.tensors:
        weight_type = torch.tensor(int(ts.tensor_type), dtype=torch.int)
        layer, suffix = ts.name.rsplit(".", 1)
        new_key, output_dim = mapping[layer]
        new_key += f".{suffix}"
        data = torch.tensor(ts.data)
        if output_dim != -1:
            data = data.view(output_dim, -1)
        if weight_type > 1:
            # Tensor types above 1 (F32=0, F16=1) are quantized GGML formats,
            # so keep the type id alongside the raw weight data.
            state_dict[new_key.replace("weight", "weight_type")] = weight_type
        state_dict[new_key] = data
    torch.save(state_dict, os.path.join(save_dir, "pytorch_model.bin"))


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Convert GGUF checkpoints to torch')
    parser.add_argument('--input', type=str, help='Path to the input GGUF file')
    parser.add_argument('--output', type=str, help='Path to the output directory')
    args = parser.parse_args()
    convert_to_state_dict(args.input, args.output)
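
# Example invocation (a sketch; the GGUF filename and output directory below are
# illustrative placeholders, not files shipped with this script):
#   python gguf_to_torch.py --input ./llama-2-7b.Q4_0.gguf --output ./llama-2-7b-torch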