punica.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. # Based on code from https://github.com/punica-ai/punica
  2. from typing import Optional
  3. import torch
  4. def _raise_import_error(e):
  5. if torch.cuda.get_device_capability() < (8, 0):
  6. raise ImportError(
  7. "punica LoRA kernels require compute capability >= 8.0") from e
  8. else:
  9. raise ImportError(
  10. "punica LoRA kernels could not be imported. If you built Aphrodite "
  11. "from source, make sure APHRODITE_INSTALL_PUNICA_KERNELS=1 env var "
  12. "was set.") from e
  13. def bgmv(
  14. y: torch.Tensor,
  15. x: torch.Tensor,
  16. w_t_all: torch.Tensor,
  17. indicies: torch.LongTensor,
  18. layer_idx: int,
  19. scale: float,
  20. ):
  21. """
  22. Semantics:
  23. y[i] += (
  24. x[i].unsqueeze(0)
  25. @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
  26. * scale
  27. ).squeeze(0)
  28. Args:
  29. y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
  30. x: Shape: `[B, H1]`. Input vectors.
  31. w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
  32. matrices.
  33. indicies: Shape: `[B]`. Indices of the weight matrices.
  34. layer_idx: Layer index of the weight matrices.
  35. scale: Scaling factor.
  36. """
  37. try:
  38. import aphrodite._punica_C as punica_kernels
  39. except ImportError as e:
  40. _raise_import_error(e)
  41. punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
  42. def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
  43. w_t_all: torch.Tensor, indicies: torch.LongTensor,
  44. layer_idx: int, scale: float, y_offset: int,
  45. y_slice_size: int):
  46. """
  47. Same as `bgmv` but you can operate on slices of y.
  48. Pass whole y, define y_offset and y_slice_size.
  49. Semantics:
  50. y[i] += (
  51. x[i].unsqueeze(0)
  52. @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
  53. * scale
  54. ).squeeze(0)
  55. Args:
  56. y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
  57. x: Shape: `[B, H1]`. Input vectors.
  58. w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of
  59. all of the transposed LoRA matrices.
  60. indicies: Shape: `[B]`. Indices of the LoRA weights.
  61. layer_idx: Layer index of LoRA weights.
  62. scale: Scaling factor.
  63. y_offset: Offset to apply to the starting column of y.
  64. y_slice_size: Size of the y column slice.
  65. """
  66. try:
  67. import aphrodite._punica_C as punica_kernels
  68. except ImportError as e:
  69. _raise_import_error(e)
  70. punica_kernels.dispatch_bgmv_low_level(
  71. y,
  72. x,
  73. w_t_all,
  74. indicies,
  75. layer_idx,
  76. scale,
  77. x.size(1),
  78. y_slice_size,
  79. y_offset,
  80. )
  81. def add_lora(y: torch.Tensor,
  82. x: torch.Tensor,
  83. wa_t_all: torch.Tensor,
  84. wb_t_all: torch.Tensor,
  85. indicies: torch.LongTensor,
  86. layer_idx: int,
  87. scale: float,
  88. *,
  89. buffer: Optional[torch.Tensor] = None):
  90. """
  91. Semantics:
  92. y[i] += (
  93. x[i].unsqueeze(0)
  94. @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
  95. @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
  96. * scale
  97. ).squeeze(0)
  98. Args:
  99. y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
  100. x: Shape: `[B, H1]`. Input vectors.
  101. wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
  102. LoRA A matrices.
  103. wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
  104. LoRA B matrices.
  105. indicies: Shape: `[B]`. Indices of the LoRA weights.
  106. layer_idx: Layer index of LoRA weights.
  107. scale: Scaling factor.
  108. buffer: Optional. Shape: `[B, R]`. Temporary buffer.
  109. """
  110. try:
  111. import aphrodite._punica_C as punica_kernels
  112. except ImportError as e:
  113. _raise_import_error(e)
  114. r = wb_t_all.size(-1)
  115. if buffer is None:
  116. # We set the buffer to be float32 by default to avoid
  117. # numerical inaccuracies that would otherwise happen
  118. # due to downcasting.
  119. buffer = torch.zeros((x.size(0), r),
  120. dtype=torch.float32,
  121. device=x.device)
  122. punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
  123. punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
  124. scale)
  125. def add_lora_slice(y: torch.Tensor,
  126. x: torch.Tensor,
  127. wa_t_all: torch.Tensor,
  128. wb_t_all: torch.Tensor,
  129. indicies: torch.LongTensor,
  130. layer_idx: int,
  131. scale: float,
  132. y_offset: int,
  133. y_slice_size: int,
  134. *,
  135. buffer: Optional[torch.Tensor] = None):
  136. """
  137. Same as `add_lora` but you can operate on slices of y.
  138. Pass whole y, define y_offset and y_slice_size.
  139. Semantics:
  140. y[i] += (
  141. x[i].unsqueeze(0)
  142. @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
  143. @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
  144. * scale
  145. ).squeeze(0)
  146. Args:
  147. y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
  148. x: Shape: `[B, H1]`. Input vectors.
  149. wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
  150. LoRA A matrices.
  151. wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
  152. LoRA B matrices.
  153. indicies: Shape: `[B]`. Indices of the LoRA weights.
  154. layer_idx: Layer index of LoRA weights.
  155. scale: Scaling factor.
  156. y_offset: Offset to apply to the starting column of y.
  157. y_slice_size: Size of the y column slice.
  158. """
  159. try:
  160. import aphrodite._punica_C as punica_kernels
  161. except ImportError as e:
  162. _raise_import_error(e)
  163. r = wb_t_all.size(-1)
  164. if buffer is None:
  165. # We set the buffer to be float32 by default to avoid
  166. # numerical inaccuracies that would otherwise happen
  167. # due to downcasting.
  168. buffer = torch.zeros((x.size(0), r),
  169. dtype=torch.float32,
  170. device=x.device)
  171. punica_kernels.dispatch_bgmv_low_level(
  172. buffer,
  173. x,
  174. wa_t_all,
  175. indicies,
  176. layer_idx,
  177. 1.0,
  178. x.size(1),
  179. buffer.size(1),
  180. 0,
  181. )
  182. punica_kernels.dispatch_bgmv_low_level(
  183. y,
  184. buffer,
  185. wb_t_all,
  186. indicies,
  187. layer_idx,
  188. scale,
  189. buffer.size(1),
  190. y_slice_size,
  191. y_offset,
  192. )