parameter.py

from typing import Callable, List, Optional, Tuple, Union

import torch
from torch.nn import Parameter

from aphrodite.distributed import get_tensor_model_parallel_rank

__all__ = [
    "BaseAphroditeParameter", "PackedAphroditeParameter",
    "PerTensorScaleParameter", "ModelWeightParameter",
    "ChannelQuantScaleParameter", "GroupQuantScaleParameter",
    "PackedColumnParameter", "RowAphroditeParameter"
]


class BaseAphroditeParameter(Parameter):
    """
    Base parameter for Aphrodite linear layers. Extends torch.nn.Parameter
    by taking in a linear weight loader. Will copy the loaded weight
    into the parameter when the provided weight loader is called.
    """

    def __new__(cls, data: torch.Tensor, **kwargs):
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, weight_loader: Callable):
        """
        Initialize the BaseAphroditeParameter

        :param data: torch tensor with the parameter data
        :param weight_loader: weight loader callable

        :returns: a torch.nn.Parameter
        """
        self._weight_loader = weight_loader

    @property
    def weight_loader(self):
        return self._weight_loader

    def _assert_and_load(self, loaded_weight: torch.Tensor):
        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        self._assert_and_load(loaded_weight)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        self._assert_and_load(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        self._assert_and_load(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        self._assert_and_load(loaded_weight)


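# Illustrative sketch (not part of the library): how a layer might construct a
# BaseAphroditeParameter and later trigger the copy. `layer` and
# `checkpoint_tensor` are hypothetical names used only for this example.
#
#     weight = BaseAphroditeParameter(
#         data=torch.empty(256, 128, dtype=torch.float16),
#         weight_loader=layer.weight_loader)
#     layer.register_parameter("weight", weight)
#     # During checkpoint loading, the registered loader is looked up and
#     # called, eventually copying `checkpoint_tensor` into `weight.data`.
#     weight.weight_loader(weight, checkpoint_tensor)

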
class _ColumnAphroditeParameter(BaseAphroditeParameter):
    """
    Private class defining weight loading functionality
    (load_merged_column_weight, load_qkv_weight)
    for parameters being loaded into linear layers with column
    parallelism. This includes QKV and MLP layers which are
    not already fused on disk. Requires an output dimension
    to be defined. Called within the weight loader of
    each of the column parallel linear layers.
    """

    def __init__(self, output_dim: int, **kwargs):
        self._output_dim = output_dim
        super().__init__(**kwargs)

    @property
    def output_dim(self):
        return self._output_dim

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.data.shape[self.output_dim]
        loaded_weight = loaded_weight.narrow(self.output_dim,
                                             tp_rank * shard_size, shard_size)
        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        if isinstance(
                self,
            (PackedColumnParameter,
             PackedAphroditeParameter)) and self.packed_dim == self.output_dim:
            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                shard_offset=shard_offset, shard_size=shard_size)

        param_data = self.data
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param_data.narrow(self.output_dim, shard_offset,
                                       shard_size)
        loaded_weight = loaded_weight.narrow(self.output_dim,
                                             tp_rank * shard_size, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        shard_id = kwargs.get("shard_id")
        num_heads = kwargs.get("num_heads")

        if isinstance(
                self,
            (PackedColumnParameter,
             PackedAphroditeParameter)) and self.output_dim == self.packed_dim:
            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                shard_offset=shard_offset, shard_size=shard_size)

        param_data = self.data
        tp_rank = get_tensor_model_parallel_rank()
        shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
        param_data = param_data.narrow(self.output_dim, shard_offset,
                                       shard_size)
        loaded_weight = loaded_weight.narrow(self.output_dim,
                                             shard_id * shard_size, shard_size)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


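# Worked example (illustrative, assumes tensor parallel size 2): for a full
# weight of shape [1024, 512] sharded along output_dim=0, each rank holds a
# [512, 512] slice. With shard_size = 512, rank 0 narrows the loaded weight to
# rows [0:512) and rank 1 to rows [512:1024) before copying into self.data.

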
class RowAphroditeParameter(BaseAphroditeParameter):
    """
    Parameter class defining weight loading functionality
    (load_row_parallel_weight) for parameters being loaded
    into linear layers with row parallelism.
    Requires an input_dim to be defined.
    """

    def __init__(self, input_dim: int, **kwargs):
        self._input_dim = input_dim
        super().__init__(**kwargs)

    @property
    def input_dim(self):
        return self._input_dim

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.data.shape[self.input_dim]
        loaded_weight = loaded_weight.narrow(self.input_dim,
                                             tp_rank * shard_size, shard_size)

        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)


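# Worked example (illustrative, assumes tensor parallel size 2): a row
# parallel weight of full shape [512, 1024] sharded along input_dim=1 leaves
# each rank with a [512, 512] slice; rank 1 narrows columns [512:1024). The
# subsequent reshape(1) simply promotes a 0-d tensor to shape (1,) so the
# shape assert can compare it against a 1-element parameter.

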
class ModelWeightParameter(_ColumnAphroditeParameter, RowAphroditeParameter):
    """
    Parameter class for linear layer weights. Uses both column and
    row parallelism.
    """
    pass


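# Illustrative sketch (not part of the library): a column parallel linear
# layer would typically create its weight roughly like this, where
# `output_size_per_partition`, `input_size`, `params_dtype` and
# `weight_loader` are hypothetical local names of the layer:
#
#     weight = ModelWeightParameter(
#         data=torch.empty(output_size_per_partition,
#                          input_size,
#                          dtype=params_dtype),
#         input_dim=1,
#         output_dim=0,
#         weight_loader=weight_loader)

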
class GroupQuantScaleParameter(_ColumnAphroditeParameter,
                               RowAphroditeParameter):
    """
    Parameter class for weight scales loaded for weights with
    grouped quantization. Uses both column and row parallelism.
    """
    pass


class ChannelQuantScaleParameter(_ColumnAphroditeParameter):
    """
    Parameter class for weight scales loaded for weights with
    channel-wise quantization. Equivalent to _ColumnAphroditeParameter.
    """
    pass


class PerTensorScaleParameter(BaseAphroditeParameter):
    """
    Parameter class for scales where the number of scales is
    equivalent to the number of logical matrices in fused linear
    layers (e.g. for QKV, there are 3 scales loaded from disk).
    This is relevant to weights with per-tensor quantization.
    Adds functionality to map the scales to a shard during
    weight loading.

    Note: additional parameter manipulation may be handled
    for each quantization config specifically, within
    process_weights_after_loading
    """

    def __init__(self, **kwargs):
        self.qkv_idxs = {"q": 0, "k": 1, "v": 2}
        super().__init__(**kwargs)

    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
        if isinstance(shard_id, int):
            return shard_id

        # if not an int, assume shard_id is a qkv key;
        # map it to an int and return
        assert isinstance(shard_id, str)
        assert shard_id in self.qkv_idxs
        return self.qkv_idxs[shard_id]

    # For row parallel layers, no sharding is needed;
    # load the weight into the parameter as-is.
    def load_row_parallel_weight(self, *args, **kwargs):
        super().load_row_parallel_weight(*args, **kwargs)

    def load_merged_column_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    def load_qkv_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    def load_column_parallel_weight(self, *args, **kwargs):
        super().load_row_parallel_weight(*args, **kwargs)

    def _load_into_shard_id(self, loaded_weight: torch.Tensor,
                            shard_id: Union[str, int], **kwargs):
        """
        Slice the parameter data based on the shard id for
        loading.
        """
        param_data = self.data
        shard_id = self._shard_id_as_int(shard_id)

        # AutoFP8 scales do not have a shape;
        # compressed-tensors scales do have a shape
        if len(loaded_weight.shape) != 0:
            assert loaded_weight.shape[0] == 1
            loaded_weight = loaded_weight[0]

        param_data = param_data[shard_id]
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


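# Illustrative sketch (not part of the library): for a fused QKV layer with
# per-tensor quantization, the scale parameter holds one value per logical
# matrix, and load_qkv_weight routes each checkpoint scale to its slot via the
# "q"/"k"/"v" -> 0/1/2 mapping. `weight_loader` and `k_scale` below are
# hypothetical names:
#
#     scale = PerTensorScaleParameter(data=torch.empty(3, dtype=torch.float32),
#                                     weight_loader=weight_loader)
#     # while loading the k_proj scale from disk:
#     scale.load_qkv_weight(loaded_weight=k_scale, shard_id="k")  # -> index 1

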
class PackedColumnParameter(_ColumnAphroditeParameter):
    """
    Parameter for model parameters which are packed on disk
    and support column parallelism only. See PackedAphroditeParameter
    for more details on the packed properties.
    """

    def __init__(self,
                 packed_factor: int,
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 **kwargs):
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def packed_factor(self):
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size)


class PackedAphroditeParameter(ModelWeightParameter):
    """
    Parameter for model weights which are packed on disk.
    Example: GPTQ Marlin weights are int4 or int8, packed into int32.
    Extends ModelWeightParameter to take in the packed factor, the
    packed dimension, and, optionally, the marlin tile size for marlin
    kernels. Adjusts the shard_size and shard_offset for fused linear
    layer weight loading by accounting for packing and, optionally,
    the marlin tile size.
    """

    def __init__(self,
                 packed_factor: int,
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 **kwargs):
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def packed_factor(self):
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size)


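# Illustrative sketch (not part of the library): a GPTQ-style layer storing
# int4 values packed 8-per-int32 along the input dimension might construct its
# qweight roughly as follows (all names are hypothetical):
#
#     qweight = PackedAphroditeParameter(
#         data=torch.empty(input_size_per_partition // 8,
#                          output_size_per_partition,
#                          dtype=torch.int32),
#         input_dim=0,
#         output_dim=1,
#         packed_dim=0,
#         packed_factor=8,
#         weight_loader=weight_loader)

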
def permute_param_layout_(param: BaseAphroditeParameter, input_dim: int,
                          output_dim: int, **kwargs) -> BaseAphroditeParameter:
    """
    Permute a parameter's layout to the specified input and output dimensions,
    useful for forcing the parameter into a known layout. For example, if a
    packed (quantized) weight matrix must be in the layout
        {input_dim = 0, output_dim = 1, packed_dim = 0}
    then calling
        permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
    ensures x is in the correct layout (permuting it to the correct layout if
    required, and asserting if it cannot be permuted to the correct layout).
    """

    curr_input_dim = getattr(param, "input_dim", None)
    curr_output_dim = getattr(param, "output_dim", None)

    if curr_input_dim is None or curr_output_dim is None:
        assert param.data.dim() == 2,\
            "permute_param_layout_ only supports 2D parameters when either "\
            "input_dim or output_dim is not set"

    # if one of the dimensions is not set, set it to the opposite of the other;
    # we can only do this since we asserted the parameter is 2D above
    if curr_input_dim is None:
        assert curr_output_dim is not None,\
            "either input or output dim must be set"
        curr_input_dim = (curr_output_dim + 1) % 2
    if curr_output_dim is None:
        assert curr_input_dim is not None,\
            "either input or output dim must be set"
        curr_output_dim = (curr_input_dim + 1) % 2

    # create a permutation from the current layout to the layout with
    # param.input_dim at input_dim and param.output_dim at output_dim,
    # preserving the other dimensions
    perm = [
        i for i in range(param.data.dim())
        if i not in [curr_input_dim, curr_output_dim]
    ]
    perm.insert(input_dim, curr_input_dim)
    perm.insert(output_dim, curr_output_dim)

    if "packed_dim" in kwargs:
        assert hasattr(param, "packed_dim") and\
            param.packed_dim == perm[kwargs["packed_dim"]],\
            "permute_param_layout_ currently doesn't support repacking"

    param.data = param.data.permute(*perm)

    if hasattr(param, "_input_dim"):
        param._input_dim = input_dim
    if hasattr(param, "_output_dim"):
        param._output_dim = output_dim
    if "packed_dim" in kwargs and hasattr(param, "_packed_dim"):
        param._packed_dim = kwargs["packed_dim"]

    return param


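# Illustrative sketch (not part of the library): forcing a qweight into the
# {input_dim=0, output_dim=1, packed_dim=0} layout. If the parameter already
# has that layout the permutation is the identity; if it is stored transposed,
# param.data is permuted in place and the tracked *_dim attributes are updated
# (provided the packed dim still lines up, since repacking is not supported):
#
#     qweight = permute_param_layout_(qweight,
#                                     input_dim=0,
#                                     output_dim=1,
#                                     packed_dim=0)

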
def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
                                     marlin_tile_size):
    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
                                      marlin_tile_size):
    shard_size = shard_size // packed_factor
    shard_offset = shard_offset // packed_factor
    if marlin_tile_size is not None:
        return _adjust_shard_indexes_for_marlin(
            shard_size=shard_size,
            shard_offset=shard_offset,
            marlin_tile_size=marlin_tile_size)
    return shard_size, shard_offset


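# Worked example (illustrative): with packed_factor=8 (int4 packed into
# int32), a logical shard of size 4096 at offset 8192 becomes shard_size=512
# and shard_offset=1024 in packed units; if marlin_tile_size=16 is also given,
# both are then scaled up to 8192 and 16384 to index the marlin-tiled layout.

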
# Qweights in HQQ need to be reshaped such that the shape of the stored
# tensors is the actual shape used in weight multiplication. This is needed to
# correctly repack to Marlin. We also store shard sizes and offsets in order
# to be able to correctly unpack (shard by shard) from 4-bit to 8-bit.
class HQQQweightParameter(PackedAphroditeParameter):

    def __init__(self, packed_factor: int, packed_dim: int, **kwargs):
        super().__init__(packed_factor, packed_dim, None, **kwargs)
        self.shard_offsets: List[Tuple[int, int]] = []
        self.pack_factor = packed_factor

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
            shard_offset=shard_offset, shard_size=shard_size)
        self.shard_offsets.append((shard_offset, shard_size))

        loaded_weight = loaded_weight.reshape(-1, self.shape[1])
        super().load_merged_column_weight(loaded_weight, **kwargs)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        self.shard_offsets.append((0, self.shape[self.output_dim]))

        loaded_weight = loaded_weight.reshape(-1, self.shape[1])
        super().load_row_parallel_weight(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
            shard_offset=shard_offset, shard_size=shard_size)
        self.shard_offsets.append((shard_offset, shard_size))

        loaded_weight = loaded_weight.reshape(-1, self.shape[1])
        super().load_qkv_weight(loaded_weight, **kwargs)


# Zero points and scales in HQQ must also be reshaped to their actual shapes.
class HQQZeroScaleParameter(GroupQuantScaleParameter):

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        loaded_weight = loaded_weight.reshape(-1, self.shape[1])
        super().load_merged_column_weight(loaded_weight, **kwargs)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        loaded_weight = loaded_weight.reshape(-1, self.shape[1])
        super().load_row_parallel_weight(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        loaded_weight = loaded_weight.reshape(-1, self.shape[1])
        super().load_qkv_weight(loaded_weight, **kwargs)
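

# Worked example (illustrative, assumes packed_factor=8 and no tensor
# parallelism): loading a fused merged-column HQQ qweight whose two logical
# shards each have size 4096 at offsets 0 and 4096 appends (0, 512) and
# (512, 512) to `shard_offsets` (in packed units), which is what later allows
# each shard to be unpacked from 4-bit to 8-bit independently.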