# linear.py

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
from loguru import logger
from torch.nn.parameter import Parameter

from aphrodite.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
from aphrodite.modeling.layers.fused_moe import fused_moe
from aphrodite.modeling.utils import set_weight_attrs


def adjust_marlin_shard(param, shard_size, shard_offset):
    marlin_tile_size = getattr(param, "marlin_tile_size", None)
    if marlin_tile_size is None:
        return shard_size, shard_offset

    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


class LinearMethodBase(ABC):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, Any]:
        """Create weights for a linear layer."""
        raise NotImplementedError

    @abstractmethod
    def apply_weights(self,
                      weights: Dict[str, torch.Tensor],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights to the input tensor."""
        raise NotImplementedError

    def create_moe_weights(self, num_experts: int,
                           input_size_per_partition: int,
                           output_partition_sizes: List[int], input_size: int,
                           output_size: int,
                           params_dtype: torch.dtype) -> Dict[str, Any]:
        """Create MoE weights by replicating the linear weights once per
        expert along a new leading expert dimension."""
        linear_weights = self.create_weights(input_size_per_partition,
                                             output_partition_sizes,
                                             input_size, output_size,
                                             params_dtype)
        if num_experts == 1:
            return linear_weights
        for name, param in tuple(linear_weights.items()):
            if isinstance(param, Parameter):
                repeat_size = (num_experts, ) + (1, ) * param.dim()
                new_param = Parameter(param.unsqueeze(0).repeat(*repeat_size),
                                      requires_grad=False)
                set_weight_attrs(new_param, param.__dict__)
                linear_weights[name] = new_param
        return linear_weights

    @abstractmethod
    def apply_moe_weights(self, w1: Dict[str, torch.Tensor],
                          w2: Dict[str, torch.Tensor], x: torch.Tensor,
                          gating_output: torch.Tensor, topk: int,
                          renormalize: bool) -> torch.Tensor:
        """Apply the MoE weights to the input tensor."""
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.

    Args:
        separate_bias_add: If true, add bias separately after matrix
                           multiplication.
    """

    def __init__(self, separate_bias_add: bool = False):
        self.separate_bias_add = separate_bias_add

    def create_weights(self, input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, Any]:
        output_size_per_partition = sum(output_partition_sizes)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        return {"weight": weight}

    def apply_weights(self,
                      weights: Dict[str, torch.Tensor],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        weight = weights["weight"]
        if self.separate_bias_add:
            if bias is not None:
                return F.linear(x, weight) + bias
            return F.linear(x, weight)
        return F.linear(x, weight, bias)

    def apply_embedding(self, weights: Dict[str, torch.Tensor],
                        x: torch.Tensor) -> torch.Tensor:
        weight = weights["weight"]
        return F.embedding(x, weight)

    def apply_moe_weights(self, w1: Dict[str, torch.Tensor],
                          w2: Dict[str, torch.Tensor], x: torch.Tensor,
                          gating_output: torch.Tensor, topk: int,
                          renormalize: bool) -> torch.Tensor:
        return fused_moe(x, w1["weight"], w2["weight"], gating_output, topk,
                         renormalize)
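

# Illustrative sketch (not part of the original module): how a LinearMethodBase
# implementation is used. It builds the weight dict, applies it like a plain
# nn.Linear, and shows that `create_moe_weights` replicates the weight along a
# new leading expert dimension. Sizes, dtypes, and the helper name below are
# arbitrary examples, not part of the library API.
def _example_linear_method_usage():
    method = UnquantizedLinearMethod()
    # A single partition of a 512 -> 1024 projection.
    weights = method.create_weights(input_size_per_partition=512,
                                    output_partition_sizes=[1024],
                                    input_size=512,
                                    output_size=1024,
                                    params_dtype=torch.float16)
    x = torch.randn(4, 512, dtype=torch.float16)
    y = method.apply_weights(weights, x)          # shape: (4, 1024)
    # MoE variant: the weight gains an expert dimension, here (8, 1024, 512).
    moe_weights = method.create_moe_weights(8, 512, [1024], 512, 1024,
                                            torch.float16)
    return y.shape, moe_weights["weight"].shape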


class ReplicatedLinear(torch.nn.Module):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if linear_method is None:
            linear_method = UnquantizedLinearMethod()
        self.linear_method = linear_method
        self.linear_weights = self.linear_method.create_weights(
            self.input_size, [self.output_size], self.input_size,
            self.output_size, self.params_dtype)
        for name, weight in self.linear_weights.items():
            if isinstance(weight, torch.nn.parameter.Parameter):
                self.register_parameter(name, weight)
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {"output_dim": 0})
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bias = self.bias if not self.skip_bias_add else None
        output = self.linear_method.apply_weights(self.linear_weights, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
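

# Illustrative sketch (not part of the original module): constructing and
# calling ReplicatedLinear. Every rank keeps a full copy of the weight, and
# forward returns an (output, output_bias) tuple; output_bias is only populated
# when skip_bias_add=True. The sizes and dtype here are arbitrary.
def _example_replicated_linear():
    layer = ReplicatedLinear(input_size=1024,
                             output_size=4096,
                             bias=True,
                             params_dtype=torch.float32)
    x = torch.randn(2, 1024)
    output, output_bias = layer(x)    # output: (2, 4096), output_bias: None
    return output.shape, output_bias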


class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs; otherwise, every GPU will have its own
                       output, which is Y_i = XA_i.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
        output_sizes: list of output sizes packed into one output; e.g. for
                      QKV the list would have three entries.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        linear_method: Optional[LinearMethodBase] = None,
        output_sizes: Optional[List[int]] = None,
        num_experts: int = 1,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, tp_size)
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if linear_method is None:
            linear_method = UnquantizedLinearMethod()
        if output_sizes is None:
            output_sizes = [output_size]
        self.linear_method = linear_method
        self.num_experts = num_experts
        self.linear_weights = self.linear_method.create_moe_weights(
            num_experts, self.input_size, [x // tp_size for x in output_sizes],
            self.input_size, self.output_size, self.params_dtype)
        for name, weight in self.linear_weights.items():
            if isinstance(weight, torch.nn.parameter.Parameter):
                self.register_parameter(name, weight)
                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      expert_id: int = -1):
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        output_dim = getattr(param, "output_dim", None)
        param_data = param.data
        if self.num_experts > 1:
            if expert_id >= 0:
                param_data = param_data[expert_id]
            # Loaded weight is packed at expert dim
            else:
                output_dim = output_dim + 1
        if output_dim is not None:
            if loaded_weight.shape[output_dim] % tp_size != 0:
                raise ValueError(
                    "Size is not aligned with the quantized weight shape")
            shard_size = loaded_weight.shape[output_dim] // tp_size
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        if isinstance(param, torch.nn.parameter.UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
            param_data = param.data
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        output_parallel = self.linear_method.apply_weights(
            self.linear_weights, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
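

# Illustrative sketch (not part of the original module): column parallelism in
# isolation. With tensor parallel size p, each rank holds a slice A_i of shape
# (output_size / p, input_size) and computes Y_i = X @ A_i^T; gather_output=True
# all-gathers the Y_i slices back into the full Y. This assumes the distributed
# state has already been initialized (here with tp_size == 1) and uses
# arbitrary sizes.
def _example_column_parallel_linear():
    layer = ColumnParallelLinear(input_size=1024,
                                 output_size=4096,
                                 bias=False,
                                 gather_output=True)
    x = torch.randn(2, 1024)
    # With tp_size == 1 the local shard is the whole (4096, 1024) weight.
    output, _ = layer(x)              # output: (2, 4096)
    return output.shape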


class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs; otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: List[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        linear_method: Optional[LinearMethodBase] = None,
        num_experts: int = 1,
    ):
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size, sum(output_sizes), bias, gather_output,
                         skip_bias_add, params_dtype, linear_method,
                         self.output_sizes, num_experts)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None,
                      expert_id: int = -1):
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        is_metadata = getattr(param, "is_metadata", False)
        if self.num_experts > 1:
            if expert_id >= 0:
                param_data = param_data[expert_id]
            # Loaded weight is packed at expert dim
            elif output_dim is not None:
                output_dim = output_dim + 1
        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            shard_offsets = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # If quantized, we need to adjust the offset and size to
                # account for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # If marlin, we need to adjust the offset and size to
                    # account for the tiling.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # If quantized, we need to adjust the offset and size to
            # account for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # If marlin, we need to adjust the offset and size to
                # account for the tiling.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assuming the weight is "
                    "the same for all partitions.")
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
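

# Illustrative sketch (not part of the original module): MergedColumnParallelLinear
# typically packs two projections that share an input, e.g. the gate and up
# projections of a SwiGLU MLP, into one GEMM. The checkpoint tensors for the two
# sub-layers are loaded one at a time through weight_loader with their shard id.
# Assumes tp_size == 1 and arbitrary sizes; the tensors here are random stand-ins
# for checkpoint weights.
def _example_merged_column_parallel_linear():
    layer = MergedColumnParallelLinear(input_size=1024,
                                       output_sizes=[2048, 2048],
                                       bias=False)
    weight = layer.linear_weights["weight"]      # shape: (4096, 1024)
    gate_proj = torch.randn(2048, 1024)
    up_proj = torch.randn(2048, 1024)
    layer.weight_loader(weight, gate_proj, loaded_shard_id=0)
    layer.weight_loader(weight, up_proj, loaded_shard_id=1)
    output, _ = layer(torch.randn(2, 1024))      # output: (2, 4096)
    return output.shape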


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value heads
    may be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        super().__init__(input_size, output_size, bias, False, skip_bias_add,
                         params_dtype, linear_method, [
                             self.num_heads * tp_size * self.head_size,
                             self.num_kv_heads * tp_size * self.head_size,
                             self.num_kv_heads * tp_size * self.head_size
                         ])

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        is_metadata = getattr(param, "is_metadata", False)
        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # If quantized, we need to adjust the offset and size to
                # account for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # If marlin, we need to adjust the offset and size to
                    # account for the tiling.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            # If quantized, we need to adjust the offset and size to
            # account for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # If marlin, we need to adjust the offset and size to
                # account for the tiling.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            if loaded_shard_id == "q":
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assuming the weight is the same "
                    "for all partitions.")
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
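

# Illustrative sketch (not part of the original module): a grouped-query
# attention projection with 32 query heads, 8 KV heads, and head size 128. The
# packed output is (num_heads + 2 * num_kv_heads) * head_size wide, and the
# three checkpoint tensors are loaded under the shard ids "q", "k", and "v".
# Assumes tp_size == 1; the sizes and random tensors are arbitrary examples.
def _example_qkv_parallel_linear():
    layer = QKVParallelLinear(hidden_size=4096,
                              head_size=128,
                              total_num_heads=32,
                              total_num_kv_heads=8,
                              bias=False)
    weight = layer.linear_weights["weight"]      # shape: (6144, 4096)
    layer.weight_loader(weight, torch.randn(4096, 4096), "q")
    layer.weight_loader(weight, torch.randn(1024, 4096), "k")
    layer.weight_loader(weight, torch.randn(1024, 4096), "v")
    qkv, _ = layer(torch.randn(2, 4096))         # qkv: (2, 6144)
    return qkv.shape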


class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:

               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        linear_method: Optional[LinearMethodBase] = None,
        num_experts: int = 1,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        # Divide the weight matrix along the last dimension.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.skip_bias_add = skip_bias_add
        if linear_method is None:
            linear_method = UnquantizedLinearMethod()
        self.linear_method = linear_method
        self.num_experts = num_experts
        self.linear_weights = self.linear_method.create_moe_weights(
            num_experts, self.input_size_per_partition, [self.output_size],
            self.input_size, self.output_size, self.params_dtype)
        for name, weight in self.linear_weights.items():
            if isinstance(weight, torch.nn.parameter.Parameter):
                self.register_parameter(name, weight)
                set_weight_attrs(weight, {"weight_loader": self.weight_loader})

        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reducing the results, adding bias to "
                             "the results can lead to incorrect results.")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      expert_id: int = -1):
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        input_dim = getattr(param, "input_dim", None)
        param_data = param.data
        if self.num_experts > 1:
            if expert_id >= 0:
                param_data = param_data[expert_id]
            # Loaded weight is packed at expert dim
            elif input_dim is not None:
                input_dim = input_dim + 1
        if input_dim is not None:
            if loaded_weight.shape[input_dim] % tp_size != 0:
                raise ValueError(
                    "Size is not aligned with the quantized weight shape")
            shard_size = loaded_weight.shape[input_dim] // tp_size
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)
        if isinstance(param, torch.nn.parameter.UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
            param_data = param.data
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        is_exl2 = hasattr(
            self.linear_method, "quant_config"
        ) and self.linear_method.quant_config.get_name() == "exl2"
        if self.input_is_parallel and is_exl2:
            input_parallel = tensor_model_parallel_all_gather(input_)
        elif self.input_is_parallel or is_exl2:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.linear_method.apply_weights(
            self.linear_weights, input_parallel)
        if self.reduce_results and self.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias
        return output, output_bias
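

# Illustrative sketch (not part of the original module): the classic pairing of
# a ColumnParallelLinear with a RowParallelLinear, as in a Megatron-style MLP.
# The first layer keeps its output sharded (gather_output=False); the second
# consumes that shard directly (input_is_parallel=True) and all-reduces the
# partial results, so only one collective is needed per MLP block. Assumes the
# distributed state is initialized (tp_size == 1 here) and arbitrary sizes.
def _example_column_then_row_parallel():
    up = ColumnParallelLinear(input_size=1024,
                              output_size=4096,
                              bias=False,
                              gather_output=False)
    down = RowParallelLinear(input_size=4096,
                             output_size=1024,
                             bias=False,
                             input_is_parallel=True)
    x = torch.randn(2, 1024)
    hidden, _ = up(x)              # local shard: (2, 4096 / tp_size)
    out, _ = down(F.relu(hidden))  # all-reduced full output: (2, 1024)
    return out.shape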