lora.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. from typing import List, Optional
  2. import torch
  3. from aphrodite.common.utils import is_pin_memory_available
  4. class LoRALayerWeights:
  5. """LoRA weights for a layer composed of two low rank matrixes."""
  6. def __init__(
  7. self,
  8. module_name: str,
  9. rank: int,
  10. lora_alpha: int,
  11. lora_a: torch.Tensor,
  12. lora_b: torch.Tensor,
  13. embeddings_tensor: Optional[torch.Tensor] = None,
  14. scaling: Optional[float] = None,
  15. ) -> None:
  16. self.module_name = module_name
  17. self.rank = rank
  18. self.lora_alpha = lora_alpha
  19. self.lora_a = lora_a
  20. self.lora_b = lora_b
  21. self.embeddings_tensor = embeddings_tensor
  22. if scaling is None:
  23. self.scaling = self.lora_alpha / self.rank
  24. else:
  25. self.scaling = scaling
  26. def optimize(self) -> "LoRALayerWeights":
  27. """Optimize the LoRA by merging the scaling into lora_b."""
  28. if self.scaling == 1:
  29. return self
  30. self.lora_b *= self.scaling
  31. self.scaling = 1
  32. return self
  33. @property
  34. def input_dim(self) -> int:
  35. return self.lora_a.shape[0]
  36. @property
  37. def output_dim(self) -> int:
  38. return self.lora_b.shape[1]
  39. @property
  40. def is_packed(self) -> bool:
  41. return False
  42. @property
  43. def extra_vocab_size(self) -> int:
  44. return self.embeddings_tensor.shape[
  45. 0] if self.embeddings_tensor is not None else 0
  46. @classmethod
  47. def create_dummy_lora_weights(
  48. cls,
  49. module_name: str,
  50. input_dim: int,
  51. output_dim: int,
  52. rank: int,
  53. dtype: torch.dtype,
  54. device: torch.device,
  55. embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
  56. pin_memory = str(device) == "cpu" and is_pin_memory_available()
  57. lora_a = torch.zeros([input_dim, rank],
  58. dtype=dtype,
  59. device=device,
  60. pin_memory=pin_memory)
  61. lora_b = torch.zeros([rank, output_dim],
  62. dtype=dtype,
  63. device=device,
  64. pin_memory=pin_memory)
  65. embeddings_tensor = torch.rand(
  66. 10,
  67. embeddings_tensor_dim,
  68. dtype=dtype,
  69. device=device,
  70. pin_memory=pin_memory) if embeddings_tensor_dim else None
  71. return cls(
  72. module_name,
  73. rank=rank,
  74. lora_alpha=1,
  75. lora_a=lora_a,
  76. lora_b=lora_b,
  77. embeddings_tensor=embeddings_tensor,
  78. )
  79. class PackedLoRALayerWeights(LoRALayerWeights):
  80. """LoRA used for packed layers (eg. qkv_proj)."""
  81. def __init__(
  82. self,
  83. module_name: str,
  84. rank: int,
  85. lora_alphas: List[Optional[int]],
  86. lora_a: List[Optional[torch.Tensor]],
  87. lora_b: List[Optional[torch.Tensor]],
  88. scaling: Optional[List[float]] = None,
  89. ) -> None:
  90. super().__init__(
  91. module_name=module_name,
  92. rank=rank,
  93. lora_alpha=0,
  94. lora_a=lora_a,
  95. lora_b=lora_b,
  96. scaling=scaling, # type: ignore
  97. embeddings_tensor=None,
  98. )
  99. self.lora_alphas = lora_alphas
  100. if scaling is None:
  101. self.scaling = [ # type: ignore
  102. lora_alpha / self.rank # type: ignore # noqa
  103. for lora_alpha in self.lora_alphas
  104. ]
  105. @classmethod
  106. def pack(
  107. cls, loras: List[Optional["LoRALayerWeights"]]
  108. ) -> "PackedLoRALayerWeights":
  109. """Pack a list of LoRAs into a single LoRA.
  110. If LoRA is None, it signifies that the submodule does not have a LoRA.
  111. """
  112. first_lora = next(lora for lora in loras if lora is not None)
  113. for lora in loras:
  114. if lora is None:
  115. continue
  116. lora.optimize()
  117. rank = first_lora.rank
  118. module_name = first_lora.module_name
  119. obj = cls(
  120. module_name,
  121. rank,
  122. [lora.lora_alpha if lora is not None else None for lora in loras],
  123. [lora.lora_a if lora is not None else None for lora in loras],
  124. [lora.lora_b if lora is not None else None for lora in loras],
  125. scaling=[
  126. 1 if lora is not None else None # type: ignore
  127. for lora in loras
  128. ])
  129. return obj
  130. def optimize(self) -> "PackedLoRALayerWeights":
  131. """Optimize the LoRA by merging the scaling into lora_b."""
  132. for i in range(len(self.lora_b)):
  133. if self.scaling[i] == 1 or self.lora_b[i] is None: # type: ignore
  134. continue
  135. self.lora_b[i] *= self.scaling[i] # type: ignore
  136. self.scaling[i] = 1 # type: ignore
  137. return self
  138. @property
  139. def input_dim(self) -> int:
  140. raise NotImplementedError()
  141. @property
  142. def output_dim(self) -> int:
  143. raise NotImplementedError()
  144. @property
  145. def is_packed(self) -> bool:
  146. return True