# linear.py

from abc import abstractmethod
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from loguru import logger
from torch.nn.parameter import Parameter, UninitializedParameter

# yapf: disable
from aphrodite.distributed import (divide,
                                   get_current_tp_rank_partition_offset,
                                   get_current_tp_rank_partition_size,
                                   get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   split_tensor_along_last_dim,
                                   tensor_model_parallel_all_gather,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.parameter import (BaseAphroditeParameter,
                                          PackedAphroditeParameter,
                                          PackedColumnParameter,
                                          PerTensorScaleParameter,
                                          RowAphroditeParameter)
# yapf: enable
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import (QuantizationConfig,
                                                QuantizeMethodBase)

WEIGHT_LOADER_V2_SUPPORTED = [
    "AWQLinearMethod",
    "AWQMarlinLinearMethod",
    "CompressedTensorsLinearMethod",
    "FBGEMMFp8LinearMethod",
    "Fp8LinearMethod",
    "GPTQLinearMethod",
    "GPTQMarlin24LinearMethod",
    "GPTQMarlinLinearMethod",
    "HQQMarlinMethod",
    "MarlinLinearMethod",
    "ModelOptFp8LinearMethod",
    "QQQLinearMethod",
    "TPUInt8LinearMethod",
]


def adjust_marlin_shard(param, shard_size, shard_offset):
    marlin_tile_size = getattr(param, "marlin_tile_size", None)
    if marlin_tile_size is None:
        return shard_size, shard_offset

    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
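
# Illustrative sketch only (kept as a comment; the numbers are hypothetical).
# Marlin stores weights in tiles, so shard indices must be scaled by the tile
# size recorded on the parameter. With marlin_tile_size = 16:
#
#   adjust_marlin_shard(param, shard_size=4096, shard_offset=8192)
#   # -> (65536, 131072), i.e. both values multiplied by 16.
#
# Parameters without a `marlin_tile_size` attribute pass through unchanged.

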
def adjust_bitsandbytes_4bit_shard(param: Parameter,
                                   qkv_offsets: Dict[str, Tuple[int, int]],
                                   loaded_shard_id: str) -> Tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""

    total, _ = qkv_offsets["total"]
    orig_offset, orig_size = qkv_offsets[loaded_shard_id]

    quantized_total = param.data.shape[0]
    quantized_offset = orig_offset * quantized_total // total
    quantized_size = orig_size * quantized_total // total

    return quantized_size, quantized_offset
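
# Worked example (comment only; hypothetical sizes). Suppose the fused QKV
# output has 6144 rows in total, i.e. qkv_offsets["total"] == (6144, 0), and
# the 4-bit-quantized parameter holds param.data.shape[0] == 3072 rows. For
# the "k" shard with qkv_offsets["k"] == (4096, 1024):
#
#   quantized_offset = 4096 * 3072 // 6144  # = 2048
#   quantized_size   = 1024 * 3072 // 6144  # = 512
#
# so the function returns (512, 2048), rescaling the unquantized offsets into
# the packed parameter's coordinate system.

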
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
    the shard_id for loading.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[shard_id], loaded_weight
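
# Illustrative sketch (comment only; hypothetical values). For a fused QKV
# module, `param` here is a length-3 array holding one scale per logical
# matrix. Loading the "k" scale selects index 1, and a shaped scalar such as
# tensor([0.02]) is squeezed to its 0-dim value before the copy:
#
#   param_slice, scale = adjust_scalar_to_fused_array(param, tensor, "k")
#   # param_slice is param[1]; scale is tensor[0] if tensor had shape (1,).

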
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer."""
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        return F.linear(x, layer.weight, bias)
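
# Minimal usage sketch (comment only; sizes are hypothetical and no tensor
# parallelism is involved). UnquantizedLinearMethod simply registers a dense
# weight and dispatches to F.linear:
#
#   method = UnquantizedLinearMethod()
#   layer = torch.nn.Module()
#   method.create_weights(layer,
#                         input_size_per_partition=1024,
#                         output_partition_sizes=[4096],
#                         input_size=1024,
#                         output_size=4096,
#                         params_dtype=torch.float16)
#   # layer.weight has shape (4096, 1024) with input_dim=1, output_dim=0.
#   x = torch.randn(2, 1024, dtype=torch.float16)
#   y = method.apply(layer, x)  # same as F.linear(x, layer.weight) -> (2, 4096)

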
class LinearBase(torch.nn.Module):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix)

        # All linear layers are expected to provide a quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {"output_dim": 0})
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        # If the weight on disk does not have a shape, give it one
        # (e.g. scales for AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s
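
# Usage sketch (comment only; sizes are hypothetical). ReplicatedLinear keeps
# the full weight on every rank and, like the parallel layers below, returns a
# (output, output_bias) tuple so callers can defer the bias add:
#
#   proj = ReplicatedLinear(input_size=1024, output_size=4096, bias=True)
#   out, out_bias = proj(x)  # x: (batch, 1024) -> out: (batch, 4096)
#   # out_bias is None unless skip_bias_add=True was requested.

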
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its own
                       output, which is Y_i = XA_i.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        output_sizes: list of output sizes packed into one output, like for QKV
                       the list would be size 3.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 output_sizes: Optional[List[int]] = None,
                 prefix: str = ""):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config, prefix)

        self.gather_output = gather_output

        # Divide the weight matrix along the last dimension.
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        assert self.quant_method is not None
        if quant_config is None:
            self.output_size_per_partition = get_current_tp_rank_partition_size(
                output_size, tp_rank, tp_size)
        else:
            self.output_size_per_partition = divide(self.output_size, tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            if quant_config is None:
                self.output_partition_sizes = [
                    get_current_tp_rank_partition_size(output_size, tp_rank,
                                                       tp_size)
                    for output_size in self.output_sizes
                ]
            else:
                self.output_partition_sizes = [
                    divide(output_size, tp_size)
                    for output_size in self.output_sizes
                ]

        if output_sizes is None:
            output_sizes = [output_size]
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None:
            shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_column_parallel_weight(loaded_weight=loaded_weight)

    def forward(self, input_):
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += f", gather_output={self.gather_output}"
        return s
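
# Sharding sketch (comment only; sizes hypothetical, assuming the output
# divides evenly across ranks). For a ColumnParallelLinear with
# input_size=1024, output_size=8192 and tp_size=4, each rank holds
# output_size_per_partition = 8192 // 4 = 2048 rows of the weight, i.e. a
# (2048, 1024) shard, and weight_loader narrows the checkpoint tensor to rows
# [tp_rank * 2048, (tp_rank + 1) * 2048). The per-rank forward output is
# (..., 2048); with gather_output=True the all-gather restores (..., 8192).

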
class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_sizes: List[int],
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        self.output_sizes = output_sizes
        self.quant_config = quant_config
        if quant_config is not None:
            tp_size = get_tensor_model_parallel_world_size()
            assert all(output_size % tp_size == 0
                       for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.data[loaded_shard_id].copy_(loaded_weight)
            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            return

        if is_gguf_weight and isinstance(param, UninitializedParameter):
            from gguf.constants import GGML_QUANT_SIZES

            ori_shape = param.tensor_shape
            weight_types = self.qweight_type.shard_weight_type.values()
            row_size = []
            for weight_type in weight_types:
                block_size, type_size = GGML_QUANT_SIZES[weight_type]
                row_size.append(ori_shape[1] // block_size * type_size)
            q_shape = (ori_shape[0], max(row_size))
            param.materialize(q_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv/mlp).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            shard_offsets: List[Tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            if self.quant_config is None:
                shard_offset = sum(
                    get_current_tp_rank_partition_size(output_size, tp_rank,
                                                       tp_size)
                    for output_size in self.output_sizes[:loaded_shard_id])
                shard_size = get_current_tp_rank_partition_size(
                    self.output_sizes[loaded_shard_id], tp_rank, tp_size)
            else:
                shard_offset = sum(
                    self.output_sizes[:loaded_shard_id]) // tp_size
                shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            if use_bitsandbytes_4bit:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            if is_gguf_weight:
                tp_size = get_tensor_model_parallel_world_size()
                output_dim = getattr(param, "output_dim", None)
                shard_shape = list(loaded_weight.shape)
                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
                param.shard_id.append(loaded_shard_id)
                param.shard_size[loaded_shard_id] = shard_shape

                input_dim = getattr(param, "input_dim", None)
                input_size = loaded_weight.shape[input_dim]
                param_data = param_data.narrow(input_dim, 0, input_size)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            if self.quant_config is None:
                start_idx = get_current_tp_rank_partition_offset(
                    loaded_weight.shape[output_dim], tp_rank, tp_size)
            else:
                start_idx = tp_rank * shard_size
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow here
            if not use_bitsandbytes_4bit:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(self, param: BaseAphroditeParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        current_shard_offset = 0
        shard_offsets: List[Tuple[int, int, int]] = []
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter,
                                  PackedAphroditeParameter,
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                        shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BaseAphroditeParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            elif type(param) in (RowAphroditeParameter, BaseAphroditeParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        tp_size = get_tensor_model_parallel_world_size()
        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
        shard_size = self.output_sizes[loaded_shard_id] // tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size)
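
# Loading sketch (comment only; sizes hypothetical). A Llama-style gate/up
# projection built as MergedColumnParallelLinear(input_size=4096,
# output_sizes=[11008, 11008]) with tp_size=2 places both logical matrices in
# one fused parameter. In weight_loader_v2, loading the up projection
# (loaded_shard_id=1) uses
#
#   shard_offset = sum(output_sizes[:1]) // tp_size  # = 11008 // 2 = 5504
#   shard_size   = output_sizes[1] // tp_size        # = 5504
#
# so the up weights land in rows [5504, 11008) of each rank's fused shard.

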
class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
            (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        self.quant_config = quant_config
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        if quant_config is None:
            self.num_heads_per_kv_head = (self.total_num_heads //
                                          self.total_num_kv_heads)
            self.num_kv_heads = get_current_tp_rank_partition_size(
                self.total_num_kv_heads, tp_rank, tp_size)
            self.num_heads = self.num_kv_heads * self.num_heads_per_kv_head
            self.num_kv_head_replicas = 1
        else:
            self.num_heads = divide(self.total_num_heads, tp_size)
            if tp_size >= self.total_num_kv_heads:
                self.num_kv_heads = 1
                self.num_kv_head_replicas = divide(tp_size,
                                                   self.total_num_kv_heads)
            elif tp_size < self.total_num_kv_heads and quant_config is not None:
                self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
                self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

    def _get_shard_offset_mapping(self, loaded_shard_id: str):
        shard_offset_mapping = {
            "q": 0,
            "k": self.num_heads * self.head_size,
            "v": (self.num_heads + self.num_kv_heads) * self.head_size,
            "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size
        }
        return shard_offset_mapping.get(loaded_shard_id)

    def _get_shard_size_mapping(self, loaded_shard_id: str):
        shard_size_mapping = {
            "q": self.num_heads * self.head_size,
            "k": self.num_kv_heads * self.head_size,
            "v": self.num_kv_heads * self.head_size,
        }
        return shard_size_mapping.get(loaded_shard_id)

    def _load_fused_module_from_checkpoint(self, param: BaseAphroditeParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where QKV layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """
        shard_offsets = [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, self.total_num_heads * self.head_size),
            ("k", self.total_num_heads * self.head_size,
             self.total_num_kv_heads * self.head_size),
            ("v",
             (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
             self.total_num_kv_heads * self.head_size),
        ]

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter,
                                  PackedAphroditeParameter,
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                        shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BaseAphroditeParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[str] = None):
        if loaded_shard_id is None:  # special case for certain models
            if isinstance(param, PerTensorScaleParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
                return
            elif type(param) in (RowAphroditeParameter, BaseAphroditeParameter):
                param.load_qkv_weight(loaded_weight=loaded_weight)
                return
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id in ["q", "k", "v"]

        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
        shard_size = self._get_shard_size_mapping(loaded_shard_id)

        param.load_qkv_weight(loaded_weight=loaded_weight,
                              num_heads=self.num_kv_head_replicas,
                              shard_id=loaded_shard_id,
                              shard_offset=shard_offset,
                              shard_size=shard_size)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type and loaded_shard_id is not None:
            idx_map = {"q": 0, "k": 1, "v": 2}
            param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            return

        if is_gguf_weight and isinstance(param, UninitializedParameter):
            from gguf.constants import GGML_QUANT_SIZES

            ori_shape = param.tensor_shape
            weight_types = self.qweight_type.shard_weight_type.values()
            row_size = []
            for weight_type in weight_types:
                block_size, type_size = GGML_QUANT_SIZES[weight_type]
                row_size.append(ori_shape[1] // block_size * type_size)
            q_shape = (ori_shape[0], max(row_size))
            param.materialize(q_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)

        # Special case for per-tensor scales in fused case.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv/mlp).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantized Weights.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
                if self.quant_config is None:
                    multiple_of = self.head_size * self.num_heads_per_kv_head
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
                if self.quant_config is None:
                    multiple_of = self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
                if self.quant_config is None:
                    multiple_of = self.head_size
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                if self.quant_config is None:
                    multiple_of = multiple_of // param.pack_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            if use_bitsandbytes_4bit:
                orig_qkv_offsets = {
                    "q": (0, self.num_heads * self.head_size),
                    "k": (self.num_heads * self.head_size,
                          self.num_kv_heads * self.head_size),
                    "v":
                    ((self.num_heads + self.num_kv_heads) * self.head_size,
                     self.num_kv_heads * self.head_size),
                    "total":
                    ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                     0)
                }

                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, orig_qkv_offsets, loaded_shard_id)

            if is_gguf_weight:
                tp_size = get_tensor_model_parallel_world_size()
                output_dim = getattr(param, "output_dim", None)
                shard_shape = list(loaded_weight.shape)
                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
                param.shard_id.append(loaded_shard_id)
                param.shard_size[loaded_shard_id] = shard_shape

                input_dim = getattr(param, "input_dim", None)
                input_size = loaded_weight.shape[input_dim]
                param_data = param_data.narrow(input_dim, 0, input_size)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            if self.quant_config is None:
                tp_size = get_tensor_model_parallel_world_size()
                total_size = loaded_weight.shape[output_dim]
                start_idx = get_current_tp_rank_partition_offset(
                    total_size, tp_rank, tp_size, multiple_of=multiple_of)
            else:
                if loaded_shard_id == "q":
                    shard_id = tp_rank
                else:
                    shard_id = tp_rank // self.num_kv_head_replicas
                start_idx = shard_id * shard_size

            # bitsandbytes loads the weights of the specific portion
            # no need to narrow here
            if not use_bitsandbytes_4bit:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                     shard_size)

        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
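
# Head-partitioning sketch (comment only; a hypothetical grouped-query
# attention configuration, assuming heads divide evenly across ranks).
# QKVParallelLinear(hidden_size=4096, head_size=128, total_num_heads=32,
# total_num_kv_heads=8) with tp_size=4 gives per rank
#
#   num_heads    = 32 // 4 = 8
#   num_kv_heads = 8 // 4  = 2   # no replication: num_kv_head_replicas = 1
#
# so each rank's fused weight has (8 + 2 * 2) * 128 = 1536 output rows, with
# shard offsets q=0, k=8*128=1024, v=(8+2)*128=1280 and shard sizes q=1024,
# k=256, v=256 used by the weight loaders above.

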
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |       X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        partition_multiple_of: Partitions will be divided,
                               so each partition is a multiple of this number.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 input_is_parallel: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 reduce_results: bool = True,
                 quant_config: Optional[QuantizationConfig] = None,
                 partition_multiple_of: int = 1,
                 prefix: str = ""):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config, prefix)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results
        self.quant_config = quant_config

        # Divide the weight matrix along the last dimension.
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        if quant_config is None:
            self.partition_multiple_of = partition_multiple_of
            self.input_size_per_partition = get_current_tp_rank_partition_size(
                input_size, self.tp_rank, self.tp_size, partition_multiple_of)
        else:
            self.input_size_per_partition = divide(input_size, self.tp_size)
        assert self.quant_method is not None

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=[self.output_size],
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=(
                self.weight_loader_v2 if self.quant_method.__class__.__name__
                in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When reduce_results=False, adding bias to the "
                             "output can lead to incorrect results.")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        tp_size = get_tensor_model_parallel_world_size()
        input_dim = getattr(param, "input_dim", None)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            weight_shape = list(loaded_weight.shape)
            if input_dim:
                weight_shape[input_dim] = weight_shape[input_dim] // tp_size
            param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

        param_data = param.data
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow here
        if input_dim is not None and not use_bitsandbytes_4bit:
            shard_size = param_data.shape[input_dim]
            if self.quant_config is None:
                start_idx = get_current_tp_rank_partition_offset(
                    self.input_size,
                    self.tp_rank,
                    self.tp_size,
                    multiple_of=self.partition_multiple_of)
            else:
                start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def weight_loader_v2(self, param: BaseAphroditeParameter,
                         loaded_weight: torch.Tensor):
        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)

        param.load_row_parallel_weight(loaded_weight=loaded_weight)

    def forward(self, input_):
        if self.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self,
                                                  input_parallel,
                                                  bias=bias_)
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        return output, output_bias

    def extra_repr(self) -> str:
        s = f"input_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s
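
# Reduction sketch (comment only; sizes hypothetical). For a RowParallelLinear
# with input_size=8192, output_size=4096 and tp_size=4, each rank holds a
# (4096, 2048) weight shard and receives the matching (..., 2048) slice of the
# input (already split when input_is_parallel=True). Each rank computes a
# partial (..., 4096) product; with reduce_results=True the all-reduce sums the
# partials into the full result, and because only rank 0 fuses the bias into
# its GEMM the bias is counted exactly once in that sum.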