libentry.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. # Copied From https://github.com/FlagOpen/FlagGems
  2. import inspect
  3. import triton
  4. class LibEntry(triton.KernelInterface):
  5. def __init__(
  6. self,
  7. fn,
  8. ):
  9. self.fn = fn
  10. self.arg_names = fn.arg_names
  11. self.divisibility = 16
  12. self.kernel_cache = dict()
  13. fn = self.fn
  14. while not isinstance(fn, triton.runtime.JITFunction):
  15. fn = fn.fn
  16. self.jit_function: triton.runtime.JITFunction = fn
  17. self.specialize_indices = [
  18. p.num for p in self.jit_function.params
  19. if not p.is_constexpr and not p.do_not_specialize
  20. ]
  21. self.do_not_specialize_indices = [
  22. p.num for p in self.jit_function.params
  23. if not p.is_constexpr and p.do_not_specialize
  24. ]
  25. def key(self, spec_args, dns_args, const_args):
  26. spec_key = [(arg.dtype, arg.data_ptr() %
  27. self.divisibility == 0) if hasattr(arg, "data_ptr") else
  28. (type(arg), arg) for arg in spec_args]
  29. dns_key = [
  30. arg.dtype if hasattr(
  31. arg, "data_ptr") else type(arg) if not isinstance(arg, int)
  32. else "i32" if -(2**31) <= arg and arg <= 2**31 -
  33. 1 else "u64" if 2**63 <= arg and arg <= 2**64 - 1 else "i64"
  34. for arg in dns_args
  35. ]
  36. # const args passed by position
  37. return tuple(spec_key + dns_key + const_args)
  38. def run(self, *args, **kwargs):
  39. grid = kwargs["grid"]
  40. # collect all the arguments
  41. spec_args = [] # specialize arguments
  42. dns_args = [] # do not specialize arguments
  43. const_args = [] # constexpr arguments
  44. k_args = [] # kernel arguments
  45. for i, arg in enumerate(args):
  46. if i in self.specialize_indices:
  47. k_args.append(arg)
  48. spec_args.append(arg)
  49. elif i in self.do_not_specialize_indices:
  50. k_args.append(arg)
  51. dns_args.append(arg)
  52. else:
  53. const_args.append(arg)
  54. for p in self.jit_function.params[len(args):]:
  55. if p.name in kwargs:
  56. val = kwargs[p.name]
  57. elif p.default is inspect._empty:
  58. continue
  59. else:
  60. val = p.default
  61. if p.is_constexpr:
  62. const_args.append(val)
  63. elif p.do_not_specialize:
  64. dns_args.append(val)
  65. k_args.append(val)
  66. else:
  67. spec_args.append(val)
  68. k_args.append(val)
  69. entry_key = self.key(spec_args, dns_args, const_args)
  70. if entry_key not in self.kernel_cache:
  71. # compile the kernel also completes the related computations
  72. kernel = self.fn.run(*args, **kwargs)
  73. fn = self.fn
  74. # collect constexpr arguments for grid computation
  75. constexprs = {}
  76. while not isinstance(fn, triton.runtime.JITFunction):
  77. if isinstance(fn, triton.runtime.Autotuner):
  78. config = fn.best_config
  79. constexprs["num_warps"] = config.num_warps
  80. constexprs["num_stages"] = config.num_stages
  81. constexprs["num_ctas"] = config.num_ctas
  82. constexprs = {**constexprs, **config.kwargs}
  83. elif isinstance(fn, triton.runtime.Heuristics):
  84. for v, heur in fn.values.items():
  85. constexprs[v] = heur({
  86. **dict(zip(fn.arg_names, args)),
  87. **kwargs,
  88. **constexprs,
  89. })
  90. else:
  91. raise RuntimeError("Invalid Runtime Function")
  92. fn = fn.fn
  93. # In vLLM, certain kernels like fused_moe_kernel get the
  94. # best_config(as kwargs) from a configuration json file, rather
  95. # than using Autotuner & Heuristics. Therefore, all their constexprs
  96. # (tl.constexpr) are assigned values through the following loop.
  97. for p in self.jit_function.params:
  98. if p.is_constexpr and p.name not in constexprs:
  99. constexprs[p.name] = p.default #default=inspect._empty
  100. self.kernel_cache[entry_key] = (kernel, constexprs)
  101. else:
  102. # load kernel from cache directly
  103. kernel, constexprs = self.kernel_cache[entry_key]
  104. if callable(grid):
  105. # collect all arguments to the grid fn,ie:
  106. # 1. args,
  107. # 2. kwargs,
  108. # 3. all all other captured arguments in CompiledKernel from
  109. # Autotunner & Heuristics when kwargs & captured args conflict,
  110. # captured args have higher priority
  111. # 4. We must filter out captured args with default value firstly
  112. constexprs = {
  113. k: v
  114. for k, v in constexprs.items() if v is not inspect._empty
  115. }
  116. meta = {
  117. **dict(zip(self.arg_names, args)),
  118. **kwargs,
  119. **constexprs,
  120. }
  121. grid = grid(meta)
  122. if isinstance(grid, tuple):
  123. grid = grid + (1, 1)
  124. elif isinstance(grid, list):
  125. grid = grid + [1, 1]
  126. kernel[grid[0:3]](*k_args)
  127. # maintaining the same return type as the JITFunction.run
  128. return kernel
  129. def libentry():
  130. """
  131. Decorator for triton library entries.
  132. Motivation:
  133. The runtime overhead of Triton kernels is the reason for the lower
  134. performance of small kernels, particularly evident with smaller models.
  135. Using this decorator can reduce Triton runtime overhead.
  136. How:
  137. The `run` function of JITFunction needs to accomplish:
  138. - Parameter binding using inspect
  139. - KernelArg type wrapping
  140. - Cache key calculation
  141. When dealing with small size, these steps can become bottlenecks in
  142. Triton runtime. Libentry simplifies these steps to reduce runtime
  143. overhead, thereby improving the runtime expenses of small kernels.
  144. NOTE:
  145. When Triton is upgraded to version 3.0.0, libentry can be removed.
  146. """
  147. def decorator(fn):
  148. return LibEntry(fn)
  149. return decorator