compressed_tensors_w8a8.py

from typing import Callable, List, Tuple, Union

import torch
from torch.nn import Parameter

from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.compressed_tensors.schemes import \
    CompressedTensorsScheme
from aphrodite.quantization.compressed_tensors.utils import \
    QuantizationStrategy


class CompressedTensorsW8A8(CompressedTensorsScheme):
    """W8A8 scheme (int8 weights, int8 activations) for compressed-tensors
    checkpoints. ``strategy`` selects how weight scales are laid out,
    e.g. per-tensor or per-channel (see QuantizationStrategy)."""

    def __init__(self, strategy: str):
        self.strategy = strategy

    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
        # Integer shard ids pass through; fused q/k/v shard names are
        # mapped to their integer indices.
        if isinstance(shard_id, int):
            return shard_id

        assert isinstance(shard_id, str)
        qkv_idxs = {"q": 0, "k": 1, "v": 2}
        assert shard_id in qkv_idxs
        return qkv_idxs[shard_id]

    def scales_shard_splitter(
            self, param: torch.Tensor, loaded_weight: torch.Tensor,
            shard_id: Union[str, int],
            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Select the slice of the scale parameter that belongs to this shard
        # of a fused/partitioned layer.
        shard_id = self._shard_id_as_int(shard_id)
        offset = sum(logical_widths[:shard_id])
        size = logical_widths[shard_id]
        # Update loaded weight with copies for broadcast.
        loaded_weight = loaded_weight.repeat(size)
        return param[offset:offset + size], loaded_weight

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        # One scale entry per output row when the layer is partitioned or
        # channel-quantized; a single scalar scale otherwise. Per-tensor
        # shard scales are broadcast to their rows by scales_shard_splitter.
        is_tensor_partitioned = len(output_partition_sizes) != 1
        weight_scale_dim = sum(output_partition_sizes) if (
            is_tensor_partitioned
            or self.strategy == QuantizationStrategy.CHANNEL) else 1

        shape: Union[Tuple[int], Tuple[int, int]] = (weight_scale_dim, )
        if self.strategy == QuantizationStrategy.CHANNEL:
            shape = (weight_scale_dim, 1)

        weight_scale = Parameter(torch.empty(*shape, dtype=torch.float32),
                                 requires_grad=False)

        layer.register_parameter("weight_scale", weight_scale)
        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})

        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=torch.int8),
                           requires_grad=False)

        layer.register_parameter("weight", weight)
        set_weight_attrs(
            weight, {
                "input_dim": 1,
                "output_dim": 0,
                "weight_loader": weight_loader,
                "logical_widths": output_partition_sizes
            })

        # Don't need a shard_splitter for channel-wise quantization;
        # use the default loading method.
        if self.strategy == QuantizationStrategy.CHANNEL:
            set_weight_attrs(weight_scale, {
                "output_dim": 0,
            })
        else:
            set_weight_attrs(
                weight_scale, {
                    "logical_widths": output_partition_sizes,
                    "shard_splitter": self.scales_shard_splitter,
                })
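

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the upstream file):
# shows how the scheme registers the int8 ``weight`` and float32
# ``weight_scale`` parameters on a bare module. The "tensor" strategy string,
# the example sizes, and the no-op ``weight_loader`` are assumptions for this
# sketch; in Aphrodite the linear layer supplies the real loader.
if __name__ == "__main__":
    layer = torch.nn.Module()
    scheme = CompressedTensorsW8A8(strategy="tensor")  # assumed per-tensor value
    scheme.create_weights(
        layer=layer,
        output_partition_sizes=[1024, 1024, 1024],  # e.g. fused q/k/v shards
        input_size_per_partition=4096,
        params_dtype=torch.float16,
        weight_loader=lambda param, loaded_weight, **kw: None,  # placeholder
    )
    # Partitioned per-tensor layout: one scale entry per output row.
    print(layer.weight.shape)        # torch.Size([3072, 4096])
    print(layer.weight_scale.shape)  # torch.Size([3072])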