|
@@ -1,37 +1,17 @@
|
|
-from typing import Callable, List, Tuple, Union
|
|
|
|
|
|
+from typing import Callable, List
|
|
|
|
|
|
import torch
|
|
import torch
|
|
from torch.nn import Parameter
|
|
from torch.nn import Parameter
|
|
|
|
|
|
from aphrodite import _custom_ops as custom_ops
|
|
from aphrodite import _custom_ops as custom_ops
|
|
from aphrodite.modeling.utils import set_weight_attrs
|
|
from aphrodite.modeling.utils import set_weight_attrs
|
|
-from aphrodite.quantization.compressed_tensors.schemes import \
|
|
|
|
- CompressedTensorsScheme
|
|
|
|
|
|
+from aphrodite.quantization.compressed_tensors.schemes.compressed_tensors_w8a8 import \
|
|
|
|
+ CompressedTensorsW8A8 # noqa: E501
|
|
|
|
|
|
__all__ = ["CompressedTensorsW8A8StaticTensor"]
|
|
__all__ = ["CompressedTensorsW8A8StaticTensor"]
|
|
|
|
|
|
|
|
|
|
-class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
|
|
|
|
-
|
|
|
|
- def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
|
|
|
|
- if isinstance(shard_id, int):
|
|
|
|
- return shard_id
|
|
|
|
-
|
|
|
|
- assert isinstance(shard_id, str)
|
|
|
|
- qkv_idxs = {"q": 0, "k": 1, "v": 2}
|
|
|
|
- assert shard_id in qkv_idxs
|
|
|
|
- return qkv_idxs[shard_id]
|
|
|
|
-
|
|
|
|
- def scales_shard_splitter(
|
|
|
|
- self, param: torch.Tensor, loaded_weight: torch.Tensor,
|
|
|
|
- shard_id: Union[str, int],
|
|
|
|
- logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
- shard_id = self._shard_id_as_int(shard_id)
|
|
|
|
- offset = sum(logical_widths[:shard_id])
|
|
|
|
- size = logical_widths[shard_id]
|
|
|
|
- # update loaded weight with copies for broadcast.
|
|
|
|
- loaded_weight = loaded_weight.repeat(size)
|
|
|
|
- return param[offset:offset + size], loaded_weight
|
|
|
|
|
|
+class CompressedTensorsW8A8StaticTensor(CompressedTensorsW8A8):
|
|
|
|
|
|
def create_weights(self, layer: torch.nn.Module,
|
|
def create_weights(self, layer: torch.nn.Module,
|
|
output_partition_sizes: List[int],
|
|
output_partition_sizes: List[int],
|
|
@@ -39,41 +19,21 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
|
|
params_dtype: torch.dtype, weight_loader: Callable,
|
|
params_dtype: torch.dtype, weight_loader: Callable,
|
|
**kwargs):
|
|
**kwargs):
|
|
|
|
|
|
- is_tensor_partitioned = len(output_partition_sizes) != 1
|
|
|
|
- weight_scale_dim = sum(
|
|
|
|
- output_partition_sizes) if is_tensor_partitioned else 1
|
|
|
|
|
|
+ super().create_weights(
|
|
|
|
+ layer=layer,
|
|
|
|
+ output_partition_sizes=output_partition_sizes,
|
|
|
|
+ input_size_per_partition=input_size_per_partition,
|
|
|
|
+ params_dtype=params_dtype,
|
|
|
|
+ weight_loader=weight_loader)
|
|
|
|
|
|
input_scale = Parameter(torch.empty(1, dtype=torch.float32),
|
|
input_scale = Parameter(torch.empty(1, dtype=torch.float32),
|
|
requires_grad=False)
|
|
requires_grad=False)
|
|
|
|
|
|
- weight_scale = Parameter(torch.empty(weight_scale_dim,
|
|
|
|
- dtype=torch.float32),
|
|
|
|
- requires_grad=False)
|
|
|
|
-
|
|
|
|
- weight = Parameter(torch.empty(sum(output_partition_sizes),
|
|
|
|
- input_size_per_partition,
|
|
|
|
- dtype=torch.int8),
|
|
|
|
- requires_grad=False)
|
|
|
|
-
|
|
|
|
- layer.register_parameter("weight", weight)
|
|
|
|
- set_weight_attrs(weight, {
|
|
|
|
- "weight_loader": weight_loader,
|
|
|
|
- "input_dim": 1,
|
|
|
|
- "output_dim": 0,
|
|
|
|
- })
|
|
|
|
layer.register_parameter("input_scale", input_scale)
|
|
layer.register_parameter("input_scale", input_scale)
|
|
set_weight_attrs(input_scale, {
|
|
set_weight_attrs(input_scale, {
|
|
"weight_loader": weight_loader,
|
|
"weight_loader": weight_loader,
|
|
"ignore_warning": True,
|
|
"ignore_warning": True,
|
|
})
|
|
})
|
|
- layer.register_parameter("weight_scale", weight_scale)
|
|
|
|
- set_weight_attrs(
|
|
|
|
- weight_scale, {
|
|
|
|
- "weight_loader": weight_loader,
|
|
|
|
- "shard_splitter": self.scales_shard_splitter,
|
|
|
|
- "logical_widths": output_partition_sizes,
|
|
|
|
- "ignore_warning": True,
|
|
|
|
- })
|
|
|
|
|
|
|
|
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
|
|
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
|
|
weight = layer.weight
|
|
weight = layer.weight
|
|
@@ -84,4 +44,4 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
|
|
x_q, _ = custom_ops.scaled_int8_quant(x, act_scale)
|
|
x_q, _ = custom_ops.scaled_int8_quant(x, act_scale)
|
|
|
|
|
|
return custom_ops.cutlass_scaled_mm(x_q, weight.t(), act_scale,
|
|
return custom_ops.cutlass_scaled_mm(x_q, weight.t(), act_scale,
|
|
- weight_scale, x.dtype)
|
|
|
|
|
|
+ weight_scale, x.dtype)
|