linear.py 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164
  1. from abc import abstractmethod
  2. from typing import Dict, List, Optional, Tuple
  3. import torch
  4. import torch.nn.functional as F
  5. from loguru import logger
  6. from torch.nn.parameter import Parameter, UninitializedParameter
  7. # yapf: disable
  8. from aphrodite.distributed import (divide,
  9. get_current_tp_rank_partition_offset,
  10. get_current_tp_rank_partition_size,
  11. get_tensor_model_parallel_rank,
  12. get_tensor_model_parallel_world_size,
  13. split_tensor_along_last_dim,
  14. tensor_model_parallel_all_gather,
  15. tensor_model_parallel_all_reduce)
  16. from aphrodite.modeling.parameter import (BaseAphroditeParameter,
  17. BlockQuantScaleParameter,
  18. PackedAphroditeParameter,
  19. PackedColumnParameter,
  20. PerTensorScaleParameter,
  21. RowAphroditeParameter)
  22. # yapf: enable
  23. from aphrodite.modeling.utils import set_weight_attrs
  24. from aphrodite.quantization.base_config import (QuantizationConfig,
  25. QuantizeMethodBase)
# Class names of the quantization linear methods whose parameters are loaded
# through `weight_loader_v2`. ColumnParallelLinear compares
# `quant_method.__class__.__name__` against this list when deciding which
# loader to pass to `create_weights`; any method not listed here keeps the
# legacy `weight_loader`.
WEIGHT_LOADER_V2_SUPPORTED = [
    "AWQLinearMethod",
    "AWQMarlinLinearMethod",
    "CompressedTensorsLinearMethod",
    "FBGEMMFp8LinearMethod",
    "Fp8LinearMethod",
    "GPTQLinearMethod",
    "GPTQMarlin24LinearMethod",
    "GPTQMarlinLinearMethod",
    "HQQMarlinMethod",
    "MarlinLinearMethod",
    "ModelOptFp8LinearMethod",
    "QQQLinearMethod",
    "TPUInt8LinearMethod",
]
  41. def adjust_marlin_shard(param, shard_size, shard_offset):
  42. marlin_tile_size = getattr(param, "marlin_tile_size", None)
  43. if marlin_tile_size is None:
  44. return shard_size, shard_offset
  45. return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
  46. def adjust_bitsandbytes_4bit_shard(param: Parameter,
  47. qkv_offsets: Dict[str, Tuple[int, int]],
  48. loaded_shard_id: str) -> Tuple[int, int]:
  49. """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
  50. total, _ = qkv_offsets["total"]
  51. orig_offset, orig_size = qkv_offsets[loaded_shard_id]
  52. quantized_total = param.data.shape[0]
  53. quantized_offset = orig_offset * quantized_total // total
  54. quantized_size = orig_size * quantized_total // total
  55. return quantized_size, quantized_offset
  56. def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
  57. """For fused modules (QKV and MLP) we have an array of length
  58. N that holds 1 scale for each "logical" matrix. So the param
  59. is an array of length N. The loaded_weight corresponds to
  60. one of the shards on disk. Here, we slice the param based on
  61. the shard_id for loading.
  62. """
  63. qkv_idxs = {"q": 0, "k": 1, "v": 2}
  64. if isinstance(shard_id, str):
  65. shard_id = qkv_idxs[shard_id]
  66. elif not isinstance(shard_id, int):
  67. raise ValueError(f"Unknown Shard Id {shard_id}")
  68. # AutoFP8 scales do not have a shape
  69. # compressed-tensors scales do have a shape
  70. if len(loaded_weight.shape) != 0:
  71. assert loaded_weight.shape[0] == 1
  72. loaded_weight = loaded_weight[0]
  73. return param[shard_id], loaded_weight
  74. class LinearMethodBase(QuantizeMethodBase):
  75. """Base class for different (maybe quantized) linear methods."""
  76. @abstractmethod
  77. def create_weights(self, layer: torch.nn.Module,
  78. input_size_per_partition: int,
  79. output_partition_sizes: List[int], input_size: int,
  80. output_size: int, params_dtype: torch.dtype,
  81. **extra_weight_attrs):
  82. """Create weights for a linear layer.
  83. The weights will be set as attributes of the layer.
  84. Args:
  85. layer: The layer that is using the LinearMethodBase factory.
  86. input_size_per_partition: Size of the weight input dim on rank X.
  87. output_partition_sizes: Sizes of the output dim of each logical
  88. weight on rank X. E.g., output_partition_sizes for QKVLinear
  89. is a list contains the width of Wq, Wk, Wv on rank X.
  90. input_size: Size of the input dim of the weight across all ranks.
  91. output_size: Size of the output dim of the weight across all ranks.
  92. params_dtype: Datatype of the parameters.
  93. """
  94. raise NotImplementedError
  95. @abstractmethod
  96. def apply(self,
  97. layer: torch.nn.Module,
  98. x: torch.Tensor,
  99. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  100. """Apply the weights in layer to the input tensor.
  101. Expects create_weights to have been called before on the layer."""
  102. raise NotImplementedError
  103. class UnquantizedLinearMethod(LinearMethodBase):
  104. """Linear method without quantization."""
  105. def create_weights(self, layer: torch.nn.Module,
  106. input_size_per_partition: int,
  107. output_partition_sizes: List[int], input_size: int,
  108. output_size: int, params_dtype: torch.dtype,
  109. **extra_weight_attrs):
  110. weight = Parameter(torch.empty(sum(output_partition_sizes),
  111. input_size_per_partition,
  112. dtype=params_dtype),
  113. requires_grad=False)
  114. set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
  115. layer.register_parameter("weight", weight)
  116. set_weight_attrs(weight, extra_weight_attrs)
  117. def apply(self,
  118. layer: torch.nn.Module,
  119. x: torch.Tensor,
  120. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  121. return F.linear(x, layer.weight, bias)
  122. class LinearBase(torch.nn.Module):
  123. """Base linear layer.
  124. Args:
  125. input_size: input dimension of the linear layer.
  126. output_size: output dimension of the linear layer.
  127. bias: If true, add bias.
  128. skip_bias_add: If true, skip adding bias but instead return it.
  129. params_dtype: Data type for the parameters.
  130. quant_config: Quantization configure.
  131. """
  132. def __init__(
  133. self,
  134. input_size: int,
  135. output_size: int,
  136. skip_bias_add: bool = False,
  137. params_dtype: Optional[torch.dtype] = None,
  138. quant_config: Optional[QuantizationConfig] = None,
  139. prefix: str = "",
  140. ):
  141. super().__init__()
  142. # Keep input parameters
  143. self.input_size = input_size
  144. self.output_size = output_size
  145. self.skip_bias_add = skip_bias_add
  146. if params_dtype is None:
  147. params_dtype = torch.get_default_dtype()
  148. self.params_dtype = params_dtype
  149. if quant_config is None:
  150. self.quant_method: Optional[
  151. QuantizeMethodBase] = UnquantizedLinearMethod()
  152. else:
  153. self.quant_method = quant_config.get_quant_method(self,
  154. prefix=prefix)
  155. def forward(self, x: torch.Tensor) -> torch.Tensor:
  156. raise NotImplementedError
  157. class ReplicatedLinear(LinearBase):
  158. """Replicated linear layer.
  159. Args:
  160. input_size: input dimension of the linear layer.
  161. output_size: output dimension of the linear layer.
  162. bias: If true, add bias.
  163. skip_bias_add: If true, skip adding bias but instead return it.
  164. params_dtype: Data type for the parameters.
  165. quant_config: Quantization configure.
  166. prefix: The name of the layer in the state dict, including all parents
  167. (e.g. model.layers.0.qkv_proj)
  168. """
  169. def __init__(self,
  170. input_size: int,
  171. output_size: int,
  172. bias: bool = True,
  173. skip_bias_add: bool = False,
  174. params_dtype: Optional[torch.dtype] = None,
  175. quant_config: Optional[QuantizationConfig] = None,
  176. prefix: str = ""):
  177. super().__init__(input_size,
  178. output_size,
  179. skip_bias_add,
  180. params_dtype,
  181. quant_config,
  182. prefix=prefix)
  183. # All the linear layer supports quant method.
  184. assert self.quant_method is not None
  185. self.quant_method.create_weights(self,
  186. self.input_size, [self.output_size],
  187. self.input_size,
  188. self.output_size,
  189. self.params_dtype,
  190. weight_loader=self.weight_loader)
  191. if bias:
  192. self.bias = Parameter(
  193. torch.empty(self.output_size, dtype=self.params_dtype))
  194. set_weight_attrs(self.bias, {"output_dim": 0})
  195. else:
  196. self.register_parameter("bias", None)
  197. def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
  198. # If the weight on disk does not have a shape, give it one
  199. # (such scales for AutoFp8).
  200. if len(loaded_weight.shape) == 0:
  201. loaded_weight = loaded_weight.reshape(1)
  202. assert param.size() == loaded_weight.size()
  203. param.data.copy_(loaded_weight)
  204. def forward(self, x: torch.Tensor) -> torch.Tensor:
  205. bias = self.bias if not self.skip_bias_add else None
  206. assert self.quant_method is not None
  207. output = self.quant_method.apply(self, x, bias)
  208. output_bias = self.bias if self.skip_bias_add else None
  209. return output, output_bias
  210. def extra_repr(self) -> str:
  211. s = f"in_features={self.input_size}"
  212. s += f", output_features={self.output_size}"
  213. s += f", bias={self.bias is not None}"
  214. return s
  215. class ColumnParallelLinear(LinearBase):
  216. """Linear layer with column parallelism.
  217. The linear layer is defined as Y = XA + b. A is parallelized along
  218. its second dimension as A = [A_1, ..., A_p].
  219. Args:
  220. input_size: first dimension of matrix A.
  221. output_size: second dimension of matrix A.
  222. bias: If true, add bias.
  223. gather_output: If true, call all-gather on output and make Y available
  224. to all GPUs, otherwise, every GPU will have its output
  225. which is Y_i = XA_i
  226. skip_bias_add: This was added to enable performance optimizations where
  227. bias can be fused with other element-wise operations. we
  228. skip adding bias but instead return it.
  229. params_dtype: Data type for the parameters.
  230. quant_config: Quantization configure.
  231. output_sizes: list of output sizes packed into one output, like for QKV
  232. the list would be size 3.
  233. prefix: The name of the layer in the state dict, including all parents
  234. (e.g. model.layers.0.qkv_proj)
  235. """
  236. def __init__(self,
  237. input_size: int,
  238. output_size: int,
  239. bias: bool = True,
  240. gather_output: bool = False,
  241. skip_bias_add: bool = False,
  242. params_dtype: Optional[torch.dtype] = None,
  243. quant_config: Optional[QuantizationConfig] = None,
  244. output_sizes: Optional[List[int]] = None,
  245. prefix: str = ""):
  246. super().__init__(input_size, output_size, skip_bias_add, params_dtype,
  247. quant_config, prefix)
  248. self.gather_output = gather_output
  249. # Divide the weight matrix along the last dimension.
  250. tp_rank = get_tensor_model_parallel_rank()
  251. tp_size = get_tensor_model_parallel_world_size()
  252. assert self.quant_method is not None
  253. if quant_config is None:
  254. self.output_size_per_partition = get_current_tp_rank_partition_size(
  255. output_size, tp_rank, tp_size)
  256. else:
  257. self.output_size_per_partition = divide(self.output_size, tp_size)
  258. self.output_partition_sizes = [self.output_size_per_partition]
  259. # If QKV or MergedColumn, use output size of each partition.
  260. if hasattr(self, "output_sizes"):
  261. if quant_config is None:
  262. self.output_partition_sizes = [
  263. get_current_tp_rank_partition_size(output_size, tp_rank,
  264. tp_size)
  265. for output_size in self.output_sizes
  266. ]
  267. else:
  268. self.output_partition_sizes = [
  269. divide(output_size, tp_size)
  270. for output_size in self.output_sizes
  271. ]
  272. if output_sizes is None:
  273. output_sizes = [output_size]
  274. self.quant_method.create_weights(
  275. layer=self,
  276. input_size_per_partition=self.input_size,
  277. output_partition_sizes=self.output_partition_sizes,
  278. input_size=self.input_size,
  279. output_size=self.output_size,
  280. params_dtype=self.params_dtype,
  281. weight_loader=(
  282. self.weight_loader_v2 if self.quant_method.__class__.__name__
  283. in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
  284. if bias:
  285. self.bias = Parameter(
  286. torch.empty(self.output_size_per_partition,
  287. dtype=params_dtype))
  288. set_weight_attrs(self.bias, {
  289. "output_dim": 0,
  290. "weight_loader": self.weight_loader,
  291. })
  292. else:
  293. self.register_parameter("bias", None)
  294. def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
  295. tp_rank = get_tensor_model_parallel_rank()
  296. output_dim = getattr(param, "output_dim", None)
  297. # Special case for GGUF
  298. is_gguf_weight = getattr(param, "is_gguf_weight", False)
  299. is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
  300. if is_gguf_weight_type:
  301. param.weight_type = loaded_weight.item()
  302. # Materialize GGUF UninitializedParameter
  303. if is_gguf_weight and isinstance(param, UninitializedParameter):
  304. param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
  305. param_data = param.data
  306. if output_dim is not None:
  307. shard_size = param_data.shape[output_dim]
  308. start_idx = tp_rank * shard_size
  309. loaded_weight = loaded_weight.narrow(output_dim, start_idx,
  310. shard_size)
  311. # Special case for loading scales off disk, which often do not
  312. # have a shape (such as in the case of AutoFP8).
  313. if len(loaded_weight.shape) == 0:
  314. loaded_weight = loaded_weight.reshape(1)
  315. assert param_data.shape == loaded_weight.shape
  316. param_data.copy_(loaded_weight)
  317. def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
  318. # Special case for loading scales off disk, which often do not
  319. # have a shape (such as in the case of AutoFP8).
  320. if len(loaded_weight.shape) == 0:
  321. assert loaded_weight.numel() == 1
  322. loaded_weight = loaded_weight.reshape(1)
  323. param.load_column_parallel_weight(loaded_weight=loaded_weight)
  324. def forward(self, input_):
  325. bias = self.bias if not self.skip_bias_add else None
  326. # Matrix multiply.
  327. assert self.quant_method is not None
  328. output_parallel = self.quant_method.apply(self, input_, bias)
  329. if self.gather_output:
  330. # All-gather across the partitions.
  331. output = tensor_model_parallel_all_gather(output_parallel)
  332. else:
  333. output = output_parallel
  334. output_bias = self.bias if self.skip_bias_add else None
  335. return output, output_bias
  336. def extra_repr(self) -> str:
  337. s = f"in_features={self.input_size}"
  338. s += f", output_features={self.output_size_per_partition}"
  339. s += f", bias={self.bias is not None}"
  340. s += f", tp_size={get_tensor_model_parallel_world_size()}"
  341. s += f", gather_output={self.gather_output}"
  342. return s
class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. we
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                        (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_sizes: List[int],
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # Set before super().__init__: the ColumnParallelLinear constructor
        # checks hasattr(self, "output_sizes") to size each packed partition.
        self.output_sizes = output_sizes
        self.quant_config = quant_config
        if quant_config is not None:
            # The quantized path requires every packed output to split
            # evenly across the tensor-parallel ranks.
            tp_size = get_tensor_model_parallel_world_size()
            assert all(output_size % tp_size == 0
                       for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load one shard (or an already-fused tensor) into ``param``.

        Args:
            param: Destination parameter. Optional attributes on it
                (``output_dim``, ``packed_dim``, ``is_metadata``,
                ``needs_scalar_to_array``, the GGUF and bitsandbytes flags)
                select the special cases below.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` of the logical
                matrix being loaded, or None when the checkpoint tensor is
                already fused.
        """

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            # Record the quant type of this shard; nothing else to load.
            param.data[loaded_shard_id].copy_(loaded_weight)
            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            return

        if is_gguf_weight and isinstance(param, UninitializedParameter):
            from gguf.constants import GGML_QUANT_SIZES

            ori_shape = param.tensor_shape
            weight_types = self.qweight_type.shard_weight_type.values()
            row_size = []
            for weight_type in weight_types:
                block_size, type_size = GGML_QUANT_SIZES[weight_type]
                row_size.append(ori_shape[1] // block_size * type_size)
            # Allocate for the widest row among the recorded shard types.
            q_shape = (ori_shape[0], max(row_size))
            param.materialize(q_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv/mlp).
            if output_dim is None:
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            # Split the fused tensor into (shard_id, offset, size) pieces
            # following self.output_sizes, then load each piece through this
            # same method with an explicit shard id.
            current_shard_offset = 0
            shard_offsets: List[Tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            # shard_offset / shard_size locate this shard inside the local
            # fused parameter.
            if self.quant_config is None:
                shard_offset = sum(
                    get_current_tp_rank_partition_size(output_size, tp_rank,
                                                       tp_size)
                    for output_size in self.output_sizes[:loaded_shard_id])
                shard_size = get_current_tp_rank_partition_size(
                    self.output_sizes[loaded_shard_id], tp_rank, tp_size)
            else:
                shard_offset = sum(
                    self.output_sizes[:loaded_shard_id]) // tp_size
                shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                            False)
            if use_bitsandbytes_4bit:
                # Overrides the values computed above: size and offset are
                # taken from the loaded tensor itself.
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            if is_gguf_weight:
                tp_size = get_tensor_model_parallel_world_size()
                output_dim = getattr(param, "output_dim", None)
                shard_shape = list(loaded_weight.shape)
                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
                # Record which shard was loaded and its per-rank shape on
                # the parameter.
                param.shard_id.append(loaded_shard_id)
                param.shard_size[loaded_shard_id] = shard_shape

                # The parameter was materialised for the widest row; restrict
                # it to the width of this shard.
                input_dim = getattr(param, "input_dim", None)
                input_size = loaded_weight.shape[input_dim]
                param_data = param_data.narrow(input_dim, 0, input_size)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # start_idx locates this rank's slice inside the checkpoint
            # tensor.
            if self.quant_config is None:
                start_idx = get_current_tp_rank_partition_offset(
                    loaded_weight.shape[output_dim], tp_rank, tp_size)
            else:
                start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def _load_fused_module_from_checkpoint(self, param: BaseAphroditeParameter,
                                           loaded_weight: torch.Tensor):
        """
        Handle special case for models where MLP layers are already
        fused on disk. In this case, we have no shard id. This function
        determines the shard id by splitting these layers and then calls
        the weight loader using the shard id.

        An example of a model with these fused layers:
        https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
        """

        # (shard_id, offset, size) of each logical matrix inside the fused
        # tensor, following self.output_sizes.
        current_shard_offset = 0
        shard_offsets: List[Tuple[int, int, int]] = []
        for i, output_size in enumerate(self.output_sizes):
            shard_offsets.append((i, current_shard_offset, output_size))
            current_shard_offset += output_size

        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if isinstance(param, (PackedColumnParameter,
                                  PackedAphroditeParameter,
                                  )) and param.packed_dim == param.output_dim:
                shard_size, shard_offset = \
                    param.adjust_shard_indexes_for_packing(
                    shard_size=shard_size, shard_offset=shard_offset)

            loaded_weight_shard = loaded_weight.narrow(param.output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader_v2(param, loaded_weight_shard, shard_id)

    def weight_loader_v2(self,
                         param: BaseAphroditeParameter,
                         loaded_weight: torch.Tensor,
                         loaded_shard_id: Optional[int] = None):
        """Load a shard through the parameter's own loading method.

        Without a shard id, per-tensor scales and plain row/base parameters
        are loaded directly; any other parameter is split by
        ``_load_fused_module_from_checkpoint``. With a shard id, the shard's
        offset and size are computed and passed to
        ``param.load_merged_column_weight``.
        """
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight,
                                                shard_id=0)
                return
            # Exact type match: subclasses of these two types fall through
            # to the fused-checkpoint split below.
            elif type(param) in (RowAphroditeParameter, BaseAphroditeParameter):
                param.load_merged_column_weight(loaded_weight=loaded_weight)
                return
            self._load_fused_module_from_checkpoint(param, loaded_weight)
            return

        assert loaded_shard_id < len(self.output_sizes)

        tp_size = get_tensor_model_parallel_world_size()

        if isinstance(param, BlockQuantScaleParameter):
            from aphrodite.quantization.fp8 import (Fp8LinearMethod,
                                                    Fp8MoEMethod)
            assert self.quant_method is not None
            assert isinstance(self.quant_method,
                              (Fp8LinearMethod, Fp8MoEMethod))
            weight_block_size = self.quant_method.quant_config.weight_block_size
            assert weight_block_size is not None
            block_n, _ = weight_block_size[0], weight_block_size[1]
            # Offsets and sizes are counted in blocks of block_n (rounded
            # up) and then divided across ranks.
            shard_offset = (
                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
                block_n) // tp_size
            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
                          block_n // tp_size)
        else:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size

        param.load_merged_column_weight(loaded_weight=loaded_weight,
                                        shard_id=loaded_shard_id,
                                        shard_offset=shard_offset,
                                        shard_size=shard_size)
  578. class QKVParallelLinear(ColumnParallelLinear):
  579. """Linear layers for the attention's QKV transformation.
  580. Linear layers for the linear transformation of the query, key, and value
  581. vectors in the attention layer. The weight matrix is concatenated along
  582. the output dimension. The layer is parallelized along the head dimension.
  583. When the number of key/value heads is smaller than the number of query
  584. heads (e.g., multi-query/grouped-query attention), the key/value head may
  585. be replicated while the query heads are partitioned.
  586. Args:
  587. hidden_size: input hidden state size of the transformer.
  588. head_size: size of each attention head.
  589. total_num_heads: total number of attention query heads.
  590. total_num_kv_heads: total number of attention key/value heads. If
  591. None, assume total_num_kv_heads = total_num_heads.
  592. bias: If true, add bias.
  593. skip_bias_add: This was added to enable performance optimizations where
  594. bias can be fused with other element-wise operations. we
  595. skip adding bias but instead return it.
  596. params_dtype: Data type for the parameters.
  597. quant_config: Quantization configure.
  598. prefix: The name of the layer in the state dict, including all parents
  599. (e.g. model.layers.0.qkv_proj)
  600. """
  def __init__(self,
               hidden_size: int,
               head_size: int,
               total_num_heads: int,
               total_num_kv_heads: Optional[int] = None,
               bias: bool = True,
               skip_bias_add: bool = False,
               params_dtype: Optional[torch.dtype] = None,
               quant_config: Optional[QuantizationConfig] = None,
               prefix: str = ""):
      """Work out this rank's head layout and initialise the fused layer.

      Stores the global head counts, derives the per-rank ``num_heads``,
      ``num_kv_heads`` and ``num_kv_head_replicas``, and then passes the
      combined q/k/v output width to the parent constructor with
      ``gather_output=False``.
      """
      self.hidden_size = hidden_size
      self.head_size = head_size
      self.total_num_heads = total_num_heads
      self.quant_config = quant_config
      # Without an explicit KV head count, use one KV head per query head.
      self.total_num_kv_heads = (total_num_heads if total_num_kv_heads
                                 is None else total_num_kv_heads)

      world_size = get_tensor_model_parallel_world_size()
      rank = get_tensor_model_parallel_rank()

      if quant_config is None:
          # Unquantized: the KV heads owned by this rank come from the
          # partition helper, and the query heads follow from them.
          self.num_heads_per_kv_head = (self.total_num_heads //
                                        self.total_num_kv_heads)
          self.num_kv_heads = get_current_tp_rank_partition_size(
              self.total_num_kv_heads, rank, world_size)
          self.num_heads = self.num_kv_heads * self.num_heads_per_kv_head
          self.num_kv_head_replicas = 1
      else:
          self.num_heads = divide(self.total_num_heads, world_size)

      # NOTE(review): the nesting of this if/elif was reconstructed from a
      # listing that had lost its indentation. It is placed at the outer
      # level because the ``quant_config is not None`` guard is only
      # meaningful there -- confirm against the canonical file.
      if world_size >= self.total_num_kv_heads:
          # At least as many ranks as KV heads: one KV head per rank, each
          # head shared by several ranks.
          self.num_kv_heads = 1
          self.num_kv_head_replicas = divide(world_size,
                                             self.total_num_kv_heads)
      elif quant_config is not None:
          self.num_kv_heads = divide(self.total_num_kv_heads, world_size)
          self.num_kv_head_replicas = 1

      q_out = self.num_heads * self.head_size * world_size
      kv_out = self.num_kv_heads * self.head_size * world_size
      self.output_sizes = [
          q_out,  # q_proj
          kv_out,  # k_proj
          kv_out,  # v_proj
      ]

      super().__init__(input_size=self.hidden_size,
                       output_size=q_out + 2 * kv_out,
                       bias=bias,
                       gather_output=False,
                       skip_bias_add=skip_bias_add,
                       params_dtype=params_dtype,
                       quant_config=quant_config,
                       prefix=prefix)
  def _get_shard_offset_mapping(self, loaded_shard_id: str):
      """Return the start offset of a shard inside this rank's fused output.

      Known ids are ``"q"``, ``"k"``, ``"v"`` and ``"total"`` (the width of
      the whole fused output); any other id yields ``None``.
      """
      q_width = self.num_heads * self.head_size
      kv_width = self.num_kv_heads * self.head_size
      offsets = {
          "q": 0,
          "k": q_width,
          "v": q_width + kv_width,
          "total": q_width + 2 * kv_width,
      }
      return offsets.get(loaded_shard_id)
  def _get_shard_size_mapping(self, loaded_shard_id: str):
      """Return the width of a shard inside this rank's fused output.

      Known ids are ``"q"``, ``"k"`` and ``"v"``; any other id yields
      ``None``.
      """
      kv_width = self.num_kv_heads * self.head_size
      sizes = {
          "q": self.num_heads * self.head_size,
          "k": kv_width,
          "v": kv_width,
      }
      return sizes.get(loaded_shard_id)
  def _load_fused_module_from_checkpoint(self, param: BaseAphroditeParameter,
                                         loaded_weight: torch.Tensor):
      """Load a QKV tensor that is already fused in the checkpoint.

      Such a tensor arrives without a shard id. It is cut into its q, k
      and v parts along ``param.output_dim`` using the *global* head
      counts, and each part is passed to ``weight_loader_v2`` with the
      matching shard id.

      An example of a model with these fused layers:
      https://huggingface.co/microsoft/Phi-3-mini-4k-instruct
      """
      q_width = self.total_num_heads * self.head_size
      kv_width = self.total_num_kv_heads * self.head_size
      fused_layout = (
          # (shard_id, shard_offset, shard_size)
          ("q", 0, q_width),
          ("k", q_width, kv_width),
          ("v", q_width + kv_width, kv_width),
      )

      for shard_id, offset, size in fused_layout:
          # Quantized parameters packed along the output dimension hold
          # several values per element, so offset and size are rescaled
          # by the parameter itself.
          packed_on_output = (isinstance(
              param, (PackedColumnParameter, PackedAphroditeParameter))
                              and param.packed_dim == param.output_dim)
          if packed_on_output:
              size, offset = param.adjust_shard_indexes_for_packing(
                  shard_size=size, shard_offset=offset)

          shard = loaded_weight.narrow(param.output_dim, offset, size)
          self.weight_loader_v2(param, shard, shard_id)
  def weight_loader_v2(self,
                       param: BaseAphroditeParameter,
                       loaded_weight: torch.Tensor,
                       loaded_shard_id: Optional[str] = None):
      """Load a q/k/v shard, or an already-fused tensor, into ``param``.

      With a shard id (``"q"``, ``"k"`` or ``"v"``) the parameter loads
      the shard at the offset and size this layer computes for it.
      Without one, the tensor is treated as fused on disk.
      """
      if loaded_shard_id is not None:
          assert loaded_shard_id in ["q", "k", "v"]
          offset = self._get_shard_offset_mapping(loaded_shard_id)
          size = self._get_shard_size_mapping(loaded_shard_id)
          param.load_qkv_weight(loaded_weight=loaded_weight,
                                num_heads=self.num_kv_head_replicas,
                                shard_id=loaded_shard_id,
                                shard_offset=offset,
                                shard_size=size)
          return

      # No shard id: special case for certain models.
      if isinstance(param, PerTensorScaleParameter):
          param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
          return
      # Exact type match (not isinstance): subclasses of these parameter
      # types fall through to the fused-splitting path below.
      if type(param) in (RowAphroditeParameter, BaseAphroditeParameter):
          param.load_qkv_weight(loaded_weight=loaded_weight)
          return
      self._load_fused_module_from_checkpoint(param, loaded_weight)
  def weight_loader(self,
                    param: Parameter,
                    loaded_weight: torch.Tensor,
                    loaded_shard_id: Optional[str] = None):
      """Copy a checkpoint tensor into the fused q/k/v parameter.

      ``loaded_shard_id`` is ``"q"``, ``"k"`` or ``"v"`` for a single
      projection, or ``None`` when the checkpoint tensor is already fused.
      A fused tensor whose parameter has an ``output_dim`` is cut into its
      three parts and each part is fed back through this method.
      """
      # Special case for GGUF: the parameter can only be materialized
      # once the quantization type of every shard is known.
      gguf_weight = getattr(param, "is_gguf_weight", False)
      gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
      if gguf_weight_type and loaded_shard_id is not None:
          slot = {"q": 0, "k": 1, "v": 2}[loaded_shard_id]
          param.data[slot].copy_(loaded_weight)
          param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
          return

      if gguf_weight and isinstance(param, UninitializedParameter):
          from gguf.constants import GGML_QUANT_SIZES

          full_shape = param.tensor_shape
          shard_types = self.qweight_type.shard_weight_type.values()
          row_sizes = [
              full_shape[1] // block_size * type_size
              for block_size, type_size in (GGML_QUANT_SIZES[shard_type]
                                            for shard_type in shard_types)
          ]
          # Allocate enough columns for the widest shard encoding.
          param.materialize((full_shape[0], max(row_sizes)),
                            dtype=loaded_weight.dtype)

      param_data = param.data
      output_dim = getattr(param, "output_dim", None)
      # Special case for AQLM codebooks.
      is_metadata = getattr(param, "is_metadata", False)
      # Special case for per-tensor scales in fused case.
      needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

      if loaded_shard_id is None:
          # Loaded weight is already fused on disk (qkv/mlp).
          if output_dim is None:
              if needs_scalar_to_array:
                  param_data, loaded_weight = adjust_scalar_to_fused_array(
                      param_data, loaded_weight, 0)
              assert param_data.shape == loaded_weight.shape
              param_data.copy_(loaded_weight)
              return

          total_q_width = self.total_num_heads * self.head_size
          total_kv_width = self.total_num_kv_heads * self.head_size
          fused_layout = [
              # (shard_id, shard_offset, shard_size)
              ("q", 0, total_q_width),
              ("k", total_q_width, total_kv_width),
              ("v", total_q_width + total_kv_width, total_kv_width),
          ]
          packed_dim = getattr(param, "packed_dim", None)
          packed_on_output = packed_dim == output_dim
          for shard_id, shard_offset, shard_size in fused_layout:
              # Special case for Quantized Weights.
              # If quantized, we need to adjust the offset and size to
              # account for the packing.
              if packed_on_output:
                  shard_size = shard_size // param.pack_factor
                  shard_offset = shard_offset // param.pack_factor
                  # Special case for Marlin.
                  shard_size, shard_offset = adjust_marlin_shard(
                      param, shard_size, shard_offset)

              self.weight_loader(
                  param,
                  loaded_weight.narrow(output_dim, shard_offset, shard_size),
                  shard_id)
          return

      tp_rank = get_tensor_model_parallel_rank()
      assert loaded_shard_id in ["q", "k", "v"]

      # If output dim is defined, use the default loading process.
      if output_dim is not None:
          unquantized = self.quant_config is None
          q_width = self.num_heads * self.head_size
          kv_width = self.num_kv_heads * self.head_size
          # ``multiple_of`` is only defined (and only read) on the
          # unquantized path.
          if loaded_shard_id == "q":
              shard_offset, shard_size = 0, q_width
              if unquantized:
                  multiple_of = self.head_size * self.num_heads_per_kv_head
          elif loaded_shard_id == "k":
              shard_offset, shard_size = q_width, kv_width
              if unquantized:
                  multiple_of = self.head_size
          elif loaded_shard_id == "v":
              shard_offset, shard_size = q_width + kv_width, kv_width
              if unquantized:
                  multiple_of = self.head_size

          # Special case for Quantized Weights.
          # If quantized, we need to adjust the offset and size to account
          # for the packing.
          packed_dim = getattr(param, "packed_dim", None)
          if packed_dim == output_dim:
              shard_size = shard_size // param.pack_factor
              shard_offset = shard_offset // param.pack_factor
              if unquantized:
                  multiple_of = multiple_of // param.pack_factor
              # Special case for Marlin.
              shard_size, shard_offset = adjust_marlin_shard(
                  param, shard_size, shard_offset)

          if getattr(param, "use_bitsandbytes_4bit", False):
              orig_qkv_offsets = {
                  "q": (0, q_width),
                  "k": (q_width, kv_width),
                  "v": (q_width + kv_width, kv_width),
                  "total": (q_width + 2 * kv_width, 0),
              }
              shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                  param, orig_qkv_offsets, loaded_shard_id)

          if gguf_weight:
              tp_size = get_tensor_model_parallel_world_size()
              output_dim = getattr(param, "output_dim", None)
              # Record the per-rank shape of this shard on the parameter.
              shard_shape = list(loaded_weight.shape)
              shard_shape[output_dim] = shard_shape[output_dim] // tp_size
              param.shard_id.append(loaded_shard_id)
              param.shard_size[loaded_shard_id] = shard_shape

              # Limit the destination to the input extent of this shard.
              input_dim = getattr(param, "input_dim", None)
              input_size = loaded_weight.shape[input_dim]
              param_data = param_data.narrow(input_dim, 0, input_size)

          param_data = param_data.narrow(output_dim, shard_offset,
                                         shard_size)

          if unquantized:
              tp_size = get_tensor_model_parallel_world_size()
              total_size = loaded_weight.shape[output_dim]
              start_idx = get_current_tp_rank_partition_offset(
                  total_size, tp_rank, tp_size, multiple_of=multiple_of)
          else:
              # Ranks that share a replicated KV head read the same slice.
              if loaded_shard_id == "q":
                  shard_id = tp_rank
              else:
                  shard_id = tp_rank // self.num_kv_head_replicas
              start_idx = shard_id * shard_size
          loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                               shard_size)
      # Special case for AQLM codebooks.
      elif is_metadata:
          # metadata indicates fixed size concatenated along dim 0
          shard_size = loaded_weight.shape[0]
          shard_index = ["q", "k", "v"].index(loaded_shard_id)
          param_data = param_data.narrow(0, shard_index * shard_size,
                                         shard_size)
      # Special case for per-tensor scales in fused case.
      elif needs_scalar_to_array:
          param_data, loaded_weight = adjust_scalar_to_fused_array(
              param_data, loaded_weight, loaded_shard_id)
      else:
          if not getattr(param, "ignore_warning", False):
              logger.warning(
                  "Loading a weight without `output_dim` attribute in "
                  "QKVParallelLinear, assume the weight is the same "
                  "for all partitions.")

      assert param_data.shape == loaded_weight.shape
      param_data.copy_(loaded_weight)
  876. class RowParallelLinear(LinearBase):
  877. """Linear layer with row parallelism.
  878. The linear layer is defined as Y = XA + b. A is parallelized along
  879. its first dimension and X along its second dimension as:
  880. - -
  881. | A_1 |
  882. | . |
  883. A = | . | X = [X_1, ..., X_p]
  884. | . |
  885. | A_p |
  886. - -
  887. Arguments:
  888. input_size: first dimension of matrix A.
  889. output_size: second dimension of matrix A.
  890. bias: If true, add bias. Note that bias is not parallelized.
  891. input_is_parallel: If true, we assume that the input is already
  892. split across the GPUs and we do not split
  893. again.
  894. skip_bias_add: This was added to enable performance optimization where
  895. bias can be fused with other element-wise operations.
  896. We skip adding bias but instead return it.
  897. params_dtype: Data type for the parameters.
  898. quant_config: Quantization configure.
  899. partition_multiple_of: Partitions will be divided,
  900. so each partition is a multiple of this number.
  901. """
  def __init__(self,
               input_size: int,
               output_size: int,
               bias: bool = True,
               input_is_parallel: bool = True,
               skip_bias_add: bool = False,
               params_dtype: Optional[torch.dtype] = None,
               reduce_results: bool = True,
               quant_config: Optional[QuantizationConfig] = None,
               partition_multiple_of: int = 1,
               prefix: str = ""):
      """Set up the per-rank input partition and create the weights.

      Args:
          input_size: first dimension of matrix A.
          output_size: second dimension of matrix A.
          bias: If true, create a bias parameter of length ``output_size``.
          input_is_parallel: If true, ``forward`` uses its input as-is
              instead of splitting it across ranks.
          skip_bias_add: If true, ``forward`` returns the bias instead of
              adding it.
          params_dtype: Data type for the parameters.
          reduce_results: If true, ``forward`` all-reduces the partial
              outputs when running on more than one rank.
          quant_config: Quantization configuration. When ``None`` the
              input dimension is split with the partition helper,
              otherwise it is divided by the tensor-parallel size.
          partition_multiple_of: Granularity handed to the partition
              helper on the unquantized path.
          prefix: Passed through to the parent constructor.

      Raises:
          ValueError: if ``reduce_results`` is false while a bias would be
              added to the (unreduced) output.
      """
      super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                       quant_config, prefix)

      self.input_is_parallel = input_is_parallel
      self.reduce_results = reduce_results
      self.quant_config = quant_config

      # Rank and world size are queried once each (the rank used to be
      # fetched twice).
      self.tp_rank = get_tensor_model_parallel_rank()
      self.tp_size = get_tensor_model_parallel_world_size()

      # Always recorded, so the attribute also exists on quantized layers
      # even though only the unquantized path reads it.
      self.partition_multiple_of = partition_multiple_of
      if quant_config is None:
          self.input_size_per_partition = get_current_tp_rank_partition_size(
              input_size, self.tp_rank, self.tp_size, partition_multiple_of)
      else:
          self.input_size_per_partition = divide(input_size, self.tp_size)

      assert self.quant_method is not None
      # Quantization methods listed in WEIGHT_LOADER_V2_SUPPORTED get the
      # v2 loader; all others get the legacy one.
      self.quant_method.create_weights(
          layer=self,
          input_size_per_partition=self.input_size_per_partition,
          output_partition_sizes=[self.output_size],
          input_size=self.input_size,
          output_size=self.output_size,
          params_dtype=self.params_dtype,
          weight_loader=(
              self.weight_loader_v2 if self.quant_method.__class__.__name__
              in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))

      if not reduce_results and (bias and not skip_bias_add):
          raise ValueError("When not reduce the results, adding bias to the "
                           "results can lead to incorrect results")

      if bias:
          self.bias = Parameter(
              torch.empty(self.output_size, dtype=params_dtype))
          set_weight_attrs(self.bias, {
              "output_dim": 0,
              "weight_loader": self.weight_loader,
          })
      else:
          self.register_parameter("bias", None)
  def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
      """Copy this rank's slice of a checkpoint tensor into ``param``.

      If ``param`` has an ``input_dim`` attribute, the checkpoint tensor
      is narrowed along that dimension to the slice owned by this rank.
      Otherwise it is copied whole. Zero-dimensional tensors are reshaped
      to ``(1,)`` before the copy.

      Args:
          param: Destination parameter. An ``UninitializedParameter``
              flagged ``is_gguf_weight`` is materialized here first.
          loaded_weight: Tensor read from the checkpoint.
      """
      # Use the size cached in __init__, matching the self.tp_rank /
      # self.tp_size use further down.
      tp_size = self.tp_size
      input_dim = getattr(param, "input_dim", None)

      # Special case for GGUF
      is_gguf_weight = getattr(param, "is_gguf_weight", False)
      is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
      if is_gguf_weight_type:
          param.weight_type = loaded_weight.item()

      # Materialize GGUF UninitializedParameter
      if is_gguf_weight and isinstance(param, UninitializedParameter):
          weight_shape = list(loaded_weight.shape)
          # Compare against None: dimension 0 is a valid input_dim, and a
          # truthiness test would skip the division for it.
          if input_dim is not None:
              weight_shape[input_dim] = weight_shape[input_dim] // tp_size
          param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

      param_data = param.data
      if input_dim is not None:
          shard_size = param_data.shape[input_dim]
          if self.quant_config is None:
              # Unquantized layers take the offset from the partition
              # helper.
              start_idx = get_current_tp_rank_partition_offset(
                  self.input_size,
                  self.tp_rank,
                  self.tp_size,
                  multiple_of=self.partition_multiple_of)
          else:
              start_idx = self.tp_rank * shard_size
          loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                               shard_size)

      # Special case for loading scales off disk, which often do not
      # have a shape (such as in the case of AutoFP8).
      if len(loaded_weight.shape) == 0:
          loaded_weight = loaded_weight.reshape(1)

      assert param_data.shape == loaded_weight.shape
      param_data.copy_(loaded_weight)
  def weight_loader_v2(self, param: BaseAphroditeParameter,
                       loaded_weight: torch.Tensor):
      """Pass a checkpoint tensor to the parameter's row-parallel loader.

      Scales read from disk often have no shape (such as in the case of
      AutoFP8); those are reshaped to ``(1,)`` first.
      """
      is_shapeless = not loaded_weight.shape
      if is_shapeless:
          assert loaded_weight.numel() == 1
          loaded_weight = loaded_weight.reshape(1)

      param.load_row_parallel_weight(loaded_weight=loaded_weight)
  def forward(self, input_):
      """Apply the layer to ``input_`` and return ``(output, output_bias)``.

      ``output_bias`` is the bias when ``skip_bias_add`` is set, otherwise
      ``None``.
      """
      if self.input_is_parallel:
          local_input = input_
      else:
          # Take this rank's slice of the last dimension.
          rank = get_tensor_model_parallel_rank()
          pieces = split_tensor_along_last_dim(input_,
                                               num_partitions=self.tp_size)
          local_input = pieces[rank].contiguous()

      # Matrix multiply.
      assert self.quant_method is not None
      # Only fuse bias add into GEMM for rank 0 (this ensures that
      # bias will not get added more than once in TP>1 case)
      fuse_bias = not self.tp_rank > 0 and not self.skip_bias_add
      bias_ = self.bias if fuse_bias else None
      partial_output = self.quant_method.apply(self,
                                               local_input,
                                               bias=bias_)

      needs_reduce = self.reduce_results and self.tp_size > 1
      if needs_reduce:
          output = tensor_model_parallel_all_reduce(partial_output)
      else:
          output = partial_output

      return output, (self.bias if self.skip_bias_add else None)
  def extra_repr(self) -> str:
      """Return a comma-separated summary of the layer's sizes and flags."""
      fields = (
          f"input_features={self.input_size_per_partition}",
          f"output_features={self.output_size}",
          f"bias={self.bias is not None}",
          f"tp_size={self.tp_size}",
          f"reduce_results={self.reduce_results}",
      )
      return ", ".join(fields)