parameter.py

from typing import Callable, Optional, Union

import torch
from torch.nn import Parameter

from aphrodite.distributed import get_tensor_model_parallel_rank

__all__ = [
    "BaseAphroditeParameter", "PackedAphroditeParameter",
    "PerTensorScaleParameter", "ModelWeightParameter",
    "ChannelQuantScaleParameter", "GroupQuantScaleParameter",
    "PackedColumnParameter", "RowAphroditeParameter"
]


class BaseAphroditeParameter(Parameter):
    """
    Base parameter for Aphrodite linear layers. Extends the torch.nn.Parameter
    by taking in a linear weight loader. Will copy the loaded weight
    into the parameter when the provided weight loader is called.
    """

    def __new__(cls, data: torch.Tensor, **kwargs):
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, weight_loader: Callable):
        """
        Initialize the BaseAphroditeParameter

        :param data: torch tensor with the parameter data
        :param weight_loader: weight loader callable

        :returns: a torch.nn.Parameter
        """
        self._weight_loader = weight_loader

    @property
    def weight_loader(self):
        return self._weight_loader

    def _assert_and_load(self, loaded_weight: torch.Tensor):
        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        self._assert_and_load(loaded_weight)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        self._assert_and_load(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        self._assert_and_load(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        self._assert_and_load(loaded_weight)
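

# Illustrative usage (not part of the module; `make_loader` and
# `checkpoint_tensor` below are hypothetical stand-ins): a linear layer would
# construct the parameter with its weight-loader callable, and the model
# loader later invokes the load_* hooks, which for this base class simply
# shape-check and copy.
#
#     weight = BaseAphroditeParameter(data=torch.empty(4096, 4096),
#                                     weight_loader=make_loader())
#     weight.load_column_parallel_weight(loaded_weight=checkpoint_tensor)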


class _ColumnAphroditeParameter(BaseAphroditeParameter):
    """
    Private class defining weight loading functionality
    (load_merged_column_weight, load_qkv_weight)
    for parameters being loaded into linear layers with column
    parallelism. This includes QKV and MLP layers which are
    not already fused on disk. Requires an output dimension
    to be defined. Called within the weight loader of
    each of the column parallel linear layers.
    """

    def __init__(self, output_dim: int, **kwargs):
        self._output_dim = output_dim
        super().__init__(**kwargs)

    @property
    def output_dim(self):
        return self._output_dim

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.data.shape[self.output_dim]
        loaded_weight = loaded_weight.narrow(self.output_dim,
                                             tp_rank * shard_size, shard_size)
        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")

        if isinstance(self, (PackedColumnParameter, PackedAphroditeParameter)
                      ) and self.packed_dim == self.output_dim:
            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                shard_offset=shard_offset, shard_size=shard_size)

        param_data = self.data
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param_data.narrow(self.output_dim, shard_offset,
                                       shard_size)
        loaded_weight = loaded_weight.narrow(self.output_dim,
                                             tp_rank * shard_size, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        shard_id = kwargs.get("shard_id")
        num_heads = kwargs.get("num_heads")

        if isinstance(self, (PackedColumnParameter, PackedAphroditeParameter)
                      ) and self.output_dim == self.packed_dim:
            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                shard_offset=shard_offset, shard_size=shard_size)

        param_data = self.data
        tp_rank = get_tensor_model_parallel_rank()
        shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
        param_data = param_data.narrow(self.output_dim, shard_offset,
                                       shard_size)
        loaded_weight = loaded_weight.narrow(self.output_dim,
                                             shard_id * shard_size, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
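

# Sharding sketch for load_column_parallel_weight (shapes are assumed, for
# illustration only): with tensor_parallel_size=2 and output_dim=0, a full
# checkpoint weight of shape [8192, 4096] pairs with a local parameter of
# shape [4096, 4096]; shard_size is 4096, so rank 0 copies rows 0:4096 and
# rank 1 copies rows 4096:8192 via loaded_weight.narrow(0, tp_rank * 4096,
# 4096).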


class RowAphroditeParameter(BaseAphroditeParameter):
    """
    Parameter class defining weight loading functionality
    (load_row_parallel_weight) for parameters being loaded
    into linear layers with row parallelism.
    Requires an input_dim to be defined.
    """

    def __init__(self, input_dim: int, **kwargs):
        self._input_dim = input_dim
        super().__init__(**kwargs)

    @property
    def input_dim(self):
        return self._input_dim

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.data.shape[self.input_dim]
        loaded_weight = loaded_weight.narrow(self.input_dim,
                                             tp_rank * shard_size, shard_size)

        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)
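

# Row-parallel counterpart (assumed shapes): with tensor_parallel_size=2 and
# input_dim=1, a full weight of shape [4096, 8192] is split along the input
# dimension into [4096, 4096] slices, so rank 1 copies columns 4096:8192.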


class ModelWeightParameter(_ColumnAphroditeParameter, RowAphroditeParameter):
    """
    Parameter class for linear layer weights. Uses both column and
    row parallelism.
    """
    pass


class GroupQuantScaleParameter(_ColumnAphroditeParameter,
                               RowAphroditeParameter):
    """
    Parameter class for weight scales loaded for weights with
    grouped quantization. Uses both column and row parallelism.
    """
    pass


class ChannelQuantScaleParameter(_ColumnAphroditeParameter):
    """
    Parameter class for weight scales loaded for weights with
    channel-wise quantization. Equivalent to _ColumnAphroditeParameter.
    """
    pass


class PerTensorScaleParameter(BaseAphroditeParameter):
    """
    Parameter class for scales where the number of scales is
    equivalent to the number of logical matrices in fused linear
    layers (e.g. for QKV, there are 3 scales loaded from disk).
    This is relevant to weights with per-tensor quantization.
    Adds functionality to map the scales to a shard during
    weight loading.

    Note: additional parameter manipulation may be handled
    for each quantization config specifically, within
    process_weights_after_loading.
    """

    def __init__(self, **kwargs):
        self.qkv_idxs = {"q": 0, "k": 1, "v": 2}
        super().__init__(**kwargs)

    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
        if isinstance(shard_id, int):
            return shard_id

        assert isinstance(shard_id, str)
        assert shard_id in self.qkv_idxs
        return self.qkv_idxs[shard_id]

    def load_merged_column_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    def load_qkv_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    def load_column_parallel_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    def _load_into_shard_id(self, loaded_weight: torch.Tensor,
                            shard_id: Union[str, int], **kwargs):
        """
        Slice the parameter data based on the shard id for
        loading.
        """
        param_data = self.data
        shard_id = self._shard_id_as_int(shard_id)

        # AutoFP8 scales do not have a shape;
        # compressed-tensors scales do have a shape.
        if len(loaded_weight.shape) != 0:
            assert loaded_weight.shape[0] == 1
            loaded_weight = loaded_weight[0]

        param_data = param_data[shard_id]
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
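

# Example of the shard-id mapping (shapes assumed for illustration): a fused
# QKV layer carries one scale per logical matrix, so the parameter has shape
# [3]. Loading the "k" scale resolves to index 1 via qkv_idxs; a
# compressed-tensors scale of shape [1] is first reduced to a scalar, and the
# value is copied into param.data[1].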


class PackedColumnParameter(_ColumnAphroditeParameter):
    """
    Parameter for model parameters which are packed on disk
    and support column parallelism only. See PackedAphroditeParameter
    for more details on the packed properties.
    """

    def __init__(self,
                 packed_factor: int,
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 **kwargs):
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def packed_factor(self):
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size)


class PackedAphroditeParameter(ModelWeightParameter):
    """
    Parameter for model weights which are packed on disk.
    Example: GPTQ Marlin weights are int4 or int8, packed into int32.
    Extends the ModelWeightParameter to take in the
    packed factor, the packed dimension, and optionally, marlin
    tile size for marlin kernels. Adjusts the shard_size and
    shard_offset for fused linear layers model weight loading
    by accounting for packing and optionally, marlin tile size.
    """

    def __init__(self,
                 packed_factor: int,
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 **kwargs):
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def packed_factor(self):
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size)


def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
                                     marlin_tile_size):
    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
                                      marlin_tile_size):
    shard_size = shard_size // packed_factor
    shard_offset = shard_offset // packed_factor
    if marlin_tile_size is not None:
        return _adjust_shard_indexes_for_marlin(
            shard_size=shard_size,
            shard_offset=shard_offset,
            marlin_tile_size=marlin_tile_size)
    return shard_size, shard_offset
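

# Worked example of the packing adjustment (values are illustrative): for
# int4 weights packed into int32, packed_factor is 8, so a logical shard of
# size 4096 at offset 4096 maps to size 512 at offset 512 in the packed
# tensor. With marlin_tile_size=16, both values are then multiplied back up
# to 8192 to account for the Marlin tile layout.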