# deepspeedfp.py

from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.modeling.utils import set_weight_attrs


class DeepSpeedFPConfig(QuantizationConfig):
    """Config for the DeepSpeed FP quantizer. It supports fp6 and fp8.

    Args:
        weight_bits: the target quantization bits, 6 or 8.
        group_size: group size for quantization, defaults to 512.
    """

    def __init__(
        self,
        weight_bits: int = 8,
        group_size: int = 512,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.valid_types = [torch.bfloat16, torch.float16]

        if self.weight_bits not in (4, 6, 8, 12):
            raise ValueError(
                "Currently, only 4-bit, 6-bit, 8-bit, and 12-bit weight "
                "quantization is supported for DeepSpeed FP quantization, "
                f"but got {self.weight_bits} bits.")

    def __repr__(self) -> str:
        return (f"DeepSpeedFPConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size})")

    @classmethod
    def get_name(cls) -> str:
        return "DeepSpeedFP"

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits=weight_bits, group_size=group_size)

    def get_linear_method(self) -> "DeepSpeedFPLinearMethod":
        return DeepSpeedFPLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # TODO: confirm the actual minimum GPU compute capability required.
        return 60

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    def get_quant_method(
            self,
            layer: torch.nn.Module) -> Optional["DeepSpeedFPLinearMethod"]:
        if isinstance(layer, LinearBase):
            return DeepSpeedFPLinearMethod(self)
        return None
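
# Illustrative sketch (not part of the module): how a checkpoint's quantization
# config might be turned into a DeepSpeedFPConfig. The JSON contents below are
# an assumption; only the "bits" and "group_size" keys read by `from_config`
# matter.
#
#     import json
#     with open("quant_config.json") as f:   # see get_config_filenames()
#         ds_cfg = DeepSpeedFPConfig.from_config(json.load(f))
#     # e.g. {"bits": 8, "group_size": 512} -> weight_bits=8, group_size=512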


class DeepSpeedFPLinearMethod(LinearMethodBase):
    """Linear method for the DeepSpeedFP quantizer.

    Args:
        quant_config: the DeepSpeedFP quantization config.
    """

    def __init__(self, quant_config: DeepSpeedFPConfig):
        self.quant_config = quant_config
        self.weight = None

    def create_weights(self,
                       layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int],
                       input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype,
                       weight_loader=None,
                       **extra_weight_attrs):
        del output_size
        del input_size
        output_size_per_partition = sum(output_partition_sizes)
        weight = DeepSpeedFPParameter(
            torch.Size((output_size_per_partition, input_size_per_partition)),
            params_dtype=params_dtype,
            quant_config=self.quant_config,
        )
        set_weight_attrs(weight, {
            "input_dim": 1,
            "output_dim": 0,
        })
        layer.register_parameter("weight", weight)

        def quant_weight_loader(param, loaded_weight, *args, **kwargs):
            # Calls the original weight loader (if any), quantizes the result,
            # and then loads the quantized parameter.
            if weight_loader is not None:
                orig_param_data = param.data
                param.data = param.ds_dequantize()
                weight_loader(param, loaded_weight, *args, **kwargs)
                param.data, loaded_weight = orig_param_data, param.data
            param.ds_quantize_(loaded_weight.cuda())

        extra_weight_attrs["weight_loader"] = quant_weight_loader
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Dequantize the weight on the fly and use a standard dense matmul.
        weight = layer.weight
        y = weight.ds_dequantize()
        return F.linear(x, y, bias)
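
# Illustrative sketch (an assumption about typical usage, not code from this
# module): `get_quant_method()` hands a DeepSpeedFPLinearMethod to a LinearBase
# layer, which calls `create_weights()` once and then `apply()` on each
# forward pass. `layer`, the sizes, and `x` below are hypothetical
# placeholders.
#
#     method = DeepSpeedFPLinearMethod(DeepSpeedFPConfig(weight_bits=8))
#     method.create_weights(layer,
#                           input_size_per_partition=4096,
#                           output_partition_sizes=[4096],
#                           input_size=4096,
#                           output_size=4096,
#                           params_dtype=torch.bfloat16)
#     # ... after the quantizing weight loader has populated layer.weight ...
#     out = method.apply(layer, x)  # dequantize, then F.linear(x, weight)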


class DeepSpeedFPParameter(nn.Parameter):
    """
    DeepSpeedFP quantized parameter class that implements fp8/fp6
    quantization with DeepSpeed. Weights are stored in quantized form on
    GPUs, and can be dequantized on-the-fly when needed by the model.
    """

    def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype,
                quant_config: DeepSpeedFPConfig):
        try:
            import deepspeed
            if deepspeed.__version__ < "0.14.2":
                raise ImportError("deepspeed version is too old. Please "
                                  "install deepspeed>=0.14.2.")
            from deepspeed.ops.fp_quantizer import FP_Quantize
        except ImportError as err:
            raise ImportError("Please install deepspeed>=0.14.2 via "
                              "`pip install deepspeed>=0.14.2` to use "
                              "the deepspeedfp quantizer.") from err
        # One row per quantization group: the packed weights
        # (group_size * weight_bits // 8 bytes) plus 4 extra bytes of
        # per-group quantization metadata.
        data = torch.empty((
            orig_shape.numel() // quant_config.group_size,
            quant_config.group_size * quant_config.weight_bits // 8 + 4,
        ),
                           dtype=torch.int8)
        self = torch.Tensor._make_subclass(cls, data, data.requires_grad)
        self.orig_shape = orig_shape
        self.quant_config = quant_config
        self.fp_quantizer = FP_Quantize(group_size=quant_config.group_size)
        self.fp_quantizer.orig_shape = orig_shape
        self.fp_quantizer.orig_dtype = params_dtype
        return self

    def ds_quantize_(self, tensor: torch.Tensor):
        assert tensor.device.type == "cuda" and tensor.dtype != torch.int8
        return self.data.copy_(
            self.fp_quantizer.quantize(
                tensor.data,
                q_bits=self.quant_config.weight_bits,
            ))

    def ds_dequantize(self, fp_out=None) -> torch.Tensor:
        """
        Return a tensor containing the dequantized weights of this parameter.
        """
        assert (self.data.device.type == "cuda"
                and self.data.dtype == torch.int8)
        return self.fp_quantizer.dequantize(
            self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits)

    def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor:
        """
        Return a tensor where only the weights at `indices` are dequantized
        (to save HBM -> SRAM bandwidth).
        """
        assert (self.data.device.type == "cuda"
                and self.data.dtype == torch.int8)
        return self.fp_quantizer.selective_dequantize(
            self.data,
            indices,
            fp_out=fp_out,
            q_bits=self.quant_config.weight_bits)
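

if __name__ == "__main__":
    # Minimal smoke test: a sketch only, not part of the quantizer. It
    # round-trips a random bf16 tensor through quantize/dequantize using the
    # methods defined above. Assumes a CUDA device and deepspeed>=0.14.2 are
    # available; the shape below is an arbitrary choice for illustration.
    cfg = DeepSpeedFPConfig(weight_bits=8, group_size=512)
    shape = torch.Size((1024, 1024))
    param = DeepSpeedFPParameter(shape,
                                 params_dtype=torch.bfloat16,
                                 quant_config=cfg)
    # Move the int8 backing storage to the GPU before quantizing into it.
    param.data = param.data.cuda()
    w = torch.randn(shape, dtype=torch.bfloat16, device="cuda")
    param.ds_quantize_(w)
    w_hat = param.ds_dequantize()
    print("max abs reconstruction error:", (w - w_hat).abs().max().item())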