bgmv_expand_slice.py

  1. """
  2. Based on:
  3. Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
  4. Punica: Multi-Tenant LoRA Serving.
  5. https://arxiv.org/abs/2310.18547
  6. """
  7. import torch
  8. import triton
  9. import triton.language as tl
  10. from .utils import get_lora_op_configs
  11. @triton.jit
  12. def _bgmv_expand_slice_kernel(
  13. input_ptr,
  14. lora_ptr,
  15. out_ptr,
  16. N,
  17. K,
  18. lora_indices,
  19. xm_stride,
  20. xk_stride,
  21. l0_stride,
  22. lora_k_stride,
  23. lora_n_stride,
  24. cm_stride,
  25. cn_stride,
  26. slice_offset,
  27. BLOCK_N: tl.constexpr,
  28. BLOCK_K: tl.constexpr,
  29. SPLIT_N: tl.constexpr,
  30. EVEN_K: tl.constexpr,
  31. ADD_INPUTS: tl.constexpr,
  32. CAST_TYPE: tl.constexpr,
  33. ):
  34. """
  35. GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's
  36. performance
  37. """
    pid_sn = tl.program_id(axis=0)
    cur_batch = tl.program_id(axis=1)
    lora_index = tl.load(lora_indices + cur_batch)
    if lora_index == -1:
        # An index of -1 means no LoRA is applied to this batch row.
        return
    offset_k = tl.arange(0, BLOCK_K)
    offset_n = tl.arange(0, BLOCK_N)
    if EVEN_K:
        tiled_a = tl.load(input_ptr + cur_batch * xm_stride +
                          offset_k * xk_stride)  # [BLOCK_K]
    else:
        tiled_a = tl.load(
            input_ptr + cur_batch * xm_stride + offset_k * xk_stride,
            mask=offset_k < K,
            other=0,
        )  # [BLOCK_K]
    # N must be divisible by SPLIT_N
    split_n_length = tl.cdiv(N, SPLIT_N)
    if CAST_TYPE:
        tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
    # Slide to this program's row-block.
    b_ptr = (lora_ptr + l0_stride * lora_index +
             pid_sn * split_n_length * lora_k_stride)
    c_ptr = (out_ptr + cur_batch * cm_stride + pid_sn * split_n_length +
             slice_offset * cn_stride)
    for n in range(0, split_n_length, BLOCK_N):
        current_n = n + offset_n
        b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :]
                                                              < K)
        c_mask = current_n < split_n_length
        tiled_b = tl.load(
            b_ptr + current_n[:, None] * lora_k_stride +
            offset_k[None, :] * lora_n_stride,
            mask=b_ptr_mask,
            other=0.0,
        )  # [BLOCK_N, BLOCK_K]
        if ADD_INPUTS:
            tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask)
            accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
        else:
            accumulator = tl.sum(tiled_a * tiled_b, 1)
        tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask)


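# A pure-PyTorch reference for the kernel above (an illustrative sketch, not
# part of the original Punica/vLLM code). For each batch row b with
# lora_indices[b] = i >= 0, the kernel computes
#     out[b, slice_offset : slice_offset + N]  (+)=  lora_b[i] @ x[b]
# where lora_b[i] has shape (N, K) and x[b] has shape (K,); "(+)=" is "+=" when
# ADD_INPUTS is set and "=" otherwise. The sketch ignores the kernel's
# low-precision cast/accumulation details and is meant as a test oracle.
def _bgmv_expand_slice_reference(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
) -> None:
    weights = (lora_b_weights.squeeze(dim=1)
               if lora_b_weights.ndim == 4 else lora_b_weights)
    for b, idx in enumerate(lora_indices_tensor.tolist()):
        if idx == -1:
            continue  # -1 means: leave this row untouched
        acc = weights[idx].float() @ inputs[b].float()  # shape: (slice_size,)
        out = output_tensor[b, slice_offset:slice_offset + slice_size]
        if add_inputs:
            out += acc.to(output_tensor.dtype)
        else:
            out.copy_(acc.to(output_tensor.dtype))

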
@torch.inference_mode()
def _bgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
) -> None:
    """
    Args:
        inputs (torch.Tensor): input tensor
        lora_b_weights (torch.Tensor): LoRA B weights
        output_tensor (torch.Tensor): output tensor
        lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
            corresponding to each batch element; an index of -1 means no LoRA
            should be applied.
        slice_offset (int): offset into output_tensor's hidden dimension
        slice_size (int): size of the current output slice
        add_inputs (bool, optional): whether to accumulate into the existing
            values of output_tensor rather than overwrite them. Defaults to
            True.

    See the usage sketch at the end of this file.
    """
    assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
    assert lora_b_weights.dtype in [
        torch.float16,
        torch.bfloat16,
    ]
    assert inputs.size(1) == lora_b_weights.size(-1)
    assert slice_size == lora_b_weights.size(-2)
    assert inputs.is_contiguous()
    assert output_tensor.is_contiguous()
    if lora_b_weights.ndim == 4:  # shape: (lora_num, 1, size, rank)
        assert lora_b_weights.size(1) == 1
        lora_b_weights = lora_b_weights.squeeze(dim=1)
    else:
        assert lora_b_weights.ndim == 3  # shape: (lora_num, size, rank)
    assert lora_b_weights.is_contiguous()
    # TODO: tune this config
    N, K = lora_b_weights.shape[-2:]  # K = rank, N = hidden_size
    BLOCK_K = triton.next_power_of_2(K)
    EVEN_K = K % BLOCK_K == 0
    ADD_INPUTS = add_inputs
    CAST_TYPE = False
    if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
            torch.float16,
            torch.bfloat16,
    ]:
        CAST_TYPE = True
    batches = lora_indices_tensor.size(0)
    config = get_lora_op_configs("expand", batches, N)
    grid = lambda META: (
        META["SPLIT_N"],
        batches,
    )
    _bgmv_expand_slice_kernel[grid](
        inputs,
        lora_b_weights,
        output_tensor,
        N,
        K,
        lora_indices_tensor,
        inputs.stride(0),
        inputs.stride(1),
        lora_b_weights.stride(0),
        lora_b_weights.stride(1),
        lora_b_weights.stride(2),
        output_tensor.stride(0),
        output_tensor.stride(1),
        slice_offset,
        BLOCK_K=BLOCK_K,
        EVEN_K=EVEN_K,
        ADD_INPUTS=ADD_INPUTS,
        CAST_TYPE=CAST_TYPE,
        **config,
    )
    return


try:
    bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice",
                                                _bgmv_expand_slice,
                                                mutates_args=["output_tensor"])
except AttributeError:
    # Older PyTorch versions lack torch.library.custom_op; fall back to the
    # plain function.
    bgmv_expand_slice = _bgmv_expand_slice
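

# Minimal usage sketch (an illustrative addition, not part of the original
# file). The shapes, dtypes, and index values below are assumptions chosen to
# satisfy the asserts above; a CUDA device is required, and the function is
# not invoked automatically.
def _usage_example() -> None:
    batch, rank, hidden = 4, 16, 4096
    inputs = torch.randn(batch, rank, dtype=torch.float16, device="cuda")
    # Two LoRA adapters; B matrices have shape (lora_num, size, rank).
    lora_b = torch.randn(2, hidden, rank,
                         dtype=torch.float16, device="cuda")
    # Output wide enough to hold two concatenated slices of width `hidden`.
    out = torch.zeros(batch, 2 * hidden,
                      dtype=torch.float16, device="cuda")
    # Per-row adapter index; -1 leaves that row untouched.
    idx = torch.tensor([0, 1, -1, 0], dtype=torch.long, device="cuda")
    # Write into the second slice: columns [hidden, 2 * hidden).
    bgmv_expand_slice(inputs, lora_b, out, idx,
                      slice_offset=hidden, slice_size=hidden)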