linear.py

from abc import abstractmethod
from typing import List, Optional

import torch
import torch.nn.functional as F
from loguru import logger
from torch.nn.parameter import Parameter

from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   split_tensor_along_last_dim,
                                   tensor_model_parallel_all_gather,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import (QuantizationConfig,
                                                QuantizeMethodBase)


def adjust_marlin_shard(param, shard_size, shard_offset):
    marlin_tile_size = getattr(param, "marlin_tile_size", None)
    if marlin_tile_size is None:
        return shard_size, shard_offset

    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
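
# Usage sketch (illustrative only, not part of the original module): a
# Marlin-packed parameter carries a `marlin_tile_size` attribute, and its shard
# geometry is scaled by that tile size. With the made-up values below, a
# (size=64, offset=128) shard on a param with marlin_tile_size=16 becomes:
#
#     shard_size, shard_offset = adjust_marlin_shard(param, 64, 128)
#     # -> (64 * 16, 128 * 16) == (1024, 2048)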


class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for a linear layer.
        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all ranks.
            params_dtype: Datatype of the parameters.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.
        Expects create_weights to have been called before on the layer."""
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.

    Args:
        separate_bias_add: If true, add bias separately after matrix
                           multiplication.
    """

    def __init__(self, separate_bias_add: bool = False):
        self.separate_bias_add = separate_bias_add

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        output_size_per_partition = sum(output_partition_sizes)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        weight = layer.weight
        if self.separate_bias_add:
            if bias is not None:
                return F.linear(x, weight) + bias
            return F.linear(x, weight)
        return F.linear(x, weight, bias)
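
# Illustrative sketch (not in the original module): how a LinearMethodBase
# implementation is driven by the Linear* layers below. The bare nn.Module and
# all concrete sizes here are made up for the example.
#
#     method = UnquantizedLinearMethod()
#     layer = torch.nn.Module()
#     method.create_weights(layer,
#                           input_size_per_partition=16,
#                           output_partition_sizes=[32],
#                           input_size=16,
#                           output_size=32,
#                           params_dtype=torch.float32)
#     y = method.apply(layer, torch.randn(4, 16))  # y.shape == (4, 32)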


class LinearBase(torch.nn.Module):
    """Base linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if quant_config is None:
            self.quant_method = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)

        self.quant_method.create_weights(self, self.input_size,
                                         [self.output_size], self.input_size,
                                         self.output_size, self.params_dtype)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {"output_dim": 0})
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bias = self.bias if not self.skip_bias_add else None
        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
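
# Illustrative sketch (not in the original module): every Linear* layer in this
# file returns an (output, output_bias) tuple so that, with skip_bias_add=True,
# the bias can be fused into a later element-wise op instead of being added
# here. The sizes and the fusion call below are made up.
#
#     proj = ReplicatedLinear(1024, 4096, bias=True, skip_bias_add=True)
#     out, bias = proj(x)                   # out: (..., 4096); bias returned, not added
#     out = some_fused_kernel(out, bias)    # hypothetical downstream fusion point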


class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
        output_sizes: list of output sizes packed into one output, like for QKV
                      the list would be size 3.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        output_sizes: Optional[List[int]] = None,
    ):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)

        self.gather_output = gather_output

        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, tp_size)
        if output_sizes is None:
            output_sizes = [output_size]
        self.quant_method.create_weights(self,
                                         self.input_size,
                                         [x // tp_size for x in output_sizes],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)
        param_data = param.data
        if output_dim is not None:
            shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
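
# Illustrative sharding sketch (not in the original module): with made-up sizes
# output_size=4096 and tp_size=4, each rank holds a (1024, input_size) weight
# shard, and weight_loader() copies rows [tp_rank * 1024, (tp_rank + 1) * 1024)
# of the full checkpoint tensor. With gather_output=True the per-rank pieces
# Y_i = X @ A_i.T are all-gathered back into the full 4096-wide output.
#
#     layer = ColumnParallelLinear(1024, 4096, bias=False, gather_output=False)
#     y_local, _ = layer(x)   # y_local: (..., 4096 // tp_size)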


class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: List[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size, sum(output_sizes), bias, gather_output,
                         skip_bias_add, params_dtype, quant_config,
                         self.output_sizes)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        is_metadata = getattr(param, "is_metadata", False)
        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            shard_offsets = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # If quantized, we need to adjust the offset and size to
                # account for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # If marlin, we need to adjust the offset and size to
                    # account for the tiling.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # If marlin, we need to adjust the offset and size to
                # account for the tiling.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
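
# Illustrative sketch (not in the original module): MergedColumnParallelLinear
# is typically used for fused MLP projections, e.g. a LLaMA-style gate/up pair.
# The sizes and source tensors below are made up; weight_loader() is called once
# per checkpoint tensor with the matching shard id.
#
#     gate_up = MergedColumnParallelLinear(4096, [11008, 11008], bias=False)
#     gate_up.weight_loader(gate_up.weight, gate_proj_weight, loaded_shard_id=0)
#     gate_up.weight_loader(gate_up.weight, up_proj_weight, loaded_shard_id=1)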


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value head may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization config.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        output_sizes = [
            self.num_heads * tp_size * self.head_size,
            self.num_kv_heads * tp_size * self.head_size,
            self.num_kv_heads * tp_size * self.head_size
        ]

        super().__init__(input_size, output_size, bias, False, skip_bias_add,
                         params_dtype, quant_config, output_sizes)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        is_metadata = getattr(param, "is_metadata", False)
        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # If quantized, we need to adjust the offset and size to
                # account for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # If marlin, we need to adjust the offset and size to
                    # account for the tiling.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # If marlin, we need to adjust the offset and size to
                # account for the tiling.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            if loaded_shard_id == "q":
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
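
# Illustrative sketch (not in the original module): with made-up GQA settings
# total_num_heads=32, total_num_kv_heads=8, head_size=128 and tp_size=4, each
# rank holds num_heads=8 query heads and num_kv_heads=2 key/value heads, so the
# packed output width per rank is (8 + 2 * 2) * 128 = 1536. If tp_size exceeded
# the 8 KV heads, each KV head would instead be replicated across
# num_kv_head_replicas = tp_size // 8 ranks.
#
#     qkv = QKVParallelLinear(hidden_size=4096, head_size=128,
#                             total_num_heads=32, total_num_kv_heads=8)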


class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, all-reduce the partial outputs across
                        tensor-parallel ranks.
        quant_config: Quantization config.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        # Divide the weight matrix along the last dimension.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.quant_method.create_weights(self,
                                         self.input_size_per_partition,
                                         [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader)

        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reducing the results, adding bias to "
                             "the results can lead to incorrect results.")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        input_dim = getattr(param, "input_dim", None)
        param_data = param.data
        if input_dim is not None:
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.quant_method.apply(self, input_parallel)
        if self.reduce_results and self.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias

        return output, output_bias
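
# Illustrative sketch (not in the original module): the standard Megatron-style
# pairing splits an MLP into a column-parallel up-projection followed by a
# row-parallel down-projection, so the only cross-rank communication is the
# all-reduce inside RowParallelLinear. Sizes below are made up.
#
#     up_proj = ColumnParallelLinear(4096, 11008, bias=False)    # shards output columns
#     down_proj = RowParallelLinear(11008, 4096, bias=False,
#                                   input_is_parallel=True)      # shards input rows
#     h, _ = up_proj(x)        # (..., 11008 // tp_size) on each rank
#     out, _ = down_proj(h)    # all-reduced to the full (..., 4096)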