parameter.py

from typing import Callable, Optional, Union

import torch
from torch.nn import Parameter

from aphrodite.distributed import get_tensor_model_parallel_rank

__all__ = [
    "BaseAphroditeParameter", "PackedAphroditeParameter",
    "PerTensorScaleParameter", "ModelWeightParameter",
    "ChannelQuantScaleParameter", "GroupQuantScaleParameter"
]


class BaseAphroditeParameter(Parameter):
    """
    Base parameter for Aphrodite linear layers. Extends torch.nn.Parameter
    by taking in a linear weight loader. Will copy the loaded weight
    into the parameter when the provided weight loader is called.
    """

    def __new__(cls, data: torch.Tensor, **kwargs):
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, weight_loader: Callable):
        """
        Initialize the BaseAphroditeParameter

        :param data: torch tensor with the parameter data
        :param weight_loader: weight loader callable

        :returns: a torch.nn.Parameter
        """
        self._weight_loader = weight_loader

    @property
    def weight_loader(self):
        return self._weight_loader

    def _assert_and_load(self, loaded_weight: torch.Tensor):
        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        self._assert_and_load(loaded_weight)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        self._assert_and_load(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor,
                                  **kwargs):
        self._assert_and_load(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        self._assert_and_load(loaded_weight)
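
# Illustrative usage (the layer and attribute names below are hypothetical,
# not part of this module): a linear layer typically constructs one of these
# parameters with its own weight loader and registers it, e.g.
#
#   weight = ModelWeightParameter(data=torch.empty(out_size, in_size),
#                                 output_dim=0, input_dim=1,
#                                 weight_loader=layer.weight_loader)
#   layer.register_parameter("weight", weight)
#
# The layer's weight loader then calls the appropriate load_* method on the
# parameter when the checkpoint weight is loaded.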


class _ColumnAphroditeParameter(BaseAphroditeParameter):
    """
    Private class defining weight loading functionality
    (load_merged_column_weight, load_qkv_weight)
    for parameters being loaded into linear layers with column
    parallelism. This includes QKV and MLP layers which are
    not already fused on disk. Requires an output dimension
    to be defined. Called within the weight loader of
    each of the column parallel linear layers.
    """

    def __init__(self, output_dim: int, **kwargs):
        self._output_dim = output_dim
        super().__init__(**kwargs)

    @property
    def output_dim(self):
        return self._output_dim

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.data.shape[self.output_dim]
        loaded_weight = loaded_weight.narrow(self.output_dim,
                                             tp_rank * shard_size, shard_size)
        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)
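
    # Illustrative example of the column-parallel slice above: with
    # output_dim=0, a checkpoint tensor of shape [out_features, in_features]
    # and a local shard of shape [out_features // tp_size, in_features],
    # rank `tp_rank` copies rows
    # [tp_rank * shard_size : (tp_rank + 1) * shard_size).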

    def load_merged_column_weight(self, loaded_weight: torch.Tensor,
                                  **kwargs):
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        if isinstance(
                self, PackedAphroditeParameter
        ) and self.packed_dim == self.output_dim:
            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                shard_offset=shard_offset, shard_size=shard_size)

        param_data = self.data
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param_data.narrow(self.output_dim, shard_offset,
                                       shard_size)
        loaded_weight = loaded_weight.narrow(self.output_dim,
                                             tp_rank * shard_size, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        shard_id = kwargs.get("shard_id")
        num_heads = kwargs.get("num_heads")

        if isinstance(
                self, PackedAphroditeParameter
        ) and self.output_dim == self.packed_dim:
            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                shard_offset=shard_offset, shard_size=shard_size)

        param_data = self.data
        tp_rank = get_tensor_model_parallel_rank()
        shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
        param_data = param_data.narrow(self.output_dim, shard_offset,
                                       shard_size)
        loaded_weight = loaded_weight.narrow(self.output_dim,
                                             shard_id * shard_size,
                                             shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
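
    # Note on the shard index above: the offset into the loaded QKV tensor is
    # tp_rank * shard_size for the "q" shard and
    # (tp_rank // num_heads) * shard_size for "k"/"v", so several
    # tensor-parallel ranks can read the same k/v slice when KV heads are
    # replicated across ranks.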


class ModelWeightParameter(_ColumnAphroditeParameter):
    """
    Parameter class for linear layer weights. Extends the
    _ColumnAphroditeParameter by adding loading support
    for linear layers with row parallelism.
    Requires an input dimension to be defined.
    """

    def __init__(self, input_dim: int, **kwargs):
        self._input_dim = input_dim
        super().__init__(**kwargs)

    @property
    def input_dim(self):
        return self._input_dim

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.data.shape[self.input_dim]
        loaded_weight = loaded_weight.narrow(self.input_dim,
                                             tp_rank * shard_size, shard_size)
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)
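
    # Illustrative example of the row-parallel slice above: with input_dim=1,
    # a checkpoint tensor of shape [out_features, in_features] and a local
    # shard of shape [out_features, in_features // tp_size], rank `tp_rank`
    # copies columns [tp_rank * shard_size : (tp_rank + 1) * shard_size).
    # The reshape(1) guard keeps a 0-dim loaded weight compatible with a
    # 1-element parameter.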


class GroupQuantScaleParameter(ModelWeightParameter):
    """
    Parameter class for weight scales loaded for weights with
    grouped quantization. Equivalent to ModelWeightParameter.
    """
    pass


class ChannelQuantScaleParameter(_ColumnAphroditeParameter):
    """
    Parameter class for weight scales loaded for weights with
    channel-wise quantization. Equivalent to _ColumnAphroditeParameter.
    """
    pass


class PerTensorScaleParameter(BaseAphroditeParameter):
    """
    Parameter class for scales where the number of scales is
    equivalent to the number of logical matrices in fused linear
    layers (e.g. for QKV, there are 3 scales loaded from disk).
    This is relevant to weights with per-tensor quantization.
    Adds functionality to map the scales to a shard during
    weight loading.

    Note: additional parameter manipulation may be handled
    for each quantization config specifically, within
    process_weights_after_loading
    """

    def __init__(self, **kwargs):
        self.qkv_idxs = {"q": 0, "k": 1, "v": 2}
        super().__init__(**kwargs)

    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
        if isinstance(shard_id, int):
            return shard_id

        assert isinstance(shard_id, str)
        assert shard_id in self.qkv_idxs
        return self.qkv_idxs[shard_id]

    def load_merged_column_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    def load_qkv_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    def load_column_parallel_weight(self, *args, **kwargs):
        self._load_into_shard_id(*args, **kwargs)

    def _load_into_shard_id(self, loaded_weight: torch.Tensor,
                            shard_id: Union[str, int], **kwargs):
        """
        Slice the parameter data based on the shard id for
        loading.
        """
        param_data = self.data
        shard_id = self._shard_id_as_int(shard_id)

        # AutoFP8 scales do not have a shape
        # compressed-tensors scales do have a shape
        if len(loaded_weight.shape) != 0:
            assert loaded_weight.shape[0] == 1
            loaded_weight = loaded_weight[0]

        param_data = param_data[shard_id]
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
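
    # Illustrative example (shapes are assumptions for illustration): for a
    # fused QKV projection with per-tensor quantization, `self.data` holds one
    # scale per logical matrix, e.g. shape [3]. Loading the "k" scale resolves
    # shard_id to 1 via qkv_idxs and copies the scalar into param_data[1].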


class PackedAphroditeParameter(ModelWeightParameter):
    """
    Parameter for model weights which are packed on disk.
    Example: GPTQ Marlin weights are int4 or int8, packed into int32.

    Extends the ModelWeightParameter to take in the packed factor,
    the packed dimension, and optionally, the marlin tile size for
    marlin kernels. Adjusts the shard_size and shard_offset during
    fused linear layer weight loading to account for packing and,
    optionally, the marlin tile size.
    """

    def __init__(self,
                 packed_factor: int,
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 **kwargs):
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile = marlin_tile_size
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def packed_factor(self):
        return self._packed_factor

    @property
    def marlin_tile(self):
        return self._marlin_tile

    def _adjust_shard_indexes_for_marlin(self, shard_size, shard_offset):
        return shard_size * self.marlin_tile, shard_offset * self.marlin_tile

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        shard_size = shard_size // self.packed_factor
        shard_offset = shard_offset // self.packed_factor
        if self.marlin_tile is not None:
            return self._adjust_shard_indexes_for_marlin(
                shard_size, shard_offset)
        return shard_size, shard_offset
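
# Illustrative arithmetic for adjust_shard_indexes_for_packing (values are
# assumptions, not defaults): with int4 weights packed eight to an int32,
# packed_factor == 8, so a logical shard_size of 4096 becomes 4096 // 8 = 512
# packed elements. If marlin_tile_size == 16 is also given, the result is then
# scaled by the tile size to 512 * 16 = 8192.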