# parameter.py

from fractions import Fraction
from typing import Callable, List, Optional, Tuple, Union

import torch
from torch.nn import Parameter

from aphrodite.distributed import get_tensor_model_parallel_rank

__all__ = [
    "BaseAphroditeParameter", "PackedAphroditeParameter",
    "PerTensorScaleParameter", "ModelWeightParameter",
    "ChannelQuantScaleParameter", "GroupQuantScaleParameter",
    "PackedColumnParameter", "RowAphroditeParameter"
]


class BaseAphroditeParameter(Parameter):
    """
    Base parameter for Aphrodite linear layers. Extends torch.nn.Parameter
    by taking in a linear weight loader. Copies the loaded weight
    into the parameter when the provided weight loader is called.
    """

    def __new__(cls, data: torch.Tensor, **kwargs):
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, weight_loader: Callable):
        """
        Initialize the BaseAphroditeParameter

        :param data: torch tensor with the parameter data
        :param weight_loader: weight loader callable

        :returns: a torch.nn.Parameter
        """
        self._weight_loader = weight_loader

    @property
    def weight_loader(self):
        return self._weight_loader

    def _assert_and_load(self, loaded_weight: torch.Tensor):
        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        self._assert_and_load(loaded_weight)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        self._assert_and_load(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        self._assert_and_load(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        self._assert_and_load(loaded_weight)


class _ColumnAphroditeParameter(BaseAphroditeParameter):
    """
    Private class defining weight loading functionality
    (load_merged_column_weight, load_qkv_weight)
    for parameters being loaded into linear layers with column
    parallelism. This includes QKV and MLP layers which are
    not already fused on disk. Requires an output dimension
    to be defined. Called within the weight loader of
    each of the column parallel linear layers.
    """

    def __init__(self, output_dim: int, **kwargs):
        self._output_dim = output_dim
        super().__init__(**kwargs)

    @property
    def output_dim(self):
        return self._output_dim

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.data.shape[self.output_dim]
        loaded_weight = loaded_weight.narrow(self.output_dim,
                                             tp_rank * shard_size, shard_size)
        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        if isinstance(
                self,
            (PackedColumnParameter,
             PackedAphroditeParameter)) and self.packed_dim == self.output_dim:
            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                shard_offset=shard_offset, shard_size=shard_size)

        param_data = self.data

        tp_rank = get_tensor_model_parallel_rank()
        param_data = param_data.narrow(self.output_dim, shard_offset,
                                       shard_size)
        loaded_weight = loaded_weight.narrow(self.output_dim,
                                             tp_rank * shard_size, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        shard_id = kwargs.get("shard_id")
        num_heads = kwargs.get("num_heads")

        if isinstance(
                self,
            (PackedColumnParameter,
             PackedAphroditeParameter)) and self.output_dim == self.packed_dim:
            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                shard_offset=shard_offset, shard_size=shard_size)

        param_data = self.data
        tp_rank = get_tensor_model_parallel_rank()
        shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
        param_data = param_data.narrow(self.output_dim, shard_offset,
                                       shard_size)
        loaded_weight = loaded_weight.narrow(self.output_dim,
                                             shard_id * shard_size, shard_size)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
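

# Illustrative sketch (not part of the original module): the slice taken by
# load_column_parallel_weight. Each tensor-parallel rank keeps a contiguous
# `shard_size` slice of the checkpoint tensor along output_dim; the rank
# value here is hard-coded instead of coming from the distributed group.
def _example_column_shard_slice(tp_rank: int = 1) -> None:
    full_weight = torch.randn(32, 16)  # full checkpoint tensor, output_dim=0
    shard_size = 8                     # rows owned by each of 4 TP ranks
    shard = full_weight.narrow(0, tp_rank * shard_size, shard_size)
    assert shard.shape == (8, 16)
    assert torch.equal(shard, full_weight[8:16])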


class RowAphroditeParameter(BaseAphroditeParameter):
    """
    Parameter class defining weight loading functionality
    (load_row_parallel_weight) for parameters being loaded
    into linear layers with row parallelism.
    Requires an input_dim to be defined.
    """

    def __init__(self, input_dim: int, **kwargs):
        self._input_dim = input_dim
        super().__init__(**kwargs)

    @property
    def input_dim(self):
        return self._input_dim

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.data.shape[self.input_dim]
        loaded_weight = loaded_weight.narrow(self.input_dim,
                                             tp_rank * shard_size, shard_size)

        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)


class ModelWeightParameter(_ColumnAphroditeParameter, RowAphroditeParameter):
    """
    Parameter class for linear layer weights. Uses both column and
    row parallelism.
    """
    pass
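

# Illustrative sketch (not part of the original module): how a column-parallel
# linear layer would typically construct its weight, with the sharded output
# dimension first. The lambda loader is a hypothetical stand-in for the loader
# the layer would normally supply.
def _example_model_weight_parameter() -> ModelWeightParameter:
    return ModelWeightParameter(
        data=torch.empty(128, 64, dtype=torch.float16),  # (out_per_rank, in)
        output_dim=0,
        input_dim=1,
        weight_loader=lambda param, loaded_weight: None)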


class GroupQuantScaleParameter(_ColumnAphroditeParameter,
                               RowAphroditeParameter):
    """
    Parameter class for weight scales loaded for weights with
    grouped quantization. Uses both column and row parallelism.
    """
    pass


class ChannelQuantScaleParameter(_ColumnAphroditeParameter):
    """
    Parameter class for weight scales loaded for weights with
    channel-wise quantization. Equivalent to _ColumnAphroditeParameter.
    """
    pass


class PerTensorScaleParameter(BaseAphroditeParameter):
    """
    Parameter class for scales where the number of scales is
    equivalent to the number of logical matrices in fused linear
    layers (e.g. for QKV, there are 3 scales loaded from disk).
    This is relevant to weights with per-tensor quantization.
    Adds functionality to map the scales to a shard during
    weight loading.

    Note: additional parameter manipulation may be handled
    for each quantization config specifically, within
    process_weights_after_loading
    """

    def __init__(self, **kwargs):
        self.qkv_idxs = {"q": 0, "k": 1, "v": 2}
        super().__init__(**kwargs)

    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
        if isinstance(shard_id, int):
            return shard_id

        # if not int, assume shard_id is for qkv
        # map to int and return
        assert isinstance(shard_id, str)
        assert shard_id in self.qkv_idxs
        return self.qkv_idxs[shard_id]

    # For row parallel layers, no sharding is needed;
    # load the weight into the parameter as is
    def load_row_parallel_weight(self, *args, **kwargs):
        super().load_row_parallel_weight(*args, **kwargs)

    def load_merged_column_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    def load_qkv_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    def load_column_parallel_weight(self, *args, **kwargs):
        super().load_row_parallel_weight(*args, **kwargs)

    def _load_into_shard_id(self, loaded_weight: torch.Tensor,
                            shard_id: Union[str, int], **kwargs):
        """
        Slice the parameter data based on the shard id for
        loading.
        """
        param_data = self.data
        shard_id = self._shard_id_as_int(shard_id)

        # AutoFP8 scales do not have a shape;
        # compressed-tensors scales do have a shape
        if len(loaded_weight.shape) != 0:
            assert loaded_weight.shape[0] == 1
            loaded_weight = loaded_weight[0]

        param_data = param_data[shard_id]
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
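

# Illustrative sketch (not part of the original module): a fused QKV layer
# keeps one per-tensor scale per logical matrix, and loading the "k" scale
# writes into index 1 of the parameter (see qkv_idxs above). The lambda
# loader is a hypothetical placeholder.
def _example_per_tensor_scales() -> None:
    scales = PerTensorScaleParameter(data=torch.zeros(3),
                                     weight_loader=lambda param, weight: None)
    scales.load_qkv_weight(torch.tensor(0.25), shard_id="k")
    assert scales.data[1].item() == 0.25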


class PackedColumnParameter(_ColumnAphroditeParameter):
    """
    Parameter for model parameters which are packed on disk
    and support column parallelism only. See PackedAphroditeParameter
    for more details on the packed properties.
    """

    def __init__(self,
                 packed_factor: Union[int, Fraction],
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 **kwargs):
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def packed_factor(self):
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size)


class PackedAphroditeParameter(ModelWeightParameter):
    """
    Parameter for model weights which are packed on disk.
    Example: GPTQ Marlin weights are int4 or int8, packed into int32.
    Extends the ModelWeightParameter to take in the
    packed factor, the packed dimension, and optionally, marlin
    tile size for marlin kernels. Adjusts the shard_size and
    shard_offset for fused linear layers model weight loading
    by accounting for packing and optionally, marlin tile size.
    """

    def __init__(self,
                 packed_factor: Union[int, Fraction],
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 **kwargs):
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def packed_factor(self):
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size)
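

# Illustrative sketch (not part of the original module): a GPTQ-style int4
# qweight packed 8-per-int32 along the input dim (packed_dim=0 in the
# (in, out) qweight layout). The shapes and lambda loader are hypothetical.
# A logical shard of 2048 input features starting at offset 4096 maps to
# 256 packed rows starting at packed offset 512.
def _example_packed_parameter() -> None:
    qweight = PackedAphroditeParameter(
        data=torch.empty(64, 256, dtype=torch.int32),
        input_dim=0,
        output_dim=1,
        packed_dim=0,
        packed_factor=8,
        weight_loader=lambda param, loaded_weight: None)
    size, offset = qweight.adjust_shard_indexes_for_packing(
        shard_size=2048, shard_offset=4096)
    assert (size, offset) == (256, 512)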


def permute_param_layout_(param: BaseAphroditeParameter, input_dim: int,
                          output_dim: int, **kwargs) -> BaseAphroditeParameter:
    """
    Permute a parameter's layout to the specified input and output dimensions,
    useful for forcing the parameter into a known layout. For example, to
    force a packed (quantized) weight matrix into the layout
        {input_dim = 0, output_dim = 1, packed_dim = 0}
    call:
        permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
    This ensures x is in the correct layout (permuting it if required and
    asserting if it cannot be brought into that layout).
    """
    curr_input_dim = getattr(param, "input_dim", None)
    curr_output_dim = getattr(param, "output_dim", None)

    if curr_input_dim is None or curr_output_dim is None:
        assert param.data.dim() == 2,\
            "permute_param_layout_ only supports 2D parameters when either "\
            "input_dim or output_dim is not set"

    # if one of the dimensions is not set, set it to the opposite of the other
    # we can only do this since we asserted the parameter is 2D above
    if curr_input_dim is None:
        assert curr_output_dim is not None,\
            "either input or output dim must be set"
        curr_input_dim = (curr_output_dim + 1) % 2
    if curr_output_dim is None:
        assert curr_input_dim is not None,\
            "either input or output dim must be set"
        curr_output_dim = (curr_input_dim + 1) % 2

    # create permutation from the current layout to the layout with
    # self.input_dim at input_dim and self.output_dim at output_dim preserving
    # other dimensions
    perm = [
        i for i in range(param.data.dim())
        if i not in [curr_input_dim, curr_output_dim]
    ]
    perm.insert(input_dim, curr_input_dim)
    perm.insert(output_dim, curr_output_dim)

    if "packed_dim" in kwargs:
        assert hasattr(param, "packed_dim") and\
            param.packed_dim == perm[kwargs["packed_dim"]],\
            "permute_param_layout_ currently doesn't support repacking"

    param.data = param.data.permute(*perm)
    if hasattr(param, "_input_dim"):
        param._input_dim = input_dim
    if hasattr(param, "_output_dim"):
        param._output_dim = output_dim
    if "packed_dim" in kwargs and hasattr(param, "_packed_dim"):
        param._packed_dim = kwargs["packed_dim"]

    return param
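

# Illustrative sketch (not part of the original module): take a weight stored
# as (out, in) and force it into the (in, out) layout some kernels expect.
# The shape flips from (128, 64) to (64, 128) and the tracked dims follow.
def _example_permute_layout() -> None:
    weight = ModelWeightParameter(data=torch.empty(128, 64),
                                  output_dim=0,
                                  input_dim=1,
                                  weight_loader=lambda param, w: None)
    permute_param_layout_(weight, input_dim=0, output_dim=1)
    assert weight.data.shape == (64, 128)
    assert weight.input_dim == 0 and weight.output_dim == 1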


def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
                                     marlin_tile_size):
    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
                                      marlin_tile_size):
    shard_size = shard_size // packed_factor
    shard_offset = shard_offset // packed_factor
    if marlin_tile_size is not None:
        return _adjust_shard_indexes_for_marlin(
            shard_size=shard_size,
            shard_offset=shard_offset,
            marlin_tile_size=marlin_tile_size)
    return shard_size, shard_offset
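

# Illustrative worked example (not part of the original module): with int4
# packed 8-per-int32 and a Marlin tile size of 16, a logical shard of size
# 512 at offset 1024 first shrinks by the pack factor to (64, 128) and is
# then rescaled by the tile size to (1024, 2048).
def _example_marlin_shard_adjustment() -> None:
    size, offset = _adjust_shard_indexes_for_packing(shard_size=512,
                                                     shard_offset=1024,
                                                     packed_factor=8,
                                                     marlin_tile_size=16)
    assert (size, offset) == (1024, 2048)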


# Qweights in HQQ need to be reshaped such that the shape of the stored tensors
# is the actual shape used in weight multiplication. This is needed to correctly
# repack to Marlin. We also store shard sizes and offsets in order to be able to
# correctly unpack (shard by shard) from 4-bit to 8-bit.
class HQQQweightParameter(PackedAphroditeParameter):

    def __init__(self, packed_factor: int, packed_dim: int, **kwargs):
        super().__init__(packed_factor, packed_dim, None, **kwargs)
        self.shard_offsets: List[Tuple[int, int]] = []
        self.pack_factor = packed_factor

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
            shard_offset=shard_offset, shard_size=shard_size)
        self.shard_offsets.append((shard_offset, shard_size))

        loaded_weight = loaded_weight.reshape(-1, self.shape[1])
        super().load_merged_column_weight(loaded_weight, **kwargs)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        self.shard_offsets.append((0, self.shape[self.output_dim]))

        loaded_weight = loaded_weight.reshape(-1, self.shape[1])
        super().load_row_parallel_weight(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
            shard_offset=shard_offset, shard_size=shard_size)
        self.shard_offsets.append((shard_offset, shard_size))

        loaded_weight = loaded_weight.reshape(-1, self.shape[1])
        super().load_qkv_weight(loaded_weight, **kwargs)


# Zero points and scales in HQQ must also be reshaped to their actual shapes.
class HQQZeroScaleParameter(GroupQuantScaleParameter):

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        loaded_weight = loaded_weight.reshape(-1, self.shape[1])
        super().load_merged_column_weight(loaded_weight, **kwargs)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        loaded_weight = loaded_weight.reshape(-1, self.shape[1])
        super().load_row_parallel_weight(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        loaded_weight = loaded_weight.reshape(-1, self.shape[1])
        super().load_qkv_weight(loaded_weight, **kwargs)