linear.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965
  1. from abc import abstractmethod
  2. from typing import Dict, List, Optional, Tuple
  3. import torch
  4. import torch.nn.functional as F
  5. from loguru import logger
  6. from torch.nn.parameter import Parameter, UninitializedParameter
  7. # yapf: disable
  8. from aphrodite.distributed import (divide,
  9. get_current_tp_rank_partition_offset,
  10. get_current_tp_rank_partition_size,
  11. get_tensor_model_parallel_rank,
  12. get_tensor_model_parallel_world_size,
  13. split_tensor_along_last_dim,
  14. tensor_model_parallel_all_gather,
  15. tensor_model_parallel_all_reduce)
  16. # yapf: enable
  17. from aphrodite.modeling.utils import set_weight_attrs
  18. from aphrodite.quantization.base_config import (QuantizationConfig,
  19. QuantizeMethodBase)
  20. def adjust_marlin_shard(param, shard_size, shard_offset):
  21. marlin_tile_size = getattr(param, "marlin_tile_size", None)
  22. if marlin_tile_size is None:
  23. return shard_size, shard_offset
  24. return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
  25. def adjust_bitsandbytes_shard(param: Parameter,
  26. qkv_offsets: Dict[str, Tuple[int, int]],
  27. loaded_shard_id: str) -> Tuple[int, int]:
  28. """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
  29. total, _ = qkv_offsets["total"]
  30. orig_offset, orig_size = qkv_offsets[loaded_shard_id]
  31. quantized_total = param.data.shape[0]
  32. quantized_offset = orig_offset * quantized_total // total
  33. quantized_size = orig_size * quantized_total // total
  34. return quantized_size, quantized_offset
  35. def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
  36. """For fused modules (QKV and MLP) we have an array of length
  37. N that holds 1 scale for each "logical" matrix. So the param
  38. is an array of length N. The loaded_weight corresponds to
  39. one of the shards on disk. Here, we slice the param based on
  40. the shard_id for loading.
  41. """
  42. qkv_idxs = {"q": 0, "k": 1, "v": 2}
  43. if isinstance(shard_id, str):
  44. shard_id = qkv_idxs[shard_id]
  45. elif not isinstance(shard_id, int):
  46. raise ValueError(f"Unknown Shard Id {shard_id}")
  47. # AutoFP8 scales do not have a shape
  48. # compressed-tensors scales do have a shape
  49. if len(loaded_weight.shape) != 0:
  50. assert loaded_weight.shape[0] == 1
  51. loaded_weight = loaded_weight[0]
  52. return param[shard_id], loaded_weight
  53. class LinearMethodBase(QuantizeMethodBase):
  54. """Base class for different (maybe quantized) linear methods."""
  55. @abstractmethod
  56. def create_weights(self, layer: torch.nn.Module,
  57. input_size_per_partition: int,
  58. output_partition_sizes: List[int], input_size: int,
  59. output_size: int, params_dtype: torch.dtype,
  60. **extra_weight_attrs):
  61. """Create weights for a linear layer.
  62. The weights will be set as attributes of the layer.
  63. Args:
  64. layer: The layer that is using the LinearMethodBase factory.
  65. input_size_per_partition: Size of the weight input dim on rank X.
  66. output_partition_sizes: Sizes of the output dim of each logical
  67. weight on rank X. E.g., output_partition_sizes for QKVLinear
  68. is a list contains the width of Wq, Wk, Wv on rank X.
  69. input_size: Size of the input dim of the weight across all ranks.
  70. output_size: Size of the output dim of the weight across all ranks.
  71. params_dtype: Datatype of the parameters.
  72. """
  73. raise NotImplementedError
  74. @abstractmethod
  75. def apply(self,
  76. layer: torch.nn.Module,
  77. x: torch.Tensor,
  78. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  79. """Apply the weights in layer to the input tensor.
  80. Expects create_weights to have been called before on the layer."""
  81. raise NotImplementedError
  82. class UnquantizedLinearMethod(LinearMethodBase):
  83. """Linear method without quantization."""
  84. def create_weights(self, layer: torch.nn.Module,
  85. input_size_per_partition: int,
  86. output_partition_sizes: List[int], input_size: int,
  87. output_size: int, params_dtype: torch.dtype,
  88. **extra_weight_attrs):
  89. weight = Parameter(torch.empty(sum(output_partition_sizes),
  90. input_size_per_partition,
  91. dtype=params_dtype),
  92. requires_grad=False)
  93. set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
  94. layer.register_parameter("weight", weight)
  95. set_weight_attrs(weight, extra_weight_attrs)
  96. def apply(self,
  97. layer: torch.nn.Module,
  98. x: torch.Tensor,
  99. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  100. return F.linear(x, layer.weight, bias)
  101. class LinearBase(torch.nn.Module):
  102. """Base linear layer.
  103. Args:
  104. input_size: input dimension of the linear layer.
  105. output_size: output dimension of the linear layer.
  106. bias: If true, add bias.
  107. skip_bias_add: If true, skip adding bias but instead return it.
  108. params_dtype: Data type for the parameters.
  109. quant_config: Quantization configure.
  110. """
  111. def __init__(
  112. self,
  113. input_size: int,
  114. output_size: int,
  115. skip_bias_add: bool = False,
  116. params_dtype: Optional[torch.dtype] = None,
  117. quant_config: Optional[QuantizationConfig] = None,
  118. prefix: str = "",
  119. ):
  120. super().__init__()
  121. # Keep input parameters
  122. self.input_size = input_size
  123. self.output_size = output_size
  124. self.skip_bias_add = skip_bias_add
  125. if params_dtype is None:
  126. params_dtype = torch.get_default_dtype()
  127. self.params_dtype = params_dtype
  128. if quant_config is None:
  129. self.quant_method: Optional[
  130. QuantizeMethodBase] = UnquantizedLinearMethod()
  131. else:
  132. self.quant_method = quant_config.get_quant_method(self,
  133. prefix=prefix)
  134. def forward(self, x: torch.Tensor) -> torch.Tensor:
  135. raise NotImplementedError
  136. class ReplicatedLinear(LinearBase):
  137. """Replicated linear layer.
  138. Args:
  139. input_size: input dimension of the linear layer.
  140. output_size: output dimension of the linear layer.
  141. bias: If true, add bias.
  142. skip_bias_add: If true, skip adding bias but instead return it.
  143. params_dtype: Data type for the parameters.
  144. quant_config: Quantization configure.
  145. prefix: The name of the layer in the state dict, including all parents
  146. (e.g. model.layers.0.qkv_proj)
  147. """
  148. def __init__(self,
  149. input_size: int,
  150. output_size: int,
  151. bias: bool = True,
  152. skip_bias_add: bool = False,
  153. params_dtype: Optional[torch.dtype] = None,
  154. quant_config: Optional[QuantizationConfig] = None,
  155. prefix: str = ""):
  156. super().__init__(input_size,
  157. output_size,
  158. skip_bias_add,
  159. params_dtype,
  160. quant_config,
  161. prefix=prefix)
  162. # All the linear layer supports quant method.
  163. assert self.quant_method is not None
  164. self.quant_method.create_weights(self,
  165. self.input_size, [self.output_size],
  166. self.input_size,
  167. self.output_size,
  168. self.params_dtype,
  169. prefix=prefix)
  170. if bias:
  171. self.bias = Parameter(
  172. torch.empty(self.output_size, dtype=self.params_dtype))
  173. set_weight_attrs(self.bias, {"output_dim": 0})
  174. else:
  175. self.register_parameter("bias", None)
  176. def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
  177. # If the weight on disk does not have a shape, give it one
  178. # (such scales for AutoFp8).
  179. if len(loaded_weight.shape) == 0:
  180. loaded_weight = loaded_weight.reshape(1)
  181. assert param.size() == loaded_weight.size()
  182. param.data.copy_(loaded_weight)
  183. def forward(self, x: torch.Tensor) -> torch.Tensor:
  184. bias = self.bias if not self.skip_bias_add else None
  185. assert self.quant_method is not None
  186. output = self.quant_method.apply(self, x, bias)
  187. output_bias = self.bias if self.skip_bias_add else None
  188. return output, output_bias
  189. def extra_repr(self) -> str:
  190. s = f"in_features={self.input_size}"
  191. s += f", output_features={self.output_size}"
  192. s += f", bias={self.bias is not None}"
  193. return s
  194. class ColumnParallelLinear(LinearBase):
  195. """Linear layer with column parallelism.
  196. The linear layer is defined as Y = XA + b. A is parallelized along
  197. its second dimension as A = [A_1, ..., A_p].
  198. Args:
  199. input_size: first dimension of matrix A.
  200. output_size: second dimension of matrix A.
  201. bias: If true, add bias.
  202. gather_output: If true, call all-gather on output and make Y available
  203. to all GPUs, otherwise, every GPU will have its output
  204. which is Y_i = XA_i
  205. skip_bias_add: This was added to enable performance optimizations where
  206. bias can be fused with other element-wise operations. we
  207. skip adding bias but instead return it.
  208. params_dtype: Data type for the parameters.
  209. quant_config: Quantization configure.
  210. output_sizes: list of output sizes packed into one output, like for QKV
  211. the list would be size 3.
  212. prefix: The name of the layer in the state dict, including all parents
  213. (e.g. model.layers.0.qkv_proj)
  214. """
  215. def __init__(self,
  216. input_size: int,
  217. output_size: int,
  218. bias: bool = True,
  219. gather_output: bool = False,
  220. skip_bias_add: bool = False,
  221. params_dtype: Optional[torch.dtype] = None,
  222. quant_config: Optional[QuantizationConfig] = None,
  223. output_sizes: Optional[List[int]] = None,
  224. prefix: str = ""):
  225. super().__init__(input_size, output_size, skip_bias_add, params_dtype,
  226. quant_config, prefix)
  227. self.gather_output = gather_output
  228. # Divide the weight matrix along the last dimension.
  229. tp_rank = get_tensor_model_parallel_rank()
  230. tp_size = get_tensor_model_parallel_world_size()
  231. assert self.quant_method is not None
  232. if quant_config is None:
  233. self.output_size_per_partition = get_current_tp_rank_partition_size(
  234. output_size, tp_rank, tp_size)
  235. else:
  236. self.output_size_per_partition = divide(self.output_size, tp_size)
  237. self.output_partition_sizes = [self.output_size_per_partition]
  238. # If QKV or MergedColumn, use output size of each partition.
  239. if hasattr(self, "output_sizes"):
  240. if quant_config is None:
  241. self.output_partition_sizes = [
  242. get_current_tp_rank_partition_size(output_size, tp_rank,
  243. tp_size)
  244. for output_size in self.output_sizes
  245. ]
  246. else:
  247. self.output_partition_sizes = [
  248. divide(output_size, tp_size)
  249. for output_size in self.output_sizes
  250. ]
  251. if output_sizes is None:
  252. output_sizes = [output_size]
  253. self.quant_method.create_weights(
  254. layer=self,
  255. input_size_per_partition=self.input_size,
  256. output_partition_sizes=self.output_partition_sizes,
  257. input_size=self.input_size,
  258. output_size=self.output_size,
  259. params_dtype=self.params_dtype,
  260. weight_loader=self.weight_loader,
  261. prefix=prefix)
  262. if bias:
  263. self.bias = Parameter(
  264. torch.empty(self.output_size_per_partition,
  265. dtype=params_dtype))
  266. set_weight_attrs(self.bias, {
  267. "output_dim": 0,
  268. "weight_loader": self.weight_loader,
  269. })
  270. else:
  271. self.register_parameter("bias", None)
  272. def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
  273. tp_rank = get_tensor_model_parallel_rank()
  274. output_dim = getattr(param, "output_dim", None)
  275. # Special case for GGUF
  276. is_gguf_weight = getattr(param, "is_gguf_weight", False)
  277. is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
  278. if is_gguf_weight_type:
  279. param.weight_type = loaded_weight.item()
  280. # Materialize GGUF UninitializedParameter
  281. if is_gguf_weight and isinstance(param, UninitializedParameter):
  282. param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
  283. param_data = param.data
  284. if output_dim is not None:
  285. shard_size = param_data.shape[output_dim]
  286. start_idx = tp_rank * shard_size
  287. loaded_weight = loaded_weight.narrow(output_dim, start_idx,
  288. shard_size)
  289. # Special case for loading scales off disk, which often do not
  290. # have a shape (such as in the case of AutoFP8).
  291. if len(loaded_weight.shape) == 0:
  292. loaded_weight = loaded_weight.reshape(1)
  293. assert param_data.shape == loaded_weight.shape
  294. param_data.copy_(loaded_weight)
  295. def forward(self, input_):
  296. bias = self.bias if not self.skip_bias_add else None
  297. # Matrix multiply.
  298. assert self.quant_method is not None
  299. output_parallel = self.quant_method.apply(self, input_, bias)
  300. if self.gather_output:
  301. # All-gather across the partitions.
  302. output = tensor_model_parallel_all_gather(output_parallel)
  303. else:
  304. output = output_parallel
  305. output_bias = self.bias if self.skip_bias_add else None
  306. return output, output_bias
  307. def extra_repr(self) -> str:
  308. s = f"in_features={self.input_size}"
  309. s += f", output_features={self.output_size_per_partition}"
  310. s += f", bias={self.bias is not None}"
  311. s += f", tp_size={get_tensor_model_parallel_world_size()}"
  312. s += f", gather_output={self.gather_output}"
  313. return s
class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure. When set, every entry of
                      ``output_sizes`` must be divisible by the
                      tensor-parallel world size (asserted in ``__init__``).
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_sizes: List[int],
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # Must be set before ColumnParallelLinear.__init__ runs: it checks
        # hasattr(self, "output_sizes") to build one partition size per
        # packed sub-matrix. quant_config is read later by weight_loader.
        self.output_sizes = output_sizes
        self.quant_config = quant_config
        if quant_config is not None:
            # Quantized layers require every sub-matrix to split evenly.
            tp_size = get_tensor_model_parallel_world_size()
            assert all(output_size % tp_size == 0
                       for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load one checkpoint tensor into (a slice of) ``param``.

        Args:
            param: Destination parameter of this layer.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` of the packed
                sub-matrix that ``loaded_weight`` belongs to. None means the
                checkpoint already stores the fused matrix; it is then split
                and this method recurses once per sub-matrix.
        """
        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            # GGUF quant-type tensor: store the scalar type id for this
            # shard both in the tensor and in shard_weight_type.
            param.data[loaded_shard_id].copy_(loaded_weight)
            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            return

        if is_gguf_weight and isinstance(param, UninitializedParameter):
            from gguf.constants import GGML_QUANT_SIZES

            # Allocate the quantized weight: keep the original row count and
            # use the widest packed row over all shards' quant types.
            ori_shape = param.tensor_shape
            weight_types = self.qweight_type.shard_weight_type.values()
            row_size = []
            for weight_type in weight_types:
                block_size, type_size = GGML_QUANT_SIZES[weight_type]
                row_size.append(ori_shape[1] // block_size * type_size)
            q_shape = (ori_shape[0], max(row_size))
            param.materialize(q_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv/mlp).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            # Offsets of each sub-matrix inside the full fused tensor.
            current_shard_offset = 0
            shard_offsets: List[Tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                # Recurse with an explicit shard id for each sub-matrix.
                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            # Position and width of this sub-matrix inside this rank's
            # partition of param. Unquantized layers may be split unevenly,
            # so per-rank sizes come from the partitioning helper.
            if self.quant_config is None:
                shard_offset = sum(
                    get_current_tp_rank_partition_size(output_size, tp_rank,
                                                       tp_size)
                    for output_size in self.output_sizes[:loaded_shard_id])
                shard_size = get_current_tp_rank_partition_size(
                    self.output_sizes[loaded_shard_id], tp_rank, tp_size)
            else:
                shard_offset = sum(
                    self.output_sizes[:loaded_shard_id]) // tp_size
                shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            # bitsandbytes: the slot width is taken from the loaded tensor
            # itself, and slots are laid out back to back by shard id.
            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
            if use_bitsandbytes:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            if is_gguf_weight:
                # Record this shard's per-rank shape on the param (presumably
                # consumed later by the GGUF quant method — confirm) and
                # restrict the destination to the loaded input width.
                tp_size = get_tensor_model_parallel_world_size()
                output_dim = getattr(param, "output_dim", None)
                shard_shape = list(loaded_weight.shape)
                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
                param.shard_id.append(loaded_shard_id)
                param.shard_size[loaded_shard_id] = shard_shape

                input_dim = getattr(param, "input_dim", None)
                input_size = loaded_weight.shape[input_dim]
                param_data = param_data.narrow(input_dim, 0, input_size)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # Where this rank's slice starts inside the full on-disk
            # sub-matrix: helper-computed for (possibly uneven) unquantized
            # splits, rank * size for even quantized splits.
            if self.quant_config is None:
                start_idx = get_current_tp_rank_partition_offset(
                    loaded_weight.shape[output_dim], tp_rank, tp_size)
            else:
                start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
  482. class QKVParallelLinear(ColumnParallelLinear):
  483. """Linear layers for the attention's QKV transformation.
  484. Linear layers for the linear transformation of the query, key, and value
  485. vectors in the attention layer. The weight matrix is concatenated along
  486. the output dimension. The layer is parallelized along the head dimension.
  487. When the number of key/value heads is smaller than the number of query
  488. heads (e.g., multi-query/grouped-query attention), the key/value head may
  489. be replicated while the query heads are partitioned.
  490. Args:
  491. hidden_size: input hidden state size of the transformer.
  492. head_size: size of each attention head.
  493. total_num_heads: total number of attention query heads.
  494. total_num_kv_heads: total number of attention key/value heads. If
  495. None, assume total_num_kv_heads = total_num_heads.
  496. bias: If true, add bias.
  497. skip_bias_add: This was added to enable performance optimizations where
  498. bias can be fused with other element-wise operations. we
  499. skip adding bias but instead return it.
  500. params_dtype: Data type for the parameters.
  501. quant_config: Quantization configure.
  502. prefix: The name of the layer in the state dict, including all parents
  503. (e.g. model.layers.0.qkv_proj)
  504. """
  505. def __init__(self,
  506. hidden_size: int,
  507. head_size: int,
  508. total_num_heads: int,
  509. total_num_kv_heads: Optional[int] = None,
  510. bias: bool = True,
  511. skip_bias_add: bool = False,
  512. params_dtype: Optional[torch.dtype] = None,
  513. quant_config: Optional[QuantizationConfig] = None,
  514. prefix: str = ""):
  515. self.hidden_size = hidden_size
  516. self.head_size = head_size
  517. self.total_num_heads = total_num_heads
  518. self.quant_config = quant_config
  519. if total_num_kv_heads is None:
  520. total_num_kv_heads = total_num_heads
  521. self.total_num_kv_heads = total_num_kv_heads
  522. # Divide the weight matrix along the last dimension.
  523. tp_size = get_tensor_model_parallel_world_size()
  524. tp_rank = get_tensor_model_parallel_rank()
  525. if quant_config is None:
  526. self.num_heads_per_kv_head = (self.total_num_heads //
  527. self.total_num_kv_heads)
  528. self.num_kv_heads = get_current_tp_rank_partition_size(
  529. self.total_num_kv_heads, tp_rank, tp_size)
  530. self.num_heads = self.num_kv_heads * self.num_heads_per_kv_head
  531. self.num_kv_head_replicas = 1
  532. else:
  533. self.num_heads = divide(self.total_num_heads, tp_size)
  534. if tp_size >= self.total_num_kv_heads:
  535. self.num_kv_heads = 1
  536. self.num_kv_head_replicas = divide(tp_size,
  537. self.total_num_kv_heads)
  538. elif tp_size < self.total_num_kv_heads and quant_config is not None:
  539. self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
  540. self.num_kv_head_replicas = 1
  541. input_size = self.hidden_size
  542. output_size = (self.num_heads +
  543. 2 * self.num_kv_heads) * tp_size * self.head_size
  544. self.output_sizes = [
  545. self.num_heads * self.head_size * tp_size, # q_proj
  546. self.num_kv_heads * self.head_size * tp_size, # k_proj
  547. self.num_kv_heads * self.head_size * tp_size, # v_proj
  548. ]
  549. super().__init__(input_size=input_size,
  550. output_size=output_size,
  551. bias=bias,
  552. gather_output=False,
  553. skip_bias_add=skip_bias_add,
  554. params_dtype=params_dtype,
  555. quant_config=quant_config,
  556. prefix=prefix)
  557. def weight_loader(self,
  558. param: Parameter,
  559. loaded_weight: torch.Tensor,
  560. loaded_shard_id: Optional[str] = None):
  561. # Special case for GGUF
  562. # initialize GGUF param after we know the quantize type
  563. is_gguf_weight = getattr(param, "is_gguf_weight", False)
  564. is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
  565. if is_gguf_weight_type and loaded_shard_id is not None:
  566. idx_map = {"q": 0, "k": 1, "v": 2}
  567. param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
  568. param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
  569. return
  570. if is_gguf_weight and isinstance(param, UninitializedParameter):
  571. from gguf.constants import GGML_QUANT_SIZES
  572. ori_shape = param.tensor_shape
  573. weight_types = self.qweight_type.shard_weight_type.values()
  574. row_size = []
  575. for weight_type in weight_types:
  576. block_size, type_size = GGML_QUANT_SIZES[weight_type]
  577. row_size.append(ori_shape[1] // block_size * type_size)
  578. q_shape = (ori_shape[0], max(row_size))
  579. param.materialize(q_shape, dtype=loaded_weight.dtype)
  580. param_data = param.data
  581. output_dim = getattr(param, "output_dim", None)
  582. # Special case for AQLM codebooks.
  583. is_metadata = getattr(param, "is_metadata", False)
  584. # Special case for per-tensor scales in fused case.
  585. needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
  586. if loaded_shard_id is None:
  587. # Loaded weight is already fused on disk (qkv/mlp).
  588. if output_dim is None:
  589. if needs_scalar_to_array:
  590. param_data, loaded_weight = adjust_scalar_to_fused_array(
  591. param_data, loaded_weight, 0)
  592. assert param_data.shape == loaded_weight.shape
  593. param_data.copy_(loaded_weight)
  594. return
  595. shard_offsets = [
  596. # (shard_id, shard_offset, shard_size)
  597. ("q", 0, self.total_num_heads * self.head_size),
  598. ("k", self.total_num_heads * self.head_size,
  599. self.total_num_kv_heads * self.head_size),
  600. ("v", (self.total_num_heads + self.total_num_kv_heads) *
  601. self.head_size, self.total_num_kv_heads * self.head_size),
  602. ]
  603. packed_dim = getattr(param, "packed_dim", None)
  604. for shard_id, shard_offset, shard_size in shard_offsets:
  605. # Special case for Quantized Weights.
  606. # If quantized, we need to adjust the offset and size to account
  607. # for the packing.
  608. if packed_dim == output_dim:
  609. shard_size = shard_size // param.pack_factor
  610. shard_offset = shard_offset // param.pack_factor
  611. # Special case for Marlin.
  612. shard_size, shard_offset = adjust_marlin_shard(
  613. param, shard_size, shard_offset)
  614. loaded_weight_shard = loaded_weight.narrow(
  615. output_dim, shard_offset, shard_size)
  616. self.weight_loader(param, loaded_weight_shard, shard_id)
  617. return
  618. tp_rank = get_tensor_model_parallel_rank()
  619. assert loaded_shard_id in ["q", "k", "v"]
  620. # If output dim is defined, use the default loading process.
  621. if output_dim is not None:
  622. if loaded_shard_id == "q":
  623. shard_offset = 0
  624. shard_size = self.num_heads * self.head_size
  625. if self.quant_config is None:
  626. multiple_of = self.head_size * self.num_heads_per_kv_head
  627. elif loaded_shard_id == "k":
  628. shard_offset = self.num_heads * self.head_size
  629. shard_size = self.num_kv_heads * self.head_size
  630. if self.quant_config is None:
  631. multiple_of = self.head_size
  632. elif loaded_shard_id == "v":
  633. shard_offset = (self.num_heads +
  634. self.num_kv_heads) * self.head_size
  635. shard_size = self.num_kv_heads * self.head_size
  636. if self.quant_config is None:
  637. multiple_of = self.head_size
  638. # Special case for Quantized Weights.
  639. # If quantized, we need to adjust the offset and size to account
  640. # for the packing.
  641. packed_dim = getattr(param, "packed_dim", None)
  642. if packed_dim == output_dim:
  643. shard_size = shard_size // param.pack_factor
  644. shard_offset = shard_offset // param.pack_factor
  645. if self.quant_config is None:
  646. multiple_of = multiple_of // param.pack_factor
  647. # Special case for Marlin.
  648. shard_size, shard_offset = adjust_marlin_shard(
  649. param, shard_size, shard_offset)
  650. use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
  651. if use_bitsandbytes:
  652. orig_qkv_offsets = {
  653. "q": (0, self.num_heads * self.head_size),
  654. "k": (self.num_heads * self.head_size,
  655. self.num_kv_heads * self.head_size),
  656. "v":
  657. ((self.num_heads + self.num_kv_heads) * self.head_size,
  658. self.num_kv_heads * self.head_size),
  659. "total":
  660. ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
  661. 0)
  662. }
  663. shard_size, shard_offset = adjust_bitsandbytes_shard(
  664. param, orig_qkv_offsets, loaded_shard_id)
  665. if is_gguf_weight:
  666. tp_size = get_tensor_model_parallel_world_size()
  667. output_dim = getattr(param, "output_dim", None)
  668. shard_shape = list(loaded_weight.shape)
  669. shard_shape[output_dim] = shard_shape[output_dim] // tp_size
  670. param.shard_id.append(loaded_shard_id)
  671. param.shard_size[loaded_shard_id] = shard_shape
  672. input_dim = getattr(param, "input_dim", None)
  673. input_size = loaded_weight.shape[input_dim]
  674. param_data = param_data.narrow(input_dim, 0, input_size)
  675. param_data = param_data.narrow(output_dim, shard_offset,
  676. shard_size)
  677. if self.quant_config is None:
  678. tp_size = get_tensor_model_parallel_world_size()
  679. total_size = loaded_weight.shape[output_dim]
  680. start_idx = get_current_tp_rank_partition_offset(
  681. total_size, tp_rank, tp_size, multiple_of=multiple_of)
  682. else:
  683. if loaded_shard_id == "q":
  684. shard_id = tp_rank
  685. else:
  686. shard_id = tp_rank // self.num_kv_head_replicas
  687. start_idx = shard_id * shard_size
  688. loaded_weight = loaded_weight.narrow(output_dim, start_idx,
  689. shard_size)
  690. # Special case for for AQLM codebooks.
  691. elif is_metadata:
  692. # metadata indicates fixed size concatenated along dim 0
  693. shard_size = loaded_weight.shape[0]
  694. shard_index = ["q", "k", "v"].index(loaded_shard_id)
  695. param_data = param_data.narrow(0, shard_index * shard_size,
  696. shard_size)
  697. # Special case for per-tensor scales in fused case.
  698. elif needs_scalar_to_array:
  699. param_data, loaded_weight = adjust_scalar_to_fused_array(
  700. param_data, loaded_weight, loaded_shard_id)
  701. else:
  702. ignore_warning = getattr(param, "ignore_warning", False)
  703. if not ignore_warning:
  704. logger.warning(
  705. "Loading a weight without `output_dim` attribute in "
  706. "QKVParallelLinear, assume the weight is the same "
  707. "for all partitions.")
  708. assert param_data.shape == loaded_weight.shape
  709. param_data.copy_(loaded_weight)
  710. class RowParallelLinear(LinearBase):
  711. """Linear layer with row parallelism.
  712. The linear layer is defined as Y = XA + b. A is parallelized along
  713. its first dimension and X along its second dimension as:
  714. - -
  715. | A_1 |
  716. | . |
  717. A = | . | X = [X_1, ..., X_p]
  718. | . |
  719. | A_p |
  720. - -
  721. Arguments:
  722. input_size: first dimension of matrix A.
  723. output_size: second dimension of matrix A.
  724. bias: If true, add bias. Note that bias is not parallelized.
  725. input_is_parallel: If true, we assume that the input is already
  726. split across the GPUs and we do not split
  727. again.
  728. skip_bias_add: This was added to enable performance optimization where
  729. bias can be fused with other element-wise operations.
  730. We skip adding bias but instead return it.
  731. params_dtype: Data type for the parameters.
  732. quant_config: Quantization configure.
  733. partition_multiple_of: Partitions will be divided,
  734. so each partition is a multiple of this number.
  735. """
  736. def __init__(self,
  737. input_size: int,
  738. output_size: int,
  739. bias: bool = True,
  740. input_is_parallel: bool = True,
  741. skip_bias_add: bool = False,
  742. params_dtype: Optional[torch.dtype] = None,
  743. reduce_results: bool = True,
  744. quant_config: Optional[QuantizationConfig] = None,
  745. partition_multiple_of: int = 1,
  746. prefix: str = ""):
  747. super().__init__(input_size, output_size, skip_bias_add, params_dtype,
  748. quant_config, prefix)
  749. self.input_is_parallel = input_is_parallel
  750. self.reduce_results = reduce_results
  751. self.quant_config = quant_config
  752. # Divide the weight matrix along the last dimension.
  753. self.tp_rank = get_tensor_model_parallel_rank()
  754. self.tp_size = get_tensor_model_parallel_world_size()
  755. self.tp_rank = get_tensor_model_parallel_rank()
  756. if quant_config is None:
  757. self.partition_multiple_of = partition_multiple_of
  758. self.input_size_per_partition = get_current_tp_rank_partition_size(
  759. input_size, self.tp_rank, self.tp_size, partition_multiple_of)
  760. else:
  761. self.input_size_per_partition = divide(input_size, self.tp_size)
  762. assert self.quant_method is not None
  763. self.quant_method.create_weights(
  764. layer=self,
  765. input_size_per_partition=self.input_size_per_partition,
  766. output_partition_sizes=[self.output_size],
  767. input_size=self.input_size,
  768. output_size=self.output_size,
  769. params_dtype=self.params_dtype,
  770. weight_loader=self.weight_loader,
  771. prefix=prefix)
  772. if not reduce_results and (bias and not skip_bias_add):
  773. raise ValueError("When not reduce the results, adding bias to the "
  774. "results can lead to incorrect results")
  775. if bias:
  776. self.bias = Parameter(
  777. torch.empty(self.output_size, dtype=params_dtype))
  778. set_weight_attrs(self.bias, {
  779. "output_dim": 0,
  780. "weight_loader": self.weight_loader,
  781. })
  782. else:
  783. self.register_parameter("bias", None)
  784. def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
  785. tp_size = get_tensor_model_parallel_world_size()
  786. input_dim = getattr(param, "input_dim", None)
  787. # Special case for GGUF
  788. is_gguf_weight = getattr(param, "is_gguf_weight", False)
  789. is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
  790. if is_gguf_weight_type:
  791. param.weight_type = loaded_weight.item()
  792. # Materialize GGUF UninitializedParameter
  793. if is_gguf_weight and isinstance(param, UninitializedParameter):
  794. weight_shape = list(loaded_weight.shape)
  795. if input_dim:
  796. weight_shape[input_dim] = weight_shape[input_dim] // tp_size
  797. param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
  798. param_data = param.data
  799. if input_dim is not None:
  800. shard_size = param_data.shape[input_dim]
  801. if self.quant_config is None:
  802. start_idx = get_current_tp_rank_partition_offset(
  803. self.input_size,
  804. self.tp_rank,
  805. self.tp_size,
  806. multiple_of=self.partition_multiple_of)
  807. else:
  808. start_idx = self.tp_rank * shard_size
  809. loaded_weight = loaded_weight.narrow(input_dim, start_idx,
  810. shard_size)
  811. # Special case for loading scales off disk, which often do not
  812. # have a shape (such as in the case of AutoFP8).
  813. if len(loaded_weight.shape) == 0:
  814. loaded_weight = loaded_weight.reshape(1)
  815. assert param_data.shape == loaded_weight.shape
  816. param_data.copy_(loaded_weight)
  817. def forward(self, input_):
  818. if self.input_is_parallel:
  819. input_parallel = input_
  820. else:
  821. tp_rank = get_tensor_model_parallel_rank()
  822. splitted_input = split_tensor_along_last_dim(
  823. input_, num_partitions=self.tp_size)
  824. input_parallel = splitted_input[tp_rank].contiguous()
  825. # Matrix multiply.
  826. assert self.quant_method is not None
  827. # Only fuse bias add into GEMM for rank 0 (this ensures that
  828. # bias will not get added more than once in TP>1 case)
  829. bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
  830. output_parallel = self.quant_method.apply(self,
  831. input_parallel,
  832. bias=bias_)
  833. if self.reduce_results and self.tp_size > 1:
  834. output = tensor_model_parallel_all_reduce(output_parallel)
  835. else:
  836. output = output_parallel
  837. output_bias = self.bias if self.skip_bias_add else None
  838. return output, output_bias
  839. def extra_repr(self) -> str:
  840. s = f"input_features={self.input_size_per_partition}"
  841. s += f", output_features={self.output_size}"
  842. s += f", bias={self.bias is not None}"
  843. s += f", tp_size={self.tp_size}"
  844. s += f", reduce_results={self.reduce_results}"
  845. return s