convert.py

import argparse
import copy
import gc

import torch
from auto_gptq.nn_modules.qlinear.qlinear_exllama import QuantLinear
from marlin import Layer as MarlinLayer
from transformers import AutoModelForCausalLM, AutoTokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--model-id", type=str)
parser.add_argument("--save-path", type=str)
parser.add_argument("--do-generation", action="store_true")


def _validate_compatibility(model):
    if not hasattr(model.config, "quantization_config"):
        raise ValueError(
            "Must be a quantized model to convert to Marlin Format")
    quantization_config = model.config.quantization_config
    if quantization_config.quant_method != "gptq":
        raise ValueError(
            "Only GPTQ models can be converted to Marlin format. You passed a "
            f"model with quant_method={quantization_config.quant_method}")
    if quantization_config.bits != 4:
        raise ValueError(
            "Only 4 bit quantized models can be converted to Marlin format. "
            f"You passed a model with bits={quantization_config.bits}")
    if quantization_config.group_size != 128:
        raise ValueError(
            "Only group size 128 models can be converted to Marlin format. You "
            f"passed a model with group_size={quantization_config.group_size}")
    if not quantization_config.sym:
        raise ValueError(
            "Only models with symmetric quantization can be converted to "
            "Marlin Format. You passed a model with sym="
            f"{quantization_config.sym}")
    if quantization_config.desc_act:
        raise ValueError(
            "Models with act order quantization cannot be converted to "
            "Marlin Format. You passed a model with desc_act="
            f"{quantization_config.desc_act}")


@torch.no_grad()
def unpack_4bit_to_32bit_signed(qweight, qzeros):
    # Unpack 4-bit values and interpret them as signed integers
    unpacked_weights = torch.zeros((qweight.shape[0] * 8, qweight.shape[1]),
                                   dtype=torch.int8,
                                   device=qweight.device,
                                   requires_grad=False)
    unpacked_zeros = torch.zeros((qzeros.shape[0], qzeros.shape[1] * 8),
                                 dtype=torch.int8,
                                 device=qzeros.device,
                                 requires_grad=False)

    for row in range(unpacked_weights.shape[0]):
        i = row % 8
        unpacked_weights[row, :] = (qweight[row // 8, :] >> (4 * i)) & 0xF

    for col in range(unpacked_zeros.shape[1]):
        i = col % 8
        unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF

    return unpacked_weights, unpacked_zeros + 1
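
# Illustration of the nibble layout the loops above assume (a sketch for
# reference only, not used by the conversion): eight 4-bit values are packed
# into each int32, least-significant nibble first, so nibble i of
# qweight[r, c] holds the quantized value for row r * 8 + i of the unpacked
# matrix. For example:
#   packed = sum(v << (4 * i) for i, v in enumerate([1, 2, 3, 4, 5, 6, 7, 8]))
#   hex(packed)                # '0x87654321'
#   (packed >> (4 * 2)) & 0xF  # 3, the third value, mirroring the loop above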


@torch.no_grad()
def dequantize_weight(layer):
    qweight, qzeros, scales = layer.qweight, layer.qzeros, layer.scales
    unpacked_qweight, unpacked_qzeros = unpack_4bit_to_32bit_signed(
        qweight, qzeros)
    group_size = unpacked_qweight.shape[0] // scales.shape[0]
    scales = scales.repeat_interleave(group_size, dim=0)
    unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
    unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales

    return unpacked_qweight.T
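
# Worked example of the dequantization above (numbers are made up): with a
# 4-bit value q = 7, a zero point z = 8 (after the +1 offset applied in
# unpack_4bit_to_32bit_signed) and a per-group scale s = 0.01, the recovered
# weight is (q - z) * s = (7 - 8) * 0.01 = -0.01. The scales and zero points
# are repeat_interleave'd group_size times so the elementwise expression
# lines up with the fully unpacked weight matrix.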


@torch.no_grad()
def convert_model(model, verbose=True):
    for name, module in model.named_modules():
        if not isinstance(module, QuantLinear):
            continue

        if verbose:
            print(f"--- Converting Module: {name}")
        parent_name = ".".join(name.split(".")[:-1])
        layer_name = name[len(parent_name) + 1:]

        # Dequantize the weight.
        dequantized_weight = dequantize_weight(module).to(torch.float16)
        linear_module = torch.nn.Linear(
            in_features=dequantized_weight.shape[1],
            out_features=dequantized_weight.shape[0],
            bias=False,
            dtype=torch.float16,
            device="cuda")
        linear_module.weight.data.copy_(dequantized_weight)

        # Create new linear method and copy to model.
        new_module = MarlinLayer(
            infeatures=linear_module.in_features,
            outfeatures=linear_module.out_features,
            groupsize=model.config.quantization_config.group_size)
        new_module.pack(linear_module,
                        scales=copy.deepcopy(module.scales.data.t()))

        # Save to parent.
        parent_module = model.get_submodule(parent_name)
        setattr(parent_module, layer_name, new_module)

        # Free cuda memory.
        del dequantized_weight, module
        torch.cuda.empty_cache()
        gc.collect()

    return model


@torch.no_grad()
def dequantize_model(model, verbose=True):
    for name, module in model.named_modules():
        if not isinstance(module, QuantLinear):
            continue

        if verbose:
            print(f"--- Dequantizing Module: {name}")
        parent_name = ".".join(name.split(".")[:-1])
        layer_name = name[len(parent_name) + 1:]

        # Dequantize the weight.
        dequantized_weight = dequantize_weight(module)
        dequantized_weight_cpu = dequantized_weight.to("cpu")

        # Create new linear method and copy to model.
        new_module = torch.nn.Linear(
            in_features=dequantized_weight_cpu.shape[1],
            out_features=dequantized_weight_cpu.shape[0],
            bias=False,
            dtype=torch.float16)
        new_module.weight.data.copy_(dequantized_weight_cpu)
        new_module.scales = torch.nn.Parameter(
            copy.deepcopy(module.scales.data))

        # Save to parent.
        parent_module = model.get_submodule(parent_name)
        setattr(parent_module, layer_name, new_module)

        # Free cuda memory.
        del dequantized_weight, dequantized_weight_cpu, module
        torch.cuda.empty_cache()

    return model


if __name__ == "__main__":
    args = parser.parse_args()
    model_id = args.model_id
    save_path = args.save_path
    do_generation = args.do_generation

    print("Loading gptq model...")
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Validate that this model is compatible with Marlin.
    print("Validating compatibility...")
    _validate_compatibility(model)

    # Convert the model to Marlin format.
    print("Converting model...")
    model = convert_model(model).to("cpu")

    # Save after updating quantization config.
    print("Saving marlin model...")
    model.config.quantization_config = {
        "group_size": model.config.quantization_config.group_size,
        "quant_method": "marlin"
    }
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)

    if do_generation:
        print("Generating sample text...")
        model.to("cuda")
        prompt = "My favorite song is"
        inputs = tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to("cuda") for k, v in inputs.items()}
        outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False)
        print(tokenizer.batch_decode(outputs)[0])
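
# Example invocation (the model id and save path below are placeholders;
# substitute a 4-bit, group-size-128, symmetric GPTQ checkpoint):
#   python convert.py \
#       --model-id <gptq-model-id> \
#       --save-path ./marlin-model \
#       --do-generation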