"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
  7. import torch
  8. import triton
  9. import triton.language as tl
  10. from .utils import get_lora_op_configs
  11. @triton.jit
  12. def _bgmv_expand_kernel(
  13. input_ptr,
  14. lora_ptr,
  15. out_ptr,
  16. N,
  17. K,
  18. lora_indices,
  19. xm_stride,
  20. xk_stride,
  21. l0_stride,
  22. lora_k_stride,
  23. lora_n_stride,
  24. cm_stride,
  25. cn_stride,
  26. BLOCK_N: tl.constexpr,
  27. BLOCK_K: tl.constexpr,
  28. SPLIT_N: tl.constexpr,
  29. EVEN_K: tl.constexpr,
  30. ADD_INPUTS: tl.constexpr,
  31. CAST_TYPE: tl.constexpr,
  32. ):
  33. """
  34. GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's
  35. performance
  36. """
  37. pid_sn = tl.program_id(axis=0)
  38. cur_batch = tl.program_id(axis=1)
  39. lora_index = tl.load(lora_indices + cur_batch)
  40. if lora_index == -1:
  41. return
  42. offset_k = tl.arange(0, BLOCK_K)
  43. offset_n = tl.arange(0, BLOCK_N)
  44. if EVEN_K:
  45. tiled_a = tl.load(input_ptr + cur_batch * xm_stride +
  46. offset_k * xk_stride, ) # [BLOCK_K]
  47. else:
  48. tiled_a = tl.load(
  49. input_ptr + cur_batch * xm_stride + offset_k * xk_stride,
  50. mask=offset_k < K,
  51. other=0,
  52. ) # [BLOCK_K]
  53. # N must be divisible by SPLIT_N
  54. split_n_length = tl.cdiv(N, SPLIT_N)
  55. if CAST_TYPE:
  56. tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
  57. # sliding to next row-block
  58. b_ptr = (lora_ptr + l0_stride * lora_index +
  59. pid_sn * split_n_length * lora_k_stride)
  60. c_ptr = out_ptr + cur_batch * cm_stride + pid_sn * split_n_length
  61. for n in range(0, split_n_length, BLOCK_N):
  62. current_n = n + offset_n
  63. current_n_c = tl.max_contiguous(current_n, BLOCK_N)
  64. b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :]
  65. < K)
  66. c_mask = current_n < split_n_length
  67. tiled_b = tl.load(
  68. b_ptr + current_n_c[:, None] * lora_k_stride +
  69. offset_k[None, :] * lora_n_stride,
  70. mask=b_ptr_mask,
  71. other=0.0,
  72. ) # [BLOCK_N,BLOCK_K]
  73. if ADD_INPUTS:
  74. tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask)
  75. accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
  76. else:
  77. accumulator = tl.sum(tiled_a * tiled_b, 1)
  78. tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask)
  79. @torch.inference_mode()
  80. def _bgmv_expand(
  81. inputs: torch.Tensor,
  82. lora_b_weights: torch.Tensor,
  83. output_tensor: torch.Tensor,
  84. lora_indices_tensor: torch.Tensor,
  85. add_inputs: bool = True,
  86. ) -> None:
  87. """
  88. Args:
  89. inputs (torch.Tensor): input tensor
  90. lora_b_weights (torch.Tensor): lora'a weight
  91. output_tensor (torch.Tensor): output tensor
  92. lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
  93. corresponding to each batch, An index of -1 means no lora should be
  94. applied.
  95. batches (int): batch size
  96. add_inputs (bool, optional): Defaults to False. adds the final lora
  97. results to the output.
  98. """
  99. assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
  100. assert lora_b_weights.dtype in [
  101. torch.float16,
  102. torch.bfloat16,
  103. ]
  104. assert inputs.size(1) == lora_b_weights.size(-1)
  105. assert inputs.is_contiguous()
  106. assert output_tensor.is_contiguous()
  107. if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
  108. assert lora_b_weights.size(1) == 1
  109. lora_b_weights = lora_b_weights.squeeze(dim=1)
  110. else:
  111. assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
  112. assert lora_b_weights.is_contiguous()
  113. # TODO tuning this config
  114. N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
  115. BLOCK_K = triton.next_power_of_2(K)
  116. EVEN_K = K % BLOCK_K == 0
  117. ADD_INPUTS = add_inputs
  118. CAST_TYPE = False
  119. if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
  120. torch.float16,
  121. torch.bfloat16,
  122. ]:
  123. CAST_TYPE = True
  124. batches = lora_indices_tensor.size(0)
  125. config = get_lora_op_configs("expand", batches, N)
  126. grid = lambda META: (
  127. META["SPLIT_N"],
  128. batches,
  129. )
  130. _bgmv_expand_kernel[grid](
  131. inputs,
  132. lora_b_weights,
  133. output_tensor,
  134. N,
  135. K,
  136. lora_indices_tensor,
  137. inputs.stride(0),
  138. inputs.stride(1),
  139. lora_b_weights.stride(0),
  140. lora_b_weights.stride(1),
  141. lora_b_weights.stride(2),
  142. output_tensor.stride(0),
  143. output_tensor.stride(1),
  144. BLOCK_K=BLOCK_K,
  145. EVEN_K=EVEN_K,
  146. ADD_INPUTS=ADD_INPUTS,
  147. CAST_TYPE=CAST_TYPE,
  148. **config,
  149. )
  150. return
  151. bgmv_expand = torch.library.custom_op("lora::bgmv_expand",
  152. _bgmv_expand,
  153. mutates_args=["output_tensor"])