layers.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367
  1. # Copyright 2023 The PygmalionAI team.
  2. # Copyright 2023 The vLLM team.
  3. # Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py
  4. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
  5. # Parts of the code here are adapted from PyTorch
  6. # repo: https://github.com/pytorch/pytorch
from typing import Optional

import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter

from aphrodite.modeling.megatron.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from .mappings import (
    gather_from_tensor_model_parallel_region,
    reduce_from_tensor_model_parallel_region,
    scatter_to_tensor_model_parallel_region,
)
from .utils import (
    divide,
    VocabUtility,
)
  24. _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
  25. 'partition_dim': -1,
  26. 'partition_stride': 1}
  27. def param_is_not_tensor_parallel_duplicate(param):
  28. return (hasattr(param, 'tensor_model_parallel') and
  29. param.tensor_model_parallel) or (
  30. get_tensor_model_parallel_rank() == 0)
  31. def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
  32. # Make sure the attributes are not set.
  33. for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
  34. assert not hasattr(tensor, attribute)
  35. # Set the attributes.
  36. setattr(tensor, 'tensor_model_parallel', is_parallel)
  37. setattr(tensor, 'partition_dim', dim)
  38. setattr(tensor, 'partition_stride', stride)
  39. def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
  40. def maybe_set(attribute, value):
  41. if not hasattr(tensor, attribute):
  42. setattr(tensor, attribute, value)
  43. for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
  44. maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
  45. def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
  46. def maybe_copy(attribute):
  47. if hasattr(source_tensor, attribute):
  48. setattr(destination_tensor, attribute,
  49. getattr(source_tensor, attribute))
  50. for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
  51. maybe_copy(attribute)
  52. class VocabParallelEmbedding(torch.nn.Module):
  53. """Embedding parallelized in the vocabulary dimension.
  54. This is mainly adapted from torch.nn.Embedding and all the default
  55. values are kept.
  56. Arguments:
  57. num_embeddings: vocabulary size.
  58. embedding_dim: size of hidden state.
  59. Keyword Arguments:
  60. init_method: method to initialize weights.
  61. params_dtype
  62. use_cpu_initialization
  63. perform_initialization
  64. """
  65. def __init__(self, num_embeddings: int, embedding_dim: int, *,
  66. init_method=init.xavier_normal_,
  67. params_dtype: torch.dtype=None,
  68. use_cpu_initialization: bool=False,
  69. perform_initialization: bool=True):
  70. super(VocabParallelEmbedding, self).__init__()
  71. assert not perform_initialization
  72. assert not use_cpu_initialization
  73. # Keep the input dimensions.
  74. self.num_embeddings = num_embeddings
  75. self.embedding_dim = embedding_dim
  76. if params_dtype is None:
  77. params_dtype = torch.get_default_dtype()
  78. # Set the defaults for compatibility.
  79. self.padding_idx = None
  80. self.max_norm = None
  81. self.norm_type = 2.
  82. self.scale_grad_by_freq = False
  83. self.sparse = False
  84. self._weight = None
  85. self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
  86. # Divide the weight matrix along the vocaburaly dimension.
  87. self.vocab_start_index, self.vocab_end_index = \
  88. VocabUtility.vocab_range_from_global_vocab_size(
  89. self.num_embeddings, get_tensor_model_parallel_rank(),
  90. self.tensor_model_parallel_size)
  91. self.num_embeddings_per_partition = self.vocab_end_index - \
  92. self.vocab_start_index
  93. self.weight = Parameter(torch.empty(
  94. self.num_embeddings_per_partition, self.embedding_dim,
  95. device=torch.cuda.current_device(), dtype=params_dtype))
  96. def forward(self, input_):
  97. if self.tensor_model_parallel_size > 1:
  98. # Build the mask.
  99. input_mask = (input_ < self.vocab_start_index) | \
  100. (input_ >= self.vocab_end_index)
  101. # Mask the input.
  102. masked_input = input_.clone() - self.vocab_start_index
  103. masked_input[input_mask] = 0
  104. else:
  105. masked_input = input_
  106. # Get the embeddings.
  107. output_parallel = F.embedding(masked_input, self.weight,
  108. self.padding_idx, self.max_norm,
  109. self.norm_type, self.scale_grad_by_freq,
  110. self.sparse)
  111. # Mask the output embedding.
  112. if self.tensor_model_parallel_size > 1:
  113. output_parallel[input_mask, :] = 0.0
  114. # Reduce across all the model parallel GPUs.
  115. output = reduce_from_tensor_model_parallel_region(output_parallel)
  116. return output
  117. class ColumnParallelLinear(torch.nn.Module):
  118. """Linear layer with column parallelism.
  119. The linear layer is defined as Y = XA + b. A is parallelized along
  120. its second dimension as A = [A_1, ..., A_p].
  121. Arguments:
  122. input_size: first dimension of matrix A.
  123. output_size: second dimension of matrix A.
  124. Keyword Arguments
  125. bias: If true, add bias
  126. gather_output: If true, call all-gather on output and make Y available
  127. to all GPUs, otherwise, every GPU will have its output
  128. which is Y_i = XA_i
  129. init_method: method to initialize weights. Note that bias is always set
  130. to zero.
  131. stride: For the strided linear layers.
  132. keep_master_weight_for_test: This was added for testing and should be
  133. set to False. It returns the master weights
  134. used for initialization.
  135. skip_bias_add: This was added to enable performance optimations where bias
  136. can be fused with other elementwise operations. we skip
  137. adding bias but instead return it.
  138. params_dtype:
  139. use_cpu_initialization:
  140. """
  141. def __init__(self, input_size, output_size, *,
  142. bias=True, gather_output=True,
  143. init_method=init.xavier_normal_, stride=1,
  144. keep_master_weight_for_test=False,
  145. skip_bias_add=False,
  146. params_dtype=None,
  147. use_cpu_initialization=False,
  148. perform_initialization=True,
  149. quant_config=None,
  150. ):
  151. super(ColumnParallelLinear, self).__init__()
  152. assert not perform_initialization
  153. assert not use_cpu_initialization
  154. # Keep input parameters
  155. self.input_size = input_size
  156. self.output_size = output_size
  157. self.gather_output = gather_output
  158. # Divide the weight matrix along the last dimension.
  159. self.world_size = get_tensor_model_parallel_world_size()
  160. self.output_size_per_partition = divide(output_size, self.world_size)
  161. self.skip_bias_add = skip_bias_add
  162. self.quant_config = quant_config
  163. if params_dtype is None:
  164. params_dtype = torch.get_default_dtype()
  165. # Parameters.
  166. # Note: torch.nn.functional.linear performs XA^T + b and as a result
  167. # we allocate the transpose.
  168. self.create_weights(params_dtype)
  169. if bias:
  170. self.bias = Parameter(torch.empty(
  171. self.output_size_per_partition,
  172. device=torch.cuda.current_device(),
  173. dtype=params_dtype))
  174. set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
  175. # Always initialize bias to zero.
  176. with torch.no_grad():
  177. self.bias.zero_()
  178. else:
  179. self.register_parameter('bias', None)
  180. def create_weights(self, dtype: torch.dtype) -> None:
  181. self.weight = Parameter(torch.empty(
  182. self.output_size_per_partition, self.input_size,
  183. device=torch.cuda.current_device(), dtype=dtype))
  184. def apply_weights(
  185. self,
  186. x: torch.Tensor,
  187. bias: Optional[torch.Tensor],
  188. ) -> torch.Tensor:
  189. return F.linear(x, self.weight, bias)
  190. def forward(self, input_):
  191. """Forward of ColumnParallelLinear
  192. Args:
  193. input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
  194. Returns:
  195. - output
  196. - bias
  197. """
  198. bias = self.bias if not self.skip_bias_add else None
  199. input_parallel = input_
  200. # Matrix multiply.
  201. output_parallel = self.apply_weights(input_parallel, bias)
  202. if self.gather_output:
  203. # All-gather across the partitions.
  204. output = gather_from_tensor_model_parallel_region(output_parallel)
  205. else:
  206. output = output_parallel
  207. output_bias = self.bias if self.skip_bias_add else None
  208. return output, output_bias
  209. class RowParallelLinear(torch.nn.Module):
  210. """Linear layer with row parallelism.
  211. The linear layer is defined as Y = XA + b. A is parallelized along
  212. its first dimension and X along its second dimension as:
  213. - -
  214. | A_1 |
  215. | . |
  216. A = | . | X = [X_1, ..., X_p]
  217. | . |
  218. | A_p |
  219. - -
  220. Arguments:
  221. input_size: first dimension of matrix A.
  222. output_size: second dimension of matrix A.
  223. Keyword Arguments:
  224. bias: If true, add bias. Note that bias is not parallelized.
  225. input_is_parallel: If true, we assume that the input is already
  226. split across the GPUs and we do not split
  227. again.
  228. init_method: method to initialize weights. Note that bias is always set
  229. to zero.
  230. stride: For the strided linear layers.
  231. keep_master_weight_for_test: This was added for testing and should be
  232. set to False. It returns the master weights
  233. used for initialization.
  234. skip_bias_add: This was added to enable performance optimization where bias
  235. can be fused with other elementwise operations. We skip
  236. adding bias but instead return it.
  237. params_dtype:
  238. use_cpu_initialization:
  239. perform_initialization:
  240. reduce_results:
  241. """
  242. def __init__(self, input_size, output_size, *,
  243. bias=True, input_is_parallel=False,
  244. init_method=init.xavier_normal_, stride=1,
  245. keep_master_weight_for_test=False,
  246. skip_bias_add=False,
  247. params_dtype=None,
  248. use_cpu_initialization=False,
  249. perform_initialization=True,
  250. reduce_results=True,
  251. quant_config=None,
  252. ):
  253. super(RowParallelLinear, self).__init__()
  254. assert not perform_initialization
  255. assert not use_cpu_initialization
  256. # Keep input parameters
  257. self.input_size = input_size
  258. self.output_size = output_size
  259. self.input_is_parallel = input_is_parallel
  260. self.reduce_results = reduce_results
  261. if params_dtype is None:
  262. params_dtype = torch.get_default_dtype()
  263. # Divide the weight matrix along the last dimension.
  264. self.world_size = get_tensor_model_parallel_world_size()
  265. self.input_size_per_partition = divide(input_size, self.world_size)
  266. self.skip_bias_add = skip_bias_add
  267. self.quant_config = quant_config
  268. self.create_weights(params_dtype)
  269. if not reduce_results and (bias and not skip_bias_add):
  270. raise ValueError("When not reduce the results, adding bias to the "
  271. "results can lead to incorrect results")
  272. if bias:
  273. self.bias = Parameter(torch.empty(
  274. self.output_size, device=torch.cuda.current_device(),
  275. dtype=params_dtype))
  276. # Always initialize bias to zero.
  277. with torch.no_grad():
  278. self.bias.zero_()
  279. else:
  280. self.register_parameter('bias', None)
  281. def create_weights(self, dtype: torch.dtype) -> None:
  282. self.weight = Parameter(torch.empty(
  283. self.output_size, self.input_size_per_partition,
  284. device=torch.cuda.current_device(), dtype=dtype))
  285. def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
  286. return F.linear(x, self.weight)
  287. def forward(self, input_):
  288. """Forward of RowParallelLinear
  289. Args:
  290. input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
  291. Returns:
  292. - output
  293. - bias
  294. """
  295. # Set up backprop all-reduce.
  296. if self.input_is_parallel:
  297. input_parallel = input_
  298. else:
  299. input_parallel = scatter_to_tensor_model_parallel_region(input_)
  300. # Matrix multiply.
  301. output_parallel = self.apply_weights(input_parallel)
  302. if self.reduce_results and self.world_size > 1:
  303. output_ = reduce_from_tensor_model_parallel_region(output_parallel)
  304. else:
  305. output_ = output_parallel
  306. if not self.skip_bias_add:
  307. output = output_ + self.bias if self.bias is not None else output_
  308. output_bias = None
  309. else:
  310. output = output_
  311. output_bias = self.bias
  312. return output, output_bias