fully_sharded_layers.py

# pylint: disable=unused-argument
from typing import TYPE_CHECKING, List, Optional, Union

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from aphrodite.common.config import LoRAConfig
from aphrodite.distributed.communication_op import (
    tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from aphrodite.distributed.parallel_state import (
    get_tensor_model_parallel_rank)
from aphrodite.lora.layers import (ColumnParallelLinearWithLoRA,
                                   MergedColumnParallelLinearWithLoRA,
                                   MergedQKVParallelLinearWithLora,
                                   QKVParallelLinearWithLora,
                                   RowParallelLinearWithLoRA)

if TYPE_CHECKING:
    pass

def _fully_sharded_can_replace(can_replace):
    """
    Decorator that adds the fully-sharded-LoRAs condition to a
    can_replace_layer() implementation: the wrapped check only passes
    when lora_config.fully_sharded_loras is enabled.
    """

    def dec(*args, **kwargs):
        return (can_replace(*args, **kwargs)
                and kwargs["lora_config"].fully_sharded_loras)

    return dec
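
# Illustrative use (mirrors the classes below): stacking the decorator on a
# classmethod,
#
#     @classmethod
#     @_fully_sharded_can_replace
#     def can_replace_layer(cls, ...): ...
#
# makes the layer eligible for replacement only when the parent check passes
# *and* lora_config.fully_sharded_loras is enabled.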

# these layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters.
# 2023, https://arxiv.org/abs/2311.03285.
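
# Sketch of the shared data flow in the column-parallel variants below, as
# read from their apply() implementations (per tensor-parallel rank):
#   1. shrink: buffer = x @ A_shard, where A is sliced along the rank dim,
#      so buffer holds only this rank's slice of the LoRA intermediate.
#   2. all_gather: the per-rank slices are concatenated into the full
#      rank-wide intermediate.
#   3. expand: the gathered intermediate is multiplied by this rank's
#      output-sharded B and added to the column-partitioned base output.
# The row-parallel variant at the bottom of the file swaps the all_gather
# for an all_reduce; see the note following that class.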

class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
    """
    Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.lora_a_stacked.shape[2]
        start_idx = tp_rank * shard_size
        lora_a = lora_a[:, start_idx:start_idx + shard_size]
        return lora_a

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        buffer = torch.zeros(
            (x.shape[0], self.lora_a_stacked.shape[2]),
            dtype=torch.float32,
            device=x.device,
        )
        self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
        buffer = tensor_model_parallel_all_gather(buffer)
        self.punica_wrapper.add_expand(output,
                                       buffer,
                                       self.lora_b_stacked,
                                       add_input=True)
        # now have column partitioned output
        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
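
# Worked example for the rank-dim slicing above (hypothetical numbers, and
# assuming lora_a_stacked was allocated with the per-rank share of the LoRA
# rank): with a LoRA rank of 16 and tp_size of 4, lora_a_stacked.shape[2]
# is 4, so tp_rank 0 keeps lora_a columns 0:4, tp_rank 1 keeps 4:8, and so
# on; the all_gather in apply() then reassembles the full 16-wide
# intermediate before the expand.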

def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
    """
    MergedColumnParallelLinearWithShardedLoRA and
    MergedQKVParallelLinearWithShardedLora share the same
    LoRA weight application method.

    The main difference is the step by shard_size for lora_b, which can
    vary for MergedQKVParallelLinearWithShardedLora but is constant for
    MergedColumnParallelLinearWithShardedLoRA.
    """
    # expecting 2 for column parallel and 3 for qkv
    n = len(layer.lora_a_stacked)
    output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

    x = x.view(-1, x.shape[-1])
    output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
    buffers = torch.zeros(
        (n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
        dtype=torch.float32,
        device=x.device,
    )
    for idx in range(n):
        layer.punica_wrapper.add_shrink(buffers[idx], x,
                                        layer.lora_a_stacked[idx], 1.0)

    buffers = tensor_model_parallel_all_gather(buffers)
    left_offset = 0
    for idx in range(n):
        shard_size = layer.lora_b_stacked[idx].shape[2]
        layer.punica_wrapper.add_expand_slice(
            output,
            buffers[idx],
            layer.lora_b_stacked[idx],
            left_offset,
            shard_size,
            add_input=True,
        )
        left_offset += shard_size

    output = output.view(*out_orig_shape)
    # now have column partitioned and packed output
    return output
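
# Illustrative walk of the expand loop in _mcp_apply (hypothetical shard
# sizes): for a merged QKV layer with per-rank output shards of 1024 (q),
# 256 (k) and 256 (v), the three add_expand_slice calls write into
# output[:, 0:1024], output[:, 1024:1280] and output[:, 1280:1536]; for a
# merged gate/up projection the two shards are equal, so the offsets are
# simply 0 and shard_size.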

class MergedColumnParallelLinearWithShardedLoRA(
        MergedColumnParallelLinearWithLoRA):
    """
    Differs from MergedColumnParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        # NOTE: lora_a contains 2 subloras, and each sublora could be None.
        output_shard_size = self.lora_a_stacked[0].shape[2]
        output_start_idx = self.tp_rank * output_shard_size
        lora_a = [
            lora_a[0][:, output_start_idx:output_start_idx +
                      output_shard_size] if lora_a[0] is not None else None,
            lora_a[1][:, output_start_idx:output_start_idx +
                      output_shard_size] if lora_a[1] is not None else None,
        ]
        return lora_a

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )

class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
    """
    Differs from QKVParallelLinearWithLora by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.lora_a_stacked.shape[2]
        start_idx = tp_rank * shard_size
        lora_a = lora_a[:, start_idx:start_idx + shard_size]
        return lora_a

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
                             dtype=torch.float32,
                             device=x.device)
        self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
        buffer = tensor_model_parallel_all_gather(buffer)
        self.punica_wrapper.add_expand(output,
                                       buffer,
                                       self.lora_b_stacked,
                                       add_input=True)
        # now have column partitioned output
        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )

class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
    """
    Differs from MergedQKVParallelLinearWithLora by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        # NOTE: lora_a contains 3 subloras, and each sublora could be None.
        shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
        start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
        lora_a = [
            lora_a[0][:, start_idx[0]:start_idx[0] +
                      shard_size[0]] if lora_a[0] is not None else None,
            lora_a[1][:, start_idx[1]:start_idx[1] +
                      shard_size[1]] if lora_a[1] is not None else None,
            lora_a[2][:, start_idx[2]:start_idx[2] +
                      shard_size[2]] if lora_a[2] is not None else None,
        ]
        return lora_a

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )

class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
    """
    Differs from RowParallelLinearWithLoRA by slicing the
    LoRA B's also.

    Based on S-LoRA, slicing happens along the output dim.
    This yields a combined partial sum from the row-parallel base
    layer and column-partitioned output from the LoRA.
    """

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        shard_size = self.lora_b_stacked.shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        buffer = torch.zeros(
            (x.shape[0], self.lora_a_stacked.shape[2]),
            dtype=torch.float32,
            device=x.device,
        )
        self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
        buffer = tensor_model_parallel_all_reduce(buffer)

        # following S-LoRA, this allows fusing the all_gather and all_reduce
        # by adding the column-partitioned LoRA output to a slice of the
        # output tensor, which is a partial sum due to row parallelism.
        # All that remains is a standard all_reduce. Users should be aware,
        # though, that the output is not the same as for a normal
        # row-parallel layer: it must be reduced before being used.
        shard_size = self.lora_b_stacked.shape[2]
        start_idx = self.tp_rank * shard_size
        self.punica_wrapper.add_expand_slice(output, buffer,
                                             self.lora_b_stacked, start_idx,
                                             shard_size)
        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
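
# Note on RowParallelLinearWithShardedLoRA.apply (derived from the code
# above): the base layer's matmul over this rank's input shard produces a
# partial sum of every output column, while the LoRA expand adds the full
# LoRA contribution only to this rank's column slice
# (start_idx:start_idx + shard_size). The single all_reduce that the caller
# is expected to perform (see the comment in apply()) therefore finishes
# both at once: it completes the base layer's partial sums and, since each
# column slice receives its LoRA term from exactly one rank, distributes
# the LoRA output to all ranks as well.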