# gguf.py

from typing import Any, Dict, List, Optional

import gguf
import torch
from torch.nn.parameter import Parameter, UninitializedParameter

from aphrodite import _custom_ops as ops
from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import (QuantizationConfig,
                                                QuantizeMethodBase)


class GGUFConfig(QuantizationConfig):
    """Config class for GGUF."""

    def __init__(self) -> None:
        pass

    def __repr__(self) -> str:
        return "GGUFConfig()"

    def get_name(self) -> str:
        return "gguf"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 60

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []  # no extra configs.

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig":
        return cls()

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase):
            return GGUFLinearMethod(self)
        elif isinstance(layer, VocabParallelEmbedding):
            return GGUFEmbeddingMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []
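
# A brief usage sketch: GGUF files record each tensor's ggml quantization
# type alongside the tensor data, so this config has nothing of its own to
# configure and `from_config` ignores its argument:
#
#     config = GGUFConfig.from_config({})   # -> GGUFConfig()
#     config.get_name()                     # -> "gguf"
#
# `get_quant_method` then dispatches on layer type alone: LinearBase layers
# get GGUFLinearMethod, VocabParallelEmbedding gets GGUFEmbeddingMethod, and
# any other layer is left unquantized.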


def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
                  qweight_type: int) -> torch.Tensor:
    # Use dequantize-then-matmul for i-quants (ggml type ids >= 16) and the
    # fused quantized matmul kernel (mmq) for k-quants and legacy quants.
    if qweight_type >= 16:
        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
        shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
        weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
        y = x @ weight.T
    else:
        y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
    return y
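
# Worked example of the dequantize path's shape arithmetic (a sketch; the
# numbers assume GGML_QUANT_SIZES[Q8_0] == (32, 34), i.e. 32 weights packed
# into 34 bytes per block):
#
#     qweight of shape (4096, 4352): 4096 output rows, 4352 raw bytes each
#     dequantized width: 4352 // 34 * 32 = 4096 elements per row
#     weight of shape (4096, 4096), so `x @ weight.T` keeps the usual
#     (num_tokens, output_size) activation shape.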


class GGUFLinearMethod(LinearMethodBase):
    """Linear method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def __init__(self, quant_config: GGUFConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        output_size_per_partition = sum(output_partition_sizes)

        tensor_shape = (output_size_per_partition, input_size_per_partition)
        qweight = UninitializedParameter(requires_grad=False)
        set_weight_attrs(
            qweight, {
                "input_dim": 1,
                "output_dim": 0,
                "tensor_shape": tensor_shape,
                "is_gguf_weight": True,
                "shard_size": {},
                "shard_id": [],
            })
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("qweight", qweight)

        qweight_type = Parameter(torch.empty(len(output_partition_sizes),
                                             dtype=torch.uint8),
                                 requires_grad=False)
        set_weight_attrs(
            qweight_type, {
                "is_gguf_weight_type": True,
                "weight_type": 0,
                "shard_weight_type": {},
                "ignore_warning": True
            })
        set_weight_attrs(qweight_type, extra_weight_attrs)
        layer.register_parameter("qweight_type", qweight_type)
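
    # A sketch of the state the weight loader is expected to leave behind for
    # a fused layer (attribute names come from create_weights above; the
    # shard names, shapes, and quant type ids are illustrative assumptions):
    #
    #     layer.qweight                         # materialized raw GGUF bytes
    #     layer.qweight.shard_id                # e.g. ["q", "k", "v"]
    #     layer.qweight.shard_size              # e.g. {"q": (4096, 2304), ...}
    #     layer.qweight_type.shard_weight_type  # e.g. {"q": 12, "k": 12, "v": 14}
    #
    # apply() below slices the fused weight by rows using shard_size and
    # dequantizes each shard with its own quantization type.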

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        shard_size = getattr(layer.qweight, "shard_size", None)
        shard_id = getattr(layer.qweight, "shard_id", None)

        if shard_id and shard_size:
            result = []
            offset = 0
            # Dequantize and multiply each shard separately, since the shards
            # of a fused layer may use different quantization types.
            # Force canonical q, k, v ordering for fused QKV projections.
            shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id
            for id in shard_id:
                shard_weight = layer.qweight[
                    offset:offset +
                    shard_size[id][0], :shard_size[id][1]].contiguous()
                qweight_type = layer.qweight_type.shard_weight_type[id]
                result.append(_fuse_mul_mat(x, shard_weight,
                                            qweight_type).contiguous())
                offset += shard_size[id][0]
            out = torch.cat(result, dim=-1)
        else:
            qweight = layer.qweight
            qweight_type = layer.qweight_type.weight_type
            out = _fuse_mul_mat(x, qweight, qweight_type)
        if bias is not None:
            out.add_(bias)
        return out
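
# Worked example for the sharded path in apply() (illustrative numbers, not
# taken from any particular model): for a fused QKV projection with
# shard_size == {"q": (4096, Wq), "k": (1024, Wk), "v": (1024, Wv)}, the q
# rows are qweight[0:4096], the k rows qweight[4096:5120], and the v rows
# qweight[5120:6144]; each slice is narrowed to its own byte width W* before
# being dequantized with its own quant type, and the three partial outputs
# are concatenated along the last dimension.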


class GGUFEmbeddingMethod(GGUFLinearMethod):
    """Embedding method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def embedding(self, layer: torch.nn.Module,
                  x: torch.Tensor) -> torch.Tensor:
        qweight = layer.qweight
        qweight_type = layer.qweight_type.weight_type

        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
        hidden_size = qweight.shape[1] // type_size * block_size
        # Types 0 and 1 are unquantized F32/F16, so a plain embedding lookup
        # works directly on the weight.
        if qweight_type < 2:
            return torch.embedding(qweight, x)
        # Otherwise gather only the rows for the requested token ids and
        # dequantize just those rows.
        x_flat = x.flatten()
        quant = torch.index_select(qweight, dim=0, index=x_flat)
        dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
                                      x_flat.shape[0])
        return dequant.view(*x.shape, hidden_size)
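

# A minimal CPU-only sanity check of the shape arithmetic used above (a
# sketch: it assumes the `gguf` package exposes GGMLQuantizationType and that
# Q8_0 packs 32 weights into 34 bytes; none of the CUDA kernels in `ops` are
# touched).
if __name__ == "__main__":
    q8_0 = gguf.GGMLQuantizationType.Q8_0
    block_size, type_size = gguf.GGML_QUANT_SIZES[q8_0]
    row_bytes = 4352  # hypothetical raw byte width of one quantized row
    hidden_size = row_bytes // type_size * block_size
    print(f"Q8_0: block_size={block_size}, type_size={type_size}, "
          f"{row_bytes} raw bytes -> {hidden_size} dequantized elements")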