# gguf.py
from contextlib import suppress
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True

# GGML tensor type id -> (block_dim, bytes_per_block): each block packs
# `block_dim` weights into `bytes_per_block` bytes.
GGML_QUANT_SIZES = {
    0: (1, 4),  # F32
    1: (1, 2),  # F16
    2: (32, 2 + 16),  # Q4_0
    3: (32, 2 + 2 + 16),  # Q4_1
    6: (32, 2 + 4 + 16),  # Q5_0
    7: (32, 2 + 2 + 4 + 16),  # Q5_1
    8: (32, 2 + 32),  # Q8_0
    9: (32, 4 + 4 + 32),  # Q8_1
    10: (256, 2 + 2 + 256 // 16 + 256 // 4),  # Q2_K
    11: (256, 2 + 256 // 4 + 256 // 8 + 12),  # Q3_K
    12: (256, 2 + 2 + 256 // 2 + 12),  # Q4_K
    13: (256, 2 + 2 + 256 // 2 + 256 // 8 + 12),  # Q5_K
    14: (256, 2 + 256 // 2 + 256 // 4 + 256 // 16),  # Q6_K
    15: (256, 4 + 256 + 256 // 8),  # Q8_K
    16: (256, 2 + 256 // 4),  # IQ2_XXS
    17: (256, 2 + 256 // 4 + 256 // 32),  # IQ2_XS
    18: (256, 2 + 3 * 256 // 8),  # IQ3_XXS
    19: (256, 2 + 256 // 8 + 256 // 16),  # IQ1_S
    20: (32, 2 + 32 // 2),  # IQ4_NL
    21: (256, 2 + 256 // 4 + 256 // 32 + 256 // 8 + 256 // 64),  # IQ3_S
    22: (256, 2 + 256 // 4 + 256 // 32 + 256 // 32),  # IQ2_S
    23: (256, 2 + 2 + 256 // 64 + 256 // 2),  # IQ4_XS
}
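
# Worked example of the table above: Q4_0 (id 2) stores each block of 32
# weights as a 2-byte f16 scale plus 16 bytes of packed 4-bit values,
# i.e. 18 bytes per block. The helper below is illustrative only (it is
# not used elsewhere in this module) and assumes block-aligned rows.
def _quantized_row_nbytes(quant_type: int, n_weights: int) -> int:
    """Byte size of one quantized row of `n_weights` weights."""
    block_dim, bytes_per_block = GGML_QUANT_SIZES[quant_type]
    if n_weights % block_dim != 0:
        raise ValueError("Row length must be a multiple of the block size.")
    return n_weights // block_dim * bytes_per_block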


class GGUFConfig(QuantizationConfig):
    """Config class for GGUF."""

    def __repr__(self) -> str:
        return "GGUFConfig()"

    def get_name(self) -> str:
        return "gguf"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half]

    def get_min_capability(self) -> int:
        return 61

    @staticmethod
    def get_config_filenames() -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig":
        return cls()

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["GGUFLinearMethod"]:
        if isinstance(layer, LinearBase):
            return GGUFLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []

    def merge_weight(self) -> bool:
        return False

    def rope_style(self) -> Optional[bool]:
        return False

    def quant_vocab(self) -> List[bool]:
        return [True, True]

    def support_fused_moe(self) -> bool:
        return False
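
# Illustrative only: how a caller would typically resolve the quant method
# from this config (layer construction elided; any `LinearBase` subclass
# instance would do):
#
#     config = GGUFConfig.from_config({})
#     method = config.get_quant_method(linear_layer)  # -> GGUFLinearMethod
#     # Non-linear layers resolve to None.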


class GGUFLinearMethod(LinearMethodBase):
    """Linear method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def __init__(self, quant_config: GGUFConfig):
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        # The quantization type of the weight is unknown until the state
        # dict is loaded, so defer allocation with an uninitialized
        # parameter.
        weight = torch.nn.parameter.UninitializedParameter(
            requires_grad=False)
        # No need for pack_factor because we don't fuse qkv layers anyway.
        set_weight_attrs(weight, {
            "input_dim": 1,
            "output_dim": 0,
        })
        layer.register_parameter("weight", weight)

        # Placeholder GGML type id; replaced with the checkpoint's actual
        # type when the weights are loaded.
        weight_type = Parameter(
            torch.tensor(1, dtype=torch.int, device="cuda"),
            requires_grad=False,
        )
        set_weight_attrs(weight_type, {"ignore_warning": True})
        layer.register_parameter("weight_type", weight_type)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        if isinstance(layer.weight_type, torch.Tensor):
            layer.weight_type = int(layer.weight_type)
            # Check the tensor-parallel shape here on the first pass.
            block_size = GGML_QUANT_SIZES[layer.weight_type][1]
            if layer.weight.shape[1] % block_size != 0:
                raise ValueError("Size is not aligned with the quantized "
                                 "weight shape.")
        weight = layer.weight
        weight_type = layer.weight_type
        infeatures = x.shape[-1]
        outfeatures = weight.shape[0]
        out_shape = x.shape[:-1] + (weight.shape[0], )
        reshaped_x = x.reshape(-1, x.shape[-1])

        if reshaped_x.shape[0] == 1:
            # Single token: use the fused matrix-vector kernel.
            out = ops.ggml_mul_mat_vec_a8(weight, reshaped_x, weight_type,
                                          outfeatures)
        elif reshaped_x.shape[0] < 8 and weight_type < 16:
            # Small batch with a non-IQ quant type (ids below 16): use the
            # fused matrix-matrix kernel.
            out = ops.ggml_mul_mat_a8(weight, reshaped_x, weight_type,
                                      outfeatures)
        else:
            # Otherwise dequantize the full weight and fall back to a
            # floating-point GEMM.
            weight = ops.ggml_dequantize(weight, weight_type, outfeatures,
                                         infeatures)
            out = reshaped_x @ weight.T
        if bias is not None:
            out = out + bias
        return out.reshape(out_shape)

    def apply_embedding(self, layer: torch.nn.Module,
                        x: torch.Tensor) -> torch.Tensor:
        if isinstance(layer.weight_type, torch.Tensor):
            layer.weight_type = int(layer.weight_type)
        weight = layer.weight
        weight_type = layer.weight_type
        dim, block_size = GGML_QUANT_SIZES[weight_type]
        vocab_size = weight.shape[0]
        hidden_size = weight.shape[1] // block_size * dim
        if weight_type < 2:
            # F32/F16 weights need no dequantization.
            return torch.embedding(weight.view(vocab_size, -1), x)
        # Gather only the quantized rows that are actually needed, then
        # dequantize them.
        x_flat = x.flatten()
        quant = torch.index_select(weight.view(vocab_size, -1),
                                   dim=0,
                                   index=x_flat)
        dequant = ops.ggml_dequantize(quant, weight_type, hidden_size,
                                      x_flat.shape[0])
        return dequant.view(*x.shape, hidden_size)

    def apply_moe_weights(self, w1: Dict[str, torch.Tensor],
                          w2: Dict[str, torch.Tensor], x: torch.Tensor,
                          gating_output: torch.Tensor, topk: int,
                          renormalize: bool) -> torch.Tensor:
        raise NotImplementedError
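

if __name__ == "__main__":
    # Illustrative smoke test of the config surface. It needs no GPU and no
    # quant kernels, though aphrodite must be importable for the imports at
    # the top of this file.
    config = GGUFConfig.from_config({})
    assert config.get_name() == "gguf"
    assert config.get_supported_act_dtypes() == [torch.half]
    assert config.quant_vocab() == [True, True]
    print(config)  # GGUFConfig()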