marlin.py
from typing import Optional, Tuple

import torch

from aphrodite import _custom_ops as ops
from aphrodite.modeling.parameter import (BaseAphroditeParameter,
                                           permute_param_layout_)
from aphrodite.quantization.utils.marlin_utils import (
    MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
    check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
    marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
    query_marlin_supported_quant_types)

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig

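# MarlinLinearKernel exposes the GPTQ-Marlin mixed-precision GEMM through the
# generic MPLinearKernel interface: weights and scales are repacked once into
# the Marlin layout in `process_weights_after_loading`, and `apply_weights`
# then dispatches to the fused kernel via `apply_gptq_marlin_linear`.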
class MarlinLinearKernel(MPLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        # Marlin requires NVIDIA Ampere (compute capability 8.0) or newer.
        return 80

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        if c.zero_points:
            return False, "Zero points currently not supported by "\
                          "MarlinLinearKernel. Will be added when AWQMarlin "\
                          "is migrated over to using MPLinearKernel backend"

        quant_types = query_marlin_supported_quant_types(c.zero_points)
        if c.weight_type not in quant_types:
            return False, f"Quant type ({c.weight_type}) not supported by"\
                          f" Marlin, supported types are: {quant_types}"

        if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
            return False, f"Group size ({c.group_size}) not supported by "\
                          "Marlin, supported group sizes are: "\
                          f"{MARLIN_SUPPORTED_GROUP_SIZES}"

        return check_marlin_supports_shape(c.partition_weight_shape[0],
                                           c.partition_weight_shape[1],
                                           c.full_weight_shape[1],
                                           c.group_size)

    # NOTE: assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale`  is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = getattr(layer, self.w_q_name).device
        c = self.config

        row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0])
        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)

        # Allocate marlin workspace.
        self.workspace = marlin_make_workspace(c.partition_weight_shape[1],
                                               device)

        # Default names since marlin requires empty parameters for these,
        # TODO: remove this requirement from marlin (allow optional tensors)
        if self.w_gidx_name is None:
            self.w_gidx_name = "g_idx"
        if self.w_zp_name is None:
            self.w_zp_name = "w_zp"

        if c.has_g_idx:
            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
                getattr(layer, self.w_gidx_name))
            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
        else:
            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        if c.zero_points:
            pass
            # TODO (lucas): add the following when AWQMarlin is migrated over
            # to using MPLinearKernel backend
            # self._transform_param(layer, self.w_zp_name, lambda x: \
            #     marlin_zero_points(
            #         x,
            #         size_k=c.partition_weight_shape[0],
            #         size_n=c.partition_weight_shape[1],
            #         num_bits=c.weight_type.size_bits))
        else:
            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))

        # Repack the quantized weights into the Marlin tile layout.
        def transform_w_q(x):
            assert isinstance(x, BaseAphroditeParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = ops.gptq_marlin_repack(x.data.contiguous(),
                                            perm=layer.g_idx_sort_indices,
                                            size_k=c.partition_weight_shape[0],
                                            size_n=c.partition_weight_shape[1],
                                            num_bits=c.weight_type.size_bits)
            return x

        # Permute the group scales to match the repacked weight layout.
        def transform_w_s(x):
            assert isinstance(x, BaseAphroditeParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = marlin_permute_scales(x.data.contiguous(),
                                           size_k=c.partition_weight_shape[0],
                                           size_n=c.partition_weight_shape[1],
                                           group_size=c.group_size)
            return x

        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        c = self.config
        w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)

        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
        # None for marlin
        return apply_gptq_marlin_linear(
            input=x,
            weight=w_q,
            weight_scale=w_s,
            weight_zp=w_zp,  # type: ignore
            g_idx=w_gidx,  # type: ignore
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=self.workspace,
            wtype=c.weight_type,
            input_size_per_partition=c.partition_weight_shape[0],
            output_size_per_partition=c.partition_weight_shape[1],
            is_k_full=self.is_k_full,
            bias=bias)
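
# Usage sketch (illustrative only; the MPLinearKernel constructor arguments
# are assumed, not defined in this file):
#
#   ok, reason = MarlinLinearKernel.can_implement(config)
#   if not ok:
#       raise ValueError(reason)
#   kernel = MarlinLinearKernel(config, ...)      # bind weight parameter names
#   kernel.process_weights_after_loading(layer)   # one-time Marlin repack
#   output = kernel.apply_weights(layer, x, bias)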