fully_sharded_layers.py

# pylint: disable=unused-argument
from typing import TYPE_CHECKING, List, Optional, Union

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from aphrodite.common.config import LoRAConfig
from aphrodite.distributed.communication_op import (
    tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from aphrodite.distributed.parallel_state import get_tensor_model_parallel_rank
from aphrodite.lora.layers import (ColumnParallelLinearWithLoRA,
                                   MergedColumnParallelLinearWithLoRA,
                                   MergedQKVParallelLinearWithLora,
                                   QKVParallelLinearWithLora,
                                   RowParallelLinearWithLoRA)

if TYPE_CHECKING:
    pass


def _fully_sharded_can_replace(can_replace):
    """
    Decorator which adds the condition of fully sharded LoRAs;
    intended to wrap can_replace_layer().
    """

    def dec(*args, **kwargs):
        return (can_replace(*args, **kwargs)
                and kwargs["lora_config"].fully_sharded_loras)

    return dec
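

# Note: with the decorator applied, can_replace_layer() returns True only when
# the wrapped check passes and kwargs["lora_config"].fully_sharded_loras is
# set, so the sharded variants below are used only for fully sharded LoRA
# serving.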


# these layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.
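#
# Each layer below follows the same pattern: run the base layer, shrink the
# input with the locally stored LoRA A, combine the per-rank intermediates
# across tensor parallel ranks (all_gather for the column parallel variants,
# all_reduce for the row parallel one), then expand with the locally stored
# LoRA B into the output.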


class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
    """
    Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.
    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.lora_a_stacked.shape[2]
        start_idx = tp_rank * shard_size
        lora_a = lora_a[:, start_idx:start_idx + shard_size]
        return lora_a

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        buffer = torch.zeros(
            (x.shape[0], self.lora_a_stacked.shape[2]),
            dtype=torch.float32,
            device=x.device,
        )
        self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
        buffer = tensor_model_parallel_all_gather(buffer)
        self.punica_wrapper.add_expand(output,
                                       buffer,
                                       self.lora_b_stacked,
                                       add_input=True)
        # now have column partitioned output
        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
    """
    MergedColumnParallelLinearWithShardedLoRA and
    MergedQKVParallelLinearWithShardedLora share the same
    LoRA weight application method.
    The main difference is the step by shard_size for lora_b, which can
    vary for MergedQKVParallelLinearWithShardedLora but is constant for
    MergedColumnParallelLinearWithShardedLoRA.
    """
    # expecting 2 for column parallel and 3 for qkv
    n = len(layer.lora_a_stacked)
    output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

    x = x.view(-1, x.shape[-1])
    output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
    buffers = torch.zeros(
        (n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
        dtype=torch.float32,
        device=x.device,
    )
    for idx in range(n):
        layer.punica_wrapper.add_shrink(buffers[idx], x,
                                        layer.lora_a_stacked[idx], 1.0)

    buffers = tensor_model_parallel_all_gather(buffers)
    left_offset = 0
    for idx in range(n):
        shard_size = layer.lora_b_stacked[idx].shape[2]
        layer.punica_wrapper.add_expand_slice(
            output,
            buffers[idx],
            layer.lora_b_stacked[idx],
            left_offset,
            shard_size,
            add_input=True,
        )
        left_offset += shard_size

    output = output.view(*out_orig_shape)
    # now have column partitioned and packed output
    return output
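

# In _mcp_apply, lora_b_stacked[idx].shape[2] is each packed sub-layer's output
# shard size: the two shards are equal for the merged column parallel (e.g.
# gate/up) case, while for merged QKV the k/v shards can be smaller than the q
# shard (for instance under grouped-query attention), which is why left_offset
# advances by a per-slice shard_size.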


class MergedColumnParallelLinearWithShardedLoRA(
        MergedColumnParallelLinearWithLoRA):
    """
    Differs from MergedColumnParallelLinearWithLoRA by slicing the
    LoRA A's also.
    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        if lora_a[0] is None or lora_a[1] is None:
            return lora_a
        output_shard_size = self.lora_a_stacked[0].shape[2]
        output_start_idx = self.tp_rank * output_shard_size
        lora_a = [
            lora_a[0][:,
                      output_start_idx:output_start_idx + output_shard_size],
            lora_a[1][:,
                      output_start_idx:output_start_idx + output_shard_size],
        ]
        return lora_a

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
    """
    Differs from QKVParallelLinearWithLora by slicing the
    LoRA A's also.
    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.lora_a_stacked.shape[2]
        start_idx = tp_rank * shard_size
        lora_a = lora_a[:, start_idx:start_idx + shard_size]
        return lora_a

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
                             dtype=torch.float32,
                             device=x.device)
        self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
        buffer = tensor_model_parallel_all_gather(buffer)
        self.punica_wrapper.add_expand(output,
                                       buffer,
                                       self.lora_b_stacked,
                                       add_input=True)
        # now have column partitioned output
        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
    """
    Differs from MergedQKVParallelLinearWithLora by slicing the
    LoRA A's also.
    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        if lora_a[0] is None or lora_a[1] is None or lora_a[2] is None:
            return lora_a
        shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
        start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
        lora_a = [
            lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]],
            lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]],
            lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]],
        ]
        return lora_a

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
    """
    Differs from RowParallelLinearWithLoRA by slicing the
    LoRA B's also.
    Based on S-LoRA, slicing happens along the output dim.
    This yields a combined partial sum from the row parallel base
    layer and column partitioned output from the LoRA.
    """

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        shard_size = self.lora_b_stacked.shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        buffer = torch.zeros(
            (x.shape[0], self.lora_a_stacked.shape[2]),
            dtype=torch.float32,
            device=x.device,
        )
        self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
        buffer = tensor_model_parallel_all_reduce(buffer)

        # following S-LoRA, allows the fusing of all_gather and all_reduce
        # by adding the column partitioned lora output to a slice of output
        # tensor, which is a partial sum due to row parallel. All that
        # remains is a standard all_reduce. User should be aware though that
        # the output is not the same as a normal row_parallel, it should be
        # reduced before being used
        shard_size = self.lora_b_stacked.shape[2]
        start_idx = self.tp_rank * shard_size
        self.punica_wrapper.add_expand_slice(output, buffer,
                                             self.lora_b_stacked, start_idx,
                                             shard_size)
        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
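

# Rough selection sketch (assuming the surrounding Aphrodite LoRA replacement
# machinery): can_replace_layer() is called with keyword arguments precisely so
# that _fully_sharded_can_replace can read kwargs["lora_config"]; each sharded
# variant above is therefore eligible to replace its base layer only when
# lora_config.fully_sharded_loras is enabled.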