gguf.py
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from aphrodite._C import ops
from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                              set_weight_attrs)
from aphrodite.modeling.layers.quantization.base_config import (
    QuantizationConfig)

# Maps GGML quantization type id -> (block size in elements, block size in
# bytes). Ids 4 and 5 (the retired Q4_2/Q4_3 formats) are unused in ggml.
GGML_QUANT_SIZES = {
    0: (1, 4),  # F32
    1: (1, 2),  # F16
    2: (32, 2 + 16),  # Q4_0
    3: (32, 2 + 2 + 16),  # Q4_1
    6: (32, 2 + 4 + 16),  # Q5_0
    7: (32, 2 + 2 + 4 + 16),  # Q5_1
    8: (32, 2 + 32),  # Q8_0
    9: (32, 4 + 4 + 32),  # Q8_1
    10: (256, 2 + 2 + 256 // 16 + 256 // 4),  # Q2_K
    11: (256, 2 + 256 // 4 + 256 // 8 + 12),  # Q3_K
    12: (256, 2 + 2 + 256 // 2 + 12),  # Q4_K
    13: (256, 2 + 2 + 256 // 2 + 256 // 8 + 12),  # Q5_K
    14: (256, 2 + 256 // 2 + 256 // 4 + 256 // 16),  # Q6_K
    15: (256, 4 + 256 + 256 // 8),  # Q8_K
    16: (256, 2 + 256 // 4),  # IQ2_XXS
    17: (256, 2 + 256 // 4 + 256 // 32),  # IQ2_XS
}
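
# A minimal sketch, not part of the original module: per the table above, one
# row of `n_elements` logical values in quant type `quant_type` occupies
# n_elements // block_size * type_size bytes. For example, a 4096-wide Q4_K
# (id 12) row takes 4096 // 256 * 144 = 2304 bytes. The helper name below is
# hypothetical.
def _ggml_row_nbytes(quant_type: int, n_elements: int) -> int:
    block_size, type_size = GGML_QUANT_SIZES[quant_type]
    assert n_elements % block_size == 0, "row must be block-aligned"
    return n_elements // block_size * type_size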


class GGUFConfig(QuantizationConfig):
    """Config class for GGUF."""

    def __repr__(self) -> str:
        return "GGUFConfig()"

    def get_name(self) -> str:
        return "gguf"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half]

    def get_min_capability(self) -> int:
        return 70

    @staticmethod
    def get_config_filenames() -> List[str]:
        # GGUF files carry their quantization metadata inline, so there is
        # no separate quantization config file to look for.
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig":
        return cls()

    def get_linear_method(self) -> "GGUFLinearMethod":
        return GGUFLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []

    def merge_weight(self) -> bool:
        return False

    def rope_style(self) -> Optional[bool]:
        return False

    def quant_vocab(self) -> Optional[bool]:
        return True
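
# Usage sketch (assumed wiring, mirroring how quantization configs are used
# elsewhere in the engine): the model loader builds the config and asks it
# for the linear method that each linear layer calls into.
#
#     quant_config = GGUFConfig.from_config({})
#     linear_method = quant_config.get_linear_method()
#     params = linear_method.create_weights(...)  # per linear layer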


class GGUFLinearMethod(LinearMethodBase):
    """Linear method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def __init__(self, quant_config: GGUFConfig):
        self.quant_config = quant_config
    def create_weights(self, input_size_per_partition: int,
                       output_size_per_partition: int, input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, Any]:
        # The quant type of the weight is unknown until the state dict is
        # loaded, so allocate an uninitialized parameter and let the loader
        # materialize it with the final shape and dtype.
        weight = torch.nn.parameter.UninitializedParameter(requires_grad=False)
        # No need for a pack_factor attribute because qkv layers are never
        # fused for GGUF models (see merge_weight above).
        set_weight_attrs(weight, {
            "input_dim": 1,
            "output_dim": 0,
        })
        weight_type = Parameter(
            torch.tensor(1, dtype=torch.int, device="cuda"),
            requires_grad=False,
        )
        set_weight_attrs(weight_type, {"ignore_warning": True})
        return {"weight": weight, "weight_type": weight_type}
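
    # Materialization sketch (assumed loader behavior; names and sizes are
    # illustrative): an UninitializedParameter has no shape until its
    # .materialize() method is called, which lets the checkpoint loader pick
    # the quantized byte layout once the tensor's GGML type is known.
    #
    #     loaded = ...  # uint8 tensor of shape (out_features, row_nbytes)
    #     weights["weight"].materialize(loaded.shape, dtype=loaded.dtype)
    #     weights["weight"].data.copy_(loaded)
    #     weights["weight_type"].data.fill_(ggml_type_id)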
    def apply_weights(self,
                      weights: Dict[str, Any],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        if isinstance(weights["weight_type"], torch.Tensor):
            weights["weight_type"] = int(weights["weight_type"])
            # Check the tensor-parallel shape on the first pass: every row
            # must hold a whole number of quantized blocks.
            block_size = GGML_QUANT_SIZES[weights["weight_type"]][1]
            if weights["weight"].shape[1] % block_size != 0:
                raise ValueError("Size is not aligned with the quantized "
                                 "weight shape.")
        weight = weights["weight"]
        weight_type = weights["weight_type"]
        infeatures = x.shape[-1]
        outfeatures = weight.shape[0]
        out_shape = x.shape[:-1] + (weight.shape[0], )
        reshaped_x = x.reshape(-1, x.shape[-1])

        if reshaped_x.shape[0] == 1:
            # Single token: fused dequantize + matvec kernel.
            out = ops.ggml_mul_mat_vec_a8(weight, reshaped_x, weight_type,
                                          outfeatures)
        elif reshaped_x.shape[0] < 8 and weight_type < 16:
            # Small batch: fused dequantize + matmul kernel (not available
            # for the IQ types, hence the weight_type < 16 guard).
            out = ops.ggml_mul_mat_a8(weight, reshaped_x, weight_type,
                                      outfeatures)
        else:
            # Large batch: dequantize the whole weight once, then use a
            # regular half-precision GEMM.
            weight = ops.ggml_dequantize(weight, weight_type, outfeatures,
                                         infeatures)
            out = reshaped_x @ weight.T

        if bias is not None:
            out = out + bias
        return out.reshape(out_shape)
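
    # Shape sketch with hypothetical numbers: for a Q4_K (id 12) layer with
    # 4096 in-features and 11008 out-features, `weight` is a byte tensor of
    # shape (11008, 4096 // 256 * 144) = (11008, 2304). A decode step with x
    # of shape (1, 4096) takes the matvec path; a 512-token prefill
    # dequantizes the weight to (11008, 4096) half precision and runs a
    # plain GEMM.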
    def apply_embedding(self, weights: Dict[str, torch.Tensor],
                        x: torch.Tensor) -> torch.Tensor:
        if isinstance(weights["weight_type"], torch.Tensor):
            weights["weight_type"] = int(weights["weight_type"])
        weight = weights["weight"]
        weight_type = weights["weight_type"]
        dim, block_size = GGML_QUANT_SIZES[weight_type]
        vocab_size = weight.shape[0]
        hidden_size = weight.shape[1] // block_size * dim
        if weight_type < 2:
            # F32/F16 tables are stored unquantized; look rows up directly.
            return torch.embedding(weight.view(vocab_size, -1), x)
        # Gather only the quantized rows for the requested token ids, then
        # dequantize those rows instead of the full vocabulary.
        x_flat = x.flatten()
        quant = torch.index_select(weight.view(vocab_size, -1),
                                   dim=0,
                                   index=x_flat)
        dequant = ops.ggml_dequantize(quant, weight_type, hidden_size,
                                      x_flat.shape[0])
        return dequant.view(*x.shape, hidden_size)
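
    # Example with hypothetical numbers: a Q8_0 (id 8) embedding table for a
    # 32000-token vocab with hidden size 4096 is stored as bytes of shape
    # (32000, 4096 // 32 * 34) = (32000, 4352). For input ids of shape
    # (2, 7), index_select gathers the 14 quantized rows and ggml_dequantize
    # expands them to (14, 4096) half precision, reshaped to (2, 7, 4096).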