# bgmv_shrink.py
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
  7. from typing import Dict, Optional
  8. import torch
  9. import triton
  10. import triton.language as tl
  11. from .utils import get_lora_op_configs
  12. @triton.jit
  13. def _bgmv_shrink_kernel(
  14. input_ptr,
  15. lora_ptr,
  16. out_ptr,
  17. N,
  18. K,
  19. lora_indices,
  20. scaling,
  21. xm_stride,
  22. xk_stride,
  23. l0_stride,
  24. lora_k_stride,
  25. lora_n_stride,
  26. cm_stride,
  27. cn_stride,
  28. BLOCK_N: tl.constexpr,
  29. BLOCK_K: tl.constexpr,
  30. SPLIT_K: tl.constexpr,
  31. ):
  32. """
  33. GroupGEMV, additionally, introducing SPLIT-K can improve large hidden_size's
  34. performance
  35. """
  36. pid_sk = tl.program_id(axis=0)
  37. cur_batch = tl.program_id(axis=1)
  38. lora_index = tl.load(lora_indices + cur_batch)
  39. if lora_index == -1:
  40. return
  41. offset_n = tl.arange(0, BLOCK_N)
  42. offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K
  43. a_ptr = input_ptr + cur_batch * xm_stride
  44. b_ptr = lora_ptr + l0_stride * lora_index
  45. accumulator = tl.zeros((BLOCK_N, ), dtype=tl.float32)
  46. for k in range(0, K, BLOCK_K * SPLIT_K):
  47. current_k = k + offset_k
  48. current_k_c = tl.max_contiguous(current_k, BLOCK_K)
  49. tiled_a = tl.load(
  50. a_ptr + current_k_c,
  51. mask=current_k < K,
  52. other=0.0,
  53. ) # [BLOCK_K]
  54. b_ptr_mask = (offset_n[:, None] < N) & (current_k[None, :] < K)
  55. tiled_b = tl.load(
  56. b_ptr + offset_n[:, None] * lora_k_stride +
  57. current_k[None, :] * lora_n_stride,
  58. mask=b_ptr_mask,
  59. other=0.0,
  60. ) # [BLOCK_N,BLOCK_K]
  61. accumulator += tl.sum(tiled_a * tiled_b, 1)
  62. accumulator *= scaling
  63. offset_cn = tl.arange(0, BLOCK_N)
  64. c_ptr = out_ptr + cur_batch * cm_stride + offset_cn * cn_stride
  65. c_mask = offset_cn < N
  66. if SPLIT_K == 1:
  67. tl.store(c_ptr, accumulator, mask=c_mask)
  68. else:
  69. tl.atomic_add(c_ptr, accumulator, mask=c_mask)
  70. @torch.inference_mode()
  71. def bgmv_shrink(
  72. inputs: torch.Tensor,
  73. lora_a_weights: torch.Tensor,
  74. output_tensor: torch.Tensor,
  75. lora_indices_tensor: torch.Tensor,
  76. scaling: float = 1.0,
  77. override_config: Optional[Dict[str, int]] = None,
  78. ):
  79. """
  80. Args:
  81. inputs (torch.Tensor): input tensor
  82. lora_a_weights (torch.Tensor): lora'a weight
  83. output_tensor (torch.Tensor): output tensor
  84. lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
  85. corresponding to each batch. An index of -1 means no lora should be
  86. applied.
  87. batches (int): batch size
  88. scaling (float): Scaling factor.
  89. override_config (Optional[Dict[str, int]], optional): Defaults to None.
  90. Triton grid config
  91. """
  92. assert inputs.dtype == lora_a_weights.dtype
  93. assert inputs.dtype in [torch.float16, torch.bfloat16]
  94. assert lora_a_weights.dtype in [
  95. torch.float16,
  96. torch.bfloat16,
  97. ]
  98. assert inputs.size(1) == lora_a_weights.size(-1)
  99. assert inputs.is_contiguous()
  100. if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)
  101. assert lora_a_weights.size(1) == 1
  102. lora_a_weights = lora_a_weights.squeeze(dim=1)
  103. else:
  104. assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)
  105. assert lora_a_weights.is_contiguous()
  106. assert output_tensor.is_contiguous()
  107. # TODO tuning this config
  108. batches = lora_indices_tensor.size(0)
  109. N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank
  110. BLOCK_N = triton.next_power_of_2(N)
  111. if override_config:
  112. config = override_config
  113. else:
  114. # First try to load optimal config from the file
  115. config = get_lora_op_configs("bgmv_shrink", batches, K)
  116. grid = lambda META: (
  117. META["SPLIT_K"],
  118. batches,
  119. )
  120. _bgmv_shrink_kernel[grid](
  121. inputs,
  122. lora_a_weights,
  123. output_tensor,
  124. N,
  125. K,
  126. lora_indices_tensor,
  127. scaling,
  128. inputs.stride(0),
  129. inputs.stride(1),
  130. lora_a_weights.stride(0),
  131. lora_a_weights.stride(1),
  132. lora_a_weights.stride(2),
  133. output_tensor.stride(0),
  134. output_tensor.stride(1),
  135. BLOCK_N=BLOCK_N,
  136. **config,
  137. )
  138. return