1
0

gguf_reader.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. # ruff: noqa
  2. #
  3. # GGUF file reading/modification support. For API usage information,
  4. # please see the files in scripts/ for some fairly simple examples.
  5. #
  6. from __future__ import annotations
  7. import os
  8. from collections import OrderedDict
  9. from typing import Any, Literal, NamedTuple, TypeVar, Union
  10. import numpy as np
  11. import numpy.typing as npt
  12. if __name__ == "__main__":
  13. import sys
  14. from pathlib import Path
  15. # Allow running file in package as a script.
  16. sys.path.insert(0, str(Path(__file__).parent.parent))
  17. from .constants import (GGML_QUANT_SIZES, GGUF_DEFAULT_ALIGNMENT, GGUF_MAGIC,
  18. GGUF_VERSION, GGMLQuantizationType, GGUFValueType)
  19. READER_SUPPORTED_VERSIONS = [2, GGUF_VERSION]
  class ReaderField(NamedTuple):
      """A single field located in a GGUF file.

      Fields:
          offset: Offset to the start of this field.
          name:   Name of the field (not necessarily taken from file data).
          parts:  Data parts. Some types have multiple components, such as
                  strings that consist of a length followed by the string
                  data.
          data:   Indexes into ``parts`` that we can call the actual data.
                  For example, an array of strings will be populated with
                  indexes to the actual string data.
          types:  Value type(s) recorded for the field.

      NOTE: the list defaults below are created once, when the class is
      defined, so every instance built without an explicit ``parts``,
      ``data`` or ``types`` refers to the same list object. Treat those
      defaults as read-only.
      """

      offset: int
      name: str
      parts: list[npt.NDArray[Any]] = []
      data: list[int] = [-1]
      types: list[GGUFValueType] = []
  class ReaderTensor(NamedTuple):
      """Record describing one tensor found in a GGUF file."""

      # Tensor name.
      name: str
      # Quantization/storage type of the tensor.
      tensor_type: GGMLQuantizationType
      # Dimensions of the tensor.
      shape: npt.NDArray[np.uint32]
      # Total number of elements.
      n_elements: int
      # Size of the tensor data in bytes.
      n_bytes: int
      # Offset at which the tensor data starts.
      data_offset: int
      # The tensor data itself.
      data: npt.NDArray[Any]
      # The ReaderField this tensor was built from.
      field: ReaderField
  class GGUFReader:
      """Reader for GGUF files backed by a NumPy memory map.

      Constructing a reader maps the file with ``np.memmap``, checks the
      magic number and version, and then walks the key/value metadata and
      the tensor descriptions. The results are exposed as:

      * ``fields``  - ordered mapping of field name to ``ReaderField``
      * ``tensors`` - list of ``ReaderTensor``

      Every array stored in those records is a view into the mapped data.
      """

      # I - same as host, S - swapped
      byte_order: Literal['I', 'S'] = 'I'
      alignment: int = GGUF_DEFAULT_ALIGNMENT

      # Note: Internal helper, API may change.
      gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = {
          GGUFValueType.UINT8: np.uint8,
          GGUFValueType.INT8: np.int8,
          GGUFValueType.UINT16: np.uint16,
          GGUFValueType.INT16: np.int16,
          GGUFValueType.UINT32: np.uint32,
          GGUFValueType.INT32: np.int32,
          GGUFValueType.FLOAT32: np.float32,
          GGUFValueType.UINT64: np.uint64,
          GGUFValueType.INT64: np.int64,
          GGUFValueType.FLOAT64: np.float64,
          GGUFValueType.BOOL: np.bool_,
      }

      def __init__(self,
                   path: os.PathLike[str] | str,
                   mode: Literal['r', 'r+', 'c'] = 'r'):
          """Map the file at *path* and parse its header, fields and tensors.

          Args:
              path: Location of the GGUF file.
              mode: Mode passed through to ``np.memmap``.

          Raises:
              ValueError: If the magic number is wrong, the version is not
                  in ``READER_SUPPORTED_VERSIONS``, or ``general.alignment``
                  is not a UINT32 field.
              KeyError: If the same field name occurs more than once.
          """
          self.data = np.memmap(path, mode=mode)
          offs = 0

          # The magic is always read as little-endian.
          if self._get(offs, np.uint32, override_order='<')[0] != GGUF_MAGIC:
              raise ValueError('GGUF magic invalid')
          offs += 4

          temp_version = self._get(offs, np.uint32)
          if temp_version[0] & 65535 == 0:
              # If we get 0 here that means it's (probably) a GGUF file created
              # for the opposite byte order of the machine this script is
              # running on.
              self.byte_order = 'S'
              # ndarray.newbyteorder() was removed in NumPy 2.0; re-viewing
              # with a byte-swapped dtype works on both 1.x and 2.x.
              temp_version = temp_version.view(
                  temp_version.dtype.newbyteorder(self.byte_order))
          version = temp_version[0]
          if version not in READER_SUPPORTED_VERSIONS:
              raise ValueError(
                  f'Sorry, file appears to be version {version} which we cannot handle'
              )

          self.fields: OrderedDict[str, ReaderField] = OrderedDict()
          self.tensors: list[ReaderTensor] = []

          # The header values are exposed as pseudo-fields under 'GGUF.*'.
          offs += self._push_field(
              ReaderField(offs, 'GGUF.version', [temp_version], [0],
                          [GGUFValueType.UINT32]))
          temp_counts = self._get(offs, np.uint64, 2)
          offs += self._push_field(
              ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0],
                          [GGUFValueType.UINT64]))
          offs += self._push_field(
              ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0],
                          [GGUFValueType.UINT64]))
          tensor_count = int(temp_counts[0])
          kv_count = int(temp_counts[1])

          offs = self._build_fields(offs, kv_count)
          offs, tensors_fields = self._build_tensors_fields(offs, tensor_count)

          # A 'general.alignment' field overrides the default alignment.
          new_align = self.fields.get('general.alignment')
          if new_align is not None:
              if new_align.types != [GGUFValueType.UINT32]:
                  raise ValueError('Bad type for general.alignment field')
              self.alignment = new_align.parts[-1][0]

          # Round the offset up to the next multiple of the alignment; the
          # tensor data offsets are relative to that position.
          padding = offs % self.alignment
          if padding != 0:
              offs += self.alignment - padding
          self._build_tensors(offs, tensors_fields)

      _DT = TypeVar('_DT', bound=npt.DTypeLike)

      # Fetch a key/value metadata field by key.
      def get_field(self, key: str) -> Union[ReaderField, None]:
          """Return the field stored under *key*, or ``None`` if absent."""
          return self.fields.get(key, None)

      # Fetch a tensor from the list by index.
      def get_tensor(self, idx: int) -> ReaderTensor:
          """Return the tensor at position *idx* of ``self.tensors``."""
          return self.tensors[idx]

      def _get(
          self,
          offset: int,
          dtype: npt.DTypeLike,
          count: int = 1,
          override_order: Literal['I', 'S', '<'] | None = None,
      ) -> npt.NDArray[Any]:
          """Return *count* items of *dtype* starting at byte *offset*.

          The result is a view of the mapped data whose dtype carries
          ``override_order`` if given, otherwise ``self.byte_order``.
          """
          count = int(count)
          itemsize = int(np.empty([], dtype=dtype).itemsize)
          end_offs = offset + itemsize * count
          arr = self.data[offset:end_offs].view(dtype=dtype)[:count]
          # dtype.newbyteorder() exists on NumPy 1.x and 2.x, unlike the
          # ndarray method of the same name, which NumPy 2.0 removed.
          return arr.view(
              arr.dtype.newbyteorder(override_order or self.byte_order))

      def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
          """Register *field* under its name.

          Returns the total byte size of the field's parts, or 0 when
          *skip_sum* is true.

          Raises:
              KeyError: If a field with the same name is already registered.
          """
          if field.name in self.fields:
              raise KeyError(
                  f'Duplicate {field.name} already in list at offset {field.offset}'
              )
          self.fields[field.name] = field
          if skip_sum:
              return 0
          return sum(int(part.nbytes) for part in field.parts)

      def _get_str(
          self, offset: int
      ) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
          """Read a string: a uint64 length followed by that many bytes.

          Returns the ``(length, bytes)`` pair as arrays.
          """
          slen = self._get(offset, np.uint64)
          return slen, self._get(offset + 8, np.uint8, slen[0])

      def _get_field_parts(
          self,
          orig_offs: int,
          raw_type: int,
      ) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]:
          """Read the value of type *raw_type* located at *orig_offs*.

          Returns ``(size, parts, data_idxs, types)``: the number of bytes
          the value occupies, the arrays it is made of, the indexes of the
          arrays in ``parts`` that hold actual data, and the value type
          (followed by the element type(s) for a non-empty array).

          Raises:
              ValueError: If *raw_type* is not a type this reader handles.
          """
          offs = orig_offs
          gtype = GGUFValueType(raw_type)
          types: list[GGUFValueType] = [gtype]

          # Handle strings: parts are [length, data]; index 1 is the data.
          if gtype == GGUFValueType.STRING:
              sparts: list[npt.NDArray[Any]] = list(self._get_str(offs))
              size = sum(int(part.nbytes) for part in sparts)
              return size, sparts, [1], types

          # Check if it's a simple scalar type.
          nptype = self.gguf_scalar_to_np.get(gtype)
          if nptype is not None:
              val = self._get(offs, nptype)
              return int(val.nbytes), [val], [0], types

          # Handle arrays: a uint32 element type, a uint64 length, then the
          # elements, each read recursively.
          if gtype == GGUFValueType.ARRAY:
              raw_itype = self._get(offs, np.uint32)
              offs += int(raw_itype.nbytes)
              alen = self._get(offs, np.uint64)
              offs += int(alen.nbytes)
              aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
              data_idxs: list[int] = []
              for elem_no in range(int(alen[0])):
                  curr_size, curr_parts, curr_idxs, curr_types = (
                      self._get_field_parts(offs, raw_itype[0]))
                  # All elements share one type; record it once.
                  if elem_no == 0:
                      types += curr_types
                  # Shift the element's indexes to their position in aparts.
                  idxs_offs = len(aparts)
                  aparts += curr_parts
                  data_idxs += (i + idxs_offs for i in curr_idxs)
                  offs += curr_size
              return offs - orig_offs, aparts, data_idxs, types

          # We can't deal with this one.
          raise ValueError(f'Unknown/unhandled field type {gtype}')

      def _get_tensor(self, orig_offs: int) -> ReaderField:
          """Read one tensor description starting at *orig_offs*.

          The returned field's parts are
          ``[name_len, name_data, n_dims, dims, raw_dtype, offset_tensor]``.
          """
          offs = orig_offs
          name_len, name_data = self._get_str(offs)
          offs += int(name_len.nbytes + name_data.nbytes)
          n_dims = self._get(offs, np.uint32)
          offs += int(n_dims.nbytes)
          dims = self._get(offs, np.uint64, n_dims[0])
          offs += int(dims.nbytes)
          raw_dtype = self._get(offs, np.uint32)
          offs += int(raw_dtype.nbytes)
          offset_tensor = self._get(offs, np.uint64)
          offs += int(offset_tensor.nbytes)
          return ReaderField(
              orig_offs,
              str(bytes(name_data), encoding='utf-8'),
              [name_len, name_data, n_dims, dims, raw_dtype, offset_tensor],
              # Data indexes: name bytes, dims, raw dtype, data offset.
              [1, 3, 4, 5],
          )

      def _build_fields(self, offs: int, count: int) -> int:
          """Read *count* key/value fields starting at *offs* and register them.

          Returns the offset just past the last field read.
          """
          for _ in range(count):
              orig_offs = offs
              kv_klen, kv_kdata = self._get_str(offs)
              offs += int(kv_klen.nbytes + kv_kdata.nbytes)
              raw_kv_type = self._get(offs, np.uint32)
              offs += int(raw_kv_type.nbytes)
              parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
              # The value's data indexes are relative to its own parts, so
              # they are shifted past the key and type parts.
              idxs_offs = len(parts)
              field_size, field_parts, field_idxs, field_types = (
                  self._get_field_parts(offs, raw_kv_type[0]))
              parts += field_parts
              self._push_field(
                  ReaderField(
                      orig_offs,
                      str(bytes(kv_kdata), encoding='utf-8'),
                      parts,
                      [idx + idxs_offs for idx in field_idxs],
                      field_types,
                  ),
                  skip_sum=True)
              offs += field_size
          return offs

      def _build_tensors_fields(self, offs: int,
                                count: int) -> tuple[int, list[ReaderField]]:
          """Read *count* tensor descriptions starting at *offs*.

          Returns the offset just past the last description and the list of
          fields that were read.
          """
          tensor_fields = []
          for _ in range(count):
              field = self._get_tensor(offs)
              offs += sum(int(part.nbytes) for part in field.parts)
              tensor_fields.append(field)
          return offs, tensor_fields

      def _build_tensors(self, start_offs: int,
                         fields: list[ReaderField]) -> None:
          """Build ``self.tensors`` from the tensor description *fields*.

          *start_offs* is the offset the per-tensor data offsets are added to.
          Types listed in the table below are exposed with their NumPy dtype;
          every other type is exposed as raw ``uint8`` bytes.
          """
          plain_dtypes: dict[GGMLQuantizationType, npt.DTypeLike] = {
              GGMLQuantizationType.F16: np.float16,
              GGMLQuantizationType.F32: np.float32,
              GGMLQuantizationType.F64: np.float64,
              GGMLQuantizationType.I8: np.int8,
              GGMLQuantizationType.I16: np.int16,
              GGMLQuantizationType.I32: np.int32,
              GGMLQuantizationType.I64: np.int64,
          }
          tensors = []
          for field in fields:
              _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts
              ggml_type = GGMLQuantizationType(raw_dtype[0])
              # Use Python ints so the arithmetic is exact and the values
              # match the ``int`` fields declared on ReaderTensor.
              n_elems = int(np.prod(dims))
              block_size, type_size = GGML_QUANT_SIZES[ggml_type]
              n_bytes = n_elems * type_size // block_size
              data_offs = int(start_offs) + int(offset_tensor[0])

              item_type = plain_dtypes.get(ggml_type)
              if item_type is None:
                  item_count, item_type = n_bytes, np.uint8
              else:
                  item_count = n_elems

              tensors.append(
                  ReaderTensor(
                      name=str(bytes(name_data), encoding='utf-8'),
                      tensor_type=ggml_type,
                      shape=dims,
                      n_elements=n_elems,
                      n_bytes=n_bytes,
                      data_offset=data_offs,
                      data=self._get(data_offs, item_type, item_count),
                      field=field,
                  ))
          self.tensors = tensors