lora.py

from typing import List, Optional

import torch
import torch.types

from aphrodite.common.utils import is_pin_memory_available


class LoRALayerWeights:
    """LoRA weights for a layer composed of two low-rank matrices."""

    def __init__(
        self,
        module_name: str,
        rank: Optional[int],
        lora_alpha: int,
        lora_a: Optional[torch.Tensor],
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor] = None,
        scaling: Optional[float] = None,
    ) -> None:
        """
        rank == None means that we have full-rank tensors (ModulesToSave);
        in this case:
            lora_a = None
            lora_b = the full-rank tensor
        """
        self.module_name = module_name
        self.rank = rank
        self.lora_alpha = lora_alpha
        self.lora_a = lora_a
        self.lora_b = lora_b
        self.embeddings_tensor = embeddings_tensor

        self.scaling: Optional[float]
        if (scaling is None) and (self.rank is not None):
            self.scaling = self.lora_alpha / self.rank
        else:
            self.scaling = scaling

    def optimize(self) -> "LoRALayerWeights":
        """Optimize the LoRA by merging the scaling into lora_b."""
        if self.scaling == 1:
            return self
        self.lora_b *= self.scaling
        self.scaling = 1
        return self
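
    # Note added for clarity (not in the original source): in the usual LoRA
    # formulation the layer delta applied to an input x is
    #     x @ lora_a @ lora_b * scaling, with scaling = lora_alpha / rank.
    # optimize() folds that scalar into lora_b once. For example, with rank=8
    # and lora_alpha=16 the scaling is 2.0; after optimize(), lora_b has been
    # doubled in place and scaling is 1, so later matmuls can skip the extra
    # multiply.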

    @property
    def input_dim(self) -> int:
        if self.lora_a is not None:
            return self.lora_a.shape[0]
        return self.lora_b.shape[0]

    @property
    def output_dim(self) -> int:
        return self.lora_b.shape[1]

    @property
    def is_packed(self) -> bool:
        return False

    @property
    def extra_vocab_size(self) -> int:
        return self.embeddings_tensor.shape[
            0] if self.embeddings_tensor is not None else 0

    @classmethod
    def create_dummy_lora_weights(
            cls,
            module_name: str,
            input_dim: int,
            output_dim: int,
            rank: Optional[int],
            dtype: torch.dtype,
            device: torch.types.Device,
            embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        if rank is None:
            # Full-rank replacement weights (ModulesToSave): lora_a is unused.
            lora_a = None
            lora_b = torch.zeros([input_dim, output_dim],
                                 dtype=dtype,
                                 device=device,
                                 pin_memory=pin_memory)
            embeddings_tensor = None
            scaling = 1
        else:
            lora_a = torch.zeros([input_dim, rank],
                                 dtype=dtype,
                                 device=device,
                                 pin_memory=pin_memory)
            lora_b = torch.zeros([rank, output_dim],
                                 dtype=dtype,
                                 device=device,
                                 pin_memory=pin_memory)
            scaling = None
            embeddings_tensor = torch.rand(
                10,
                embeddings_tensor_dim,
                dtype=dtype,
                device=device,
                pin_memory=pin_memory) if embeddings_tensor_dim else None
        return cls(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=lora_a,
            lora_b=lora_b,
            scaling=scaling,
            embeddings_tensor=embeddings_tensor,
        )
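
    # Illustrative example (not from the original source): zero-filled
    # placeholder weights for profiling a hypothetical 4096 -> 4096 projection
    # with rank 8 on CPU would be created roughly as
    #
    #   dummy = LoRALayerWeights.create_dummy_lora_weights(
    #       "layers.0.self_attn.o_proj",   # hypothetical module name
    #       input_dim=4096,
    #       output_dim=4096,
    #       rank=8,
    #       dtype=torch.float16,
    #       device="cpu")
    #
    # yielding lora_a of shape [4096, 8] and lora_b of shape [8, 4096].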

    def lora_a_pin_memory(self):
        # Copy lora_a into page-locked (pinned) host memory so host-to-device
        # transfers can be asynchronous.
        if self.lora_a is not None:
            self.lora_a = self.lora_a.pin_memory()

    def lora_b_pin_memory(self):
        self.lora_b = self.lora_b.pin_memory()


class PackedLoRALayerWeights(LoRALayerWeights):
    """LoRA used for packed layers (e.g. qkv_proj)."""

    def __init__(
        self,
        module_name: str,
        rank: Optional[int],
        lora_alphas: List[Optional[int]],
        lora_a: List[Optional[torch.Tensor]],
        lora_b: List[Optional[torch.Tensor]],
        scaling: Optional[List[float]] = None,
    ) -> None:
        super().__init__(
            module_name=module_name,
            rank=rank,
            lora_alpha=0,
            lora_a=lora_a,
            lora_b=lora_b,
            scaling=scaling,  # type: ignore
            embeddings_tensor=None,
        )
        self.lora_alphas = lora_alphas
        if (scaling is None) and (self.rank is not None):
            self.scaling = [  # type: ignore
                lora_alpha / self.rank  # type: ignore # noqa
                for lora_alpha in self.lora_alphas
            ]

    @classmethod
    def pack(
        cls, loras: List[Optional["LoRALayerWeights"]]
    ) -> "PackedLoRALayerWeights":
        """Pack a list of LoRAs into a single LoRA.

        If LoRA is None, it signifies that the submodule does not have a LoRA.
        """
        first_lora = next(lora for lora in loras if lora is not None)
        for lora in loras:
            if lora is None:
                continue
            lora.optimize()
        rank = first_lora.rank
        module_name = first_lora.module_name
        obj = cls(
            module_name,
            rank,
            [lora.lora_alpha if lora is not None else None for lora in loras],
            [lora.lora_a if lora is not None else None for lora in loras],
            [lora.lora_b if lora is not None else None for lora in loras],
            scaling=[
                1 if lora is not None else None  # type: ignore
                for lora in loras
            ])
        return obj
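
    # Illustrative example (not from the original source): for a fused
    # qkv_proj module, the per-slice adapters would typically be combined as
    #
    #   qkv_lora = PackedLoRALayerWeights.pack([q_lora, k_lora, v_lora])
    #
    # where any entry may be None if that submodule has no adapter. Since
    # pack() calls optimize() on every sub-LoRA first, each slice's scaling is
    # already folded into its lora_b, and the packed scaling list holds 1 for
    # present slices and None for missing ones.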

    def optimize(self) -> "PackedLoRALayerWeights":
        """Optimize the LoRA by merging the scaling into lora_b."""
        for i in range(len(self.lora_b)):
            if self.scaling[i] == 1 or self.lora_b[i] is None:  # type: ignore
                continue
            self.lora_b[i] *= self.scaling[i]  # type: ignore
            self.scaling[i] = 1  # type: ignore
        return self

    @property
    def input_dim(self) -> int:
        raise NotImplementedError()

    @property
    def output_dim(self) -> int:
        raise NotImplementedError()

    @property
    def is_packed(self) -> bool:
        return True
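

if __name__ == "__main__":
    # Minimal usage sketch (illustrative; not part of the original module).
    # It builds two small adapters with made-up shapes and module names,
    # folds the scaling into lora_b via optimize(), and packs them the way a
    # fused projection's slices would be packed.
    rank, in_dim, out_dim = 4, 16, 8

    q_lora = LoRALayerWeights(
        module_name="q_proj",  # hypothetical module name
        rank=rank,
        lora_alpha=8,
        lora_a=torch.randn(in_dim, rank),
        lora_b=torch.randn(rank, out_dim),
    )
    assert q_lora.scaling == 2.0  # lora_alpha / rank

    k_lora = LoRALayerWeights(
        module_name="k_proj",  # hypothetical module name
        rank=rank,
        lora_alpha=4,
        lora_a=torch.randn(in_dim, rank),
        lora_b=torch.randn(rank, out_dim),
    )

    q_lora.optimize()
    assert q_lora.scaling == 1

    packed = PackedLoRALayerWeights.pack([q_lora, k_lora])
    assert packed.is_packed
    assert packed.scaling == [1, 1]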