gguf.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. from __future__ import annotations
  2. import os
  3. from enum import IntEnum
  4. from collections import OrderedDict
  5. from typing import Any, Literal, NamedTuple, TypeVar, Union
  6. import numpy as np
  7. import numpy.typing as npt
  8. GGUF_MAGIC = 0x46554747 # "GGUF"
  9. GGUF_VERSION = 3
  10. GGUF_DEFAULT_ALIGNMENT = 32
  11. READER_SUPPORTED_VERSIONS = [2, GGUF_VERSION]
  12. class GGMLQuantizationType(IntEnum):
  13. F32 = 0
  14. F16 = 1
  15. Q4_0 = 2
  16. Q4_1 = 3
  17. Q5_0 = 6
  18. Q5_1 = 7
  19. Q8_0 = 8
  20. Q8_1 = 9
  21. Q2_K = 10
  22. Q3_K = 11
  23. Q4_K = 12
  24. Q5_K = 13
  25. Q6_K = 14
  26. Q8_K = 15
  27. IQ2_XXS = 16
  28. IQ2_XS = 17
  29. IQ3_XXS = 18
  30. IQ1_S = 19
  31. IQ4_NL = 20
  32. IQ3_S = 21
  33. IQ2_S = 22
  34. IQ4_XS = 23
  35. QK_K = 256
  36. # Items here are (block size, type size)
  37. GGML_QUANT_SIZES = {
  38. GGMLQuantizationType.F32: (1, 4),
  39. GGMLQuantizationType.F16: (1, 2),
  40. GGMLQuantizationType.Q4_0: (32, 2 + 16),
  41. GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16),
  42. GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16),
  43. GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16),
  44. GGMLQuantizationType.Q8_0: (32, 2 + 32),
  45. GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32),
  46. GGMLQuantizationType.Q2_K: (256, 2 + 2 + QK_K // 16 + QK_K // 4),
  47. GGMLQuantizationType.Q3_K: (256, 2 + QK_K // 4 + QK_K // 8 + 12),
  48. GGMLQuantizationType.Q4_K: (256, 2 + 2 + QK_K // 2 + 12),
  49. GGMLQuantizationType.Q5_K: (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12),
  50. GGMLQuantizationType.Q6_K: (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16),
  51. GGMLQuantizationType.Q8_K: (256, 4 + QK_K + QK_K // 8),
  52. GGMLQuantizationType.IQ2_XXS: (256, 2 + QK_K // 4),
  53. GGMLQuantizationType.IQ2_XS: (256, 2 + QK_K // 4 + QK_K // 32),
  54. GGMLQuantizationType.IQ3_XXS: (256, 2 + 3 * QK_K // 8),
  55. GGMLQuantizationType.IQ1_S: (256, 2 + QK_K // 8 + QK_K // 16),
  56. GGMLQuantizationType.IQ4_NL: (32, 2 + 32 // 2),
  57. GGMLQuantizationType.IQ3_S:
  58. (256, 2 + QK_K // 4 + QK_K // 32 + QK_K // 8 + QK_K // 64),
  59. GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 32 + QK_K // 32),
  60. GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 64 + QK_K // 2),
  61. }
  62. class GGUFValueType(IntEnum):
  63. UINT8 = 0
  64. INT8 = 1
  65. UINT16 = 2
  66. INT16 = 3
  67. UINT32 = 4
  68. INT32 = 5
  69. FLOAT32 = 6
  70. BOOL = 7
  71. STRING = 8
  72. ARRAY = 9
  73. UINT64 = 10
  74. INT64 = 11
  75. FLOAT64 = 12
  76. @staticmethod
  77. def get_type(val: Any) -> GGUFValueType:
  78. if isinstance(val, (str, bytes, bytearray)):
  79. return GGUFValueType.STRING
  80. elif isinstance(val, list):
  81. return GGUFValueType.ARRAY
  82. elif isinstance(val, float):
  83. return GGUFValueType.FLOAT32
  84. elif isinstance(val, bool):
  85. return GGUFValueType.BOOL
  86. elif isinstance(val, int):
  87. return GGUFValueType.INT32
  88. class ReaderField(NamedTuple):
  89. # Offset to start of this field.
  90. offset: int
  91. # Name of the field (not necessarily from file data).
  92. name: str
  93. # Data parts. Some types have multiple components, such as strings
  94. # that consist of a length followed by the string data.
  95. parts: list[npt.NDArray[Any]] = []
  96. # Indexes into parts that we can call the actual data. For example
  97. # an array of strings will be populated with indexes to the actual
  98. # string data.
  99. data: list[int] = [-1]
  100. types: list[GGUFValueType] = []
  101. class ReaderTensor(NamedTuple):
  102. name: str
  103. tensor_type: GGMLQuantizationType
  104. shape: npt.NDArray[np.uint32]
  105. n_elements: int
  106. n_bytes: int
  107. data_offset: int
  108. data: npt.NDArray[Any]
  109. field: ReaderField
  110. class GGUFReader:
  111. # I - same as host, S - swapped
  112. byte_order: Literal['I' | 'S'] = 'I'
  113. alignment: int = GGUF_DEFAULT_ALIGNMENT
  114. # Note: Internal helper, API may change.
  115. gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = {
  116. GGUFValueType.UINT8: np.uint8,
  117. GGUFValueType.INT8: np.int8,
  118. GGUFValueType.UINT16: np.uint16,
  119. GGUFValueType.INT16: np.int16,
  120. GGUFValueType.UINT32: np.uint32,
  121. GGUFValueType.INT32: np.int32,
  122. GGUFValueType.FLOAT32: np.float32,
  123. GGUFValueType.UINT64: np.uint64,
  124. GGUFValueType.INT64: np.int64,
  125. GGUFValueType.FLOAT64: np.float64,
  126. GGUFValueType.BOOL: np.bool_,
  127. }
  128. def __init__(self,
  129. path: os.PathLike[str] | str,
  130. mode: Literal['r' | 'r+' | 'c'] = 'r'):
  131. self.data = np.memmap(path, mode=mode)
  132. offs = 0
  133. if self._get(offs, np.uint32, override_order='<')[0] != GGUF_MAGIC:
  134. raise ValueError('GGUF magic invalid')
  135. offs += 4
  136. temp_version = self._get(offs, np.uint32)
  137. if temp_version[0] & 65535 == 0:
  138. # If we get 0 here that means it's (probably) a GGUF file created
  139. # for the opposite byte order of the machine this script is
  140. # running on.
  141. self.byte_order = 'S'
  142. temp_version = temp_version.newbyteorder(self.byte_order)
  143. version = temp_version[0]
  144. if version not in READER_SUPPORTED_VERSIONS:
  145. raise ValueError(
  146. f'Sorry, file appears to be version {version} which we cannot '
  147. 'handle')
  148. self.fields: OrderedDict[str, ReaderField] = OrderedDict()
  149. self.tensors: list[ReaderTensor] = []
  150. offs += self._push_field(
  151. ReaderField(offs, 'GGUF.version', [temp_version], [0],
  152. [GGUFValueType.UINT32]))
  153. temp_counts = self._get(offs, np.uint64, 2)
  154. offs += self._push_field(
  155. ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0],
  156. [GGUFValueType.UINT64]))
  157. offs += self._push_field(
  158. ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0],
  159. [GGUFValueType.UINT64]))
  160. tensor_count, kv_count = temp_counts
  161. offs = self._build_fields(offs, kv_count)
  162. offs, tensors_fields = self._build_tensors_fields(offs, tensor_count)
  163. new_align = self.fields.get('general.alignment')
  164. if new_align is not None:
  165. if new_align.types != [GGUFValueType.UINT64]:
  166. raise ValueError('Bad type for general.alignment field')
  167. self.alignment = new_align.parts[-1][0]
  168. padding = offs % self.alignment
  169. if padding != 0:
  170. offs += self.alignment - padding
  171. self._build_tensors(offs, tensors_fields)
  172. _DT = TypeVar('_DT', bound=npt.DTypeLike)
  173. # Fetch a key/value metadata field by key.
  174. def get_field(self, key: str) -> Union[ReaderField, None]:
  175. return self.fields.get(key, None)
  176. # Fetch a tensor from the list by index.
  177. def get_tensor(self, idx: int) -> ReaderTensor:
  178. return self.tensors[idx]
  179. def _get(
  180. self,
  181. offset: int,
  182. dtype: npt.DTypeLike,
  183. count: int = 1,
  184. override_order: None | Literal['I' | 'S' | '<'] = None,
  185. ) -> npt.NDArray[Any]:
  186. count = int(count)
  187. itemsize = int(np.empty([], dtype=dtype).itemsize)
  188. end_offs = offset + itemsize * count
  189. return (self.data[offset:end_offs].view(
  190. dtype=dtype)[:count].newbyteorder(override_order
  191. or self.byte_order))
  192. def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
  193. if field.name in self.fields:
  194. raise KeyError(f'Duplicate {field.name} already in list at offset '
  195. f'{field.offset}')
  196. self.fields[field.name] = field
  197. return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)
  198. def _get_str(
  199. self, offset: int
  200. ) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
  201. slen = self._get(offset, np.uint64)
  202. return slen, self._get(offset + 8, np.uint8, slen[0])
  203. def _get_field_parts(
  204. self,
  205. orig_offs: int,
  206. raw_type: int,
  207. ) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]:
  208. offs = orig_offs
  209. types: list[GGUFValueType] = []
  210. gtype = GGUFValueType(raw_type)
  211. types.append(gtype)
  212. # Handle strings.
  213. if gtype == GGUFValueType.STRING:
  214. sparts: list[npt.NDArray[Any]] = list(self._get_str(offs))
  215. size = sum(int(part.nbytes) for part in sparts)
  216. return size, sparts, [1], types
  217. # Check if it's a simple scalar type.
  218. nptype = self.gguf_scalar_to_np.get(gtype)
  219. if nptype is not None:
  220. val = self._get(offs, nptype)
  221. return int(val.nbytes), [val], [0], types
  222. # Handle arrays.
  223. if gtype == GGUFValueType.ARRAY:
  224. raw_itype = self._get(offs, np.uint32)
  225. offs += int(raw_itype.nbytes)
  226. alen = self._get(offs, np.uint64)
  227. offs += int(alen.nbytes)
  228. aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
  229. data_idxs: list[int] = []
  230. for idx in range(alen[0]):
  231. curr_size, curr_parts, curr_idxs, curr_types = (
  232. self._get_field_parts(offs, raw_itype[0]))
  233. if idx == 0:
  234. types += curr_types
  235. idxs_offs = len(aparts)
  236. aparts += curr_parts
  237. data_idxs += (idx + idxs_offs for idx in curr_idxs)
  238. offs += curr_size
  239. return offs - orig_offs, aparts, data_idxs, types
  240. # We can't deal with this one.
  241. raise ValueError('Unknown/unhandled field type {gtype}')
  242. def _get_tensor(self, orig_offs: int) -> ReaderField:
  243. offs = orig_offs
  244. name_len, name_data = self._get_str(offs)
  245. offs += int(name_len.nbytes + name_data.nbytes)
  246. n_dims = self._get(offs, np.uint32)
  247. offs += int(n_dims.nbytes)
  248. dims = self._get(offs, np.uint64, n_dims[0])
  249. offs += int(dims.nbytes)
  250. raw_dtype = self._get(offs, np.uint32)
  251. offs += int(raw_dtype.nbytes)
  252. offset_tensor = self._get(offs, np.uint64)
  253. offs += int(offset_tensor.nbytes)
  254. return ReaderField(
  255. orig_offs,
  256. str(bytes(name_data), encoding='utf-8'),
  257. [name_len, name_data, n_dims, dims, raw_dtype, offset_tensor],
  258. [1, 3, 4, 5],
  259. )
  260. def _build_fields(self, offs: int, count: int) -> int:
  261. for _ in range(count):
  262. orig_offs = offs
  263. kv_klen, kv_kdata = self._get_str(offs)
  264. offs += int(kv_klen.nbytes + kv_kdata.nbytes)
  265. raw_kv_type = self._get(offs, np.uint32)
  266. offs += int(raw_kv_type.nbytes)
  267. parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
  268. idxs_offs = len(parts)
  269. field_size, field_parts, field_idxs, field_types = (
  270. self._get_field_parts(offs, raw_kv_type[0]))
  271. parts += field_parts
  272. self._push_field(ReaderField(
  273. orig_offs,
  274. str(bytes(kv_kdata), encoding='utf-8'),
  275. parts,
  276. [idx + idxs_offs for idx in field_idxs],
  277. field_types,
  278. ),
  279. skip_sum=True)
  280. offs += field_size
  281. return offs
  282. def _build_tensors_fields(self, offs: int,
  283. count: int) -> tuple[int, list[ReaderField]]:
  284. tensor_fields = []
  285. for _ in range(count):
  286. field = self._get_tensor(offs)
  287. offs += sum(int(part.nbytes) for part in field.parts)
  288. tensor_fields.append(field)
  289. return offs, tensor_fields
  290. def _build_tensors(self, start_offs: int,
  291. fields: list[ReaderField]) -> None:
  292. tensors = []
  293. for field in fields:
  294. # pylint: disable=unused-variable
  295. (_name_len, name_data, _n_dims, dims, raw_dtype,
  296. offset_tensor) = field.parts
  297. ggml_type = GGMLQuantizationType(raw_dtype[0])
  298. n_elems = np.prod(dims)
  299. block_size, type_size = GGML_QUANT_SIZES[ggml_type]
  300. n_bytes = n_elems * type_size // block_size
  301. data_offs = int(start_offs + offset_tensor[0])
  302. item_type: npt.DTypeLike
  303. if ggml_type == GGMLQuantizationType.F32:
  304. item_count = n_elems
  305. item_type = np.float32
  306. elif ggml_type == GGMLQuantizationType.F16:
  307. item_count = n_elems
  308. item_type = np.float16
  309. else:
  310. item_count = n_bytes
  311. item_type = np.uint8
  312. tensors.append(
  313. ReaderTensor(
  314. name=str(bytes(name_data), encoding='utf-8'),
  315. tensor_type=ggml_type,
  316. shape=dims,
  317. n_elements=n_elems,
  318. n_bytes=n_bytes,
  319. data_offset=data_offs,
  320. data=self._get(data_offs, item_type, item_count),
  321. field=field,
  322. ))
  323. self.tensors = tensors