
# linear.py

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

from loguru import logger
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from aphrodite.modeling.megatron.communication_op import (
    tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather)
from aphrodite.modeling.megatron.utils import (divide,
                                               split_tensor_along_last_dim)
from aphrodite.modeling.utils import set_weight_attrs

def adjust_marlin_shard(param, shard_size, shard_offset):
    marlin_tile_size = getattr(param, "marlin_tile_size", None)
    if marlin_tile_size is None:
        return shard_size, shard_offset

    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
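
# Worked example (illustrative comment only, not executed by this module):
# Marlin-format quantized weights are laid out in tiles, so a logical shard
# must be rescaled to tile units. Assuming a parameter carries
# `marlin_tile_size = 16` (a typical value, stated here as an assumption), a
# shard of size 128 at offset 256 becomes size 128 * 16 = 2048 at offset
# 256 * 16 = 4096 after adjust_marlin_shard(); without the attribute the
# shard is returned unchanged.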

class LinearMethodBase(ABC):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, Any]:
        """Create weights for a linear layer."""
        raise NotImplementedError

    @abstractmethod
    def apply_weights(self,
                      weights: Dict[str, torch.Tensor],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights to the input tensor."""
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.

    Args:
        separate_bias_add: If true, add bias separately after matrix
                           multiplication.
    """

    def __init__(self, separate_bias_add: bool = False):
        self.separate_bias_add = separate_bias_add

    def create_weights(self, input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, Any]:
        output_size_per_partition = sum(output_partition_sizes)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        return {"weight": weight}

    def apply_weights(self,
                      weights: Dict[str, torch.Tensor],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        weight = weights["weight"]
        if self.separate_bias_add:
            # `if bias:` would raise on a multi-element tensor; check for None.
            if bias is not None:
                return F.linear(x, weight) + bias
            return F.linear(x, weight)
        return F.linear(x, weight, bias)

    def apply_embedding(self, weights: Dict[str, torch.Tensor],
                        x: torch.Tensor) -> torch.Tensor:
        weight = weights["weight"]
        return F.embedding(x, weight)
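
# Usage sketch (illustrative comment only, not part of the original module):
# how a layer drives UnquantizedLinearMethod. The sizes and dtype below are
# assumptions for demonstration; the created weight is uninitialized until a
# checkpoint is loaded into it.
#
#     method = UnquantizedLinearMethod()
#     weights = method.create_weights(input_size_per_partition=1024,
#                                     output_partition_sizes=[4096],
#                                     input_size=1024, output_size=4096,
#                                     params_dtype=torch.float16)
#     x = torch.randn(2, 1024, dtype=torch.float16)
#     y = method.apply_weights(weights, x)  # y.shape == (2, 4096)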

class ReplicatedLinear(torch.nn.Module):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if linear_method is None:
            linear_method = UnquantizedLinearMethod()
        self.linear_method = linear_method
        self.linear_weights = self.linear_method.create_weights(
            self.input_size, [self.output_size], self.input_size,
            self.output_size, self.params_dtype)
        for name, weight in self.linear_weights.items():
            if isinstance(weight, torch.nn.parameter.Parameter):
                self.register_parameter(name, weight)
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {"output_dim": 0})
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bias = self.bias if not self.skip_bias_add else None
        output = self.linear_method.apply_weights(self.linear_weights, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
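
# Usage sketch (illustrative comment only): ReplicatedLinear keeps the full
# weight on every rank, so it needs no tensor-parallel communication. The
# sizes below are assumptions for demonstration.
#
#     proj = ReplicatedLinear(input_size=1024, output_size=4096, bias=True)
#     out, out_bias = proj(torch.randn(2, 1024))
#     # out_bias is None unless skip_bias_add=True; in that case the caller
#     # adds the returned bias after fusing it with other element-wise ops.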

class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on the output and make Y
                       available to all GPUs; otherwise, every GPU will have
                       its own output Y_i = XA_i.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
        output_sizes: list of output sizes packed into one output, e.g. for
                      QKV the list would have three elements.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        linear_method: Optional[LinearMethodBase] = None,
        output_sizes: Optional[List[int]] = None,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, tp_size)
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if linear_method is None:
            linear_method = UnquantizedLinearMethod()
        if output_sizes is None:
            output_sizes = [output_size]
        self.linear_method = linear_method
        self.linear_weights = self.linear_method.create_weights(
            self.input_size, [x // tp_size for x in output_sizes],
            self.input_size, self.output_size, self.params_dtype)
        for name, weight in self.linear_weights.items():
            if isinstance(weight, torch.nn.parameter.Parameter):
                self.register_parameter(name, weight)
                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        output_dim = getattr(param, "output_dim", None)
        param_data = param.data
        if output_dim is not None:
            if loaded_weight.shape[output_dim] % tp_size != 0:
                raise ValueError(
                    "Size is not aligned with the quantized weight shape")
            shard_size = loaded_weight.shape[output_dim] // tp_size
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        if isinstance(param, torch.nn.parameter.UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
            param_data = param.data
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        output_parallel = self.linear_method.apply_weights(
            self.linear_weights, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
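
# Sharding sketch (illustrative comment only): with tensor-parallel size 4 and
# output_size=4096 (numbers assumed for demonstration), each rank owns
# output_size_per_partition = 4096 // 4 = 1024 rows of the (output, input)
# weight, and weight_loader() narrows the checkpoint tensor along `output_dim`
# to rows [tp_rank * 1024, (tp_rank + 1) * 1024). With gather_output=True the
# per-rank (batch, 1024) results are all-gathered back into (batch, 4096).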

class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on the output and make it
                       available to all GPUs; otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: List[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size, sum(output_sizes), bias, gather_output,
                         skip_bias_add, params_dtype, linear_method,
                         self.output_sizes)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        is_metadata = getattr(param, "is_metadata", False)
        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            shard_offsets = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # If quantized, we need to adjust the offset and size to
                # account for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # If marlin, we need to adjust the offset and size to
                    # account for the tiling.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # If marlin, we need to adjust the offset and size to account
                # for the tiling.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
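
# Loading sketch (illustrative comment only): a common use of this class is a
# fused gate/up projection in a SwiGLU MLP, e.g. (sizes are assumptions)
#
#     gate_up = MergedColumnParallelLinear(input_size=4096,
#                                          output_sizes=[11008, 11008])
#
# The gate weight is loaded via weight_loader(param, w_gate, loaded_shard_id=0)
# and the up weight with loaded_shard_id=1; with tp_size=2, shard 1 starts at
# offset 11008 // 2 = 5504 inside each rank's packed parameter.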

class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value heads may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        super().__init__(input_size, output_size, bias, False, skip_bias_add,
                         params_dtype, linear_method, [
                             self.num_heads * tp_size * self.head_size,
                             self.num_kv_heads * tp_size * self.head_size,
                             self.num_kv_heads * tp_size * self.head_size
                         ])

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        is_metadata = getattr(param, "is_metadata", False)
        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # If quantized, we need to adjust the offset and size to
                # account for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # If marlin, we need to adjust the offset and size to
                    # account for the tiling.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # If marlin, we need to adjust the offset and size to account
                # for the tiling.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            if loaded_shard_id == "q":
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assume the weight is the same "
                    "for all partitions.")
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
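
# Head-partitioning sketch (illustrative comment only, numbers assumed): for a
# grouped-query model with total_num_heads=32, total_num_kv_heads=8,
# head_size=128 and tp_size=4, each rank gets num_heads = 32 // 4 = 8 query
# heads and num_kv_heads = 8 // 4 = 2 key/value heads, so its packed output is
# (8 + 2 * 2) * 128 = 1536 columns. When tp_size exceeds total_num_kv_heads,
# num_kv_heads is pinned to 1 and each KV head is replicated across
# num_kv_head_replicas = tp_size // total_num_kv_heads ranks, which is why the
# "k"/"v" shards index the checkpoint by tp_rank // num_kv_head_replicas.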

class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:

                -   -
               | A_1 |
               | .   |
           A = | .   |        X = [X_1, ..., X_p]
               | .   |
               | A_p |
                -   -

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        # Divide the weight matrix along the last dimension.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.skip_bias_add = skip_bias_add
        if linear_method is None:
            linear_method = UnquantizedLinearMethod()
        self.linear_method = linear_method
        self.linear_weights = self.linear_method.create_weights(
            self.input_size_per_partition, [self.output_size], self.input_size,
            self.output_size, self.params_dtype)
        for name, weight in self.linear_weights.items():
            if isinstance(weight, torch.nn.parameter.Parameter):
                self.register_parameter(name, weight)
                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When the results are not reduced, adding bias "
                             "to them can lead to incorrect results")
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        input_dim = getattr(param, "input_dim", None)
        param_data = param.data
        if input_dim is not None:
            if loaded_weight.shape[input_dim] % tp_size != 0:
                raise ValueError(
                    "Size is not aligned with the quantized weight shape")
            shard_size = loaded_weight.shape[input_dim] // tp_size
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)
        if isinstance(param, torch.nn.parameter.UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
            param_data = param.data
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.linear_method.apply_weights(
            self.linear_weights, input_parallel)
        if self.reduce_results and self.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias
        return output, output_bias
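
# Usage sketch (illustrative comment only): RowParallelLinear is the usual
# partner of ColumnParallelLinear in a tensor-parallel MLP: a column-parallel
# up projection followed by a row-parallel down projection. Each rank
# multiplies its (batch, input_size_per_partition) slice by its weight shard,
# and the partial results are summed with an all-reduce when
# reduce_results=True. The sizes are assumptions, and a tensor-parallel
# process group must already be initialized before constructing the layers.
#
#     up = ColumnParallelLinear(4096, 11008, gather_output=False)
#     down = RowParallelLinear(11008, 4096, input_is_parallel=True)
#     h, _ = up(x)    # h: (batch, 11008 // tp_size) on each rank
#     y, _ = down(h)  # y: (batch, 4096), identical on all ranks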