linear.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762
  1. from abc import abstractmethod
  2. from typing import List, Optional
  3. import torch
  4. import torch.nn.functional as F
  5. from loguru import logger
  6. from torch.nn.parameter import Parameter
  7. from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
  8. get_tensor_model_parallel_world_size,
  9. split_tensor_along_last_dim,
  10. tensor_model_parallel_all_gather,
  11. tensor_model_parallel_all_reduce)
  12. from aphrodite.modeling.utils import set_weight_attrs
  13. from aphrodite.quantization.base_config import (QuantizationConfig,
  14. QuantizeMethodBase)
  15. def adjust_marlin_shard(param, shard_size, shard_offset):
  16. marlin_tile_size = getattr(param, "marlin_tile_size", None)
  17. if marlin_tile_size is None:
  18. return shard_size, shard_offset
  19. return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
  20. class LinearMethodBase(QuantizeMethodBase):
  21. """Base class for different (maybe quantized) linear methods."""
  22. @abstractmethod
  23. def create_weights(self, layer: torch.nn.Module,
  24. input_size_per_partition: int,
  25. output_partition_sizes: List[int], input_size: int,
  26. output_size: int, params_dtype: torch.dtype,
  27. **extra_weight_attrs):
  28. """Create weights for a linear layer.
  29. The weights will be set as attributes of the layer.
  30. Args:
  31. layer: The layer that is using the LinearMethodBase factory.
  32. input_size_per_partition: Size of the weight input dim on rank X.
  33. output_partition_sizes: Sizes of the output dim of each logical
  34. weight on rank X. E.g., output_partition_sizes for QKVLinear
  35. is a list contains the width of Wq, Wk, Wv on rank X.
  36. input_size: Size of the input dim of the weight across all ranks.
  37. output_size: Size of the output dim of the weight across all ranks.
  38. params_dtype: Datatype of the parameters.
  39. """
  40. raise NotImplementedError
  41. @abstractmethod
  42. def apply(self,
  43. layer: torch.nn.Module,
  44. x: torch.Tensor,
  45. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  46. """Apply the weights in layer to the input tensor.
  47. Expects create_weights to have been called before on the layer."""
  48. raise NotImplementedError
  49. class UnquantizedLinearMethod(LinearMethodBase):
  50. """Linear method without quantization.
  51. Args:
  52. separate_bias_add: If true, add bias separately after matrix
  53. multiplication.
  54. """
  55. def __init__(self, separate_bias_add: bool = False):
  56. self.separate_bias_add = separate_bias_add
  57. def create_weights(self, layer: torch.nn.Module,
  58. input_size_per_partition: int,
  59. output_partition_sizes: List[int], input_size: int,
  60. output_size: int, params_dtype: torch.dtype,
  61. **extra_weight_attrs):
  62. weight = Parameter(torch.empty(sum(output_partition_sizes),
  63. input_size_per_partition,
  64. dtype=params_dtype),
  65. requires_grad=False)
  66. set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
  67. layer.register_parameter("weight", weight)
  68. set_weight_attrs(weight, extra_weight_attrs)
  69. def apply(self,
  70. layer: torch.nn.Module,
  71. x: torch.Tensor,
  72. bias: Optional[torch.Tensor] = None) -> torch.Tensor:
  73. weight = layer.weight
  74. if self.separate_bias_add:
  75. if bias is not None:
  76. return F.linear(x, weight) + bias
  77. return F.linear(x, weight)
  78. return F.linear(x, weight, bias)
  79. class LinearBase(torch.nn.Module):
  80. """Base linear layer.
  81. Args:
  82. input_size: input dimension of the linear layer.
  83. output_size: output dimension of the linear layer.
  84. bias: If true, add bias.
  85. skip_bias_add: If true, skip adding bias but instead return it.
  86. params_dtype: Data type for the parameters.
  87. quant_config: Quantization config..
  88. """
  89. def __init__(
  90. self,
  91. input_size: int,
  92. output_size: int,
  93. skip_bias_add: bool = False,
  94. params_dtype: Optional[torch.dtype] = None,
  95. quant_config: Optional[QuantizationConfig] = None,
  96. ):
  97. super().__init__()
  98. # Keep input parameters
  99. self.input_size = input_size
  100. self.output_size = output_size
  101. self.skip_bias_add = skip_bias_add
  102. if params_dtype is None:
  103. params_dtype = torch.get_default_dtype()
  104. self.params_dtype = params_dtype
  105. if quant_config is None:
  106. self.quant_method = UnquantizedLinearMethod()
  107. else:
  108. self.quant_method = quant_config.get_quant_method(self)
  109. def forward(self, x: torch.Tensor) -> torch.Tensor:
  110. raise NotImplementedError
  111. class ReplicatedLinear(LinearBase):
  112. """Replicated linear layer.
  113. Args:
  114. input_size: input dimension of the linear layer.
  115. output_size: output dimension of the linear layer.
  116. bias: If true, add bias.
  117. skip_bias_add: If true, skip adding bias but instead return it.
  118. params_dtype: Data type for the parameters.
  119. quant_config: Quantization configure.
  120. """
  121. def __init__(self,
  122. input_size: int,
  123. output_size: int,
  124. bias: bool = True,
  125. skip_bias_add: bool = False,
  126. params_dtype: Optional[torch.dtype] = None,
  127. quant_config: Optional[QuantizationConfig] = None):
  128. super().__init__(input_size, output_size, skip_bias_add, params_dtype,
  129. quant_config)
  130. self.quant_method.create_weights(self, self.input_size,
  131. [self.output_size], self.input_size,
  132. self.output_size, self.params_dtype)
  133. if bias:
  134. self.bias = Parameter(
  135. torch.empty(self.output_size, dtype=self.params_dtype))
  136. set_weight_attrs(self.bias, {"output_dim": 0})
  137. else:
  138. self.register_parameter("bias", None)
  139. def forward(self, x: torch.Tensor) -> torch.Tensor:
  140. bias = self.bias if not self.skip_bias_add else None
  141. output = self.quant_method.apply(self, x, bias)
  142. output_bias = self.bias if self.skip_bias_add else None
  143. return output, output_bias
  144. def extra_repr(self) -> str:
  145. s = f"in_features={self.input_size}"
  146. s += f", output_features={self.output_size}"
  147. s += f", bias={self.bias is not None}"
  148. return s
  149. class ColumnParallelLinear(LinearBase):
  150. """Linear layer with column parallelism.
  151. The linear layer is defined as Y = XA + b. A is parallelized along
  152. its second dimension as A = [A_1, ..., A_p].
  153. Args:
  154. input_size: first dimension of matrix A.
  155. output_size: second dimension of matrix A.
  156. bias: If true, add bias.
  157. gather_output: If true, call all-gather on output and make Y available
  158. to all GPUs, otherwise, every GPU will have its output
  159. which is Y_i = XA_i
  160. skip_bias_add: This was added to enable performance optimizations where
  161. bias can be fused with other element-wise operations. we
  162. skip adding bias but instead return it.
  163. params_dtype: Data type for the parameters.
  164. quant_config: Quantization configure.
  165. output_sizes: list of output sizes packed into one output, like for QKV
  166. the list would be size 3.
  167. """
  168. def __init__(self,
  169. input_size: int,
  170. output_size: int,
  171. bias: bool = True,
  172. gather_output: bool = False,
  173. skip_bias_add: bool = False,
  174. params_dtype: Optional[torch.dtype] = None,
  175. quant_config: Optional[QuantizationConfig] = None,
  176. output_sizes: Optional[List[int]] = None):
  177. super().__init__(input_size, output_size, skip_bias_add, params_dtype,
  178. quant_config)
  179. self.gather_output = gather_output
  180. # Divide the weight matrix along the last dimension.
  181. tp_size = get_tensor_model_parallel_world_size()
  182. assert self.quant_method is not None
  183. self.output_size_per_partition = divide(self.output_size, tp_size)
  184. self.output_partition_sizes = [self.output_size_per_partition]
  185. # If QKV or MergedColumn, use output size of each partition.
  186. if hasattr(self, "output_sizes"):
  187. self.output_partition_sizes = [
  188. divide(output_size, tp_size)
  189. for output_size in self.output_sizes
  190. ]
  191. if output_sizes is None:
  192. output_sizes = [output_size]
  193. self.quant_method.create_weights(
  194. layer=self,
  195. input_size_per_partition=self.input_size,
  196. output_partition_sizes=self.output_partition_sizes,
  197. input_size=self.input_size,
  198. output_size=self.output_size,
  199. params_dtype=self.params_dtype,
  200. weight_loader=self.weight_loader)
  201. if bias:
  202. self.bias = Parameter(
  203. torch.empty(self.output_size_per_partition,
  204. dtype=params_dtype))
  205. set_weight_attrs(self.bias, {
  206. "output_dim": 0,
  207. "weight_loader": self.weight_loader,
  208. })
  209. else:
  210. self.register_parameter("bias", None)
  211. def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
  212. # Special case for Fp8 scales.
  213. fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
  214. None)
  215. tp_rank = get_tensor_model_parallel_rank()
  216. output_dim = getattr(param, "output_dim", None)
  217. param_data = param.data
  218. if output_dim is not None:
  219. shard_size = param_data.shape[output_dim]
  220. start_idx = tp_rank * shard_size
  221. loaded_weight = loaded_weight.narrow(output_dim, start_idx,
  222. shard_size)
  223. # Special case for Fp8 scales.
  224. elif fp8_scales_shard_indexer is not None:
  225. param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
  226. loaded_weight,
  227. shard_id=0)
  228. assert param_data.shape == loaded_weight.shape
  229. param_data.copy_(loaded_weight)
  230. def forward(self, input_):
  231. bias = self.bias if not self.skip_bias_add else None
  232. # Matrix multiply.
  233. output_parallel = self.quant_method.apply(self, input_, bias)
  234. if self.gather_output:
  235. # All-gather across the partitions.
  236. output = tensor_model_parallel_all_gather(output_parallel)
  237. else:
  238. output = output_parallel
  239. output_bias = self.bias if self.skip_bias_add else None
  240. return output, output_bias
  241. def extra_repr(self) -> str:
  242. s = f"in_features={self.input_size}"
  243. s += f", output_features={self.output_size_per_partition}"
  244. s += f", bias={self.bias is not None}"
  245. s += f", tp_size={get_tensor_model_parallel_world_size()}"
  246. s += f", gather_output={self.gather_output}"
  247. return s
class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer. Each
                      entry must be divisible by the tensor-parallel world
                      size (checked by an assert in ``__init__``).
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
    """

    def __init__(self,
                 input_size: int,
                 output_sizes: List[int],
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        # Set before super().__init__(): ColumnParallelLinear checks for
        # ``self.output_sizes`` to derive per-partition output sizes.
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load one logical sub-matrix (or the whole packed matrix) into param.

        Args:
            param: Destination parameter. When it has ``output_dim``, it holds
                this rank's partition of every sub-matrix concatenated along
                that dimension.
            loaded_weight: Checkpoint tensor. With ``loaded_shard_id=None``
                and an ``output_dim``, it is the full packed weight: it is
                split by ``self.output_sizes`` and each piece is loaded
                recursively. With a shard id and an ``output_dim``, it is the
                sub-matrix for that id, not yet split across ranks; this
                rank's slice is taken from it.
            loaded_shard_id: Index into ``self.output_sizes`` of the
                sub-matrix being loaded, or None for an already-packed weight.

        Raises:
            NotImplementedError: if ``shard_splitter`` is combined with
                ``output_dim`` or with ``loaded_shard_id=None``.
            AssertionError: on an out-of-range shard id or a shape mismatch.
        """
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        param_shard_splitter = getattr(param, "shard_splitter", None)
        if output_dim is not None and param_shard_splitter is not None:
            raise NotImplementedError(
                "We do not currently support output_dim != None and "
                "shard_splitter != None for a parameter. Please open an issue."
            )
        # If a parameter has defined a shard_splitter to be used for
        # the weight, it should be applied before the weight is
        # loaded/copied to the parameter. The shard_splitter applies
        # logic by using the loaded_shard_id to ensure that the loaded
        # param is loaded to the correct location
        # within the parameter defined by the linear method.
        if loaded_shard_id is None and param_shard_splitter is not None:
            raise NotImplementedError(
                "We do not currently support loaded_shard_id == None and "
                "shard_splitter != None for a parameter. Please open an issue."
            )

        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                # Not sharded along an output dim: copy it as a whole.
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            # Build (shard_id, offset, size) for each sub-matrix in the
            # full (all-rank) packed weight.
            current_shard_offset = 0
            shard_offsets = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # If marlin, we need to adjust the offset and size to
                    # account for the tiling.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                # Recurse with an explicit shard id; that call takes this
                # rank's slice of the sub-matrix.
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            # Offset/size of this sub-matrix inside the per-rank param.
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # If marlin, we need to adjust the offset and size to
                # account for the tiling.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # Select this rank's contiguous slice of the loaded sub-matrix.
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)

        # If a param_shard_splitter is defined by the LinearMethod, use it.
        elif param_shard_splitter is not None:
            logical_widths = getattr(param, "logical_widths", None)
            param_data, loaded_weight = param_shard_splitter(
                param_data, loaded_weight, loaded_shard_id, logical_widths)

        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(
                param_data, loaded_weight, loaded_shard_id)

        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        if fp8_scales_shard_indexer is None:
            # Treat 0-d (scalar) tensors as shape (1,) on both sides.
            if len(param_data.shape) == 0:
                param_data = param_data.reshape(1)

            if len(loaded_weight.shape) == 0:
                loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
    """

    def __init__(self,
                 hidden_size: int,
                 head_size: int,
                 total_num_heads: int,
                 total_num_kv_heads: Optional[int] = None,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            # Fewer KV heads than ranks: each rank holds one KV head, and
            # each KV head is replicated on tp_size / total_num_kv_heads ranks.
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        # Total width summed over ranks (counts replicated KV heads once
        # per rank).
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        # Set before super().__init__(): ColumnParallelLinear checks for
        # ``self.output_sizes`` to derive per-partition output sizes.
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size=input_size,
                         output_size=output_size,
                         bias=bias,
                         gather_output=False,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        """Load a Q, K or V sub-matrix (or the packed QKV matrix) into param.

        Args:
            param: Destination parameter. When it has ``output_dim``, it holds
                this rank's q, k and v slices concatenated along that
                dimension.
            loaded_weight: Checkpoint tensor. With ``loaded_shard_id=None``
                and an ``output_dim``, it is the full packed QKV weight: it is
                split into q/k/v using the total head counts and each piece
                is loaded recursively. With a shard id and an ``output_dim``,
                it is that projection's full weight; this rank's heads are
                sliced out of it.
            loaded_shard_id: One of "q", "k", "v", or None for an
                already-packed weight.

        Raises:
            NotImplementedError: if ``shard_splitter`` is combined with
                ``output_dim`` or with ``loaded_shard_id=None``.
            AssertionError: on an unknown shard id or a shape mismatch.
        """
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        is_metadata = getattr(param, "is_metadata", False)
        param_shard_splitter = getattr(param, "shard_splitter", None)
        if output_dim is not None and param_shard_splitter is not None:
            raise NotImplementedError(
                "We do not currently support output_dim != None and "
                "shard_splitter != None for a parameter. Please open an issue."
            )
        # If a parameter has defined a shard_splitter to be used for
        # the weight, it should be applied before the weight is
        # loaded/copied to the parameter. The shard_splitter applies
        # logic by using the loaded_shard_id to ensure that the loaded
        # param is loaded to the correct location
        # within the parameter defined by the linear method.
        if loaded_shard_id is None and param_shard_splitter is not None:
            raise NotImplementedError(
                "We do not currently support loaded_shard_id == None and "
                "shard_splitter != None for a parameter. Please open an issue."
            )

        # Special case for Fp8 scales.
        fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
                                           None)

        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                # Not sharded along an output dim: copy it as a whole.
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            # Offsets/sizes in the full checkpoint, based on total head counts.
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # If marlin, we need to adjust the offset and size to
                    # account for the tiling.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                # Recurse with an explicit shard id; that call takes this
                # rank's heads of the projection.
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]

        # If output dim is defined, use the default loading process.
        if output_dim is not None:
            # Offset/size of this projection inside the per-rank param.
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # If marlin, we need to adjust the offset and size to
                # account for the tiling.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # Q heads are split one slice per rank. A KV slice is shared by
            # num_kv_head_replicas consecutive ranks.
            if loaded_shard_id == "q":
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        # If a param_shard_splitter is defined by the LinearMethod, use it.
        elif param_shard_splitter is not None:
            logical_widths = getattr(param, "logical_widths", None)
            param_data, loaded_weight = param_shard_splitter(
                param_data, loaded_weight, loaded_shard_id, logical_widths)

        # Special case for Fp8 scales.
        elif fp8_scales_shard_indexer is not None:
            param_data, loaded_weight = fp8_scales_shard_indexer(
                param_data, loaded_weight, loaded_shard_id)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")

        # Treat 0-d (scalar) tensors as shape (1,) on both sides.
        if len(param_data.shape) == 0:
            param_data = param_data.reshape(1)

        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
  565. class RowParallelLinear(LinearBase):
  566. """Linear layer with row parallelism.
  567. The linear layer is defined as Y = XA + b. A is parallelized along
  568. its first dimension and X along its second dimension as:
  569. - -
  570. | A_1 |
  571. | . |
  572. A = | . | X = [X_1, ..., X_p]
  573. | . |
  574. | A_p |
  575. - -
  576. Arguments:
  577. input_size: first dimension of matrix A.
  578. output_size: second dimension of matrix A.
  579. bias: If true, add bias. Note that bias is not parallelized.
  580. input_is_parallel: If true, we assume that the input is already
  581. split across the GPUs and we do not split
  582. again.
  583. skip_bias_add: This was added to enable performance optimization where
  584. bias can be fused with other element-wise operations.
  585. We skip adding bias but instead return it.
  586. params_dtype: Data type for the parameters.
  587. quant_config: Quantization configure.
  588. """
  589. def __init__(self,
  590. input_size: int,
  591. output_size: int,
  592. bias: bool = True,
  593. input_is_parallel: bool = True,
  594. skip_bias_add: bool = False,
  595. params_dtype: Optional[torch.dtype] = None,
  596. reduce_results: bool = True,
  597. quant_config: Optional[QuantizationConfig] = None):
  598. super().__init__(input_size, output_size, skip_bias_add, params_dtype,
  599. quant_config)
  600. self.input_is_parallel = input_is_parallel
  601. self.reduce_results = reduce_results
  602. # Divide the weight matrix along the last dimension.
  603. self.tp_size = get_tensor_model_parallel_world_size()
  604. self.input_size_per_partition = divide(input_size, self.tp_size)
  605. assert self.quant_method is not None
  606. self.quant_method.create_weights(
  607. layer=self,
  608. input_size_per_partition=self.input_size_per_partition,
  609. output_partition_sizes=[self.output_size],
  610. input_size=self.input_size,
  611. output_size=self.output_size,
  612. params_dtype=self.params_dtype,
  613. weight_loader=self.weight_loader)
  614. if not reduce_results and (bias and not skip_bias_add):
  615. raise ValueError("When not reduce the results, adding bias to the "
  616. "results can lead to incorrect results")
  617. if bias:
  618. self.bias = Parameter(
  619. torch.empty(self.output_size, dtype=params_dtype))
  620. set_weight_attrs(self.bias, {
  621. "output_dim": 0,
  622. "weight_loader": self.weight_loader,
  623. })
  624. else:
  625. self.register_parameter("bias", None)
  626. def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
  627. # Special case for Fp8 scales.
  628. fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
  629. None)
  630. tp_rank = get_tensor_model_parallel_rank()
  631. input_dim = getattr(param, "input_dim", None)
  632. param_data = param.data
  633. if input_dim is not None:
  634. shard_size = param_data.shape[input_dim]
  635. start_idx = tp_rank * shard_size
  636. loaded_weight = loaded_weight.narrow(input_dim, start_idx,
  637. shard_size)
  638. # Special case for Fp8 scales.
  639. elif fp8_scales_shard_indexer is not None:
  640. param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
  641. loaded_weight,
  642. shard_id=0)
  643. if fp8_scales_shard_indexer is None and len(loaded_weight.shape) == 0:
  644. loaded_weight = loaded_weight.reshape(1)
  645. assert param_data.shape == loaded_weight.shape
  646. param_data.copy_(loaded_weight)
  647. def forward(self, input_):
  648. # Set up backprop all-reduce.
  649. if self.input_is_parallel:
  650. input_parallel = input_
  651. else:
  652. tp_rank = get_tensor_model_parallel_rank()
  653. splitted_input = split_tensor_along_last_dim(
  654. input_, num_partitions=self.tp_size)
  655. input_parallel = splitted_input[tp_rank].contiguous()
  656. # Matrix multiply.
  657. output_parallel = self.quant_method.apply(self, input_parallel)
  658. if self.reduce_results and self.tp_size > 1:
  659. output_ = tensor_model_parallel_all_reduce(output_parallel)
  660. else:
  661. output_ = output_parallel
  662. if not self.skip_bias_add:
  663. output = output_ + self.bias if self.bias is not None else output_
  664. output_bias = None
  665. else:
  666. output = output_
  667. output_bias = self.bias
  668. return output, output_bias
  669. def extra_repr(self) -> str:
  670. s = f"input_features={self.input_size_per_partition}"
  671. s += f", output_features={self.output_size}"
  672. s += f", bias={self.bias is not None}"
  673. s += f", tp_size={self.tp_size}"
  674. s += f", reduce_results={self.reduce_results}"
  675. return s