# fully_sharded_layers.py

# pylint: disable=unused-argument
from typing import TYPE_CHECKING, List, Optional, Union

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from aphrodite.common.config import LoRAConfig
from aphrodite.distributed.communication_op import (
    tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from aphrodite.distributed.parallel_state import get_tensor_model_parallel_rank
from aphrodite.lora.layers import (ColumnParallelLinearWithLoRA,
                                   MergedColumnParallelLinearWithLoRA,
                                   MergedQKVParallelLinearWithLora,
                                   QKVParallelLinearWithLora,
                                   RowParallelLinearWithLoRA)
from aphrodite.lora.punica import bgmv, dispatch_bgmv_low_level

if TYPE_CHECKING:
    pass


def _fully_sharded_can_replace(can_replace):
    """
    Decorator that adds the fully sharded LoRA condition on top of an
    existing check; intended to wrap can_replace_layer().
    """

    def dec(*args, **kwargs):
        return (can_replace(*args, **kwargs)
                and kwargs['lora_config'].fully_sharded_loras)

    return dec


# these layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.
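#
# Overview of the fully sharded scheme (LoRA rank r, tensor parallel size N):
#   * The column-parallel layers below additionally shard LoRA A along the
#     rank dim, so each rank holds r / N of the rank; the per-rank shrink
#     results are all-gathered before the expand with the (already
#     column-sharded) LoRA B.
#   * The row-parallel layer instead shards LoRA B along the output dim and
#     adds its expand result directly into the base layer's partial sums, so
#     the standard all_reduce that follows the row-parallel layer also
#     reduces the LoRA contribution.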
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
    """
    Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.lora_a_stacked.shape[2]
        start_idx = tp_rank * shard_size
        lora_a = lora_a[:, start_idx:start_idx + shard_size]
        return lora_a

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
                             dtype=torch.float32,
                             device=x.device)

        bgmv(buffer, x, self.lora_a_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        buffer = tensor_model_parallel_all_gather(buffer)
        bgmv(output, buffer, self.lora_b_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        # now have column partitioned output

        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


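# For ColumnParallelLinearWithShardedLoRA.apply, with illustrative numbers
# (LoRA rank 16, tensor parallel size 4, tp_rank 2): slice_lora_a keeps rank
# columns 8:12, the shrink bgmv writes into a (num_tokens, 4) fp32 buffer,
# the all_gather restores the full (num_tokens, 16) low-rank activation, and
# the expand bgmv adds the LoRA contribution into this rank's
# column-partitioned output.
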
def _mcp_apply(x, bias, layer):
    """
    MergedColumnParallelLinearWithShardedLoRA and
    MergedQKVParallelLinearWithShardedLora share the same
    LoRA weight application method.

    The main difference is the step by shard_size for lora_b, which can
    vary for MergedQKVParallelLinearWithShardedLora but is constant for
    MergedColumnParallelLinearWithShardedLoRA.
    """
    # expecting 2 for column parallel and 3 for qkv
    n = len(layer.lora_a_stacked)
    output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

    x = x.view(-1, x.shape[-1])
    output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
    buffers = torch.zeros((n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
                          dtype=torch.float32,
                          device=x.device)
    for idx in range(n):
        bgmv(buffers[idx], x, layer.lora_a_stacked[idx],
             layer.indices[:layer.indices_len[0]], 0, 1.0)

    buffers = tensor_model_parallel_all_gather(buffers)
    left_offset = 0
    for idx in range(n):
        shard_size = layer.lora_b_stacked[idx].shape[2]
        dispatch_bgmv_low_level(output, buffers[idx],
                                layer.lora_b_stacked[idx],
                                layer.indices[:layer.indices_len[0]], 0, 1.0,
                                left_offset, shard_size)
        left_offset += shard_size

    output = output.view(*out_orig_shape)
    # now have column partitioned and packed output
    return output


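# In _mcp_apply above, the shrink results for all n sub-layers are packed
# into a single (n, num_tokens, r / N) fp32 buffer so that one all_gather
# covers every slice; the expand then walks the packed output with
# left_offset / shard_size so each sub-layer's LoRA delta lands in its own
# column range of the fused output.
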
class MergedColumnParallelLinearWithShardedLoRA(
        MergedColumnParallelLinearWithLoRA):
    """
    Differs from MergedColumnParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        if lora_a[0] is None or lora_a[1] is None:
            return lora_a
        output_shard_size = self.lora_a_stacked[0].shape[2]
        output_start_idx = self.tp_rank * output_shard_size
        lora_a = [
            lora_a[0][:,
                      output_start_idx:output_start_idx + output_shard_size],
            lora_a[1][:, output_start_idx:output_start_idx + output_shard_size]
        ]
        return lora_a

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
    """
    Differs from QKVParallelLinearWithLora by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.lora_a_stacked.shape[2]
        start_idx = tp_rank * shard_size
        lora_a = lora_a[:, start_idx:start_idx + shard_size]
        return lora_a

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
                             dtype=torch.float32,
                             device=x.device)

        bgmv(buffer, x, self.lora_a_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        buffer = tensor_model_parallel_all_gather(buffer)
        bgmv(output, buffer, self.lora_b_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        # now have column partitioned output

        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
    """
    Differs from MergedQKVParallelLinearWithLora by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        if lora_a[0] is None or lora_a[1] is None or lora_a[2] is None:
            return lora_a
        shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
        start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
        lora_a = [
            lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]],
            lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]],
            lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]]
        ]
        return lora_a

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


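# In MergedQKVParallelLinearWithShardedLora the lora_b shard sizes generally
# differ between the q projection and the k/v projections (e.g. with
# grouped-query attention there are fewer kv heads than q heads), which is
# why _mcp_apply re-reads shard_size from lora_b_stacked[idx] on every
# iteration instead of assuming a constant step.
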
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
    """
    Differs from RowParallelLinearWithLoRA by slicing the
    LoRA B's also.

    Based on S-LoRA, slicing happens along the output dim.
    This yields a combined partial sum from the row parallel base
    layer and column partitioned output from the LoRA.
    """

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        shard_size = self.lora_b_stacked.shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
                             dtype=torch.float32,
                             device=x.device)

        bgmv(buffer, x, self.lora_a_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        # the input is sharded along the input dim for row parallelism, so
        # the shrink results are partial sums that must be all-reduced
        buffer = tensor_model_parallel_all_reduce(buffer)

        # Following S-LoRA, this fuses the all_gather and all_reduce by
        # adding the column-partitioned LoRA output to a slice of the output
        # tensor, which is a partial sum due to row parallelism. All that
        # remains is a standard all_reduce. Note that the output is therefore
        # not the same as that of a normal row-parallel layer; it still needs
        # to be reduced before use.
        shard_size = self.lora_b_stacked.shape[2]
        start_idx = self.tp_rank * shard_size
        dispatch_bgmv_low_level(output, buffer, self.lora_b_stacked,
                                self.indices[:self.indices_len[0]], 0, 1.0,
                                start_idx, shard_size)

        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


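# The classes in this module only take effect when
# LoRAConfig.fully_sharded_loras is enabled: _fully_sharded_can_replace wraps
# each can_replace_layer() so it returns False when the flag is off, in which
# case the regular LoRA layers from aphrodite.lora.layers are used instead.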