# sgmv_shrink.py (6.0 KB)
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
  7. import torch
  8. import triton
  9. import triton.language as tl
  10. from aphrodite.triton_utils import libentry
  11. @libentry()
  12. @triton.jit
  13. def _sgmv_shrink_kernel(
  14. input_ptr,
  15. lora_ptr,
  16. out_ptr,
  17. N,
  18. K,
  19. b_seq_start_loc,
  20. seq_lens,
  21. lora_indices,
  22. scaling,
  23. xm_stride, # hidden_size
  24. xk_stride, # 1
  25. l0_stride, # hidden_size*max_rank
  26. lora_k_stride,
  27. lora_n_stride,
  28. cm_stride,
  29. cn_stride,
  30. BLOCK_M: tl.constexpr,
  31. BLOCK_N: tl.constexpr,
  32. BLOCK_K: tl.constexpr,
  33. EVEN_K: tl.constexpr,
  34. SPLIT_K: tl.constexpr,
  35. ):
  36. """
  37. The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K.
  38. The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally,
  39. introducing SPLIT-K can improve performance
  40. """
  41. pid = tl.program_id(axis=0)
  42. pid_sk = tl.program_id(axis=1)
  43. cur_batch = tl.program_id(axis=2)
  44. cta_n_num = tl.cdiv(N, BLOCK_N)
  45. pid_m = pid // cta_n_num
  46. pid_n = pid % cta_n_num
  47. M = tl.load(seq_lens + cur_batch)
  48. if pid_m * BLOCK_M > M:
  49. return
  50. lora_index = tl.load(lora_indices + cur_batch)
  51. if lora_index == -1:
  52. return
  53. cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
  54. offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
  55. offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
  56. offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
  57. ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
  58. rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
  59. a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +
  60. offset_k[None, :] * xk_stride)
  61. b_ptr = (lora_ptr + l0_stride * lora_index + rbn[None, :] * lora_k_stride +
  62. offset_k[:, None] * lora_n_stride)
  63. accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
  64. for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
  65. if EVEN_K:
  66. tiled_a = tl.load(a_ptr)
  67. tiled_b = tl.load(b_ptr)
  68. else:
  69. k_remaining = K - k * (BLOCK_K * SPLIT_K)
  70. tiled_a = tl.load(a_ptr,
  71. mask=offset_k[None, :] < k_remaining,
  72. other=0.0)
  73. tiled_b = tl.load(b_ptr,
  74. mask=offset_k[:, None] < k_remaining,
  75. other=0.0)
  76. accumulator += tl.dot(tiled_a, tiled_b)
  77. a_ptr += BLOCK_K * SPLIT_K * xk_stride
  78. b_ptr += BLOCK_K * SPLIT_K * lora_n_stride
  79. offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
  80. offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
  81. c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
  82. offset_cn[None, :] * cn_stride)
  83. c_mask = (offset_cm[:, None] <
  84. (cur_seq_start + M)) & (offset_cn[None, :] < N)
  85. accumulator *= scaling
  86. # handles write-back with reduction-splitting
  87. if SPLIT_K == 1:
  88. tl.store(c_ptr, accumulator, mask=c_mask)
  89. else:
  90. tl.atomic_add(c_ptr, accumulator, mask=c_mask)
  91. @torch.inference_mode()
  92. def _sgmv_shrink(
  93. inputs: torch.Tensor,
  94. lora_a_weights: torch.Tensor,
  95. output_tensor: torch.Tensor,
  96. b_seq_start_loc: torch.Tensor,
  97. seq_len_tensor: torch.Tensor,
  98. lora_indices_tensor: torch.Tensor,
  99. batches: int,
  100. max_seq_length: int,
  101. scaling: float,
  102. ) -> None:
  103. """
  104. Args:
  105. inputs (torch.Tensor): input tensor
  106. lora_a_weights (torch.Tensor): lora'a weight
  107. output_tensor (torch.Tensor): output tensor
  108. b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
  109. sequence lengths of the sequences in the batch, used to index
  110. into sequence. E.g.,if the sequence length is [4, 6], it is
  111. [0, 4].
  112. seq_len_tensor (torch.Tensor): (batch_size,). record the sequence
  113. length of the sequences in the batch
  114. lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
  115. corresponding to each batch. An index of -1 means no lora should be
  116. applied.
  117. batches (int): batch size
  118. max_seq_length (int): The max sequence lengths of the sequences
  119. in the batch
  120. scaling (float): Scaling factor.
  121. """
  122. assert inputs.dtype == lora_a_weights.dtype
  123. assert inputs.dtype in [torch.float16, torch.bfloat16]
  124. assert lora_a_weights.dtype in [
  125. torch.float16,
  126. torch.bfloat16,
  127. ]
  128. assert inputs.size(1) == lora_a_weights.size(-1)
  129. assert b_seq_start_loc.size(0) == batches
  130. assert lora_indices_tensor.size(0) == batches
  131. assert inputs.is_contiguous()
  132. if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)
  133. assert lora_a_weights.size(1) == 1
  134. lora_a_weights = lora_a_weights.squeeze(dim=1)
  135. else:
  136. assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)
  137. assert lora_a_weights.is_contiguous()
  138. assert output_tensor.is_contiguous()
  139. # TODO tuning this config
  140. N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank
  141. BLOCK_M = 32
  142. BLOCK_N = 16
  143. BLOCK_K = 32
  144. SPLIT_K = 8
  145. EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
  146. grid = (
  147. triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
  148. SPLIT_K,
  149. batches,
  150. )
  151. _sgmv_shrink_kernel[grid](
  152. inputs,
  153. lora_a_weights,
  154. output_tensor,
  155. N,
  156. K,
  157. b_seq_start_loc,
  158. seq_len_tensor,
  159. lora_indices_tensor,
  160. scaling,
  161. inputs.stride(0),
  162. inputs.stride(1),
  163. lora_a_weights.stride(0),
  164. lora_a_weights.stride(1),
  165. lora_a_weights.stride(2),
  166. output_tensor.stride(0),
  167. output_tensor.stride(1),
  168. BLOCK_M,
  169. BLOCK_N,
  170. BLOCK_K,
  171. EVEN_K,
  172. SPLIT_K,
  173. )
  174. return
  175. sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink",
  176. _sgmv_shrink,
  177. mutates_args=["output_tensor"])