sgmv_expand_slice.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. """
  2. Based on:
  3. Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
  4. Punica: Multi-Tenant LoRA Serving.
  5. https://arxiv.org/abs/2310.18547
  6. """
  7. import torch
  8. import triton
  9. import triton.language as tl
  10. from aphrodite.triton_utils import libentry
  11. @libentry()
  12. @triton.jit
  13. def _sgmv_expand_slice_kernel(
  14. input_ptr,
  15. lora_ptr,
  16. out_ptr,
  17. N,
  18. K,
  19. b_seq_start_loc,
  20. seq_lens,
  21. lora_indices,
  22. xm_stride,
  23. xk_stride, # 1
  24. l0_stride, # hidden_size*max_rank
  25. lora_k_stride,
  26. lora_n_stride,
  27. cm_stride,
  28. cn_stride,
  29. slice_offset,
  30. BLOCK_M: tl.constexpr,
  31. BLOCK_N: tl.constexpr,
  32. BLOCK_K: tl.constexpr,
  33. EVEN_K: tl.constexpr,
  34. ADD_INPUTS: tl.constexpr,
  35. CAST_TYPE: tl.constexpr,
  36. ):
  37. """
  38. Similar to the 'sgmv_expand' operator, but with an added parameter
  39. 'slice_offset'. The reason for not reusing the 'sgmv_expand' operator
  40. might be that in the future, we could implement a fusion operator to
  41. achieve the current functionality instead of having to call it multiple
  42. times.
  43. """
  44. pid = tl.program_id(axis=0)
  45. cur_batch = tl.program_id(axis=1)
  46. cta_n_num = tl.cdiv(N, BLOCK_N)
  47. pid_m = pid // cta_n_num
  48. pid_n = pid % cta_n_num
  49. M = tl.load(seq_lens + cur_batch)
  50. if pid_m * BLOCK_M > M:
  51. return
  52. lora_index = tl.load(lora_indices + cur_batch)
  53. if lora_index == -1:
  54. return
  55. cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
  56. offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
  57. offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
  58. offset_k = tl.arange(0, BLOCK_K)
  59. ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
  60. rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
  61. a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +
  62. offset_k[None, :] * xk_stride, )
  63. b_ptr = (lora_ptr + l0_stride * lora_index +
  64. offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride)
  65. accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
  66. for k in range(tl.cdiv(K, BLOCK_K)):
  67. if EVEN_K:
  68. tiled_a = tl.load(a_ptr)
  69. tiled_b = tl.load(b_ptr)
  70. else:
  71. tiled_a = tl.load(a_ptr,
  72. mask=offset_k[None, :] < K - k * BLOCK_K,
  73. other=0)
  74. tiled_b = tl.load(b_ptr,
  75. mask=offset_k[:, None] < K - k * BLOCK_K,
  76. other=0)
  77. if CAST_TYPE:
  78. tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
  79. accumulator += tl.dot(
  80. tiled_a,
  81. tiled_b,
  82. )
  83. a_ptr += BLOCK_K * xk_stride
  84. b_ptr += BLOCK_K * lora_n_stride
  85. tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
  86. offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
  87. offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset
  88. c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
  89. offset_cn[None, :] * cn_stride)
  90. M = tl.load(seq_lens + cur_batch)
  91. c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] <
  92. (slice_offset + N))
  93. if ADD_INPUTS:
  94. tiled_out = tl.load(c_ptr, mask=c_mask)
  95. tiled_c += tiled_out
  96. tl.store(c_ptr, tiled_c, mask=c_mask)
  97. @torch.inference_mode()
  98. def _sgmv_expand_slice(
  99. inputs: torch.Tensor,
  100. lora_b_weights: torch.Tensor,
  101. output_tensor: torch.Tensor,
  102. b_seq_start_loc: torch.Tensor,
  103. seq_len_tensor: torch.Tensor,
  104. lora_indices_tensor: torch.Tensor,
  105. batches: int,
  106. max_seq_length: int,
  107. token_nums: int,
  108. slice_offset: int,
  109. slice_size: int,
  110. add_inputs: bool = False,
  111. ) -> None:
  112. """_summary_
  113. Args:
  114. inputs (torch.Tensor): input tensor
  115. lora_b_weights (torch.Tensor): lora'a weight
  116. output_tensor (torch.Tensor): output tensor
  117. b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
  118. sequence lengths of the sequences in the batch, used to index
  119. into sequence. E.g., if the sequence length is [4, 6], it is
  120. [0, 4, 10].
  121. seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
  122. length of the sequences in the batch
  123. lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
  124. corresponding to each batch. An index of -1 means no lora should be
  125. applied.
  126. batches (int): batch size
  127. max_seq_length (int): The max sequence lengths of the sequences
  128. in the batch
  129. token_nums (int): The token numbers in the batch. Used to verify if the
  130. token numbers in the inputs matches the one in the metadata.
  131. slice_offset (int): output_tensor's offset
  132. slice_size (int): current output_tensor's size
  133. add_inputs (bool, optional): Defaults to False, adds the final lora
  134. results to the output.
  135. """
  136. assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
  137. assert lora_b_weights.dtype in [
  138. torch.float16,
  139. torch.bfloat16,
  140. ]
  141. assert inputs.size(0) == token_nums
  142. assert inputs.size(1) == lora_b_weights.size(-1)
  143. assert b_seq_start_loc.size(0) == batches
  144. assert lora_indices_tensor.size(0) == batches
  145. assert slice_size == lora_b_weights.size(-2)
  146. assert inputs.is_contiguous()
  147. assert output_tensor.is_contiguous()
  148. if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
  149. assert lora_b_weights.size(1) == 1
  150. lora_b_weights = lora_b_weights.squeeze(dim=1)
  151. else:
  152. assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
  153. assert lora_b_weights.is_contiguous()
  154. # TODO tuning this config
  155. N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
  156. BLOCK_M = 32
  157. BLOCK_N = 32
  158. BLOCK_K = 16
  159. EVEN_K = K % BLOCK_K == 0
  160. ADD_INPUTS = add_inputs
  161. CAST_TYPE = False
  162. if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
  163. torch.float16,
  164. torch.bfloat16,
  165. ]:
  166. CAST_TYPE = True
  167. grid = (
  168. triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
  169. batches,
  170. )
  171. _sgmv_expand_slice_kernel[grid](
  172. inputs,
  173. lora_b_weights,
  174. output_tensor,
  175. N,
  176. K,
  177. b_seq_start_loc,
  178. seq_len_tensor,
  179. lora_indices_tensor,
  180. inputs.stride(0),
  181. inputs.stride(1),
  182. lora_b_weights.stride(0),
  183. lora_b_weights.stride(1),
  184. lora_b_weights.stride(2),
  185. output_tensor.stride(0),
  186. output_tensor.stride(1),
  187. slice_offset,
  188. BLOCK_M,
  189. BLOCK_N,
  190. BLOCK_K,
  191. EVEN_K,
  192. ADD_INPUTS,
  193. CAST_TYPE,
  194. )
  195. return
  196. try:
  197. sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice",
  198. _sgmv_expand_slice,
  199. mutates_args=["output_tensor"])
  200. except AttributeError:
  201. sgmv_expand_slice = _sgmv_expand_slice