1
0

tpu_int8.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. from typing import Any, Dict, List, Optional, Tuple
  2. import torch
  3. from torch.nn import Module
  4. from torch.nn.parameter import Parameter
  5. from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
  6. from aphrodite.modeling.utils import set_weight_attrs
  7. from aphrodite.quantization.base_config import QuantizationConfig
# Activation schemes accepted by Int8TpuConfig; "none" is the only entry.
ACTIVATION_SCHEMES = ["none"]
  9. class Int8TpuConfig(QuantizationConfig):
  10. """Int8 Quantization Config class for TPU Backend."""
  11. def __init__(
  12. self,
  13. activation_scheme: str = "none",
  14. ) -> None:
  15. if activation_scheme not in ACTIVATION_SCHEMES:
  16. raise ValueError(
  17. f"Unsupported activation scheme {activation_scheme}")
  18. self.activation_scheme = activation_scheme
  19. def get_name(self) -> str:
  20. return "tpu_int8"
  21. def get_supported_act_dtypes(self) -> List[torch.dtype]:
  22. return [torch.float16, torch.bfloat16]
  23. @classmethod
  24. def get_min_capability(cls) -> int:
  25. raise NotImplementedError(
  26. "This function should not be called with TPU Backend")
  27. @staticmethod
  28. def get_config_filenames() -> List[str]:
  29. return []
  30. @classmethod
  31. def from_config(cls, config: Dict[str, Any]) -> "Int8TpuConfig":
  32. activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
  33. return cls(activation_scheme=activation_scheme)
  34. def get_quant_method(self, layer: Module,
  35. prefix: str) -> Optional["TPUInt8LinearMethod"]:
  36. if isinstance(layer, LinearBase):
  37. return TPUInt8LinearMethod(self)
  38. return None
  39. def get_scaled_act_names(self) -> List[str]:
  40. return []
  41. class TPUInt8LinearMethod(LinearMethodBase):
  42. """Int8 Linear method for TPU Quant. """
  43. def __init__(self, quant_config: Int8TpuConfig):
  44. self.quant_config = quant_config
  45. def create_weights(self, layer: Module, input_size_per_partition: int,
  46. output_partition_sizes: List[int], input_size: int,
  47. output_size: int, params_dtype: torch.dtype,
  48. **extra_weight_attrs):
  49. weight = Parameter(torch.empty(sum(output_partition_sizes),
  50. input_size_per_partition,
  51. dtype=params_dtype),
  52. requires_grad=False)
  53. layer.register_parameter("weight", weight)
  54. set_weight_attrs(weight, {
  55. **extra_weight_attrs,
  56. "input_dim": 1,
  57. "output_dim": 0,
  58. })
  59. def _quantize_weight(
  60. self, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  61. weight_dtype = weight.dtype
  62. weight = weight.cpu().to(torch.float32)
  63. n_bit = 8
  64. eps = 1e-5
  65. max_int = 2**(n_bit - 1) - 1
  66. min_int = -(2**(n_bit - 1))
  67. max_val = weight.abs().amax(dim=-1, keepdim=True)
  68. max_val = max_val.clamp(min=eps)
  69. qscale = max_val / max_int
  70. qweight = torch.clamp(torch.round(weight * (1.0 / qscale)), min_int,
  71. max_int).to(torch.int8)
  72. qscale = qscale.squeeze().to(weight_dtype)
  73. return qweight, qscale
  74. def process_weights_after_loading(self, layer: Module) -> None:
  75. device = layer.weight.device
  76. qweight, qscale = self._quantize_weight(layer.weight)
  77. qweight = qweight.to(device)
  78. qscale = qscale.to(device)
  79. layer.weight = Parameter(qweight, requires_grad=False)
  80. layer.scale = Parameter(qscale, requires_grad=False)
  81. def apply(self,
  82. layer: torch.nn.Module,
  83. x: torch.Tensor,
  84. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  85. try:
  86. import torch_xla.experimental.xla_quantized_matmul # noqa: F401
  87. except ImportError as err:
  88. raise ImportError(
  89. "Please install torch_xla by following the instructions at "
  90. "https://aphrodite.pygmalion.chat/pages/installation/installation-tpu.html " # noqa: E501
  91. "to run Aphrodite on TPU.") from err
  92. weight = layer.weight
  93. scale = layer.scale
  94. out = torch.ops.xla.quantized_matmul(x, weight, scale)
  95. if bias is not None:
  96. out = out + bias
  97. return out