linear.py

from abc import ABC, abstractmethod
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F
from loguru import logger
from torch.nn.parameter import Parameter

from aphrodite.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
)
from aphrodite.modeling.layers.fused_moe import fused_moe
from aphrodite.modeling.utils import set_weight_attrs


def adjust_marlin_shard(param, shard_size, shard_offset):
    marlin_tile_size = getattr(param, "marlin_tile_size", None)
    if marlin_tile_size is None:
        return shard_size, shard_offset
    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
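

# Example (hedged sketch, not called anywhere in this module): how the Marlin
# tiling adjustment scales a logical shard. For a parameter tagged with a
# hypothetical marlin_tile_size of 16, a shard of (size=1024, offset=2048)
# maps to (size=16384, offset=32768); parameters without the attribute pass
# through unchanged.
def _example_adjust_marlin_shard():
    param = torch.nn.Parameter(torch.empty(0), requires_grad=False)
    param.marlin_tile_size = 16  # illustrative attribute value
    return adjust_marlin_shard(param, 1024, 2048)  # -> (16384, 32768)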


# Split the exl2 weight at a group boundary.
def post_init_exl2(layer):
    q_groups = layer.q_groups
    num_groups = q_groups.shape[0] // 2
    input_size = layer.q_invperm.shape[0]
    tp_rank = get_tensor_model_parallel_rank()
    tp_size = get_tensor_model_parallel_world_size()
    rows = 0
    # Each entry is (row_number, qrow_number, group_number).
    splits = [(0, 0, 0)]
    index = 1
    for i in range(num_groups - 1):
        bits = q_groups[i * 2].item()
        qrows = (q_groups[i * 2 + 3] - q_groups[i * 2 + 1]).item()
        rows += qrows * 32 // bits
        q_rows_end = q_groups[i * 2 + 3].item()
        if q_rows_end >= layer.q_weight.shape[0] // tp_size * index:
            splits.append((rows, q_rows_end, i + 1))
            index += 1
    splits.append((input_size, layer.q_weight.shape[0], num_groups))

    shard_qweight = torch.nn.Parameter(
        layer.q_weight[splits[tp_rank][1]:splits[tp_rank + 1][1]].clone(),
        requires_grad=False,
    )
    # The outer model.named_parameters() during loading still keeps a
    # reference, so we have to force-release the memory.
    layer.q_weight.data = torch.empty(0, device="cpu")
    del layer.q_weight
    layer.linear_weights["q_weight"] = shard_qweight.to(layer.device)

    shard_qgroups = torch.nn.Parameter(
        layer.q_groups[splits[tp_rank][2] * 2:splits[tp_rank + 1][2] *
                       2].clone(),
        requires_grad=False,
    )
    shard_qgroups[1::2] -= int(shard_qgroups[1])
    layer.q_groups.data = torch.empty(0, device="cpu")
    del layer.q_groups
    layer.linear_weights["q_groups"] = shard_qgroups.to(layer.device)

    shard_qscale = torch.nn.Parameter(
        layer.q_scale[splits[tp_rank][2]:splits[tp_rank + 1][2]].clone(),
        requires_grad=False,
    )
    layer.q_scale.data = torch.empty(0, device="cpu")
    del layer.q_scale
    layer.linear_weights["q_scale"] = shard_qscale.to(layer.device)

    shard_qscale_max = torch.nn.Parameter(
        layer.q_scale_max[splits[tp_rank][2]:splits[tp_rank + 1][2]].clone(),
        requires_grad=False,
    )
    layer.q_scale_max.data = torch.empty(0, device="cpu")
    del layer.q_scale_max
    layer.linear_weights["q_scale_max"] = shard_qscale_max.to(layer.device)

    q_perm = torch.argsort(layer.q_invperm).to(torch.short)
    # q_invperm is not used later; it could probably be removed as well.
    layer.linear_weights["q_perm"] = q_perm[
        splits[tp_rank][0]:splits[tp_rank + 1][0]].to(layer.device)


class LinearMethodBase(ABC):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for a linear layer."""
        raise NotImplementedError

    @abstractmethod
    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer."""
        raise NotImplementedError

    def create_moe_weights(self, layer: torch.nn.Module, num_experts: int,
                           input_size_per_partition: int,
                           output_partition_sizes: List[int], input_size: int,
                           output_size: int, params_dtype: torch.dtype,
                           **extra_weight_attrs):
        """Create MoE weights."""
        self.create_weights(layer, input_size_per_partition,
                            output_partition_sizes, input_size, output_size,
                            params_dtype, **extra_weight_attrs)
        if num_experts == 1:
            return
        raise NotImplementedError(
            "Replicating linear weights across experts (num_experts > 1) is "
            "not implemented yet.")
        # Reference from the previous weights-dict API:
        # for name, param in tuple(linear_weights.items()):
        #     if isinstance(param, Parameter):
        #         repeat_size = (num_experts, ) + (1, ) * param.dim()
        #         new_param = Parameter(param.unsqueeze(0).repeat(*repeat_size),
        #                               requires_grad=False)
        #         set_weight_attrs(new_param, param.__dict__)
        #         linear_weights[name] = new_param
        # return linear_weights

    @abstractmethod
    def apply_moe_weights(self, w1: Dict[str, torch.Tensor],
                          w2: Dict[str, torch.Tensor], x: torch.Tensor,
                          gating_output: torch.Tensor, topk: int,
                          renormalize: bool) -> torch.Tensor:
        """Apply the weights to the input tensor."""
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.

    Args:
        separate_bias_add: If true, add bias separately after matrix
                           multiplication.
    """

    def __init__(self, separate_bias_add: bool = False):
        self.separate_bias_add = separate_bias_add

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        output_size_per_partition = sum(output_partition_sizes)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        weight = layer.weight
        if self.separate_bias_add:
            if bias is not None:
                return F.linear(x, weight) + bias
            return F.linear(x, weight)
        return F.linear(x, weight, bias)

    def apply_embedding(self, layer: torch.nn.Module,
                        x: torch.Tensor) -> torch.Tensor:
        weight = layer.weight
        return F.embedding(x, weight)

    def apply_moe_weights(self, w1: Dict[str, torch.Tensor],
                          w2: Dict[str, torch.Tensor], x: torch.Tensor,
                          gating_output: torch.Tensor, topk: int,
                          renormalize: bool) -> torch.Tensor:
        return fused_moe(x, w1["weight"], w2["weight"], gating_output, topk,
                         renormalize)
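

# Example (hedged sketch, not called anywhere in this module): the
# create_weights/apply_weights protocol of LinearMethodBase, exercised with
# UnquantizedLinearMethod on a bare nn.Module. The sizes are arbitrary.
def _example_unquantized_linear_method() -> torch.Tensor:
    layer = torch.nn.Module()
    method = UnquantizedLinearMethod()
    # Registers layer.weight with shape
    # (sum(output_partition_sizes), input_size_per_partition).
    method.create_weights(layer,
                          input_size_per_partition=16,
                          output_partition_sizes=[8, 8],
                          input_size=16,
                          output_size=16,
                          params_dtype=torch.float32)
    x = torch.zeros(4, 16, dtype=torch.float32)
    # Equivalent to F.linear(x, layer.weight); returns a (4, 16) tensor.
    return method.apply_weights(layer, x)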


class ReplicatedLinear(torch.nn.Module):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if linear_method is None:
            linear_method = UnquantizedLinearMethod()
        self.linear_method = linear_method
        self.linear_method.create_weights(self, self.input_size,
                                          [self.output_size], self.input_size,
                                          self.output_size, self.params_dtype)
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {"output_dim": 0})
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bias = self.bias if not self.skip_bias_add else None
        output = self.linear_method.apply_weights(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
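

# Example (hedged sketch): ReplicatedLinear keeps the full weight on every
# rank, so it can be constructed and used without tensor parallelism.
# Defined for illustration only; not called anywhere in this module.
def _example_replicated_linear() -> torch.Tensor:
    layer = ReplicatedLinear(input_size=32, output_size=64, bias=True,
                             params_dtype=torch.float32)
    x = torch.zeros(2, 32, dtype=torch.float32)
    # output_bias is None unless skip_bias_add=True.
    output, output_bias = layer(x)
    return output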


class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs; otherwise, every GPU will have its own
                       output, which is Y_i = XA_i.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
        output_sizes: list of output sizes packed into one output; e.g., for a
                      fused QKV projection the list has three entries.
        num_experts: number of experts for sparse MoE models.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        linear_method: Optional[LinearMethodBase] = None,
        output_sizes: Optional[List[int]] = None,
        num_experts: int = 1,
    ):
        super().__init__()
        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, tp_size)
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if linear_method is None:
            linear_method = UnquantizedLinearMethod()
        if output_sizes is None:
            output_sizes = [output_size]
        self.linear_method = linear_method
        self.num_experts = num_experts
        self.linear_method.create_moe_weights(
            self,
            num_experts,
            self.input_size, [x // tp_size for x in output_sizes],
            self.input_size,
            self.output_size,
            self.params_dtype,
            weight_loader=self.weight_loader)
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      expert_id: int = -1):
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        output_dim = getattr(param, "output_dim", None)
        param_data = param.data
        if self.num_experts > 1:
            if expert_id >= 0:
                param_data = param_data[expert_id]
            # Loaded weight is packed at the expert dim
            elif output_dim is not None:
                output_dim = output_dim + 1
        if output_dim is not None:
            if loaded_weight.shape[output_dim] % tp_size != 0:
                raise ValueError(
                    "Size is not aligned with the quantized weight shape")
            shard_size = loaded_weight.shape[output_dim] // tp_size
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        if isinstance(param, torch.nn.parameter.UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
            param_data = param.data
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        bias = self.bias if not self.skip_bias_add else None
        # Matrix multiply.
        output_parallel = self.linear_method.apply_weights(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
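

# Example (hedged sketch): a column-parallel projection as a model might build
# it. Assumes the tensor-parallel group has already been initialized, so each
# rank holds output_size // tp_size rows of the weight. Defined for
# illustration only; not called anywhere in this module.
def _example_column_parallel(hidden_size: int = 4096) -> torch.Tensor:
    proj = ColumnParallelLinear(input_size=hidden_size,
                                output_size=4 * hidden_size,
                                bias=False,
                                gather_output=False)
    x = torch.zeros(1, hidden_size)
    # With gather_output=False each rank returns its own Y_i = X A_i.
    output_parallel, _ = proj(x)
    return output_parallel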


class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs; otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: List[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        linear_method: Optional[LinearMethodBase] = None,
        num_experts: int = 1,
    ):
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size, sum(output_sizes), bias, gather_output,
                         skip_bias_add, params_dtype, linear_method,
                         self.output_sizes, num_experts)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None,
                      expert_id: int = -1):
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        is_metadata = getattr(param, "is_metadata", False)
        if self.num_experts > 1:
            if expert_id >= 0:
                param_data = param_data[expert_id]
            # Loaded weight is packed at the expert dim
            elif output_dim is not None:
                output_dim = output_dim + 1
        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            shard_offsets = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # If quantized, we need to adjust the offset and size to
                # account for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # If marlin, we need to adjust the offset and size to
                    # account for the tiling.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)
                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # If marlin, we need to adjust the offset and size to
                # account for the tiling.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)
            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        elif is_metadata:
            # Metadata indicates a fixed size concatenated along dim 0.
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assuming the weight is "
                    "the same for all partitions.")
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
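

# Example (hedged sketch): MergedColumnParallelLinear fuses several
# column-parallel projections (here a gate_proj/up_proj pair as used in SwiGLU
# MLPs) into one weight; during loading, weight_loader places each checkpoint
# shard (loaded_shard_id 0 or 1) into its slice of the fused parameter.
# Assumes an initialized tensor-parallel group; not called anywhere here.
def _example_merged_column_parallel(hidden_size: int = 4096,
                                    intermediate_size: int = 11008):
    gate_up = MergedColumnParallelLinear(
        input_size=hidden_size,
        output_sizes=[intermediate_size, intermediate_size],
        bias=False)
    x = torch.zeros(1, hidden_size)
    fused_out, _ = gate_up(x)
    # Per rank, the fused output is [gate shard, up shard].
    gate, up = fused_out.chunk(2, dim=-1)
    return gate, up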


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value heads may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        super().__init__(input_size, output_size, bias, False, skip_bias_add,
                         params_dtype, linear_method, [
                             self.num_heads * tp_size * self.head_size,
                             self.num_kv_heads * tp_size * self.head_size,
                             self.num_kv_heads * tp_size * self.head_size
                         ])

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        is_metadata = getattr(param, "is_metadata", False)
        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # If quantized, we need to adjust the offset and size to
                # account for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # If marlin, we need to adjust the offset and size to
                    # account for the tiling.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)
                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # If marlin, we need to adjust the offset and size to
                # account for the tiling.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)
            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            if loaded_shard_id == "q":
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        elif is_metadata:
            # Metadata indicates a fixed size concatenated along dim 0.
            shard_size = loaded_weight.shape[0]
            shard_index = ["q", "k", "v"].index(loaded_shard_id)
            param_data = param_data.narrow(0, shard_index * shard_size,
                                           shard_size)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assuming the weight is the same "
                    "for all partitions.")
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
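

# Example (hedged sketch): a fused QKV projection for grouped-query attention
# with illustrative sizes (32 query heads, 8 KV heads, head_size 128). Per
# rank, the output packs num_heads Q heads followed by num_kv_heads K and V
# heads. Assumes an initialized tensor-parallel group; defined for
# illustration only and not called anywhere in this module.
def _example_qkv_parallel(hidden_size: int = 4096, head_size: int = 128):
    qkv_proj = QKVParallelLinear(hidden_size,
                                 head_size,
                                 total_num_heads=32,
                                 total_num_kv_heads=8,
                                 bias=False)
    x = torch.zeros(1, hidden_size)
    qkv, _ = qkv_proj(x)
    q, k, v = qkv.split([
        qkv_proj.num_heads * head_size,
        qkv_proj.num_kv_heads * head_size,
        qkv_proj.num_kv_heads * head_size,
    ], dim=-1)
    return q, k, v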


class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        linear_method: Optional[LinearMethodBase] = None,
        num_experts: int = 1,
    ):
        super().__init__()
        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        # Divide the weight matrix along the last dimension.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.skip_bias_add = skip_bias_add
        if linear_method is None:
            linear_method = UnquantizedLinearMethod()
        self.linear_method = linear_method
        self.num_experts = num_experts
        self.linear_method.create_moe_weights(self,
                                              num_experts,
                                              self.input_size_per_partition,
                                              [self.output_size],
                                              self.input_size,
                                              self.output_size,
                                              self.params_dtype,
                                              weight_loader=self.weight_loader)
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When reduce_results=False, adding bias to the "
                             "output can lead to incorrect results.")
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)
        if self.is_exl2():
            self.uninitialized = [
                "q_invperm", "q_weight", "q_scale", "q_scale_max", "q_groups"
            ]

    def is_exl2(self) -> bool:
        return hasattr(
            self.linear_method, "quant_config"
        ) and self.linear_method.quant_config.get_name() == "exl2"

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      expert_id: int = -1):
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        input_dim = getattr(param, "input_dim", None)
        param_data = param.data
        if self.num_experts > 1:
            if expert_id >= 0:
                param_data = param_data[expert_id]
            # Loaded weight is packed at the expert dim
            elif input_dim is not None:
                input_dim = input_dim + 1
        if input_dim is not None:
            if loaded_weight.shape[input_dim] % tp_size != 0:
                raise ValueError(
                    "Size is not aligned with the quantized weight shape")
            shard_size = loaded_weight.shape[input_dim] // tp_size
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)
        if isinstance(param, torch.nn.parameter.UninitializedParameter):
            if self.is_exl2():
                self.device = param.device
                param.materialize(loaded_weight.shape,
                                  dtype=loaded_weight.dtype,
                                  device="cpu")
            else:
                param.materialize(loaded_weight.shape,
                                  dtype=loaded_weight.dtype)
            param_data = param.data
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
        if self.is_exl2():
            weights = self.linear_weights  # TODO: update exl2
            self.uninitialized = [
                key for key in self.uninitialized if isinstance(
                    weights[key], torch.nn.parameter.UninitializedParameter)
            ]
            if not self.uninitialized:
                # Shard and release memory ASAP to reduce peak memory usage.
                post_init_exl2(self)

    def forward(self, input_):
        is_exl2 = self.is_exl2()
        if self.input_is_parallel and is_exl2:
            input_parallel = tensor_model_parallel_all_gather(input_)
        elif self.input_is_parallel or is_exl2:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()
        # Matrix multiply.
        output_parallel = self.linear_method.apply_weights(
            self, input_parallel)
        if self.reduce_results and self.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel
        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias
        return output, output_bias
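

# Example (hedged sketch): the usual pairing of column- and row-parallel
# layers in a tensor-parallel MLP. The column-parallel layer keeps its output
# sharded (gather_output=False); the row-parallel layer consumes that sharded
# input directly (input_is_parallel=True) and all-reduces the result. Assumes
# an initialized tensor-parallel group; not called anywhere in this module.
def _example_tensor_parallel_mlp(hidden_size: int = 4096,
                                 intermediate_size: int = 11008):
    fc1 = ColumnParallelLinear(hidden_size, intermediate_size,
                               bias=False, gather_output=False)
    fc2 = RowParallelLinear(intermediate_size, hidden_size,
                            bias=False, input_is_parallel=True)
    x = torch.zeros(1, hidden_size)
    hidden, _ = fc1(x)
    hidden = F.gelu(hidden)
    out, _ = fc2(hidden)  # all-reduced across tensor-parallel ranks
    return out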