machete.py

from functools import partial
from typing import Optional, Tuple

import torch

from aphrodite import _custom_ops as ops
from aphrodite.modeling.parameter import (BaseAphroditeParameter,
                                          permute_param_layout_)
from aphrodite.quantization.utils.machete_utils import (
    MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
    query_machete_supported_quant_types)
from aphrodite.quantization.utils.quant_utils import (
    pack_weights_into_int32, unpack_weights_into_int32)

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
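

# Machete is a mixed-precision GEMM kernel (quantized weights, floating-point
# activations) targeting NVIDIA Hopper: get_min_capability() below returns 90,
# i.e. compute capability 9.0.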
class MacheteLinearKernel(MPLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        return 90

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        if c.has_g_idx and\
                c.partition_weight_shape[0] != c.full_weight_shape[0]:
            return False, "Act reordering currently not supported by Machete "\
                          "when the input features are partitioned across "\
                          "devices"

        if c.zero_points:
            return False, "Zero points currently not supported by "\
                          "Compressed Tensors + Machete. (Kernel supports it"\
                          " but CompressedTensorsWNA16 does not, so support has"\
                          " not been added to MacheteWNA16Kernel yet)"

        if c.weight_type not in query_machete_supported_quant_types(
                c.zero_points):
            return False, f"Quant type ({c.weight_type}) not supported by "\
                          "Machete, supported types are: "\
                          f"{query_machete_supported_quant_types(c.zero_points)}"

        if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES:
            return False, f"Group size ({c.group_size}) not supported by "\
                          "Machete, supported group sizes are: "\
                          f"{MACHETE_SUPPORTED_GROUP_SIZES}"

        return check_machete_supports_shape(c.partition_weight_shape[0],
                                            c.partition_weight_shape[1])

    # note assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale`  is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module):
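        """Re-layout the loaded checkpoint parameters for Machete: build the
        activation permutation when act-ordering (g_idx) is used, permute and
        pre-pack the quantized weights, and make the scales contiguous."""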
        c = self.config

        if c.has_g_idx:
            assert self.w_gidx_name is not None
            perm = torch.argsort(getattr(layer, self.w_gidx_name))\
                .to(torch.int)

            self.act_perm = lambda x: x[:, perm]
            # use `ops.permute_cols` if possible
            if c.act_type in [torch.float16, torch.bfloat16] \
                    and c.partition_weight_shape[0] % 8 == 0:
                self.act_perm = partial(ops.permute_cols, perm=perm)
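
        # If act-ordering is used, apply the same permutation to the quantized
        # weight rows (unpack -> permute -> repack) so it matches the runtime
        # activation permutation, then pre-pack into Machete's B layout.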
        def transform_w_q(x):
            assert isinstance(x, BaseAphroditeParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            if c.has_g_idx:
                x_unpacked = unpack_weights_into_int32(x.data,
                                                       c.weight_type,
                                                       packed_dim=0)
                x_perm = x_unpacked[perm, :]
                x.data = pack_weights_into_int32(x_perm,
                                                 c.weight_type,
                                                 packed_dim=0)
            x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
                                           self.config.weight_type)
            return x
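
        # Scales only need their layout normalized to
        # {input_dim = 0, output_dim = 1} and to be made contiguous.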
        def transform_w_s(x):
            assert isinstance(x, BaseAphroditeParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x

        # Repack weights and scales for Machete
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
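        """Run the Machete GEMM: flatten the activations to 2D, apply the
        activation permutation if act-ordering is enabled, then reshape the
        output back to the input's leading dimensions."""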
        c = self.config
        w_q, w_s, _, _ = self._get_weight_params(layer)

        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )

        if c.has_g_idx:
            x_2d = self.act_perm(x_2d)
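
        # w_q was pre-packed and w_s made contiguous in
        # process_weights_after_loading; zero points are never passed
        # (b_zeros=None) because can_implement rejects them.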
        output = ops.machete_gemm(a=x_2d,
                                  b_q=w_q,
                                  b_type=c.weight_type,
                                  b_zeros=None,
                                  b_scales=w_s,
                                  b_group_size=c.group_size)

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
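

# Usage sketch (hypothetical wiring; the constructor keyword names follow the
# MPLinearKernel base class and are assumptions, not defined in this file):
#
#   ok, reason = MacheteLinearKernel.can_implement(cfg)
#   if ok:
#       kernel = MacheteLinearKernel(cfg,
#                                    w_q_param_name="weight_packed",
#                                    w_s_param_name="weight_scale")
#       kernel.process_weights_after_loading(layer)
#       y = kernel.apply_weights(layer, x)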