# deepspeedfp.py

from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig


class DeepSpeedFPConfig(QuantizationConfig):
  9. """Config for DeepSpeed FP quantizer. It supports fp6 and fp8.
  10. Args:
  11. weight_bits: the target quantization bits, 6 or 8.
  12. group_size: group size for quantizaiton, default to 128.
  13. """

    def __init__(
        self,
        weight_bits: int = 8,
        group_size: int = 512,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.valid_types = [torch.bfloat16, torch.float16]
        if self.weight_bits not in (4, 6, 8, 12):
            raise ValueError(
                "Currently, only 4-bit, 6-bit, 8-bit, and 12-bit weight "
                "quantization is supported for DeepSpeed FP quantization, "
                f"but got {self.weight_bits} bits.")

    def __repr__(self) -> str:
        return (f"DeepSpeedFPConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size})")

    @classmethod
    def get_name(cls) -> str:
        return "DeepSpeedFP"

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits=weight_bits, group_size=group_size)
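
    # A minimal sketch (an assumed example, not a config shipped with this
    # repo) of a quant_config.json that from_config() above would accept:
    #
    #     {"bits": 8, "group_size": 512}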

    def get_linear_method(self) -> "DeepSpeedFPLinearMethod":
        return DeepSpeedFPLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # TODO: confirm the actual minimum GPU compute capability required.
        return 60

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["DeepSpeedFPLinearMethod"]:
        if isinstance(layer, LinearBase):
            return DeepSpeedFPLinearMethod(self)
        return None


class DeepSpeedFPLinearMethod(LinearMethodBase):
    """Linear method for the DeepSpeedFP quantizer.

    Args:
        quant_config: the DeepSpeedFP quantization config.
    """

    def __init__(self, quant_config: DeepSpeedFPConfig):
        self.quant_config = quant_config
        self.weight = None

    def create_weights(self,
                       layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int],
                       input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype,
                       weight_loader=None,
                       **extra_weight_attrs):
        del output_size
        del input_size
        output_size_per_partition = sum(output_partition_sizes)
        weight = DeepSpeedFPParameter(
            torch.Size((output_size_per_partition, input_size_per_partition)),
            params_dtype=params_dtype,
            quant_config=self.quant_config,
        )
        set_weight_attrs(weight, {
            "input_dim": 1,
            "output_dim": 0,
        })
        layer.register_parameter("weight", weight)

        def quant_weight_loader(param, loaded_weight, *args, **kwargs):
            # Calls the original weight loader (if any), quantizes the result,
            # and then loads the quantized parameter.
            if weight_loader is not None:
                orig_param_data = param.data
                param.data = param.ds_dequantize()
                weight_loader(param, loaded_weight, *args, **kwargs)
                param.data, loaded_weight = orig_param_data, param.data
            param.ds_quantize_(loaded_weight.cuda())

        extra_weight_attrs["weight_loader"] = quant_weight_loader
        set_weight_attrs(weight, extra_weight_attrs)
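
    # Note on quant_weight_loader above: the wrapped loader (if any) writes
    # into a dequantized, full-precision view of the parameter (presumably so
    # that sharded or merged loaders can copy their slices in the original
    # dtype), and the filled view is then re-quantized on the GPU.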

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Dequantize the full weight on the fly and fall back to a standard
        # dense matmul.
        weight = layer.weight
        y = weight.ds_dequantize()
        return F.linear(x, y, bias)


class DeepSpeedFPParameter(nn.Parameter):
    """
    DeepSpeedFP quantized parameter class that implements fp8/fp6
    quantization via DeepSpeed. Weights are stored in quantized form on
    GPUs, and can be dequantized on-the-fly when needed by the model.
    """

    def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype,
                quant_config: DeepSpeedFPConfig):
        try:
            import deepspeed
            if deepspeed.__version__ < "0.14.2":
                raise ImportError("deepspeed version is too old. Please "
                                  "install deepspeed>=0.14.2.")
            from deepspeed.ops.fp_quantizer import FP_Quantize
        except ImportError as err:
            raise ImportError("Please install deepspeed>=0.14.2 via "
                              "`pip install deepspeed>=0.14.2` to use "
                              "the deepspeedfp quantizer.") from err
        data = torch.empty(
            (
                orig_shape.numel() // quant_config.group_size,
                quant_config.group_size * quant_config.weight_bits // 8 + 4,
            ),
            dtype=torch.int8)
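        # Worked size example (illustrative, assuming the layout implied by
        # the shape above): with weight_bits=8 and group_size=512, each group
        # of 512 weights packs into 512 * 8 // 8 = 512 bytes plus 4 extra
        # bytes (presumably per-group scale metadata), i.e. one 516-byte int8
        # row per group.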
        self = torch.Tensor._make_subclass(cls, data, data.requires_grad)
        self.orig_shape = orig_shape
        self.quant_config = quant_config
        self.fp_quantizer = FP_Quantize(group_size=quant_config.group_size)
        self.fp_quantizer.orig_shape = orig_shape
        self.fp_quantizer.orig_dtype = params_dtype
        return self

    def ds_quantize_(self, tensor: torch.Tensor):
        assert tensor.device.type == "cuda" and tensor.dtype != torch.int8
        return self.data.copy_(
            self.fp_quantizer.quantize(
                tensor.data,
                q_bits=self.quant_config.weight_bits,
            ))

    def ds_dequantize(self, fp_out=None) -> torch.Tensor:
        """
        Return a tensor containing the dequantized weights of this parameter.
        """
        assert self.data.device.type == "cuda" and self.data.dtype == torch.int8
        return self.fp_quantizer.dequantize(
            self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits)

    def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor:
        """
        Return a tensor where only the weights at `indices` are dequantized
        (to save HBM -> SRAM bandwidth).
        """
        assert self.data.device.type == "cuda" and self.data.dtype == torch.int8
        return self.fp_quantizer.selective_dequantize(
            self.data,
            indices,
            fp_out=fp_out,
            q_bits=self.quant_config.weight_bits)
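

# Illustrative round trip (a sketch, not executed as part of this module;
# assumes a CUDA device and deepspeed>=0.14.2 installed):
#
#     cfg = DeepSpeedFPConfig(weight_bits=8, group_size=512)
#     param = DeepSpeedFPParameter(torch.Size((4096, 4096)),
#                                  params_dtype=torch.float16,
#                                  quant_config=cfg)
#     param.data = param.data.cuda()
#     w = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
#     param.ds_quantize_(w)           # pack weights into int8 groups in place
#     w_hat = param.ds_dequantize()   # fp16 approximation of w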