linear.py

from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from aphrodite.modeling.megatron.communication_op import (
    tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather)
from aphrodite.modeling.megatron.utils import (divide,
                                               split_tensor_along_last_dim)
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.common.logger import init_logger

logger = init_logger(__name__)


class LinearMethodBase(ABC):
    """Base class for different (maybe quantized) linear methods."""

    @abstractmethod
    def create_weights(self, input_size_per_partition: int,
                       output_size_per_partition: int, input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, Any]:
        """Create weights for a linear layer."""
        raise NotImplementedError

    @abstractmethod
    def apply_weights(self,
                      weights: Dict[str, torch.Tensor],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the weights to the input tensor."""
        raise NotImplementedError


class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization.

    Args:
        separate_bias_add: If true, add bias separately after matrix
                           multiplication.
    """

    def __init__(self, separate_bias_add: bool = False):
        self.separate_bias_add = separate_bias_add

    def create_weights(self, input_size_per_partition: int,
                       output_size_per_partition: int, input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, Any]:
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        return {"weight": weight}

    def apply_weights(self,
                      weights: Dict[str, torch.Tensor],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        weight = weights["weight"]
        if self.separate_bias_add:
            if bias is not None:
                return F.linear(x, weight) + bias
            return F.linear(x, weight)
        return F.linear(x, weight, bias)

    def apply_embedding(self, weights: Dict[str, torch.Tensor],
                        x: torch.Tensor) -> torch.Tensor:
        weight = weights["weight"]
        return F.embedding(x, weight)
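

# Illustrative usage sketch (not part of the original module): how the
# create_weights/apply_weights pair of a linear method is driven. The sizes
# and dtype below are arbitrary example values.
def _example_unquantized_linear_method() -> torch.Tensor:
    method = UnquantizedLinearMethod()
    # For an unsharded layer, the per-partition sizes equal the full sizes.
    weights = method.create_weights(input_size_per_partition=16,
                                    output_size_per_partition=32,
                                    input_size=16,
                                    output_size=32,
                                    params_dtype=torch.float32)
    x = torch.randn(4, 16)
    # With no bias, this is equivalent to F.linear(x, weights["weight"]).
    return method.apply_weights(weights, x)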


class ReplicatedLinear(torch.nn.Module):
    """Replicated linear layer.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if linear_method is None:
            linear_method = UnquantizedLinearMethod()
        self.linear_method = linear_method
        self.linear_weights = self.linear_method.create_weights(
            self.input_size, self.output_size, self.input_size,
            self.output_size, self.params_dtype)
        for name, weight in self.linear_weights.items():
            if isinstance(weight, torch.nn.parameter.Parameter):
                self.register_parameter(name, weight)
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {"output_dim": 0})
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bias = self.bias if not self.skip_bias_add else None
        output = self.linear_method.apply_weights(self.linear_weights, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
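

# Illustrative usage sketch (not part of the original module): every layer in
# this file returns an (output, output_bias) tuple. With skip_bias_add=True
# the bias is returned instead of added, so the caller can fuse it with a
# later element-wise operation. Sizes are example values.
def _example_replicated_linear() -> torch.Tensor:
    layer = ReplicatedLinear(input_size=16, output_size=32, skip_bias_add=True)
    x = torch.randn(4, 16, dtype=layer.params_dtype)
    output, output_bias = layer(x)
    # output has no bias added yet; output_bias is the layer's bias parameter.
    return output + output_bias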


class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, tp_size)
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        if linear_method is None:
            linear_method = UnquantizedLinearMethod()
        self.linear_method = linear_method
        self.linear_weights = self.linear_method.create_weights(
            self.input_size, self.output_size_per_partition, self.input_size,
            self.output_size, self.params_dtype)
        for name, weight in self.linear_weights.items():
            if isinstance(weight, torch.nn.parameter.Parameter):
                self.register_parameter(name, weight)
                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        output_dim = getattr(param, "output_dim", None)
        param_data = param.data
        if output_dim is not None:
            if loaded_weight.shape[output_dim] % tp_size != 0:
                raise ValueError("Size is not aligned with the "
                                 "quantized weight shape")
            shard_size = loaded_weight.shape[output_dim] // tp_size
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        if isinstance(param, torch.nn.parameter.UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
            param_data = param.data
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        output_parallel = self.linear_method.apply_weights(
            self.linear_weights, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
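

# Illustrative sketch (not part of the original module): the sharding that
# ColumnParallelLinear.weight_loader performs. Each tensor-parallel rank keeps
# a slice of size output_size // tp_size along the output dimension (dim 0 of
# the [output_size, input_size] weight). tp_rank and tp_size are example
# values here; the real layer reads them from the parallel state.
def _example_column_parallel_shard(full_weight: torch.Tensor,
                                   tp_rank: int = 1,
                                   tp_size: int = 4) -> torch.Tensor:
    output_dim = 0
    shard_size = full_weight.shape[output_dim] // tp_size
    start_idx = tp_rank * shard_size
    # This is the slice that rank `tp_rank` copies into its local parameter.
    return full_weight.narrow(output_dim, start_idx, shard_size)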


class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
    """

    def __init__(
        self,
        input_size: int,
        output_sizes: List[int],
        bias: bool = True,
        gather_output: bool = False,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size, sum(output_sizes), bias, gather_output,
                         skip_bias_add, params_dtype, linear_method)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            current_shard_offset = 0
            shard_offsets = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # If quantized, we need to adjust the offset and size to
                # account for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assuming the weight is "
                    "the same for all partitions.")
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
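

# Illustrative sketch (not part of the original module): the offset arithmetic
# used by MergedColumnParallelLinear.weight_loader when separate projections
# (e.g. a gate and an up projection) are packed into one parameter. With
# output_sizes=[11008, 11008] and tp_size=2, shard 1 starts at row
# 11008 // 2 = 5504 and spans 5504 rows on each rank. Numbers are example
# values only.
def _example_merged_shard_offsets(output_sizes: List[int],
                                  loaded_shard_id: int,
                                  tp_size: int) -> tuple:
    shard_offset = sum(output_sizes[:loaded_shard_id]) // tp_size
    shard_size = output_sizes[loaded_shard_id] // tp_size
    return shard_offset, shard_size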


class QKVParallelLinear(ColumnParallelLinear):
    """Linear layers for the attention's QKV transformation.

    Linear layers for the linear transformation of the query, key, and value
    vectors in the attention layer. The weight matrix is concatenated along
    the output dimension. The layer is parallelized along the head dimension.
    When the number of key/value heads is smaller than the number of query
    heads (e.g., multi-query/grouped-query attention), the key/value heads may
    be replicated while the query heads are partitioned.

    Args:
        hidden_size: input hidden state size of the transformer.
        head_size: size of each attention head.
        total_num_heads: total number of attention query heads.
        total_num_kv_heads: total number of attention key/value heads. If
                            None, assume total_num_kv_heads = total_num_heads.
        bias: If true, add bias.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
    """

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: Optional[int] = None,
        bias: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        if tp_size >= self.total_num_kv_heads:
            self.num_kv_heads = 1
            self.num_kv_head_replicas = divide(tp_size,
                                               self.total_num_kv_heads)
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
        input_size = self.hidden_size
        output_size = (self.num_heads +
                       2 * self.num_kv_heads) * tp_size * self.head_size
        super().__init__(input_size, output_size, bias, False, skip_bias_add,
                         params_dtype, linear_method)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[str] = None):
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            shard_offsets = [
                # (shard_id, shard_offset, shard_size)
                ("q", 0, self.total_num_heads * self.head_size),
                ("k", self.total_num_heads * self.head_size,
                 self.total_num_kv_heads * self.head_size),
                ("v", (self.total_num_heads + self.total_num_kv_heads) *
                 self.head_size, self.total_num_kv_heads * self.head_size),
            ]
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # If quantized, we need to adjust the offset and size to
                # account for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        tp_rank = get_tensor_model_parallel_rank()
        assert loaded_shard_id in ["q", "k", "v"]
        if output_dim is not None:
            if loaded_shard_id == "q":
                shard_offset = 0
                shard_size = self.num_heads * self.head_size
            elif loaded_shard_id == "k":
                shard_offset = self.num_heads * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            elif loaded_shard_id == "v":
                shard_offset = (self.num_heads +
                                self.num_kv_heads) * self.head_size
                shard_size = self.num_kv_heads * self.head_size
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            if loaded_shard_id == "q":
                shard_id = tp_rank
            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        else:
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "QKVParallelLinear, assuming the weight is the same "
                    "for all partitions.")
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
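

# Illustrative sketch (not part of the original module): how QKVParallelLinear
# sizes its per-rank shards for grouped-query attention. With example values
# total_num_heads=32, total_num_kv_heads=8, head_size=128 and tp_size=4, each
# rank holds 8 query heads and 2 key/value heads, so the packed per-rank
# output has (8 + 2 * 2) * 128 = 1536 rows.
def _example_qkv_shard_sizes(total_num_heads: int = 32,
                             total_num_kv_heads: int = 8,
                             head_size: int = 128,
                             tp_size: int = 4) -> Dict[str, int]:
    num_heads = divide(total_num_heads, tp_size)
    if tp_size >= total_num_kv_heads:
        # The key/value heads are replicated across ranks.
        num_kv_heads = 1
    else:
        num_kv_heads = divide(total_num_kv_heads, tp_size)
    return {
        "q": num_heads * head_size,
        "k": num_kv_heads * head_size,
        "v": num_kv_heads * head_size,
    }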


class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        linear_method: (Maybe quantized) linear method.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = True,
        input_is_parallel: bool = True,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = True,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        # Divide the weight matrix along the last dimension.
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.skip_bias_add = skip_bias_add
        if linear_method is None:
            linear_method = UnquantizedLinearMethod()
        self.linear_method = linear_method
        self.linear_weights = self.linear_method.create_weights(
            self.input_size_per_partition, self.output_size, self.input_size,
            self.output_size, self.params_dtype)
        for name, weight in self.linear_weights.items():
            if isinstance(weight, torch.nn.parameter.Parameter):
                self.register_parameter(name, weight)
                set_weight_attrs(weight, {"weight_loader": self.weight_loader})

        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reducing the results, adding bias to "
                             "the results can lead to incorrect results")

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        input_dim = getattr(param, "input_dim", None)
        param_data = param.data
        if input_dim is not None:
            if loaded_weight.shape[input_dim] % tp_size != 0:
                raise ValueError("Size is not aligned with the quantized "
                                 "weight shape")
            shard_size = loaded_weight.shape[input_dim] // tp_size
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)
        if isinstance(param, torch.nn.parameter.UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
            param_data = param.data
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        output_parallel = self.linear_method.apply_weights(
            self.linear_weights, input_parallel)
        if self.reduce_results and self.tp_size > 1:
            output_ = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output_ = output_parallel

        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias
        return output, output_bias
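

# Illustrative sketch (not part of the original module): the usual pairing of
# these layers in a tensor-parallel MLP. The column-parallel layer keeps its
# output sharded (gather_output=False) and feeds it directly into the
# row-parallel layer (input_is_parallel=True), so the only communication is
# the all-reduce inside RowParallelLinear. Requires the tensor-parallel groups
# to be initialized; sizes are example values.
class _ExampleParallelMLP(torch.nn.Module):

    def __init__(self,
                 hidden_size: int = 4096,
                 intermediate_size: int = 11008):
        super().__init__()
        self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size,
                                            bias=False, gather_output=False)
        self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
                                           bias=False, input_is_parallel=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden, _ = self.up_proj(x)
        hidden = F.silu(hidden)
        output, _ = self.down_proj(hidden)
        return output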