convert.py

import torch, argparse, copy
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_gptq.nn_modules.qlinear.qlinear_exllama import QuantLinear
from marlin import Layer as MarlinLayer
import gc

parser = argparse.ArgumentParser()
parser.add_argument("--model-id", type=str)
parser.add_argument("--save-path", type=str)
parser.add_argument("--do-generation", action="store_true")
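
# Example invocation (the model id and save path below are placeholders, not part of the script):
#   python convert.py --model-id <hf-gptq-model-id> --save-path ./marlin-model --do-generation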

def _validate_compatibility(model):
    if not hasattr(model.config, "quantization_config"):
        raise ValueError("Must be a quantized model to convert to Marlin Format")
    quantization_config = model.config.quantization_config
    if quantization_config.quant_method != "gptq":
        raise ValueError(f"Only GPTQ models can be converted to Marlin format. You passed a model with quant_method={quantization_config.quant_method}")
    if quantization_config.bits != 4:
        raise ValueError(f"Only 4 bit quantized models can be converted to Marlin format. You passed a model with bits={quantization_config.bits}")
    if quantization_config.group_size != 128:
        raise ValueError(f"Only group size 128 models can be converted to Marlin format. You passed a model with group_size={quantization_config.group_size}")
    if not quantization_config.sym:
        raise ValueError(f"Only models with symmetric quantization can be converted to Marlin Format. You passed a model with sym={quantization_config.sym}")
    if quantization_config.desc_act:
        raise ValueError(f"Models with act order quantization cannot be converted to Marlin Format. You passed a model with desc_act={quantization_config.desc_act}")

@torch.no_grad()
def unpack_4bit_to_32bit_signed(qweight, qzeros):
    # Unpack 4-bit values and interpret them as signed integers
    unpacked_weights = torch.zeros((qweight.shape[0]*8, qweight.shape[1]), dtype=torch.int8, device=qweight.device, requires_grad=False)
    unpacked_zeros = torch.zeros((qzeros.shape[0], qzeros.shape[1]*8), dtype=torch.int8, device=qzeros.device, requires_grad=False)

    for row in range(unpacked_weights.shape[0]):
        i = row % 8
        unpacked_weights[row, :] = (qweight[row // 8, :] >> (4 * i)) & 0xF

    for col in range(unpacked_zeros.shape[1]):
        i = col % 8
        unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF

    return unpacked_weights, unpacked_zeros + 1
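
# Note: GPTQ packs eight 4-bit values into each int32 word (qweight along dim 0,
# qzeros along dim 1), which is why both loops shift by 4 * (index % 8). The
# trailing "+ 1" compensates for the GPTQ convention of storing zero-points
# offset by one.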

@torch.no_grad()
def dequantize_weight(layer):
    qweight, qzeros, scales = layer.qweight, layer.qzeros, layer.scales
    unpacked_qweight, unpacked_qzeros = unpack_4bit_to_32bit_signed(qweight, qzeros)
    group_size = unpacked_qweight.shape[0] // scales.shape[0]
    scales = scales.repeat_interleave(group_size, dim=0)
    unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
    unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales

    return unpacked_qweight.T
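
# The reconstruction above is the standard group-wise affine dequantization,
# W ≈ (q - z) * s applied per group of rows, returned transposed so the result
# has shape (out_features, in_features) like torch.nn.Linear.weight.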

@torch.no_grad()
def convert_model(model, verbose=True):
    for name, module in model.named_modules():
        if not isinstance(module, QuantLinear):
            continue

        if verbose:
            print(f"--- Converting Module: {name}")
        parent_name = ".".join(name.split(".")[:-1])
        layer_name = name[len(parent_name) + 1:]

        # Dequantize the weight.
        dequantized_weight = dequantize_weight(module).to(torch.float16)
        linear_module = torch.nn.Linear(
            in_features=dequantized_weight.shape[1],
            out_features=dequantized_weight.shape[0],
            bias=False,
            dtype=torch.float16,
            device="cuda")
        linear_module.weight.data.copy_(dequantized_weight)

        # Create new linear method and copy to model.
        new_module = MarlinLayer(
            infeatures=linear_module.in_features,
            outfeatures=linear_module.out_features,
            groupsize=model.config.quantization_config.group_size)
        new_module.pack(linear_module, scales=copy.deepcopy(module.scales.data.t()))

        # Save to parent.
        parent_module = model.get_submodule(parent_name)
        setattr(parent_module, layer_name, new_module)

        # Free cuda memory.
        del dequantized_weight, module
        torch.cuda.empty_cache()
        gc.collect()

    return model
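
# Each GPTQ QuantLinear is dequantized to fp16 and handed to MarlinLayer.pack(),
# which is expected to re-quantize it into Marlin's packed weight layout using
# the original per-group scales; the new module then replaces the old one on its
# parent in place.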

@torch.no_grad()
def dequantize_model(model, verbose=True):
    for name, module in model.named_modules():
        if not isinstance(module, QuantLinear):
            continue

        if verbose:
            print(f"--- Dequantizing Module: {name}")
        parent_name = ".".join(name.split(".")[:-1])
        layer_name = name[len(parent_name) + 1:]

        # Dequantize the weight.
        dequantized_weight = dequantize_weight(module)
        dequantized_weight_cpu = dequantized_weight.to("cpu")

        # Create new linear method and copy to model.
        new_module = torch.nn.Linear(
            in_features=dequantized_weight_cpu.shape[1],
            out_features=dequantized_weight_cpu.shape[0],
            bias=False,
            dtype=torch.float16)
        new_module.weight.data.copy_(dequantized_weight_cpu)
        new_module.scales = torch.nn.Parameter(copy.deepcopy(module.scales.data))

        # Save to parent.
        parent_module = model.get_submodule(parent_name)
        setattr(parent_module, layer_name, new_module)

        # Free cuda memory.
        del dequantized_weight, dequantized_weight_cpu, module
        torch.cuda.empty_cache()

    return model
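
# dequantize_model() mirrors convert_model() but swaps each QuantLinear for a
# plain fp16 torch.nn.Linear on CPU; it is provided as a helper and is not
# called from the __main__ block below.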

if __name__ == "__main__":
    args = parser.parse_args()
    model_id = args.model_id
    save_path = args.save_path
    do_generation = args.do_generation

    print("Loading gptq model...")
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Validate that this model is compatible with Marlin.
    print("Validating compatibility...")
    _validate_compatibility(model)

    # Convert the model to Marlin format.
    print("Converting model...")
    model = convert_model(model).to("cpu")

    # Save after updating quantization config.
    print("Saving marlin model...")
    model.config.quantization_config = {
        "group_size": model.config.quantization_config.group_size,
        "quant_method": "marlin"
    }
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)

    if do_generation:
        print("Generating sample text...")
        model.to("cuda")
        prompt = "My favorite song is"
        inputs = tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to("cuda") for k, v in inputs.items()}
        outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False)
        print(tokenizer.batch_decode(outputs)[0])
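
# The optional generation step exercises the converted MarlinLayer modules
# directly, which require a CUDA device (note the model.to("cuda") above);
# greedy decoding (do_sample=False) keeps the sample output deterministic for a
# quick sanity check.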