lora.py

from typing import List, Optional

import torch
import torch.types

from aphrodite.common.utils import is_pin_memory_available


class LoRALayerWeights:
    """LoRA weights for a layer composed of two low rank matrices."""

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alpha: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor] = None,
        scaling: Optional[float] = None,
    ) -> None:
        self.module_name = module_name
        self.rank = rank
        self.lora_alpha = lora_alpha
        self.lora_a = lora_a
        self.lora_b = lora_b
        self.embeddings_tensor = embeddings_tensor

        if scaling is None:
            # Default LoRA scaling factor: alpha / rank.
            self.scaling = self.lora_alpha / self.rank
        else:
            self.scaling = scaling

    def optimize(self) -> "LoRALayerWeights":
        """Optimize the LoRA by merging the scaling into lora_b."""
        if self.scaling == 1:
            return self
        self.lora_b *= self.scaling
        self.scaling = 1
        return self

    @property
    def input_dim(self) -> int:
        return self.lora_a.shape[0]

    @property
    def output_dim(self) -> int:
        return self.lora_b.shape[1]

    @property
    def is_packed(self) -> bool:
        return False

    @property
    def extra_vocab_size(self) -> int:
        return self.embeddings_tensor.shape[
            0] if self.embeddings_tensor is not None else 0

    @classmethod
    def create_dummy_lora_weights(
            cls,
            module_name: str,
            input_dim: int,
            output_dim: int,
            rank: int,
            dtype: torch.dtype,
            device: torch.types.Device,
            embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
        # Pin CPU tensors so they can be copied to the GPU asynchronously.
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        lora_a = torch.zeros([input_dim, rank],
                             dtype=dtype,
                             device=device,
                             pin_memory=pin_memory)
        lora_b = torch.zeros([rank, output_dim],
                             dtype=dtype,
                             device=device,
                             pin_memory=pin_memory)
        embeddings_tensor = torch.rand(
            10,
            embeddings_tensor_dim,
            dtype=dtype,
            device=device,
            pin_memory=pin_memory) if embeddings_tensor_dim else None
        return cls(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=lora_a,
            lora_b=lora_b,
            embeddings_tensor=embeddings_tensor,
        )
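
# Example (illustrative; module name, shapes, and values below are arbitrary
# placeholders): the default scaling is lora_alpha / rank, and optimize()
# folds it into lora_b in place, resetting scaling to 1.
#
#     lora = LoRALayerWeights(
#         module_name="layers.0.self_attn.qkv_proj",
#         rank=8,
#         lora_alpha=16,
#         lora_a=torch.zeros(4096, 8),
#         lora_b=torch.zeros(8, 4096),
#     )
#     assert lora.scaling == 2.0   # lora_alpha / rank
#     lora = lora.optimize()       # lora_b *= 2.0, scaling becomes 1
#     assert lora.scaling == 1
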
class PackedLoRALayerWeights(LoRALayerWeights):
    """LoRA used for packed layers (e.g. qkv_proj)."""

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alphas: List[Optional[int]],
        lora_a: List[Optional[torch.Tensor]],
        lora_b: List[Optional[torch.Tensor]],
        scaling: Optional[List[float]] = None,
    ) -> None:
        # The base class's single lora_alpha is only a placeholder (0) here;
        # per-sublora alphas live in lora_alphas and scaling is a list.
        super().__init__(
            module_name=module_name,
            rank=rank,
            lora_alpha=0,
            lora_a=lora_a,
            lora_b=lora_b,
            scaling=scaling,  # type: ignore
            embeddings_tensor=None,
        )
        self.lora_alphas = lora_alphas
        if scaling is None:
            self.scaling = [  # type: ignore
                lora_alpha / self.rank  # type: ignore # noqa
                for lora_alpha in self.lora_alphas
            ]

    @classmethod
    def pack(
        cls, loras: List[Optional["LoRALayerWeights"]]
    ) -> "PackedLoRALayerWeights":
        """Pack a list of LoRAs into a single LoRA.

        If LoRA is None, it signifies that the submodule does not have a LoRA.
        """
        first_lora = next(lora for lora in loras if lora is not None)
        for lora in loras:
            if lora is None:
                continue
            lora.optimize()
        rank = first_lora.rank
        module_name = first_lora.module_name
        obj = cls(
            module_name,
            rank,
            [lora.lora_alpha if lora is not None else None for lora in loras],
            [lora.lora_a if lora is not None else None for lora in loras],
            [lora.lora_b if lora is not None else None for lora in loras],
            scaling=[
                1 if lora is not None else None  # type: ignore
                for lora in loras
            ])
        return obj

    def optimize(self) -> "PackedLoRALayerWeights":
        """Optimize the LoRA by merging the scaling into lora_b."""
        for i in range(len(self.lora_b)):
            if self.scaling[i] == 1 or self.lora_b[i] is None:  # type: ignore
                continue
            self.lora_b[i] *= self.scaling[i]  # type: ignore
            self.scaling[i] = 1  # type: ignore
        return self

    @property
    def input_dim(self) -> int:
        raise NotImplementedError()

    @property
    def output_dim(self) -> int:
        raise NotImplementedError()

    @property
    def is_packed(self) -> bool:
        return True
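
# Usage sketch (illustrative; module name, dims, rank, and dtype are arbitrary
# placeholder values): pack the q/k/v sub-LoRAs of a fused qkv_proj layer into
# a single PackedLoRALayerWeights.
if __name__ == "__main__":
    subloras = [
        LoRALayerWeights.create_dummy_lora_weights(
            module_name="layers.0.self_attn.qkv_proj",
            input_dim=4096,
            output_dim=out_dim,
            rank=8,
            dtype=torch.float16,
            device=torch.device("cpu"),
        ) for out_dim in (4096, 1024, 1024)  # q, k, v output dims
    ]
    packed = PackedLoRALayerWeights.pack(subloras)
    assert packed.is_packed
    assert len(packed.lora_b) == 3
    # pack() pre-applies each sublora's scaling via optimize().
    assert packed.scaling == [1, 1, 1]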