# punica.py
# Based on code from https://github.com/punica-ai/punica

from typing import Optional

import torch

from aphrodite import _custom_ops as ops
from aphrodite.platforms import current_platform


def _check_punica_support():
    """Raise ImportError if the punica LoRA kernels are unavailable."""
    if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"):
        return

    if current_platform.get_device_capability() < (8, 0):
        raise ImportError(
            "punica LoRA kernels require compute capability >= 8.0")
    else:
        raise ImportError(
            "punica LoRA kernels could not be imported. If you built vLLM "
            "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
            "was set.")


def bgmv(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indices: torch.LongTensor,
    layer_idx: int,
    scale: float,
):
    """
    Semantics:
        y[i] += (
            x[i].unsqueeze(0)
            @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
            * scale
        ).squeeze(0)

    Args:
        y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
        x: Shape: `[B, H1]`. Input vectors.
        w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
            matrices.
        indices: Shape: `[B]`. Indices of the weight matrices.
        layer_idx: Layer index of the weight matrices.
        scale: Scaling factor.
    """
    _check_punica_support()

    ops.dispatch_bgmv(y, x, w_t_all, indices, layer_idx, scale)
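

# A minimal pure-PyTorch sketch of the `bgmv` semantics above, handy as a
# reference when testing the custom kernel. `_bgmv_ref` is a hypothetical
# helper (not part of this module); the per-row loop trades speed for
# clarity and assumes the shapes documented in `bgmv`.
def _bgmv_ref(y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor,
              indices: torch.LongTensor, layer_idx: int, scale: float):
    for i in range(x.size(0)):
        # Select this row's weight matrix: [H2, H1], transposed to [H1, H2].
        w = w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
        xi = x[i].unsqueeze(0).to(w.dtype)  # [1, H1]
        # [1, H1] @ [H1, H2] -> [1, H2], scaled and accumulated in-place.
        y[i] += ((xi @ w) * scale).squeeze(0).to(y.dtype)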


def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
                            w_t_all: torch.Tensor, indices: torch.LongTensor,
                            layer_idx: int, scale: float, y_offset: int,
                            y_slice_size: int):
    """
    Same as `bgmv` but operates on a column slice of y: pass the whole
    y and specify the slice with `y_offset` and `y_slice_size`.

    Semantics:
        y[i, y_offset : y_offset + y_slice_size] += (
            x[i].unsqueeze(0)
            @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
            * scale
        ).squeeze(0)

    Args:
        y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
        x: Shape: `[B, H1]`. Input vectors.
        w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of
            all of the transposed LoRA matrices.
        indices: Shape: `[B]`. Indices of the LoRA weights.
        layer_idx: Layer index of LoRA weights.
        scale: Scaling factor.
        y_offset: Offset to apply to the starting column of y.
        y_slice_size: Size of the y column slice.
    """
    _check_punica_support()

    ops.dispatch_bgmv_low_level(
        y,
        x,
        w_t_all,
        indices,
        layer_idx,
        scale,
        x.size(1),
        y_slice_size,
        y_offset,
    )
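

# A pure-PyTorch sketch of the slice semantics above; `_bgmv_slice_ref` is
# again a hypothetical reference helper, not part of this module. It makes
# explicit that only columns [y_offset, y_offset + y_slice_size) of y are
# touched, which the kernel achieves without materializing the
# non-contiguous view y[:, y_offset:y_offset + y_slice_size].
def _bgmv_slice_ref(y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor,
                    indices: torch.LongTensor, layer_idx: int, scale: float,
                    y_offset: int, y_slice_size: int):
    for i in range(x.size(0)):
        # Column partition: [y_slice_size, H1], transposed to
        # [H1, y_slice_size].
        w = w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
        xi = x[i].unsqueeze(0).to(w.dtype)  # [1, H1]
        y[i, y_offset:y_offset + y_slice_size] += (
            ((xi @ w) * scale).squeeze(0).to(y.dtype))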


def add_lora(y: torch.Tensor,
             x: torch.Tensor,
             wa_t_all: torch.Tensor,
             wb_t_all: torch.Tensor,
             indices: torch.LongTensor,
             layer_idx: int,
             scale: float,
             *,
             buffer: Optional[torch.Tensor] = None):
    """
    Semantics:
        y[i] += (
            x[i].unsqueeze(0)
            @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
            @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
            * scale
        ).squeeze(0)

    Args:
        y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
        x: Shape: `[B, H1]`. Input vectors.
        wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
            LoRA A matrices.
        wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
            LoRA B matrices.
        indices: Shape: `[B]`. Indices of the LoRA weights.
        layer_idx: Layer index of LoRA weights.
        scale: Scaling factor.
        buffer: Optional. Shape: `[B, R]`. Temporary buffer.
    """
    _check_punica_support()

    r = wb_t_all.size(-1)
    if buffer is None:
        # We set the buffer to be float32 by default to avoid
        # numerical inaccuracies that would otherwise happen
        # due to downcasting.
        buffer = torch.zeros((x.size(0), r),
                             dtype=torch.float32,
                             device=x.device)
    ops.dispatch_bgmv(buffer, x, wa_t_all, indices, layer_idx, 1.0)
    ops.dispatch_bgmv(y, buffer, wb_t_all, indices, layer_idx, scale)
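

# A pure-PyTorch sketch of the `add_lora` semantics, mirroring the float32
# intermediate used above. `_add_lora_ref` is a hypothetical reference
# helper, not part of this module.
def _add_lora_ref(y: torch.Tensor, x: torch.Tensor, wa_t_all: torch.Tensor,
                  wb_t_all: torch.Tensor, indices: torch.LongTensor,
                  layer_idx: int, scale: float):
    for i in range(x.size(0)):
        wa = wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)  # [H1, R]
        wb = wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)  # [R, H2]
        xi = x[i].unsqueeze(0).to(torch.float32)
        tmp = xi @ wa.to(torch.float32)   # shrink: [1, R]
        out = tmp @ wb.to(torch.float32)  # expand: [1, H2]
        y[i] += (out * scale).squeeze(0).to(y.dtype)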


def add_lora_slice(y: torch.Tensor,
                   x: torch.Tensor,
                   wa_t_all: torch.Tensor,
                   wb_t_all: torch.Tensor,
                   indices: torch.LongTensor,
                   layer_idx: int,
                   scale: float,
                   y_offset: int,
                   y_slice_size: int,
                   *,
                   buffer: Optional[torch.Tensor] = None):
    """
    Same as `add_lora` but operates on a column slice of y: pass the
    whole y and specify the slice with `y_offset` and `y_slice_size`.

    Semantics:
        y[i, y_offset : y_offset + y_slice_size] += (
            x[i].unsqueeze(0)
            @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
            @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
            * scale
        ).squeeze(0)

    Args:
        y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
        x: Shape: `[B, H1]`. Input vectors.
        wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
            LoRA A matrices.
        wb_t_all: Shape: `[None, L, y_slice_size, R]`. Column partition of
            all of the transposed LoRA B matrices.
        indices: Shape: `[B]`. Indices of the LoRA weights.
        layer_idx: Layer index of LoRA weights.
        scale: Scaling factor.
        y_offset: Offset to apply to the starting column of y.
        y_slice_size: Size of the y column slice.
    """
    _check_punica_support()

    r = wb_t_all.size(-1)
    if buffer is None:
        # We set the buffer to be float32 by default to avoid
        # numerical inaccuracies that would otherwise happen
        # due to downcasting.
        buffer = torch.zeros((x.size(0), r),
                             dtype=torch.float32,
                             device=x.device)
    ops.dispatch_bgmv_low_level(
        buffer,
        x,
        wa_t_all,
        indices,
        layer_idx,
        1.0,
        x.size(1),
        buffer.size(1),
        0,
    )
    ops.dispatch_bgmv_low_level(
        y,
        buffer,
        wb_t_all,
        indices,
        layer_idx,
        scale,
        buffer.size(1),
        y_slice_size,
        y_offset,
    )
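

# A hedged usage sketch: applying per-slice LoRA updates to a packed
# projection output (for example a fused QKV projection). Everything here
# is illustrative and assumed, not part of this module: `qkv` is the packed
# [B, H2] output, `hidden` the [B, H1] input, `lora_a_stacked` /
# `lora_b_stacked` are per-slice stacks of transposed LoRA A/B weights, and
# `output_slices` gives each slice's width.
def _apply_packed_lora_example(qkv, hidden, lora_a_stacked, lora_b_stacked,
                               indices, layer_idx, output_slices):
    offset = 0
    for wa_t_all, wb_t_all, size in zip(lora_a_stacked, lora_b_stacked,
                                        output_slices):
        # Each LoRA pair only touches its own column slice of qkv.
        add_lora_slice(qkv, hidden, wa_t_all, wb_t_all, indices,
                       layer_idx, 1.0, offset, size)
        offset += size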