# Based on code from https://github.com/punica-ai/punica

from typing import Optional

import torch

import_exc = None

try:
    import aphrodite._punica_C as punica_kernels
except ImportError as e:
    import_exc = e

if import_exc is None:

    def bgmv(
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        indicies: torch.LongTensor,
        layer_idx: int,
        scale: float,
    ):
        """
        Semantics:
            y[i] += (
                x[i].unsqueeze(0)
                @ w_t_all[indicies[i], layer_idx, :, :].transpose(-1, -2)
                * scale
            ).squeeze(0)

        Args:
            y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
            x: Shape: `[B, H1]`. Input vectors.
            w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
                matrices.
            indicies: Shape: `[B]`. Indices of the weight matrices.
            layer_idx: Layer index of the weight matrices.
            scale: Scaling factor.
        """
        punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx,
                                     scale)
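
    # A pure-PyTorch sketch of the `bgmv` semantics above, useful as a
    # mental model or for testing against the CUDA kernel. `_bgmv_torch` is
    # a hypothetical helper added for illustration; it is not part of the
    # punica API and is far slower than `dispatch_bgmv`.
    def _bgmv_torch(
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        indicies: torch.LongTensor,
        layer_idx: int,
        scale: float,
    ):
        for i in range(x.size(0)):
            w = w_t_all[indicies[i], layer_idx]  # [H2, H1]
            # [1, H1] @ [H1, H2] -> [1, H2], accumulated into y in-place.
            y[i] += (x[i].unsqueeze(0) @ w.transpose(-1, -2)
                     * scale).squeeze(0)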

    def add_lora(y: torch.Tensor,
                 x: torch.Tensor,
                 wa_t_all: torch.Tensor,
                 wb_t_all: torch.Tensor,
                 indicies: torch.LongTensor,
                 layer_idx: int,
                 scale: float,
                 *,
                 buffer: Optional[torch.Tensor] = None):
        """
        Semantics:
            y[i] += (
                x[i].unsqueeze(0)
                @ wa_t_all[indicies[i], layer_idx, :, :].transpose(-1, -2)
                @ wb_t_all[indicies[i], layer_idx, :, :].transpose(-1, -2)
                * scale
            ).squeeze(0)

        Args:
            y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
            x: Shape: `[B, H1]`. Input vectors.
            wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
                LoRA A matrices.
            wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
                LoRA B matrices.
            indicies: Shape: `[B]`. Indices of the LoRA weights.
            layer_idx: Layer index of the LoRA weights.
            scale: Scaling factor.
            buffer: Optional. Shape: `[B, R]`. Temporary buffer.
        """
        r = wb_t_all.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default to avoid
            # numerical inaccuracies that would otherwise happen
            # due to downcasting.
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)
        punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx,
                                     1.0)
        punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
                                     scale)
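
    # A pure-PyTorch sketch of the `add_lora` semantics, mirroring the
    # float32 intermediate used on the kernel path. `_add_lora_torch` is a
    # hypothetical helper for illustration/testing only; it is not part of
    # the punica API.
    def _add_lora_torch(y: torch.Tensor,
                        x: torch.Tensor,
                        wa_t_all: torch.Tensor,
                        wb_t_all: torch.Tensor,
                        indicies: torch.LongTensor,
                        layer_idx: int,
                        scale: float):
        for i in range(x.size(0)):
            wa = wa_t_all[indicies[i], layer_idx].float()  # [R, H1]
            wb = wb_t_all[indicies[i], layer_idx].float()  # [H2, R]
            # Shrink to rank R, then expand back to H2, in float32.
            tmp = x[i].unsqueeze(0).float() @ wa.transpose(-1, -2)  # [1, R]
            out = tmp @ wb.transpose(-1, -2) * scale  # [1, H2]
            y[i] += out.squeeze(0).to(y.dtype)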

    def add_lora_slice(y: torch.Tensor,
                       x: torch.Tensor,
                       wa_t_all: torch.Tensor,
                       wb_t_all: torch.Tensor,
                       indicies: torch.LongTensor,
                       layer_idx: int,
                       scale: float,
                       y_offset: int,
                       y_slice_size: int,
                       *,
                       buffer: Optional[torch.Tensor] = None):
        """
        Same as `add_lora` but you can operate on slices of y.
        Pass the whole y tensor and specify the slice with `y_offset`
        and `y_slice_size`.

        Semantics:
            y[i] += (
                x[i].unsqueeze(0)
                @ wa_t_all[indicies[i], layer_idx, :, :].transpose(-1, -2)
                @ wb_t_all[indicies[i], layer_idx, :, :].transpose(-1, -2)
                * scale
            ).squeeze(0)

        Args:
            y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
            x: Shape: `[B, H1]`. Input vectors.
            wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
                LoRA A matrices.
            wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
                LoRA B matrices.
            indicies: Shape: `[B]`. Indices of the LoRA weights.
            layer_idx: Layer index of the LoRA weights.
            scale: Scaling factor.
            y_offset: Offset to apply to the starting column of y.
            y_slice_size: Size of the y column slice.
            buffer: Optional. Shape: `[B, R]`. Temporary buffer.
        """
        r = wb_t_all.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default to avoid
            # numerical inaccuracies that would otherwise happen
            # due to downcasting.
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)
        punica_kernels.dispatch_bgmv_low_level(
            buffer,
            x,
            wa_t_all,
            indicies,
            layer_idx,
            1.0,
            x.size(1),
            buffer.size(1),
            0,
        )
        punica_kernels.dispatch_bgmv_low_level(
            y,
            buffer,
            wb_t_all,
            indicies,
            layer_idx,
            scale,
            buffer.size(1),
            y_slice_size,
            y_offset,
        )
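
    # Usage sketch for `add_lora_slice` (hypothetical names and shapes):
    # with a fused projection whose output packs several sub-matrices, a
    # separate LoRA can be applied to each slice of the packed output
    # without materializing per-slice copies of y, e.g. for a packed QKV
    # output of widths q_size, kv_size, kv_size:
    #
    #     add_lora_slice(y, x, wa_q, wb_q, indicies, layer_idx, 1.0,
    #                    0, q_size)
    #     add_lora_slice(y, x, wa_k, wb_k, indicies, layer_idx, 1.0,
    #                    q_size, kv_size)
    #     add_lora_slice(y, x, wa_v, wb_v, indicies, layer_idx, 1.0,
    #                    q_size + kv_size, kv_size)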

else:

    def _raise_exc(
            *args,  # pylint: disable=unused-argument
            **kwargs  # pylint: disable=unused-argument
    ):
        if torch.cuda.get_device_capability() < (8, 0):
            raise ImportError("punica LoRA kernels require compute "
                              "capability >= 8.0") from import_exc
        else:
            raise ImportError(
                "punica LoRA kernels could not be imported. If you built "
                "Aphrodite from source, make sure the "
                "APHRODITE_INSTALL_PUNICA_KERNELS=1 env var was set."
            ) from import_exc

    bgmv = _raise_exc
    add_lora = _raise_exc
    add_lora_slice = _raise_exc

__all__ = [
    "bgmv",
    "add_lora",
    "add_lora_slice",
]
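
# Example call pattern (a hypothetical sketch; the kernels require CUDA
# tensors, and `num_loras` is the leading `None` dimension of the stacked
# weights documented above):
#
#     B, H1, H2, R, L, num_loras = 8, 4096, 4096, 16, 32, 4
#     x = torch.randn(B, H1, dtype=torch.float16, device="cuda")
#     y = torch.zeros(B, H2, dtype=torch.float16, device="cuda")
#     wa_t_all = torch.randn(num_loras, L, R, H1,
#                            dtype=torch.float16, device="cuda")
#     wb_t_all = torch.randn(num_loras, L, H2, R,
#                            dtype=torch.float16, device="cuda")
#     indicies = torch.randint(num_loras, (B,), device="cuda")
#     add_lora(y, x, wa_t_all, wb_t_all, indicies, layer_idx=0, scale=1.0)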