from abc import abstractmethod
from typing import List, Optional

import torch
import torch.nn.functional as F
from loguru import logger
from torch.nn.parameter import Parameter

from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   split_tensor_along_last_dim,
                                   tensor_model_parallel_all_gather,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import (QuantizationConfig,
                                                QuantizeMethodBase)

def adjust_marlin_shard(param, shard_size, shard_offset):
    marlin_tile_size = getattr(param, "marlin_tile_size", None)
    if marlin_tile_size is None:
        return shard_size, shard_offset

    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


def adjust_bitblas_shard(param, shard_size, shard_offset):
    bitblas_tile_size = getattr(param, "bitblas_tile_size", None)
    weight_propagation = getattr(param, "weight_propagation", None)
    if weight_propagation and bitblas_tile_size is not None:
        return (shard_size // bitblas_tile_size,
                shard_offset // bitblas_tile_size)

    return shard_size, shard_offset
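

# Illustrative example (comment only, nothing here runs at import time): for a
# parameter carrying marlin_tile_size=16, adjust_marlin_shard scales both
# values by the tile size, e.g. (shard_size=128, shard_offset=256) becomes
# (2048, 4096); with weight_propagation set and bitblas_tile_size=16,
# adjust_bitblas_shard divides instead, giving (8, 16).
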

class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for a linear layer.
        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all
                ranks.
            params_dtype: Datatype of the parameters.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer."""
        raise NotImplementedError

class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.

    Args:
        separate_bias_add: If true, add bias separately after matrix
                           multiplication.
    """

    def __init__(self, separate_bias_add: bool = False):
        self.separate_bias_add = separate_bias_add

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        weight = layer.weight
        if self.separate_bias_add:
            if bias is not None:
                return F.linear(x, weight) + bias
            return F.linear(x, weight)
        return F.linear(x, weight, bias)
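

# Usage sketch (illustrative only, not executed here): a linear method builds
# the per-rank weight from the partition sizes it is given and later applies
# it. With a tensor-parallel size of 2 and two logical outputs of 4096 each:
#
#     method = UnquantizedLinearMethod()
#     method.create_weights(layer, input_size_per_partition=1024,
#                           output_partition_sizes=[2048, 2048],
#                           input_size=1024, output_size=8192,
#                           params_dtype=torch.float16)
#     # layer.weight: (sum(output_partition_sizes), input_size_per_partition)
#     #             = (4096, 1024) on this rank
#     y = method.apply(layer, x)  # equivalent to F.linear(x, layer.weight)
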

class LinearBase(torch.nn.Module):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if quant_config is None:
            self.quant_method = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)

        self.quant_method.create_weights(self, self.input_size,
                                         [self.output_size], self.input_size,
                                         self.output_size, self.params_dtype)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {"output_dim": 0})
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bias = self.bias if not self.skip_bias_add else None
        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s
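

# Illustrative usage sketch: ReplicatedLinear keeps the full weight on every
# rank, and forward() returns an (output, output_bias) tuple. output_bias is
# only non-None when skip_bias_add=True, so a caller can fuse the bias into a
# later element-wise op instead:
#
#     proj = ReplicatedLinear(1024, 4096, bias=True)
#     out, _ = proj(x)  # out: [*, 4096]
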

class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs; otherwise, every GPU will have its own
                       output, which is Y_i = XA_i.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        output_sizes: list of output sizes packed into one output, e.g. for QKV
                      the list would have size 3.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 output_sizes: Optional[List[int]] = None):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)

        self.gather_output = gather_output

        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        assert self.quant_method is not None
        self.output_size_per_partition = divide(self.output_size, tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(output_size, tp_size)
                for output_size in self.output_sizes
            ]

        if output_sizes is None:
            output_sizes = [output_size]
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=self.weight_loader)
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)
        param_data = param.data
        if output_dim is not None:
            shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
                                                                 loaded_weight,
                                                                 shard_id=0)

        assert param_data.dtype == loaded_weight.dtype, (
            f"{param_data.dtype} != {loaded_weight.dtype}")
        assert param_data.shape == loaded_weight.shape, (
            f"{param_data.shape} != {loaded_weight.shape}")
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += f", gather_output={self.gather_output}"
        return s
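

# Illustrative shapes (assuming an initialized tensor-parallel group of size
# 2): ColumnParallelLinear(1024, 4096) stores a local weight of shape
# (4096 // 2, 1024) = (2048, 1024) on each rank. forward() returns
# (Y_local, bias), where Y_local is [*, 2048] unless gather_output=True, in
# which case the partitions are all-gathered back to the full [*, 4096].
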

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs; otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
    """

    def __init__(self,
                 input_size: int,
                 output_sizes: List[int],
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)

        param_shard_splitter = getattr(param, "shard_splitter", None)

        if output_dim is not None and param_shard_splitter is not None:
            raise NotImplementedError(
                "We do not currently support output_dim != None and "
                "shard_splitter != None for a parameter. Please open an issue."
            )
        # If a parameter has defined a shard_splitter to be used for
        # the weight, it should be applied before the weight is
        # loaded/copied to the parameter. The shard_splitter applies
        # logic by using the loaded_shard_id to ensure that the loaded
        # param is loaded to the correct location
        # within the parameter defined by the linear method.
        if loaded_shard_id is None and param_shard_splitter is not None:
            raise NotImplementedError(
                "We do not currently support loaded_shard_id == None and "
                "shard_splitter != None for a parameter. Please open an issue."
            )

        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            shard_offsets = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # If quantized, we need to adjust the offset and size to
                # account for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # If marlin, we need to adjust the offset and size to
                    # account for the tiling.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)
                    shard_size, shard_offset = adjust_bitblas_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # If marlin, we need to adjust the offset and size to
                # account for the tiling.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)
                shard_size, shard_offset = adjust_bitblas_shard(
                    param, shard_size, shard_offset)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        elif is_metadata:
            # Metadata indicates fixed-size shards concatenated along dim 0.
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)
        # If a param_shard_splitter is defined by the LinearMethod, use it.
        elif param_shard_splitter is not None:
            logical_widths = getattr(param, "logical_widths", None)
            param_data, loaded_weight = param_shard_splitter(
                param_data, loaded_weight, loaded_shard_id, logical_widths)
        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        if fp8_scales_shard_indexer is None:
            if len(param_data.shape) == 0:
                param_data = param_data.reshape(1)

            if len(loaded_weight.shape) == 0:
                loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
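

# Illustrative usage (hypothetical tensor names, assuming tp_size=2): a fused
# gate/up projection can be declared as
#
#     gate_up = MergedColumnParallelLinear(4096, [11008, 11008], bias=False)
#
# and checkpoint weights are then loaded per logical shard, e.g.
#
#     gate_up.weight_loader(gate_up.weight, gate_proj_weight, 0)
#     gate_up.weight_loader(gate_up.weight, up_proj_weight, 1)
#
# Each call narrows the packed parameter to that shard's offset and size
# (divided by tp_size) and copies only this rank's slice of the loaded tensor.
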

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value heads may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        is_metadata = getattr(param, "is_metadata", False)

        param_shard_splitter = getattr(param, "shard_splitter", None)

        if output_dim is not None and param_shard_splitter is not None:
            raise NotImplementedError(
                "We do not currently support output_dim != None and "
                "shard_splitter != None for a parameter. Please open an issue."
            )
        # If a parameter has defined a shard_splitter to be used for
        # the weight, it should be applied before the weight is
        # loaded/copied to the parameter. The shard_splitter applies
        # logic by using the loaded_shard_id to ensure that the loaded
        # param is loaded to the correct location
        # within the parameter defined by the linear method.
        if loaded_shard_id is None and param_shard_splitter is not None:
            raise NotImplementedError(
                "We do not currently support loaded_shard_id == None and "
                "shard_splitter != None for a parameter. Please open an issue."
            )

        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # If quantized, we need to adjust the offset and size to
                # account for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # If marlin, we need to adjust the offset and size to
                    # account for the tiling.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)
                    shard_size, shard_offset = adjust_bitblas_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # If marlin, we need to adjust the offset and size to
                # account for the tiling.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)
                shard_size, shard_offset = adjust_bitblas_shard(
                    param, shard_size, shard_offset)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            if loaded_shard_id == "q":
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        elif is_metadata:
            # Metadata indicates fixed-size shards concatenated along dim 0.
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        # If a param_shard_splitter is defined by the LinearMethod, use it.
        elif param_shard_splitter is not None:
            logical_widths = getattr(param, "logical_widths", None)
            param_data, loaded_weight = param_shard_splitter(
                param_data, loaded_weight, loaded_shard_id, logical_widths)
        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        if len(param_data.shape) == 0:
            param_data = param_data.reshape(1)

        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
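

# Illustrative head math (assuming tp_size=4): with total_num_heads=32,
# total_num_kv_heads=8 and head_size=128, each rank holds num_heads=8 query
# heads and num_kv_heads=2 key/value heads, so the per-rank output width is
# (8 + 2 * 2) * 128 = 1536 out of the full (32 + 2 * 8) * 128 = 6144. When
# tp_size exceeds total_num_kv_heads, each rank keeps a single KV head and
# num_kv_head_replicas ranks load the same checkpoint slice for "k" and "v".
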

class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, all-reduce the partial outputs across
                        tensor-parallel ranks.
        quant_config: Quantization config.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 input_is_parallel: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 reduce_results: bool = True,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        # Divide the weight matrix along the last dimension.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=[self.output_size],
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=self.weight_loader)

        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reducing the results, adding bias to "
                             "the results can lead to incorrect results.")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        tp_rank = get_tensor_model_parallel_rank()
        input_dim = getattr(param, "input_dim", None)
        param_data = param.data
        if input_dim is not None:
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)
        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
                                                                 loaded_weight,
                                                                 shard_id=0)

        if fp8_scales_shard_indexer is None and len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.quant_method.apply(self, input_parallel)
        if self.reduce_results and self.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"input_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s
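

# Illustrative shapes (assuming tp_size=2): RowParallelLinear(4096, 1024)
# stores a local weight of shape (1024, 4096 // 2) = (1024, 2048) on each
# rank. With input_is_parallel=True the incoming activation is already the
# matching [*, 2048] slice (typically the output of a ColumnParallelLinear),
# each rank computes a partial [*, 1024] result, and reduce_results=True
# all-reduces the partial outputs into the final value before bias is added.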