# compressed_tensors_moe.py
import enum
from enum import Enum
from typing import Callable, List, Optional

import torch

from aphrodite import _custom_ops as ops
from aphrodite.modeling.layers.fused_moe import FusedMoEMethodBase
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.compressed_tensors.schemes import (
    WNA16_SUPPORTED_BITS)
from aphrodite.quantization.compressed_tensors.utils import CompressionFormat


class GPTQMarlinState(Enum):
    REPACK = enum.auto()
    READY = enum.auto()


__all__ = ["CompressedTensorsMoEMethod"]
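

# FusedMoEMethodBase implementation for checkpoints quantized with the
# compressed-tensors WNA16 scheme in pack_quantized format. Weights are
# loaded in the packed int32 layout and converted into the Marlin MoE
# kernel layout in process_weights_after_loading below.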
class CompressedTensorsMoEMethod(FusedMoEMethodBase):

    def __init__(
            self,
            quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
    ):
        self.quant_config = quant_config
        # TODO: refactor this to use schemes as other kernels
        # are supported + check if the layer is being ignored.
        config = self.quant_config.target_scheme_map["Linear"].get("weights")

        self.num_bits = config.num_bits
        self.packed_factor = 32 // config.num_bits
        self.strategy = config.strategy.value
        self.group_size = config.group_size

        assert config.symmetric, (
            "Only symmetric quantization is supported for MoE")

        if not (self.quant_config.quant_format
                == CompressionFormat.pack_quantized.value
                and self.num_bits in WNA16_SUPPORTED_BITS):
            raise ValueError(
                "For Fused MoE layers, only "
                f"{CompressionFormat.pack_quantized.value} "
                "is supported for the following bits: "
                f"{WNA16_SUPPORTED_BITS}")
    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        # Will transpose the loaded weight along the
        # intermediate and hidden dim sizes. Will
        # shard for TP along the transposed dims
        extra_weight_attrs.update({
            "is_transposed": True,
            "quant_method": self.strategy
        })

        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    hidden_size //
                                                    self.packed_factor,
                                                    2 * intermediate_size,
                                                    dtype=torch.int32),
                                        requires_grad=False)
        layer.register_parameter("w13_weight_packed", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   intermediate_size //
                                                   self.packed_factor,
                                                   hidden_size,
                                                   dtype=torch.int32),
                                       requires_grad=False)
        layer.register_parameter("w2_weight_packed", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)
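
        # The "channel" strategy keeps a single scale per output channel, so
        # there is one scale group and group_size is set to -1 to signal
        # per-channel scales. Otherwise one scale group is allocated per
        # group_size elements of the input dim.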
        if self.strategy == "channel":
            num_groups_w2 = num_groups_w13 = 1
            self.group_size = -1
        else:
            num_groups_w2 = intermediate_size // self.group_size
            num_groups_w13 = hidden_size // self.group_size

        w13_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                  num_groups_w13,
                                                  2 * intermediate_size,
                                                  dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_scale)
        set_weight_attrs(w13_scale, extra_weight_attrs)

        w2_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                 num_groups_w2,
                                                 hidden_size,
                                                 dtype=params_dtype),
                                      requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_scale)
        set_weight_attrs(w2_scale, extra_weight_attrs)

        w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
                                             requires_grad=False)
        layer.register_parameter("w2_weight_shape", w2_weight_shape)
        set_weight_attrs(w2_weight_shape, extra_weight_attrs)

        w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
                                              requires_grad=False)
        layer.register_parameter("w13_weight_shape", w13_weight_shape)
        set_weight_attrs(w13_weight_shape, extra_weight_attrs)
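
        # GPTQ group-index tensors. They are registered per expert so weight
        # loading has a target, but they are replaced with empty placeholders
        # in process_weights_after_loading, which suggests activation
        # reordering is not exercised by this method.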
        w13_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_g_idx", w13_g_idx)
        set_weight_attrs(w13_g_idx, extra_weight_attrs)

        w2_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_g_idx", w2_g_idx)
        set_weight_attrs(w2_g_idx, extra_weight_attrs)

        w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_g_idx_sort_indices",
                                 w13_g_idx_sort_indices)
        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)

        w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_g_idx_sort_indices",
                                 w2_g_idx_sort_indices)
        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
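
        # Weight-only quantization: no activation (input) scales. The REPACK
        # state marks that the weights are still in the loaded pack_quantized
        # layout; they are converted to the Marlin format in
        # process_weights_after_loading.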
        layer.a13_scale = None
        layer.a2_scale = None
        layer.marlin_state = GPTQMarlinState.REPACK
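
    # Converts the loaded pack_quantized weights and scales, in place, into
    # the layout expected by the Marlin MoE kernels: the int32 qweights are
    # repacked per expert and the scales are permuted to match the kernel's
    # tile ordering.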
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

        def replace_tensor(name, new_t):
            # It is important to use resize_() here since it ensures
            # the same buffer is reused
            getattr(layer, name).resize_(new_t.shape)
            getattr(layer, name).copy_(new_t)
            del new_t

        def get_scale_perms(num_bits: int):
            scale_perm: List[int] = []
            for i in range(8):
                scale_perm.extend([i + 8 * j for j in range(8)])
            scale_perm_single: List[int] = []
            for i in range(4):
                scale_perm_single.extend(
                    [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
            return scale_perm, scale_perm_single

        def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
                                  group_size: int, num_bits: int):
            scale_perm, scale_perm_single = get_scale_perms(num_bits)
            if group_size < size_k and group_size != -1:
                s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
            else:
                s = s.reshape((-1, len(scale_perm_single)))[:,
                                                            scale_perm_single]
            s = s.reshape((-1, size_n)).contiguous()
            return s

        def marlin_moe_permute_scales(s: torch.Tensor, size_k: int,
                                      size_n: int, group_size: int,
                                      num_bits: int):
            num_experts = s.shape[0]
            output = torch.empty((num_experts, s.shape[1], s.shape[2]),
                                 device=s.device,
                                 dtype=s.dtype)
            for e in range(num_experts):
                output[e] = marlin_permute_scales(s[e], size_k, size_n,
                                                  group_size, num_bits)
            return output
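
        # Record the last-dim sizes of the packed weights for the scale
        # permutation below, then replace the group indices with empty
        # per-expert placeholders: activation reordering is not used, but
        # these tensors are still passed to the kernel in apply().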
        size_k2 = layer.w2_weight_packed.shape[2]
        size_k13 = layer.w13_weight_packed.shape[2]

        num_experts = layer.w13_g_idx.shape[0]
        device = layer.w13_g_idx.device
        layer.w13_g_idx = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
        layer.w2_g_idx = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
        layer.w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
        layer.w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
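
        # ops.gptq_marlin_moe_repack converts each expert's packed int32
        # weight into the Marlin layout; shape[1] * packed_factor recovers
        # the unpacked input-dim size and shape[2] is the output-dim size.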
        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
            layer.w13_weight_packed,
            layer.w13_g_idx_sort_indices,
            layer.w13_weight_packed.shape[1] * self.packed_factor,
            layer.w13_weight_packed.shape[2],
            self.num_bits,
        )
        replace_tensor("w13_weight_packed", marlin_w13_qweight)

        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
            layer.w2_weight_packed,
            layer.w2_g_idx_sort_indices,
            layer.w2_weight_packed.shape[1] * self.packed_factor,
            layer.w2_weight_packed.shape[2],
            self.num_bits,
        )
        replace_tensor("w2_weight_packed", marlin_w2_qweight)

        # Repack scales
        marlin_w13_scales = marlin_moe_permute_scales(
            layer.w13_weight_scale,
            size_k13,
            layer.w13_weight_scale.shape[2],
            self.group_size,
            self.num_bits,
        )
        replace_tensor("w13_weight_scale", marlin_w13_scales)

        marlin_w2_scales = marlin_moe_permute_scales(
            layer.w2_weight_scale,
            layer.w2_weight_scale.shape[1] * self.packed_factor,
            size_k2,
            self.group_size,
            self.num_bits,
        )
        replace_tensor("w2_weight_scale", marlin_w2_scales)
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        from aphrodite.modeling.layers.fused_moe.fused_moe import (
            fused_marlin_moe)

        return fused_marlin_moe(x,
                                layer.w13_weight_packed,
                                layer.w2_weight_packed,
                                router_logits,
                                layer.w13_g_idx,
                                layer.w2_g_idx,
                                layer.w13_g_idx_sort_indices,
                                layer.w2_g_idx_sort_indices,
                                top_k,
                                custom_routing_function,
                                renormalize=renormalize,
                                w1_scale=layer.w13_weight_scale,
                                w2_scale=layer.w2_weight_scale)