# linear.py
from abc import abstractmethod
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from loguru import logger
from torch.nn.parameter import Parameter

from aphrodite.distributed import (
    divide, get_current_tp_rank_partition_offset,
    get_current_tp_rank_partition_size, get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size, split_tensor_along_last_dim,
    tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import (QuantizationConfig,
                                                QuantizeMethodBase)


def adjust_marlin_shard(param, shard_size, shard_offset):
    marlin_tile_size = getattr(param, "marlin_tile_size", None)
    if marlin_tile_size is None:
        return shard_size, shard_offset

    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
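
# Illustrative only: for a Marlin-packed param with a hypothetical
# marlin_tile_size of 16, a logical shard of size 128 at offset 256 would be
# rescaled into tile space as follows:
#
#     shard_size, shard_offset = adjust_marlin_shard(param, 128, 256)
#     # -> (128 * 16, 256 * 16) == (2048, 4096)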


def adjust_bitsandbytes_shard(param: Parameter,
                              qkv_offsets: Dict[str, Tuple[int, int]],
                              loaded_shard_id: str) -> Tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""

    total, _ = qkv_offsets["total"]
    orig_offset, orig_size = qkv_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset
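
# Worked example (numbers are made up for illustration): with
#     qkv_offsets = {"q": (0, 4096), "k": (4096, 1024), "v": (5120, 1024),
#                    "total": (6144, 0)}
# and a quantized param whose dim 0 is 3072 (half the unquantized total of
# 6144), the "k" shard is rescaled to
#     quantized_offset = 4096 * 3072 // 6144 == 2048
#     quantized_size   = 1024 * 3072 // 6144 == 512
# and the function returns (512, 2048), i.e. (size, offset).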


def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
    the shard_id for loading.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[shard_id], loaded_weight
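
# Sketch of the intended behavior (hypothetical tensors, not part of the API):
#
#     scales = torch.zeros(3)       # fused per-tensor scales for q/k/v
#     loaded = torch.tensor([0.5])  # one scale loaded from disk
#     dst, src = adjust_scalar_to_fused_array(scales, loaded, "k")
#     dst.copy_(src)                # writes 0.5 into scales[1]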


class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for a linear layer.
        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.
        Expects create_weights to have been called before on the layer."""
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        return F.linear(x, layer.weight, bias)
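
# Minimal sketch of how a LinearMethodBase is driven (the bare nn.Module and
# the sizes below are only for illustration; the real callers are the
# LinearBase subclasses further down in this file):
#
#     method = UnquantizedLinearMethod()
#     layer = torch.nn.Module()
#     method.create_weights(layer, input_size_per_partition=16,
#                           output_partition_sizes=[32], input_size=16,
#                           output_size=32, params_dtype=torch.float32)
#     y = method.apply(layer, torch.randn(4, 16))  # y.shape == (4, 32)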


class LinearBase(torch.nn.Module):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError
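
# Note (descriptive, not normative): concrete subclasses are expected to call
# self.quant_method.create_weights(...) in __init__ to register parameters on
# themselves, and self.quant_method.apply(self, x, bias) in forward; see the
# subclasses below for the exact pattern.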


class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix)

        # All linear layers support a quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         prefix=prefix)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {"output_dim": 0})
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        # If the weight on disk does not have a shape, give it one
        # (such as scales for AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s
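
# Usage sketch (sizes are arbitrary; assumes no quant_config, so
# UnquantizedLinearMethod is used, and no distributed setup is required since
# the layer is fully replicated):
#
#     proj = ReplicatedLinear(input_size=512, output_size=128, bias=True)
#     out, out_bias = proj(torch.randn(8, 512))
#     # out.shape == (8, 128); out_bias is None unless skip_bias_add=True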


class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        output_sizes: list of output sizes packed into one output, e.g. for QKV
                      the list would be of size 3.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 output_sizes: Optional[List[int]] = None,
                 prefix: str = ""):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config, prefix)

        self.gather_output = gather_output

        # Divide the weight matrix along the last dimension.
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        assert self.quant_method is not None
        if quant_config is None:
            self.output_size_per_partition = get_current_tp_rank_partition_size(
                output_size, tp_rank, tp_size)
        else:
            self.output_size_per_partition = divide(self.output_size, tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            if quant_config is None:
                self.output_partition_sizes = [
                    get_current_tp_rank_partition_size(output_size, tp_rank,
                                                       tp_size)
                    for output_size in self.output_sizes
                ]
            else:
                self.output_partition_sizes = [
                    divide(output_size, tp_size)
                    for output_size in self.output_sizes
                ]

        if output_sizes is None:
            output_sizes = [output_size]
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=self.weight_loader,
            prefix=prefix)
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)
        param_data = param.data
        if output_dim is not None:
            shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += f", gather_output={self.gather_output}"
        return s
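
# Sharding sketch (illustrative numbers): with output_size=4096, tp_size=4 and
# a quant_config supplied (so the even `divide` path is taken), each rank owns
# a (1024, input_size) weight shard; rank r's weight_loader copies rows
# [r*1024, (r+1)*1024) of the full weight, and forward produces that rank's
# output slice unless gather_output=True, in which case the slices are
# all-gathered into the full output.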


class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_sizes: List[int],
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        self.output_sizes = output_sizes
        self.quant_config = quant_config
        if quant_config is not None:
            tp_size = get_tensor_model_parallel_world_size()
            assert all(output_size % tp_size == 0
                       for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv/mlp).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            shard_offsets: List[Tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for quantization.
                # If quantized, we need to adjust the offset and size to
                # account for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            if self.quant_config is None:
                shard_offset = sum(
                    get_current_tp_rank_partition_size(output_size, tp_rank,
                                                       tp_size)
                    for output_size in self.output_sizes[:loaded_shard_id])
                shard_size = get_current_tp_rank_partition_size(
                    self.output_sizes[loaded_shard_id], tp_rank, tp_size)
            else:
                shard_offset = sum(
                    self.output_sizes[:loaded_shard_id]) // tp_size
                shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
            if use_bitsandbytes:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            if self.quant_config is None:
                start_idx = get_current_tp_rank_partition_offset(
                    loaded_weight.shape[output_dim], tp_rank, tp_size)
            else:
                start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
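
# Typical use (sketch; the gate_up_proj naming and sizes are illustrative, and
# the tensor-parallel group must already be initialized): fusing two
# column-parallel projections of equal width so one GEMM produces both, with
# per-shard checkpoint loading done via
# weight_loader(param, w, loaded_shard_id) for shard ids 0 and 1:
#
#     gate_up_proj = MergedColumnParallelLinear(
#         input_size=4096, output_sizes=[11008, 11008], bias=False)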


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        self.quant_config = quant_config
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        if quant_config is None:
            self.num_heads_per_kv_head = (self.total_num_heads //
                                          self.total_num_kv_heads)
            self.num_kv_heads = get_current_tp_rank_partition_size(
                self.total_num_kv_heads, tp_rank, tp_size)
            self.num_heads = self.num_kv_heads * self.num_heads_per_kv_head
            self.num_kv_head_replicas = 1
        else:
            self.num_heads = divide(self.total_num_heads, tp_size)
            if tp_size >= self.total_num_kv_heads:
                self.num_kv_heads = 1
                self.num_kv_head_replicas = divide(tp_size,
                                                   self.total_num_kv_heads)
            elif tp_size < self.total_num_kv_heads and quant_config is not None:
                self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
                self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv/mlp).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for quantized weights.
                # If quantized, we need to adjust the offset and size to
                # account for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
                if self.quant_config is None:
                    multiple_of = self.head_size * self.num_heads_per_kv_head
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
                if self.quant_config is None:
                    multiple_of = self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
                if self.quant_config is None:
                    multiple_of = self.head_size
            # Special case for quantized weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                if self.quant_config is None:
                    multiple_of = multiple_of // param.pack_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
            if use_bitsandbytes:
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
                    "k": (self.num_heads * self.head_size,
                          self.num_kv_heads * self.head_size),
                    "v":
                    ((self.num_heads + self.num_kv_heads) * self.head_size,
                     self.num_kv_heads * self.head_size),
                    "total":
                    ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                     0)
                }
                shard_size, shard_offset = adjust_bitsandbytes_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            if self.quant_config is None:
                tp_size = get_tensor_model_parallel_world_size()
                total_size = loaded_weight.shape[output_dim]
                start_idx = get_current_tp_rank_partition_offset(
                    total_size, tp_rank, tp_size, multiple_of=multiple_of)
            else:
                if loaded_shard_id == "q":
                    shard_id = tp_rank
                else:
                    shard_id = tp_rank // self.num_kv_head_replicas
                start_idx = shard_id * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
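
# Worked GQA example (illustrative numbers): with hidden_size=4096,
# head_size=128, total_num_heads=32, total_num_kv_heads=8, tp_size=4 and a
# quant_config supplied (even-divide path), each rank ends up with num_heads=8
# and num_kv_heads=2, so its fused weight has (8 + 2*2) * 128 = 1536 output
# rows; the loader places the "q", "k" and "v" shards at row offsets 0, 1024
# and 1280 respectively.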


class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configuration.
        partition_multiple_of: Partitions will be divided,
                               so each partition is a multiple of this number.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 input_is_parallel: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 reduce_results: bool = True,
                 quant_config: Optional[QuantizationConfig] = None,
                 partition_multiple_of: int = 1,
                 prefix: str = ""):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config, prefix)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results
        self.quant_config = quant_config

        # Divide the weight matrix along the last dimension.
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        if quant_config is None:
            self.partition_multiple_of = partition_multiple_of
            self.input_size_per_partition = get_current_tp_rank_partition_size(
                input_size, self.tp_rank, self.tp_size, partition_multiple_of)
        else:
            self.input_size_per_partition = divide(input_size, self.tp_size)
        assert self.quant_method is not None

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=[self.output_size],
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=self.weight_loader,
            prefix=prefix)
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reducing the results, adding bias to "
                             "the results can lead to incorrect results")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        input_dim = getattr(param, "input_dim", None)
        param_data = param.data
        if input_dim is not None:
            shard_size = param_data.shape[input_dim]
            if self.quant_config is None:
                start_idx = get_current_tp_rank_partition_offset(
                    self.input_size,
                    self.tp_rank,
                    self.tp_size,
                    multiple_of=self.partition_multiple_of)
            else:
                start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        if self.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self,
                                                  input_parallel,
                                                  bias=bias_)
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        return output, output_bias

    def extra_repr(self) -> str:
        s = f"input_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s
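
# Usage sketch (assumes the tensor-parallel process group has already been
# initialized by the caller; sizes and the hidden_states_shard tensor are
# hypothetical):
#
#     o_proj = RowParallelLinear(input_size=4096, output_size=4096, bias=False,
#                                input_is_parallel=True, reduce_results=True)
#     out, _ = o_proj(hidden_states_shard)
#     # hidden_states_shard is this rank's activation slice of shape
#     # (..., input_size_per_partition); the all-reduce sums the partial
#     # results across ranks into the full output.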