gguf_reader.py

# ruff: noqa
#
# GGUF file reading/modification support. For API usage information,
# please see the files scripts/ for some fairly simple examples.
#
from __future__ import annotations

import os
from collections import OrderedDict
from typing import Any, Literal, NamedTuple, TypeVar, Union

import numpy as np
import numpy.typing as npt

if __name__ == "__main__":
    import sys
    from pathlib import Path

    # Allow running file in package as a script.
    sys.path.insert(0, str(Path(__file__).parent.parent))

from gguf.constants import (
    GGML_QUANT_SIZES,
    GGUF_DEFAULT_ALIGNMENT,
    GGUF_MAGIC,
    GGUF_VERSION,
    GGMLQuantizationType,
    GGUFValueType,
)


READER_SUPPORTED_VERSIONS = [2, GGUF_VERSION]


class ReaderField(NamedTuple):
    # Offset to start of this field.
    offset: int

    # Name of the field (not necessarily from file data).
    name: str

    # Data parts. Some types have multiple components, such as strings
    # that consist of a length followed by the string data.
    parts: list[npt.NDArray[Any]] = []

    # Indexes into parts that we can call the actual data. For example
    # an array of strings will be populated with indexes to the actual
    # string data.
    data: list[int] = [-1]

    types: list[GGUFValueType] = []


class ReaderTensor(NamedTuple):
    name: str
    tensor_type: GGMLQuantizationType
    shape: npt.NDArray[np.uint32]
    n_elements: int
    n_bytes: int
    data_offset: int
    data: npt.NDArray[Any]
    field: ReaderField


class GGUFReader:
    # I - same as host, S - swapped
    byte_order: Literal['I', 'S'] = 'I'
    alignment: int = GGUF_DEFAULT_ALIGNMENT

    # Note: Internal helper, API may change.
    gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = {
        GGUFValueType.UINT8: np.uint8,
        GGUFValueType.INT8: np.int8,
        GGUFValueType.UINT16: np.uint16,
        GGUFValueType.INT16: np.int16,
        GGUFValueType.UINT32: np.uint32,
        GGUFValueType.INT32: np.int32,
        GGUFValueType.FLOAT32: np.float32,
        GGUFValueType.UINT64: np.uint64,
        GGUFValueType.INT64: np.int64,
        GGUFValueType.FLOAT64: np.float64,
        GGUFValueType.BOOL: np.bool_,
    }

    def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'):
        self.data = np.memmap(path, mode = mode)
        offs = 0
        if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC:
            raise ValueError('GGUF magic invalid')
        offs += 4
        temp_version = self._get(offs, np.uint32)
        if temp_version[0] & 65535 == 0:
            # If we get 0 here that means it's (probably) a GGUF file created for
            # the opposite byte order of the machine this script is running on.
            self.byte_order = 'S'
            temp_version = temp_version.newbyteorder(self.byte_order)
        version = temp_version[0]
        if version not in READER_SUPPORTED_VERSIONS:
            raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle')
        self.fields: OrderedDict[str, ReaderField] = OrderedDict()
        self.tensors: list[ReaderTensor] = []
        offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32]))
        temp_counts = self._get(offs, np.uint64, 2)
        offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64]))
        offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64]))
        tensor_count, kv_count = temp_counts
        offs = self._build_fields(offs, kv_count)
        offs, tensors_fields = self._build_tensors_fields(offs, tensor_count)
        new_align = self.fields.get('general.alignment')
        if new_align is not None:
            if new_align.types != [GGUFValueType.UINT32]:
                raise ValueError('Bad type for general.alignment field')
            self.alignment = new_align.parts[-1][0]
        padding = offs % self.alignment
        if padding != 0:
            offs += self.alignment - padding
        self._build_tensors(offs, tensors_fields)

    _DT = TypeVar('_DT', bound = npt.DTypeLike)

    # Fetch a key/value metadata field by key.
    def get_field(self, key: str) -> Union[ReaderField, None]:
        return self.fields.get(key, None)

    # Fetch a tensor from the list by index.
    def get_tensor(self, idx: int) -> ReaderTensor:
        return self.tensors[idx]

    def _get(
        self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None,
    ) -> npt.NDArray[Any]:
        count = int(count)
        itemsize = int(np.empty([], dtype = dtype).itemsize)
        end_offs = offset + itemsize * count
        return (
            self.data[offset:end_offs]
            .view(dtype = dtype)[:count]
            .newbyteorder(override_order or self.byte_order)
        )

    def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
        if field.name in self.fields:
            raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}')
        self.fields[field.name] = field
        return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts)

    def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]:
        slen = self._get(offset, np.uint64)
        return slen, self._get(offset + 8, np.uint8, slen[0])

    def _get_field_parts(
        self, orig_offs: int, raw_type: int,
    ) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]:
        offs = orig_offs
        types: list[GGUFValueType] = []
        gtype = GGUFValueType(raw_type)
        types.append(gtype)
        # Handle strings.
        if gtype == GGUFValueType.STRING:
            sparts: list[npt.NDArray[Any]] = list(self._get_str(offs))
            size = sum(int(part.nbytes) for part in sparts)
            return size, sparts, [1], types
        # Check if it's a simple scalar type.
        nptype = self.gguf_scalar_to_np.get(gtype)
        if nptype is not None:
            val = self._get(offs, nptype)
            return int(val.nbytes), [val], [0], types
        # Handle arrays.
        if gtype == GGUFValueType.ARRAY:
            raw_itype = self._get(offs, np.uint32)
            offs += int(raw_itype.nbytes)
            alen = self._get(offs, np.uint64)
            offs += int(alen.nbytes)
            aparts: list[npt.NDArray[Any]] = [raw_itype, alen]
            data_idxs: list[int] = []
            for idx in range(alen[0]):
                curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0])
                if idx == 0:
                    types += curr_types
                idxs_offs = len(aparts)
                aparts += curr_parts
                data_idxs += (idx + idxs_offs for idx in curr_idxs)
                offs += curr_size
            return offs - orig_offs, aparts, data_idxs, types
        # We can't deal with this one.
        raise ValueError(f'Unknown/unhandled field type {gtype}')

    def _get_tensor(self, orig_offs: int) -> ReaderField:
        offs = orig_offs
        name_len, name_data = self._get_str(offs)
        offs += int(name_len.nbytes + name_data.nbytes)
        n_dims = self._get(offs, np.uint32)
        offs += int(n_dims.nbytes)
        dims = self._get(offs, np.uint64, n_dims[0])
        offs += int(dims.nbytes)
        raw_dtype = self._get(offs, np.uint32)
        offs += int(raw_dtype.nbytes)
        offset_tensor = self._get(offs, np.uint64)
        offs += int(offset_tensor.nbytes)
        return ReaderField(
            orig_offs,
            str(bytes(name_data), encoding = 'utf-8'),
            [name_len, name_data, n_dims, dims, raw_dtype, offset_tensor],
            [1, 3, 4, 5],
        )

    def _build_fields(self, offs: int, count: int) -> int:
        for _ in range(count):
            orig_offs = offs
            kv_klen, kv_kdata = self._get_str(offs)
            offs += int(kv_klen.nbytes + kv_kdata.nbytes)
            raw_kv_type = self._get(offs, np.uint32)
            offs += int(raw_kv_type.nbytes)
            parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type]
            idxs_offs = len(parts)
            field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0])
            parts += field_parts
            self._push_field(ReaderField(
                orig_offs,
                str(bytes(kv_kdata), encoding = 'utf-8'),
                parts,
                [idx + idxs_offs for idx in field_idxs],
                field_types,
            ), skip_sum = True)
            offs += field_size
        return offs

    def _build_tensors_fields(self, offs: int, count: int) -> tuple[int, list[ReaderField]]:
        tensor_fields = []
        for _ in range(count):
            field = self._get_tensor(offs)
            offs += sum(int(part.nbytes) for part in field.parts)
            tensor_fields.append(field)
        return offs, tensor_fields

    def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None:
        tensors = []
        for field in fields:
            _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts
            ggml_type = GGMLQuantizationType(raw_dtype[0])
            n_elems = np.prod(dims)
            block_size, type_size = GGML_QUANT_SIZES[ggml_type]
            n_bytes = n_elems * type_size // block_size
            data_offs = int(start_offs + offset_tensor[0])
            item_type: npt.DTypeLike
            if ggml_type == GGMLQuantizationType.F16:
                item_count = n_elems
                item_type = np.float16
            elif ggml_type == GGMLQuantizationType.F32:
                item_count = n_elems
                item_type = np.float32
            elif ggml_type == GGMLQuantizationType.F64:
                item_count = n_elems
                item_type = np.float64
            elif ggml_type == GGMLQuantizationType.I8:
                item_count = n_elems
                item_type = np.int8
            elif ggml_type == GGMLQuantizationType.I16:
                item_count = n_elems
                item_type = np.int16
            elif ggml_type == GGMLQuantizationType.I32:
                item_count = n_elems
                item_type = np.int32
            elif ggml_type == GGMLQuantizationType.I64:
                item_count = n_elems
                item_type = np.int64
            else:
                item_count = n_bytes
                item_type = np.uint8
            tensors.append(ReaderTensor(
                name = str(bytes(name_data), encoding = 'utf-8'),
                tensor_type = ggml_type,
                shape = dims,
                n_elements = n_elems,
                n_bytes = n_bytes,
                data_offset = data_offs,
                data = self._get(data_offs, item_type, item_count),
                field = field,
            ))
        self.tensors = tensors
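

# --- Example usage (illustrative sketch, not part of the upstream module) ---
# The header comment above points to scripts/ for real examples; the helper
# below is only a minimal sketch of the same reader API, and its name
# `_example_dump` is introduced here for illustration. `path` stands in for
# any existing GGUF file.
def _example_dump(path: os.PathLike[str] | str) -> None:
    reader = GGUFReader(path, mode = 'r')
    # String metadata: the raw bytes of a string field live at parts[data[-1]].
    arch = reader.get_field('general.architecture')
    if arch is not None and arch.types == [GGUFValueType.STRING]:
        print('architecture:', bytes(arch.parts[arch.data[-1]]).decode('utf-8'))
    # Tensor metadata; `tensor.data` is a view into the memory map, so nothing
    # is copied until the caller touches it.
    for tensor in reader.tensors:
        print(tensor.name, tensor.tensor_type.name, list(tensor.shape), tensor.n_bytes)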