# lora.py

from typing import List, Optional

import torch

from aphrodite.common.utils import in_wsl


class LoRALayerWeights:
    """LoRA weights for a layer composed of two low-rank matrices."""

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alpha: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor] = None,
        scaling: Optional[float] = None,
    ) -> None:
        self.module_name = module_name
        self.rank = rank
        self.lora_alpha = lora_alpha
        self.lora_a = lora_a
        self.lora_b = lora_b
        self.embeddings_tensor = embeddings_tensor
        if scaling is None:
            self.scaling = self.lora_alpha / self.rank
        else:
            self.scaling = scaling

    def optimize(self) -> "LoRALayerWeights":
        """Optimize the LoRA by merging the scaling into lora_b."""
        if self.scaling == 1:
            return self
        self.lora_b *= self.scaling
        self.scaling = 1
        return self

    @property
    def input_dim(self) -> int:
        return self.lora_a.shape[0]

    @property
    def output_dim(self) -> int:
        return self.lora_b.shape[1]

    @property
    def is_packed(self) -> bool:
        return False

    @property
    def extra_vocab_size(self) -> int:
        return (self.embeddings_tensor.shape[0]
                if self.embeddings_tensor is not None else 0)

    @classmethod
    def create_dummy_lora_weights(
            cls,
            module_name: str,
            input_dim: int,
            output_dim: int,
            rank: int,
            dtype: torch.dtype,
            device: torch.device,
            embeddings_tensor_dim: Optional[int] = None,
    ) -> "LoRALayerWeights":
        # Pinned host memory speeds up host-to-device copies, but it is not
        # supported under WSL, so only pin CPU tensors outside WSL.
        pin_memory = str(device) == "cpu" and not in_wsl()
        lora_a = torch.zeros([input_dim, rank],
                             dtype=dtype,
                             device=device,
                             pin_memory=pin_memory)
        lora_b = torch.zeros([rank, output_dim],
                             dtype=dtype,
                             device=device,
                             pin_memory=pin_memory)
        embeddings_tensor = torch.rand(
            10,
            embeddings_tensor_dim,
            dtype=dtype,
            device=device,
            pin_memory=pin_memory) if embeddings_tensor_dim else None
        return cls(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=lora_a,
            lora_b=lora_b,
            embeddings_tensor=embeddings_tensor,
        )


class PackedLoRALayerWeights(LoRALayerWeights):
    """LoRA used for packed layers (e.g. qkv_proj)."""

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alphas: List[int],
        lora_a: List[torch.Tensor],
        lora_b: List[torch.Tensor],
        scaling: Optional[List[float]] = None,
    ) -> None:
        super().__init__(
            module_name=module_name,
            rank=rank,
            lora_alpha=0,
            lora_a=lora_a,
            lora_b=lora_b,
            scaling=scaling,
            embeddings_tensor=None,
        )
        self.lora_alphas = lora_alphas
        if scaling is None:
            # Each packed sub-LoRA keeps its own alpha, so compute a
            # per-sublayer scaling list instead of a single scalar.
            self.scaling = [
                lora_alpha / self.rank for lora_alpha in self.lora_alphas
            ]

    @classmethod
    def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights":
        """Pack a list of LoRAs into a single LoRA.

        A None entry means the corresponding submodule has no LoRA.
        """
        first_lora = next(lora for lora in loras if lora is not None)
        for lora in loras:
            if lora is None:
                continue
            # Fold each sub-LoRA's scaling into its lora_b so the packed
            # adapter can use a uniform scaling of 1 below.
            lora.optimize()
        rank = first_lora.rank
        module_name = first_lora.module_name
        obj = cls(
            module_name,
            rank,
            [lora.lora_alpha if lora is not None else None for lora in loras],
            [lora.lora_a if lora is not None else None for lora in loras],
            [lora.lora_b if lora is not None else None for lora in loras],
            scaling=[1 if lora is not None else None for lora in loras],
        )
        return obj

    def optimize(self) -> "PackedLoRALayerWeights":
        """Optimize the LoRA by merging the scaling into lora_b."""
        for i in range(len(self.lora_b)):
            if self.scaling[i] == 1 or self.lora_b[i] is None:
                continue
            self.lora_b[i] *= self.scaling[i]
            self.scaling[i] = 1
        return self

    @property
    def input_dim(self) -> int:
        raise NotImplementedError()

    @property
    def output_dim(self) -> int:
        raise NotImplementedError()

    @property
    def is_packed(self) -> bool:
        return True
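

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module. It builds
    # three hand-made LoRAs and packs them as a fused "qkv_proj" adapter;
    # the module name, rank, alpha, and tensor shapes are assumptions chosen
    # only for this example.
    rank, in_dim, out_dim = 8, 32, 64
    loras = [
        LoRALayerWeights(
            module_name="qkv_proj",
            rank=rank,
            lora_alpha=16,
            lora_a=torch.randn(in_dim, rank),
            lora_b=torch.randn(rank, out_dim),
        ) for _ in range(3)
    ]
    packed = PackedLoRALayerWeights.pack(loras)
    # pack() calls optimize() on every sub-LoRA, so each per-sublayer scaling
    # has already been folded into lora_b and reset to 1.
    assert packed.is_packed
    assert packed.scaling == [1, 1, 1]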