from typing import Any, Dict, List, Optional
from contextlib import suppress

import torch
from torch.nn.parameter import Parameter

from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                              set_weight_attrs)
from aphrodite.quantization.base_config import QuantizationConfig

HAS_QUANTS = False
with suppress(ImportError):
    from aphrodite._quant_C import quant_ops as ops
    HAS_QUANTS = True

# Maps GGML quantization type id -> (block size in elements, type size in
# bytes): how many weights one quantization block holds, and how many bytes
# that block occupies.
GGML_QUANT_SIZES = {
    0: (1, 4),  # F32
    1: (1, 2),  # F16
    2: (32, 2 + 16),  # Q4_0
    3: (32, 2 + 2 + 16),  # Q4_1
    6: (32, 2 + 4 + 16),  # Q5_0
    7: (32, 2 + 2 + 4 + 16),  # Q5_1
    8: (32, 2 + 32),  # Q8_0
    9: (32, 4 + 4 + 32),  # Q8_1
    10: (256, 2 + 2 + 256 // 16 + 256 // 4),  # Q2_K
    11: (256, 2 + 256 // 4 + 256 // 8 + 12),  # Q3_K
    12: (256, 2 + 2 + 256 // 2 + 12),  # Q4_K
    13: (256, 2 + 2 + 256 // 2 + 256 // 8 + 12),  # Q5_K
    14: (256, 2 + 256 // 2 + 256 // 4 + 256 // 16),  # Q6_K
    15: (256, 4 + 256 + 256 // 8),  # Q8_K
    16: (256, 2 + 256 // 4),  # IQ2_XXS
    17: (256, 2 + 256 // 4 + 256 // 32),  # IQ2_XS
    18: (256, 2 + 3 * 256 // 8),  # IQ3_XXS
    19: (256, 2 + 256 // 8 + 256 // 16),  # IQ1_S
    20: (32, 2 + 32 // 2),  # IQ4_NL
    21: (256, 2 + 256 // 4 + 256 // 32 + 256 // 8 + 256 // 64),  # IQ3_S
    22: (256, 2 + 256 // 4 + 256 // 32 + 256 // 32),  # IQ2_S
    23: (256, 2 + 2 + 256 // 64 + 256 // 2),  # IQ4_XS
}
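
# Illustrative helper (not part of the original module): the table above
# implies the packed byte footprint of a quantized tensor. Assumes the
# element count is a multiple of the block size.
def _packed_nbytes(ggml_type: int, n_elements: int) -> int:
    block_size, type_size = GGML_QUANT_SIZES[ggml_type]
    return n_elements // block_size * type_size

# Example: Q4_0 (type 2) packs 32 weights into 18 bytes, so a 4096-element
# row occupies _packed_nbytes(2, 4096) == 4096 // 32 * 18 == 2304 bytes.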


class GGUFConfig(QuantizationConfig):
    """Config class for GGUF."""

    def __repr__(self) -> str:
        return "GGUFConfig()"

    def get_name(self) -> str:
        return "gguf"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half]

    def get_min_capability(self) -> int:
        # Minimum CUDA compute capability: 61 == 6.1 (Pascal).
        return 61

    @staticmethod
    def get_config_filenames() -> List[str]:
        # GGUF carries its quantization metadata inside the model file
        # itself, so there is no separate config file to look for.
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig":
        return cls()

    def get_linear_method(self) -> "GGUFLinearMethod":
        return GGUFLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []

    def merge_weight(self) -> bool:
        return False

    def rope_style(self) -> Optional[bool]:
        return False

    def quant_vocab(self) -> List[bool]:
        return [True, True]

    def support_fused_moe(self) -> bool:
        return False
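
# Illustrative usage (assumed call pattern, not shown in this file): the
# engine builds the config and asks it for the layer implementation, e.g.
#
#   quant_config = GGUFConfig.from_config({})
#   linear_method = quant_config.get_linear_method()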


class GGUFLinearMethod(LinearMethodBase):
    """Linear method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def __init__(self, quant_config: GGUFConfig):
        if not HAS_QUANTS:
            raise ImportError("Could not find the quantization kernels.")
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # The quantization type of the weight is unknown until the state
        # dict is loaded, so register a lazy placeholder that the loader
        # materializes later.
        weight = torch.nn.parameter.UninitializedParameter(
            requires_grad=False)
        # No need for pack_factor because we don't fuse qkv layers anyway.
        set_weight_attrs(weight, {
            "input_dim": 1,
            "output_dim": 0,
        })
        # Holds the GGML type id; overwritten by the weight loader.
        weight_type = Parameter(
            torch.tensor(1, dtype=torch.int, device="cuda"),
            requires_grad=False,
        )
        set_weight_attrs(weight_type, {"ignore_warning": True})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)
        layer.register_parameter("weight_type", weight_type)
        set_weight_attrs(weight_type, extra_weight_attrs)
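
    # For context, a sketch of the assumed load-time flow (the loader lives
    # outside this file; names here are illustrative):
    #
    #   layer.weight.materialize(raw_bytes.shape, dtype=raw_bytes.dtype)
    #   layer.weight.data.copy_(raw_bytes)          # packed GGUF blocks
    #   layer.weight_type.data.fill_(ggml_type_id)  # e.g. 12 for Q4_K
    #
    # Only after this does apply_weights know which kernel to dispatch to.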

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        if not hasattr(layer, "weight_type_int"):
            layer.weight_type_int = int(layer.weight_type)
            # Check tensor parallel shape here on first pass: the packed
            # byte width must be a whole number of quantization blocks.
            type_size = GGML_QUANT_SIZES[layer.weight_type_int][1]
            if layer.weight.shape[1] % type_size != 0:
                raise ValueError("Size is not aligned with the quantized "
                                 "weight shape.")
        weight = layer.weight
        weight_type = layer.weight_type_int
        infeatures = x.shape[-1]
        outfeatures = weight.shape[0]
        out_shape = x.shape[:-1] + (outfeatures, )
        reshaped_x = x.reshape(-1, x.shape[-1])
        if reshaped_x.shape[0] == 1:
            # Single row (e.g. one decode token): matrix-vector kernel.
            out = ops.ggml_mul_mat_vec_a8(weight, reshaped_x, weight_type,
                                          outfeatures)
        elif reshaped_x.shape[0] < 8 and weight_type < 16:
            # Small batch of a non-IQ quant type (ids < 16): quantized
            # matmul kernel.
            out = ops.ggml_mul_mat_a8(weight, reshaped_x, weight_type,
                                      outfeatures)
        else:
            # Otherwise dequantize the whole weight and fall back to a
            # dense matmul.
            weight = ops.ggml_dequantize(weight, weight_type, outfeatures,
                                         infeatures)
            out = reshaped_x @ weight.T
        if bias is not None:
            out.add_(bias)
        return out.reshape(out_shape)
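
    # Illustrative shapes (assumed): with x flattened to (batch_tokens,
    # in_features), a single decode token (batch_tokens == 1) takes the
    # mat-vec path and yields (1, out_features); a long prefill falls
    # through to the dequantize-then-matmul path.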

    def apply_embedding(self, layer: torch.nn.Module,
                        x: torch.Tensor) -> torch.Tensor:
        if not hasattr(layer, "weight_type_int"):
            layer.weight_type_int = int(layer.weight_type)
        weight = layer.weight
        weight_type = layer.weight_type_int
        block_size, type_size = GGML_QUANT_SIZES[weight_type]
        vocab_size = weight.shape[0]
        # Recover the logical hidden size from the packed byte width.
        hidden_size = weight.shape[1] // type_size * block_size
        if weight_type < 2:
            # F32/F16 embeddings are stored unquantized; look up directly.
            return torch.embedding(weight.view(vocab_size, -1), x)
        # Gather only the requested rows, then dequantize just those.
        x_flat = x.flatten()
        quant = torch.index_select(weight.view(vocab_size, -1),
                                   dim=0,
                                   index=x_flat)
        dequant = ops.ggml_dequantize(quant, weight_type, hidden_size,
                                      x_flat.shape[0])
        return dequant.view(*x.shape, hidden_size)
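
    # Worked example (assumed numbers): for Q4_K (type 12) the table gives
    # (256, 144), i.e. 256 weights per 144-byte block. A 4096-wide embedding
    # row is stored as 4096 // 256 * 144 == 2304 bytes, and the hidden_size
    # computation above inverts that: 2304 // 144 * 256 == 4096.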

    def apply_moe_weights(self, w1: Dict[str, torch.Tensor],
                          w2: Dict[str, torch.Tensor], x: torch.Tensor,
                          gating_output: torch.Tensor, topk: int,
                          renormalize: bool) -> torch.Tensor:
        raise NotImplementedError