fully_sharded_layers.py

# pylint: disable=unused-argument
from typing import TYPE_CHECKING, List, Optional, Union

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from aphrodite.common.config import LoRAConfig
from aphrodite.distributed.communication_op import (
    tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from aphrodite.distributed.parallel_state import get_tensor_model_parallel_rank
from aphrodite.lora.layers import (ColumnParallelLinearWithLoRA,
                                   MergedColumnParallelLinearWithLoRA,
                                   MergedQKVParallelLinearWithLora,
                                   RowParallelLinearWithLoRA)
from aphrodite.lora.punica import bgmv, dispatch_bgmv_low_level

if TYPE_CHECKING:
    pass


def _fully_sharded_can_replace(can_replace):
    """
    Decorator that adds the fully sharded LoRA condition;
    intended to wrap can_replace_layer().
    """

    def dec(*args, **kwargs):
        return (can_replace(*args, **kwargs)
                and kwargs['lora_config'].fully_sharded_loras)

    return dec


# these layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
    """
    Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.

    Based on S-LoRA, slicing happens along the rank dim. (A toy,
    single-device sketch of this scheme follows the class.)
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.lora_a_stacked.shape[2]
        start_idx = tp_rank * shard_size
        lora_a = lora_a[:, start_idx:start_idx + shard_size]
        return lora_a

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
                             dtype=torch.float32,
                             device=x.device)

        bgmv(buffer, x, self.lora_a_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        buffer = tensor_model_parallel_all_gather(buffer)
        bgmv(output, buffer, self.lora_b_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        # now have column partitioned output

        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
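

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): why slicing LoRA A
# along the rank dimension and all-gathering the intermediate, as in
# ColumnParallelLinearWithShardedLoRA.apply above, reproduces the unsharded
# LoRA result. Tensor parallelism is emulated on a single device, with
# torch.chunk/torch.cat standing in for the sharded weights and the
# all_gather; all sizes below are made-up toy values.
# ---------------------------------------------------------------------------
def _sketch_column_parallel_sharded_lora(tp_size: int = 2) -> None:
    torch.manual_seed(0)
    tokens, hidden, rank, out_features = 3, 8, 4, 16
    x = torch.randn(tokens, hidden)
    lora_a = torch.randn(hidden, rank)        # maps hidden -> rank
    lora_b = torch.randn(rank, out_features)  # maps rank -> output

    # Reference: the unsharded LoRA delta.
    expected = x @ lora_a @ lora_b

    # Fully sharded layout: each TP rank holds a rank-dim slice of A and an
    # output-dim (column) slice of B.
    a_shards = lora_a.chunk(tp_size, dim=1)
    b_shards = lora_b.chunk(tp_size, dim=1)

    partials = [x @ a for a in a_shards]            # per-rank intermediates
    gathered = torch.cat(partials, dim=1)           # stands in for all_gather
    col_outputs = [gathered @ b for b in b_shards]  # column-partitioned output

    assert torch.allclose(torch.cat(col_outputs, dim=1), expected, atol=1e-4)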


def _mcp_apply(x, bias, layer):
    """
    MergedColumnParallelLinearWithShardedLoRA and
    MergedQKVParallelLinearWithShardedLora share the same
    LoRA weight application method.

    The main difference is the step by shard_size for lora_b, which can
    vary between slices for MergedQKVParallelLinearWithShardedLora but is
    constant for MergedColumnParallelLinearWithShardedLoRA. (A toy sketch
    of the offset stepping follows this function.)
    """
    # expecting 2 for column parallel and 3 for qkv
    n = len(layer.lora_a_stacked)
    output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

    x = x.view(-1, x.shape[-1])
    output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
    buffers = torch.zeros((n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
                          dtype=torch.float32,
                          device=x.device)
    for idx in range(n):
        bgmv(buffers[idx], x, layer.lora_a_stacked[idx],
             layer.indices[:layer.indices_len[0]], 0, 1.0)

    buffers = tensor_model_parallel_all_gather(buffers)
    left_offset = 0
    for idx in range(n):
        shard_size = layer.lora_b_stacked[idx].shape[2]
        dispatch_bgmv_low_level(output, buffers[idx],
                                layer.lora_b_stacked[idx],
                                layer.indices[:layer.indices_len[0]], 0, 1.0,
                                left_offset, shard_size)
        left_offset += shard_size

    output = output.view(*out_orig_shape)
    # now have column partitioned and packed output
    return output
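

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module) of the offset walk in
# _mcp_apply above: each packed sub-LoRA writes its column-partitioned slice
# into the shared output tensor at `left_offset`, stepping by that slice's
# shard_size (equal shards for merged gate/up, possibly unequal for q/k/v).
# The slice-and-add below is only a stand-in for what dispatch_bgmv_low_level
# does with its offset and size arguments; the toy sizes are made up.
# ---------------------------------------------------------------------------
def _sketch_packed_output_offsets() -> None:
    torch.manual_seed(0)
    tokens, rank = 3, 4
    shard_sizes = [6, 2, 2]              # e.g. q/k/v output shards with GQA
    gathered = [torch.randn(tokens, rank) for _ in shard_sizes]
    lora_b = [torch.randn(rank, s) for s in shard_sizes]
    output = torch.zeros(tokens, sum(shard_sizes))

    left_offset = 0
    for buf, b, shard_size in zip(gathered, lora_b, shard_sizes):
        # add this sub-LoRA's delta into its window of the packed output
        output[:, left_offset:left_offset + shard_size] += buf @ b
        left_offset += shard_size

    expected = torch.cat([buf @ b for buf, b in zip(gathered, lora_b)], dim=1)
    assert torch.allclose(output, expected, atol=1e-4)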


class MergedColumnParallelLinearWithShardedLoRA(
        MergedColumnParallelLinearWithLoRA):
    """
    Differs from MergedColumnParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        if lora_a[0] is None or lora_a[1] is None:
            return lora_a
        output_shard_size = self.lora_a_stacked[0].shape[2]
        output_start_idx = self.tp_rank * output_shard_size
        lora_a = [
            lora_a[0][:,
                      output_start_idx:output_start_idx + output_shard_size],
            lora_a[1][:, output_start_idx:output_start_idx + output_shard_size]
        ]
        return lora_a

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
    """
    Differs from MergedQKVParallelLinearWithLora by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        if lora_a[0] is None or lora_a[1] is None or lora_a[2] is None:
            return lora_a
        shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
        start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
        lora_a = [
            lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]],
            lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]],
            lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]]
        ]
        return lora_a

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
    """
    Differs from RowParallelLinearWithLoRA by slicing the
    LoRA B's also.

    Based on S-LoRA, slicing happens along the output dim.
    This yields a combined partial sum from the row-parallel base
    layer and column-partitioned output from the LoRA.
    """

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        shard_size = self.lora_b_stacked.shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
                             dtype=torch.float32,
                             device=x.device)
        bgmv(buffer, x, self.lora_a_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        buffer = tensor_model_parallel_all_reduce(buffer)

        # following S-LoRA, this allows fusing the all_gather and all_reduce
        # by adding the column-partitioned LoRA output to a slice of the
        # output tensor, which is a partial sum due to row parallelism.
        # All that remains is a standard all_reduce. Note that the output
        # is not the same as for a normal row-parallel layer: it must be
        # reduced before being used. (See the toy sketch at the end of
        # this module.)
        shard_size = self.lora_b_stacked.shape[2]
        start_idx = self.tp_rank * shard_size
        dispatch_bgmv_low_level(output, buffer, self.lora_b_stacked,
                                self.indices[:self.indices_len[0]], 0, 1.0,
                                start_idx, shard_size)

        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
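

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module) of the fused reduction
# described in RowParallelLinearWithShardedLoRA.apply: the base layer leaves
# its row-parallel partial sum un-reduced, each rank adds its column slice of
# the LoRA delta into the matching slice of that partial sum, and a single
# remaining all_reduce yields base output + LoRA delta. Tensor parallelism is
# emulated on one device; the sizes and weights below are made-up toy values.
# ---------------------------------------------------------------------------
def _sketch_row_parallel_fused_reduce(tp_size: int = 2) -> None:
    torch.manual_seed(0)
    tokens, hidden, rank, out_features = 3, 8, 4, 6
    x = torch.randn(tokens, hidden)
    weight = torch.randn(hidden, out_features)   # base layer weight
    lora_a = torch.randn(hidden, rank)
    lora_b = torch.randn(rank, out_features)

    expected = x @ weight + x @ lora_a @ lora_b

    x_shards = x.chunk(tp_size, dim=1)           # row-parallel input shards
    w_shards = weight.chunk(tp_size, dim=0)      # base weight, input-dim shards
    a_shards = lora_a.chunk(tp_size, dim=0)      # LoRA A, input-dim shards
    b_shards = lora_b.chunk(tp_size, dim=1)      # LoRA B, output-dim shards
    out_shard = out_features // tp_size

    # all_reduce of the per-rank x_shard @ A_shard partials -> full x @ A
    buffer = sum(xs @ a for xs, a in zip(x_shards, a_shards))

    per_rank_outputs = []
    for r in range(tp_size):
        partial = x_shards[r] @ w_shards[r]      # un-reduced base partial sum
        start_idx = r * out_shard
        # add this rank's column slice of the LoRA delta in place
        partial[:, start_idx:start_idx + out_shard] += buffer @ b_shards[r]
        per_rank_outputs.append(partial)

    # the single remaining all_reduce combines the base partial sums and the
    # column-partitioned LoRA contributions in one step
    assert torch.allclose(sum(per_rank_outputs), expected, atol=1e-4)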