bgmv_expand.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. """
  2. Based on:
  3. Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
  4. Punica: Multi-Tenant LoRA Serving.
  5. https://arxiv.org/abs/2310.18547
  6. """
  7. from typing import Dict, Optional
  8. import torch
  9. import triton
  10. import triton.language as tl
  11. from .utils import get_lora_op_configs
  12. @triton.jit
  13. def _bgmv_expand_kernel(
  14. input_ptr,
  15. lora_ptr,
  16. out_ptr,
  17. N,
  18. K,
  19. lora_indices,
  20. xm_stride,
  21. xk_stride,
  22. l0_stride,
  23. lora_k_stride,
  24. lora_n_stride,
  25. cm_stride,
  26. cn_stride,
  27. BLOCK_N: tl.constexpr,
  28. BLOCK_K: tl.constexpr,
  29. SPLIT_N: tl.constexpr,
  30. EVEN_K: tl.constexpr,
  31. ADD_INPUTS: tl.constexpr,
  32. CAST_TYPE: tl.constexpr,
  33. ):
  34. """
  35. GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's
  36. performance
  37. """
  38. pid_sn = tl.program_id(axis=0)
  39. cur_batch = tl.program_id(axis=1)
  40. lora_index = tl.load(lora_indices + cur_batch)
  41. if lora_index == -1:
  42. return
  43. offset_k = tl.arange(0, BLOCK_K)
  44. offset_n = tl.arange(0, BLOCK_N)
  45. if EVEN_K:
  46. tiled_a = tl.load(input_ptr + cur_batch * xm_stride +
  47. offset_k * xk_stride, ) # [BLOCK_K]
  48. else:
  49. tiled_a = tl.load(
  50. input_ptr + cur_batch * xm_stride + offset_k * xk_stride,
  51. mask=offset_k < K,
  52. other=0,
  53. ) # [BLOCK_K]
  54. # N must be divisible by SPLIT_N
  55. split_n_length = tl.cdiv(N, SPLIT_N)
  56. if CAST_TYPE:
  57. tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
  58. # sliding to next row-block
  59. b_ptr = (lora_ptr + l0_stride * lora_index +
  60. pid_sn * split_n_length * lora_k_stride)
  61. c_ptr = out_ptr + cur_batch * cm_stride + pid_sn * split_n_length
  62. for n in range(0, split_n_length, BLOCK_N):
  63. current_n = n + offset_n
  64. current_n_c = tl.max_contiguous(current_n, BLOCK_N)
  65. b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :]
  66. < K)
  67. c_mask = current_n < split_n_length
  68. tiled_b = tl.load(
  69. b_ptr + current_n_c[:, None] * lora_k_stride +
  70. offset_k[None, :] * lora_n_stride,
  71. mask=b_ptr_mask,
  72. other=0.0,
  73. ) # [BLOCK_N,BLOCK_K]
  74. if ADD_INPUTS:
  75. tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask)
  76. accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
  77. else:
  78. accumulator = tl.sum(tiled_a * tiled_b, 1)
  79. tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask)
  80. @torch.inference_mode()
  81. def bgmv_expand(
  82. inputs: torch.Tensor,
  83. lora_b_weights: torch.Tensor,
  84. output_tensor: torch.Tensor,
  85. lora_indices_tensor: torch.Tensor,
  86. add_inputs: bool = True,
  87. override_config: Optional[Dict[str, int]] = None,
  88. ):
  89. """
  90. Args:
  91. inputs (torch.Tensor): input tensor
  92. lora_b_weights (torch.Tensor): lora'a weight
  93. output_tensor (torch.Tensor): output tensor
  94. lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
  95. corresponding to each batch, An index of -1 means no lora should be
  96. applied.
  97. batches (int): batch size
  98. add_inputs (bool, optional): Defaults to False. adds the final lora
  99. results to the output.
  100. override_config (Optional[Dict[str, int]], optional): Defaults to None.
  101. Triton grid config
  102. """
  103. assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
  104. assert lora_b_weights.dtype in [
  105. torch.float16,
  106. torch.bfloat16,
  107. ]
  108. assert inputs.size(1) == lora_b_weights.size(-1)
  109. assert inputs.is_contiguous()
  110. assert output_tensor.is_contiguous()
  111. if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
  112. assert lora_b_weights.size(1) == 1
  113. lora_b_weights = lora_b_weights.squeeze(dim=1)
  114. else:
  115. assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
  116. assert lora_b_weights.is_contiguous()
  117. # TODO tuning this config
  118. N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
  119. BLOCK_K = triton.next_power_of_2(K)
  120. EVEN_K = K % BLOCK_K == 0
  121. ADD_INPUTS = add_inputs
  122. CAST_TYPE = False
  123. if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
  124. torch.float16,
  125. torch.bfloat16,
  126. ]:
  127. CAST_TYPE = True
  128. batches = lora_indices_tensor.size(0)
  129. if override_config:
  130. config = override_config
  131. else:
  132. config = get_lora_op_configs("expand", batches, N)
  133. grid = lambda META: (
  134. META["SPLIT_N"],
  135. batches,
  136. )
  137. _bgmv_expand_kernel[grid](
  138. inputs,
  139. lora_b_weights,
  140. output_tensor,
  141. N,
  142. K,
  143. lora_indices_tensor,
  144. inputs.stride(0),
  145. inputs.stride(1),
  146. lora_b_weights.stride(0),
  147. lora_b_weights.stride(1),
  148. lora_b_weights.stride(2),
  149. output_tensor.stride(0),
  150. output_tensor.stride(1),
  151. BLOCK_K=BLOCK_K,
  152. EVEN_K=EVEN_K,
  153. ADD_INPUTS=ADD_INPUTS,
  154. CAST_TYPE=CAST_TYPE,
  155. **config,
  156. )
  157. return