punica.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # Based on code from https://github.com/punica-ai/punica
from typing import NoReturn, Optional

import torch
  4. def _raise_import_error(e):
  5. if torch.cuda.get_device_capability() < (8, 0):
  6. raise ImportError(
  7. "punica LoRA kernels require compute capability >= 8.0") from e
  8. else:
  9. raise ImportError(
  10. "punica LoRA kernels could not be imported. If you built Aphrodite "
  11. "from source, make sure APHRODITE_INSTALL_PUNICA_KERNELS=1 env var "
  12. "was set.") from e
  13. def bgmv(
  14. y: torch.Tensor,
  15. x: torch.Tensor,
  16. w_t_all: torch.Tensor,
  17. indicies: torch.LongTensor,
  18. layer_idx: int,
  19. scale: float,
  20. ):
  21. """
  22. Semantics:
  23. y[i] += (
  24. x[i].unsqueeze(0)
  25. @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
  26. * scale
  27. ).squeeze(0)
  28. Args:
  29. y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
  30. x: Shape: `[B, H1]`. Input vectors.
  31. w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
  32. matrices.
  33. indicies: Shape: `[B]`. Indices of the weight matrices.
  34. layer_idx: Layer index of the weight matrices.
  35. scale: Scaling factor.
  36. """
  37. try:
  38. import aphrodite._punica_C as punica_kernels
  39. except ImportError as e:
  40. _raise_import_error(e)
  41. punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
  42. def add_lora(y: torch.Tensor,
  43. x: torch.Tensor,
  44. wa_t_all: torch.Tensor,
  45. wb_t_all: torch.Tensor,
  46. indicies: torch.LongTensor,
  47. layer_idx: int,
  48. scale: float,
  49. *,
  50. buffer: Optional[torch.Tensor] = None):
  51. """
  52. Semantics:
  53. y[i] += (
  54. x[i].unsqueeze(0)
  55. @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
  56. @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
  57. * scale
  58. ).squeeze(0)
  59. Args:
  60. y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
  61. x: Shape: `[B, H1]`. Input vectors.
  62. wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
  63. LoRA A matrices.
  64. wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
  65. LoRA B matrices.
  66. indicies: Shape: `[B]`. Indices of the LoRA weights.
  67. layer_idx: Layer index of LoRA weights.
  68. scale: Scaling factor.
  69. buffer: Optional. Shape: `[B, R]`. Temporary buffer.
  70. """
  71. try:
  72. import aphrodite._punica_C as punica_kernels
  73. except ImportError as e:
  74. _raise_import_error(e)
  75. r = wb_t_all.size(-1)
  76. if buffer is None:
  77. # We set the buffer to be float32 by default to avoid
  78. # numerical inaccuracies that would otherwise happen
  79. # due to downcasting.
  80. buffer = torch.zeros((x.size(0), r),
  81. dtype=torch.float32,
  82. device=x.device)
  83. punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
  84. punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
  85. scale)
  86. def add_lora_slice(y: torch.Tensor,
  87. x: torch.Tensor,
  88. wa_t_all: torch.Tensor,
  89. wb_t_all: torch.Tensor,
  90. indicies: torch.LongTensor,
  91. layer_idx: int,
  92. scale: float,
  93. y_offset: int,
  94. y_slice_size: int,
  95. *,
  96. buffer: Optional[torch.Tensor] = None):
  97. """
  98. Same as `add_lora` but you can operate on slices of y.
  99. Pass whole y, define y_offset and y_slice_size.
  100. Semantics:
  101. y[i] += (
  102. x[i].unsqueeze(0)
  103. @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
  104. @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
  105. * scale
  106. ).squeeze(0)
  107. Args:
  108. y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
  109. x: Shape: `[B, H1]`. Input vectors.
  110. wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
  111. LoRA A matrices.
  112. wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
  113. LoRA B matrices.
  114. indicies: Shape: `[B]`. Indices of the LoRA weights.
  115. layer_idx: Layer index of LoRA weights.
  116. scale: Scaling factor.
  117. y_offset: Offset to apply to the starting column of y.
  118. y_slice_size: Size of the y column slice.
  119. """
  120. try:
  121. import aphrodite._punica_C as punica_kernels
  122. except ImportError as e:
  123. _raise_import_error(e)
  124. r = wb_t_all.size(-1)
  125. if buffer is None:
  126. # We set the buffer to be float32 by default to avoid
  127. # numerical inaccuracies that would otherwise happen
  128. # due to downcasting.
  129. buffer = torch.zeros((x.size(0), r),
  130. dtype=torch.float32,
  131. device=x.device)
  132. punica_kernels.dispatch_bgmv_low_level(
  133. buffer,
  134. x,
  135. wa_t_all,
  136. indicies,
  137. layer_idx,
  138. 1.0,
  139. x.size(1),
  140. buffer.size(1),
  141. 0,
  142. )
  143. punica_kernels.dispatch_bgmv_low_level(
  144. y,
  145. buffer,
  146. wb_t_all,
  147. indicies,
  148. layer_idx,
  149. scale,
  150. buffer.size(1),
  151. y_slice_size,
  152. y_offset,
  153. )