fp6.py

from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from loguru import logger

from aphrodite import _custom_ops as ops
from aphrodite.distributed import get_tensor_model_parallel_rank
from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.quantization.utils.fp6_utils import (_SPLIT_K_MAP,
                                                    from_scaled_tc_fpx,
                                                    to_scaled_tc_fpx)


class QuantLLMFPConfig(QuantizationConfig):
    """Config for the QuantLLM FP quantizer. It supports fp2, fp3, fp4,
    fp5, fp6, and fp7.

    Reference: https://arxiv.org/abs/2401.14112

    Args:
        weight_bits: the target quantization bits, should be one of
            2, 3, 4, 5, 6, 7.
    """

    def __init__(
        self,
        weight_bits: int = 6,
        exp_bits: int = 2,
    ) -> None:
        self.weight_bits = weight_bits
        self.exponent_bits = exp_bits
        self.mantissa_bits = weight_bits - self.exponent_bits - 1
        self.valid_types = [torch.float16]

        if self.weight_bits not in [2, 3, 4, 5, 6, 7]:
            raise ValueError(
                "Currently, only 2-, 3-, 4-, 5-, 6-, and 7-bit quantization "
                "is supported for QuantLLM FP quantization, but got "
                f"{self.weight_bits} bits.")

        if get_tensor_model_parallel_rank() == 0:
            logger.info(f"Loading model in FP{self.weight_bits}_E"
                        f"{self.exponent_bits}M{self.mantissa_bits} format.")

    def __repr__(self) -> str:
        return (f"QuantLLMFPConfig(weight_bits={self.weight_bits}, "
                f"exponent_bits={self.exponent_bits})")

    @classmethod
    def get_name(cls) -> str:
        return "QuantLLMFP"

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QuantLLMFPConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        exp_bits = cls.get_from_keys(config, ["exp_bits"])
        return cls(weight_bits=weight_bits, exp_bits=exp_bits)

    def get_linear_method(self) -> "QuantLLMFPLinearMethod":
        return QuantLLMFPLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # TODO: verify the actual minimum compute capability required.
        return 80

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    def get_quant_method(
            self, layer: torch.nn.Module,
            prefix: str) -> Optional["QuantLLMFPLinearMethod"]:
        if isinstance(layer, LinearBase):
            return QuantLLMFPLinearMethod(self)
        return None
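
# Illustrative note (a sketch, not executed at import time): the config above
# apparently encodes the bit layout as weight_bits = 1 (presumably the sign
# bit) + exp_bits + mantissa_bits, matching the FP{w}_E{e}M{m} log format, so
# the default fp6 config with exp_bits=2 is E2M3. Constructing the config
# assumes an initialized tensor-parallel group (it logs on rank 0):
#
#     QuantLLMFPConfig(weight_bits=6, exp_bits=2).mantissa_bits  # -> 3 (E2M3)
#     QuantLLMFPConfig(weight_bits=5, exp_bits=2).mantissa_bits  # -> 2 (E2M2)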


class QuantLLMFPLinearMethod(LinearMethodBase):
    """Linear method for the QuantLLMFP quantizer.

    Args:
        quant_config: the QuantLLMFP quantization config.
    """

    def __init__(self, quant_config: QuantLLMFPConfig):
        self.quant_config = quant_config
        self.weight = None

    def create_weights(self,
                       layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int],
                       input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype,
                       weight_loader=None,
                       **extra_weight_attrs):
        del output_size
        del input_size
        output_size_per_partition = sum(output_partition_sizes)
        weight = QuantLLMFPParameter(
            torch.Size((output_size_per_partition, input_size_per_partition)),
            params_dtype=params_dtype,
            quant_config=self.quant_config,
        )
        set_weight_attrs(weight, {
            "input_dim": 1,
            "output_dim": 0,
        })
        layer.register_parameter("weight", weight)

        def quant_weight_loader(param, loaded_weight, *args, **kwargs):
            # Calls the original weight loader (if any), quantizes the result,
            # and then loads the quantized parameter.
            if weight_loader is not None:
                orig_param_data = param.data
                param.data = param.quant_llmdequantize()
                weight_loader(param, loaded_weight, *args, **kwargs)
                param.data, loaded_weight = orig_param_data, param.data
            param.quant_llmquantize_(loaded_weight.cuda())

        extra_weight_attrs["weight_loader"] = quant_weight_loader
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        weight = layer.weight
        weights = weight.data
        scales = weight.scales
        out_dim, in_dim = weights.shape
        bsize = x.shape[0]
        splitK = _SPLIT_K_MAP[(bsize - 1) // 64].get(
            out_dim, 1) if bsize <= 768 else 1
        if bias is None:
            return ops.fp_eXmY_linear_forward_cuda(
                self.quant_config.exponent_bits,
                self.quant_config.mantissa_bits,
                x, weights, scales, splitK)
        else:
            return ops.fp_eXmY_linear_forward_cuda(
                self.quant_config.exponent_bits,
                self.quant_config.mantissa_bits,
                x, weights, scales, splitK) + bias
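
# A note on the split-K selection in `apply` above: _SPLIT_K_MAP appears to be
# indexed by the batch-size bucket (bsize - 1) // 64 and keyed by the output
# dimension, with splitK falling back to 1 for unknown shapes or bsize > 768.
# For example (an illustrative sketch, assuming the table has an entry for
# out_dim 4096 in that bucket):
#
#     bsize = 100                            # -> bucket (100 - 1) // 64 == 1
#     splitK = _SPLIT_K_MAP[1].get(4096, 1)  # table hit, otherwise 1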


class QuantLLMFPParameter(nn.Parameter):
    """
    QuantLLMFP quantized parameter class that implements fp5/fp6/fp7
    quantization. Weights are stored in quantized form on
    GPUs, and can be directly applied to float16 activations.
    """

    def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype,
                quant_config: QuantLLMFPConfig):
        data = torch.empty(
            torch.Size((orig_shape[0],
                        orig_shape[1] * quant_config.weight_bits // 8)),
            dtype=torch.uint8)
        self = torch.Tensor._make_subclass(cls, data, data.requires_grad)
        self.scales = torch.empty(orig_shape[0], dtype=torch.float16)
        self.quant_config = quant_config
        self.orig_shape = orig_shape
        return self
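
    # Illustrative note on the packed layout above (a sketch of the
    # arithmetic, not part of the original code): each row of orig_shape[1]
    # weights is stored in orig_shape[1] * weight_bits // 8 uint8 bytes, plus
    # one float16 scale per output row. For fp6, a (4096, 4096) weight is
    # therefore packed as a (4096, 3072) uint8 tensor with a (4096,) scale
    # vector.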

    def quant_llmquantize_(self, tensor: torch.Tensor):
        assert tensor.device.type == "cuda" and tensor.dtype != torch.int8
        data, scales = to_scaled_tc_fpx(tensor.data,
                                        self.quant_config.exponent_bits,
                                        self.quant_config.mantissa_bits)
        self.data.copy_(data)
        self.scales.copy_(scales)

    def quant_llmdequantize(self, output_dtype=None):
        output_dtype = output_dtype or torch.get_default_dtype()
        return from_scaled_tc_fpx(self.data, self.quant_config.exponent_bits,
                                  self.quant_config.mantissa_bits,
                                  self.scales).to(output_dtype)
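
# Illustrative usage sketch (assumes a CUDA device and, for the config, an
# initialized tensor-parallel group; kept as a comment so nothing runs at
# import time):
#
#     config = QuantLLMFPConfig(weight_bits=6, exp_bits=2)        # FP6_E2M3
#     param = QuantLLMFPParameter(torch.Size((4096, 4096)),
#                                 params_dtype=torch.float16,
#                                 quant_config=config)
#     fp16_weight = torch.randn(4096, 4096, dtype=torch.float16,
#                               device="cuda")
#     param.quant_llmquantize_(fp16_weight)   # pack to uint8 + fp16 scales
#     approx = param.quant_llmdequantize(torch.float16)  # approximate round trip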