sgmv_expand.py

  1. """
  2. Based on:
  3. Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
  4. Punica: Multi-Tenant LoRA Serving.
  5. https://arxiv.org/abs/2310.18547
  6. """
import torch
import triton
import triton.language as tl

from aphrodite.triton_utils import libentry


@libentry()
@triton.jit
def _sgmv_expand_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    b_seq_start_loc,
    seq_lens,
    lora_indices,
    xm_stride,
    xk_stride,  # 1
    l0_stride,  # hidden_size * max_rank
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
):
  36. """
  37. The sgmv's expand triton kernel is based on GroupGEMM.
  38. """
    pid = tl.program_id(axis=0)
    cur_batch = tl.program_id(axis=1)
    cta_n_num = tl.cdiv(N, BLOCK_N)
    pid_m = pid // cta_n_num
    pid_n = pid % cta_n_num
    M = tl.load(seq_lens + cur_batch)
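    # Exit early for tiles beyond this sequence's length and for sequences
    # that have no LoRA adapter assigned (index -1).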
    if pid_m * BLOCK_M > M:
        return
    lora_index = tl.load(lora_indices + cur_batch)
    if lora_index == -1:
        return
    cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
    offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    offset_k = tl.arange(0, BLOCK_K)
    ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
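    # A tile: this sequence's rows of the input; B tile: the selected LoRA's
    # B weight, read transposed as (K, N) through its strides.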
    a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +
             offset_k[None, :] * xk_stride)
    b_ptr = (lora_ptr + l0_stride * lora_index +
             offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride)
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
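    # Accumulate over the K (LoRA rank) dimension in BLOCK_K chunks; the tail
    # chunk is masked unless K is an exact multiple of BLOCK_K (EVEN_K).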
    for k in range(tl.cdiv(K, BLOCK_K)):
        if EVEN_K:
            tiled_a = tl.load(a_ptr)
            tiled_b = tl.load(b_ptr)
        else:
            tiled_a = tl.load(a_ptr,
                              mask=offset_k[None, :] < K - k * BLOCK_K,
                              other=0)
            tiled_b = tl.load(b_ptr,
                              mask=offset_k[:, None] < K - k * BLOCK_K,
                              other=0)
        if CAST_TYPE:
            tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
        accumulator += tl.dot(
            tiled_a,
            tiled_b,
        )
        a_ptr += BLOCK_K * xk_stride
        b_ptr += BLOCK_K * lora_n_stride
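    # Cast the accumulator to the LoRA weight dtype and build the pointers
    # and bounds mask for this sequence's slice of the output.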
    tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
    offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
             offset_cn[None, :] * cn_stride)
    M = tl.load(seq_lens + cur_batch)
    c_mask = (offset_cm[:, None] <
              (cur_seq_start + M)) & (offset_cn[None, :] < N)
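    # With ADD_INPUTS, accumulate into the existing output instead of
    # overwriting it.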
    if ADD_INPUTS:
        tiled_out = tl.load(c_ptr, mask=c_mask)
        tiled_c += tiled_out
    tl.store(c_ptr, tiled_c, mask=c_mask)


@torch.inference_mode()
def _sgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    add_inputs: bool = False,
) -> None:
  105. """
  106. Args:
  107. inputs (torch.Tensor): input tensor
  108. lora_b_weights (torch.Tensor): lora'a weight
  109. output_tensor (torch.Tensor): output tensor
  110. b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
  111. sequence lengths of the sequences in the batch, used to index
  112. into sequence. E.g.,if the sequence length is [4, 6], it is
  113. [0, 4, 10].
  114. seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
  115. length of the sequences in the batch
  116. lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
  117. corresponding to each batch. An index of -1 means no lora should be
  118. applied.
  119. batches (int): batch size
  120. max_seq_length (int): The max sequence lengths of the sequences in the
  121. batch.
  122. token_nums (int): The token numbers in the batch. Used to verify if the
  123. token numbers in the inputs matches the one in the metadata.
  124. add_inputs (bool, optional): Defaults to False, adds the final lora
  125. results to the output.
  126. """
    assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
    assert lora_b_weights.dtype in [
        torch.float16,
        torch.bfloat16,
    ]
    assert inputs.size(0) == token_nums
    assert inputs.size(1) == lora_b_weights.size(-1)
    assert b_seq_start_loc.size(0) == batches
    assert lora_indices_tensor.size(0) == batches
    assert inputs.is_contiguous()
    assert output_tensor.is_contiguous()
    if lora_b_weights.ndim == 4:  # shape:(lora_num,1,size,rank)
        assert lora_b_weights.size(1) == 1
        lora_b_weights = lora_b_weights.squeeze(dim=1)
    else:
        assert lora_b_weights.ndim == 3  # shape:(lora_num,size,rank)
    assert lora_b_weights.is_contiguous()
    # TODO: tune this config
    N, K = lora_b_weights.shape[-2:]  # K = rank, N = hidden_size
    BLOCK_M = 32
    BLOCK_N = 32
    BLOCK_K = 16
    EVEN_K = K % BLOCK_K == 0
    ADD_INPUTS = add_inputs
    CAST_TYPE = False
    if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
            torch.float16,
            torch.bfloat16,
    ]:
        CAST_TYPE = True
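    # One program per (BLOCK_M x BLOCK_N) output tile per sequence. The M
    # dimension is sized by max_seq_length, so shorter sequences simply
    # early-exit their out-of-range tiles inside the kernel.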
    grid = (
        triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        batches,
    )
    _sgmv_expand_kernel[grid](
        inputs,
        lora_b_weights,
        output_tensor,
        N,
        K,
        b_seq_start_loc,
        seq_len_tensor,
        lora_indices_tensor,
        inputs.stride(0),
        inputs.stride(1),
        lora_b_weights.stride(0),
        lora_b_weights.stride(1),
        lora_b_weights.stride(2),
        output_tensor.stride(0),
        output_tensor.stride(1),
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        ADD_INPUTS,
        CAST_TYPE,
    )
    return


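# Register as a PyTorch custom op, declaring that output_tensor is mutated;
# fall back to the plain function on older PyTorch releases that do not
# provide torch.library.custom_op.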
try:
    sgmv_expand = torch.library.custom_op("lora::sgmv_expand",
                                          _sgmv_expand,
                                          mutates_args=["output_tensor"])
except AttributeError:
    sgmv_expand = _sgmv_expand
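
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the kernel itself). The names, shapes,
# dtypes and adapter count below are illustrative assumptions only: two
# sequences of lengths 4 and 6, two LoRA adapters, rank 16, hidden size 128.
# Requires a CUDA device.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda"
    rank, hidden, num_loras = 16, 128, 2
    seq_lens = torch.tensor([4, 6], dtype=torch.long, device=device)
    # Start offset of each sequence in the flattened token dimension.
    b_start = torch.tensor([0, 4], dtype=torch.long, device=device)
    # LoRA adapter id per sequence (-1 would skip a sequence).
    lora_ids = torch.tensor([0, 1], dtype=torch.long, device=device)
    tokens = int(seq_lens.sum())

    x = torch.randn(tokens, rank, dtype=torch.float16, device=device)
    lora_b = torch.randn(num_loras, hidden, rank,
                         dtype=torch.float16, device=device)
    out = torch.zeros(tokens, hidden, dtype=torch.float16, device=device)

    sgmv_expand(x, lora_b, out, b_start, seq_lens, lora_ids,
                batches=2, max_seq_length=6, token_nums=tokens,
                add_inputs=False)
    print(out.shape)  # torch.Size([10, 128])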