# linear.py

from abc import ABC, abstractmethod
from typing import List, Optional

import torch
import torch.nn.functional as F
from loguru import logger
from torch import nn
from torch.nn.parameter import Parameter

from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
                                   get_tensor_model_parallel_world_size,
                                   split_tensor_along_last_dim,
                                   tensor_model_parallel_all_gather,
                                   tensor_model_parallel_all_reduce)
from aphrodite.modeling.utils import set_weight_attrs


def adjust_marlin_shard(param, shard_size, shard_offset):
    marlin_tile_size = getattr(param, "marlin_tile_size", None)
    if marlin_tile_size is None:
        return shard_size, shard_offset

    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
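
# Worked example (illustrative; the numbers are not from this file): for a
# Marlin-packed parameter with marlin_tile_size = 16, a logical shard of size
# 128 at offset 256 maps to a physical shard of size 128 * 16 = 2048 at offset
# 256 * 16 = 4096. Parameters without the attribute pass through unchanged.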


class LinearMethodBase(ABC):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for a linear layer.
        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all
                ranks.
            params_dtype: Datatype of the parameters.
        """
        raise NotImplementedError

    @abstractmethod
    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.
        Expects create_weights to have been called before on the layer."""
        raise NotImplementedError

    def process_weights_after_loading(self, layer: nn.Module) -> None:
        """Process the weight after loading.
        This can be used, for example, to transpose weights for computation.
        """
        return


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.

    Args:
        separate_bias_add: If true, add bias separately after matrix
                           multiplication.
    """

    def __init__(self, separate_bias_add: bool = False):
        self.separate_bias_add = separate_bias_add

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        output_size_per_partition = sum(output_partition_sizes)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        weight = layer.weight
        if self.separate_bias_add:
            if bias is not None:
                return F.linear(x, weight) + bias
            return F.linear(x, weight)
        return F.linear(x, weight, bias)
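
# Usage sketch (illustrative, not part of the original module; sizes are
# arbitrary): a linear method attaches its tensors to an arbitrary nn.Module
# in create_weights and consumes them in apply_weights, which is exactly how
# the layer classes below drive it.
#
#     layer = nn.Module()
#     method = UnquantizedLinearMethod()
#     method.create_weights(layer,
#                           input_size_per_partition=1024,
#                           output_partition_sizes=[4096],
#                           input_size=1024,
#                           output_size=4096,
#                           params_dtype=torch.float16)
#     x = torch.randn(2, 1024, dtype=torch.float16)
#     y = method.apply_weights(layer, x)   # shape: (2, 4096)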


class ReplicatedLinear(torch.nn.Module):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if linear_method is None:
            linear_method = UnquantizedLinearMethod()
        self.linear_method = linear_method
        self.linear_method.create_weights(self, self.input_size,
                                          [self.output_size], self.input_size,
                                          self.output_size, self.params_dtype)
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {"output_dim": 0})
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bias = self.bias if not self.skip_bias_add else None
        output = self.linear_method.apply_weights(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
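
# Usage sketch (illustrative; sizes are arbitrary): every rank holds the full
# weight, and forward returns an (output, output_bias) tuple so callers can
# fuse the bias elsewhere when skip_bias_add=True.
#
#     proj = ReplicatedLinear(1024, 4096, bias=True, skip_bias_add=True)
#     y, y_bias = proj(x)   # y has no bias added; y_bias is proj.bias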


class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its own
                       output, which is Y_i = XA_i.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
        output_sizes: list of output sizes packed into one output; e.g., for
                      QKV the list would have size 3.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        linear_method: Optional[LinearMethodBase] = None,
        output_sizes: Optional[List[int]] = None,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, tp_size)
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if linear_method is None:
            linear_method = UnquantizedLinearMethod()
        if output_sizes is None:
            output_sizes = [output_size]
        self.linear_method = linear_method
        self.linear_method.create_weights(self,
                                          self.input_size,
                                          [x // tp_size for x in output_sizes],
                                          self.input_size,
                                          self.output_size,
                                          self.params_dtype,
                                          weight_loader=self.weight_loader)
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)
        param_data = param.data
        if output_dim is not None:
            shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        output_parallel = self.linear_method.apply_weights(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
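
# Usage sketch (illustrative; sizes are arbitrary and the tensor-parallel
# group must already be initialized): with tp_size = 4, each rank owns a
# 4096 / 4 = 1024-wide slice of the output dimension.
#
#     proj = ColumnParallelLinear(1024, 4096, bias=False, gather_output=False)
#     y, _ = proj(x)   # per-rank output shape: (..., 1024)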


class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: List[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size, sum(output_sizes), bias, gather_output,
                         skip_bias_add, params_dtype, linear_method,
                         self.output_sizes)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        is_metadata = getattr(param, "is_metadata", False)
        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            shard_offsets = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # If quantized, we need to adjust the offset and size to
                # account for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # If marlin, we need to adjust the offset and size to
                    # account for the tiling.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # If marlin, we need to adjust the offset and size to
                # account for the tiling.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
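
# Usage sketch (illustrative; the sizes follow a gate/up MLP projection, and
# the tensor names gate_w/up_w are placeholders): two logical matrices are
# packed into one parameter, and checkpoint tensors are routed to the right
# slice of it by passing their shard id to weight_loader.
#
#     gate_up = MergedColumnParallelLinear(4096, [11008, 11008], bias=False)
#     # during checkpoint loading:
#     #   gate_up.weight_loader(gate_up.weight, gate_w, loaded_shard_id=0)
#     #   gate_up.weight_loader(gate_up.weight, up_w, loaded_shard_id=1)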


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value heads may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        output_sizes = [
            self.num_heads * tp_size * self.head_size,
            self.num_kv_heads * tp_size * self.head_size,
            self.num_kv_heads * tp_size * self.head_size
        ]
        super().__init__(input_size, output_size, bias, False, skip_bias_add,
                         params_dtype, linear_method, output_sizes)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        is_metadata = getattr(param, "is_metadata", False)
        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # If quantized, we need to adjust the offset and size to
                # account for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # If marlin, we need to adjust the offset and size to
                    # account for the tiling.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # If marlin, we need to adjust the offset and size to
                # account for the tiling.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            if loaded_shard_id == "q":
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
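
# Usage sketch (illustrative; the head counts follow a grouped-query model and
# the tensor-parallel group must already be initialized): with 32 query heads,
# 8 KV heads, head_size 128, and tp_size = 4, each rank holds 8 query heads
# and 2 KV heads, so the packed per-rank output width is
# (8 + 2 * 2) * 128 = 1536.
#
#     qkv = QKVParallelLinear(hidden_size=4096, head_size=128,
#                             total_num_heads=32, total_num_kv_heads=8,
#                             bias=False)
#     qkv_out, _ = qkv(x)   # later split into q, k, v along the last dim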


class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, all-reduce the output across tensor-parallel
                        ranks.
        linear_method: (Maybe quantized) linear method.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        # Divide the weight matrix along the last dimension.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.skip_bias_add = skip_bias_add
        if linear_method is None:
            linear_method = UnquantizedLinearMethod()
        self.linear_method = linear_method
        self.linear_method.create_weights(self,
                                          self.input_size_per_partition,
                                          [self.output_size],
                                          self.input_size,
                                          self.output_size,
                                          self.params_dtype,
                                          weight_loader=self.weight_loader)

        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reducing the results, adding bias to "
                             "the results can lead to incorrect results.")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        input_dim = getattr(param, "input_dim", None)
        param_data = param.data
        if input_dim is not None:
            shard_size = param_data.shape[input_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.linear_method.apply_weights(
            self, input_parallel)
        if self.reduce_results and self.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias
        return output, output_bias
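
# Usage sketch (illustrative; sizes are arbitrary and the tensor-parallel
# group must already be initialized): RowParallelLinear is the usual partner
# of ColumnParallelLinear, consuming an already-split input and all-reducing
# the partial products so every rank ends up with the full output.
#
#     down = RowParallelLinear(11008, 4096, bias=False,
#                              input_is_parallel=True, reduce_results=True)
#     y, _ = down(x_parallel)   # y is the full (..., 4096) output on each rank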