activation.py

# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The following code is loosely based on:
# https://github.com/unslothai/unsloth/blob/038e6d4c8d40207a87297ab3aaf787c19b1006d1/unsloth/kernels/swiglu.py
# and
# https://github.com/unslothai/unsloth/blob/038e6d4c8d40207a87297ab3aaf787c19b1006d1/unsloth/kernels/geglu.py
import torch
import triton
import triton.language as tl
from packaging.version import Version

if Version(triton.__version__) >= Version("3.0.0"):
    from triton.language.extra import libdevice
    triton_tanh = libdevice.tanh
    triton_erf = libdevice.erf
    triton_sqrt = libdevice.sqrt
else:
    triton_tanh = tl.math.tanh
    triton_erf = tl.math.erf
    triton_sqrt = tl.math.sqrt

@triton.jit
def _fg_kernel(e, g, h, n_elements, BLOCK_SIZE: tl.constexpr):
    """
    Compute SiLU activation and multiply with the gate:
    h = silu(e) * g where silu(x) = x * sigmoid(x)
    Differences from unsloth:
    1. Support for 2D inputs
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)
    g_row = tl.load(g + offsets, mask=mask, other=0)
    f_row = e_row * tl.sigmoid(e_row)
    f_row = f_row.to(g_row.dtype)
    output = f_row * g_row
    tl.store(h + offsets, output, mask=mask)

def swiglu_fg_kernel(e, g):
    # If e is 2D (num_tokens x d), add a dummy batch dimension
    squeeze = False
    if e.dim() == 2:
        e = e.unsqueeze(0)
        g = g.unsqueeze(0)
        squeeze = True
    batch, num_tokens, d = e.shape
    n_elements = batch * num_tokens * d
    h = torch.empty((batch, num_tokens, d), dtype=e.dtype, device=e.device)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    with torch.cuda.device(e.device):
        _fg_kernel[grid](
            e.reshape(-1), g.reshape(-1), h.reshape(-1),
            n_elements, BLOCK_SIZE=1024
        )
    if squeeze:
        return h.squeeze(0)
    return h
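
# Illustrative usage sketch, not part of the original file: a quick parity check of
# swiglu_fg_kernel against a plain PyTorch reference, F.silu(e) * g. The helper name,
# shapes, dtype, and tolerances are assumptions; a CUDA device is required to run it.
def _example_swiglu_parity_check():
    import torch.nn.functional as F
    e = torch.randn(2, 64, 256, dtype=torch.float16, device="cuda")
    g = torch.randn_like(e)
    ref = F.silu(e) * g            # eager reference: silu(e) * g
    out = swiglu_fg_kernel(e, g)   # fused Triton path
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)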

@triton.jit
def _exact_gelu_kernel(e, g, h, n_elements, BLOCK_SIZE: tl.constexpr):
    """
    Compute exact GELU activation and multiply with the gate:
    h = gelu(e) * g where gelu(x) = x * 0.5 * (1 + erf(x/sqrt(2)))
    Differences from unsloth:
    1. Support for 2D inputs
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)
    g_row = tl.load(g + offsets, mask=mask, other=0)
    f_row = 0.5 * e_row * (triton_erf(triton_sqrt(0.5) * e_row) + 1.0)
    f_row = f_row.to(g_row.dtype)
    output = f_row * g_row
    tl.store(h + offsets, output, mask=mask)

@triton.jit
def _approx_gelu_kernel(e, g, h, n_elements, BLOCK_SIZE: tl.constexpr):
    """
    Compute approximate GELU activation and multiply with the gate:
    h = gelu(e) * g where
    gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    Differences from unsloth:
    1. Support for 2D inputs
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    e_row = tl.load(e + offsets, mask=mask, other=0).to(tl.float32)
    g_row = tl.load(g + offsets, mask=mask, other=0)
    s = 0.7978845608028654  # sqrt(2/pi)
    f_row = 0.5 * e_row * (
        triton_tanh(s * e_row * (1.0 + 0.044715 * e_row * e_row)) + 1.0
    )
    f_row = f_row.to(g_row.dtype)
    output = f_row * g_row
    tl.store(h + offsets, output, mask=mask)

def geglu_exact_forward_kernel(e, g):
    # If e is 2D (num_tokens x d), add a dummy batch dimension
    squeeze = False
    if e.dim() == 2:
        e = e.unsqueeze(0)
        g = g.unsqueeze(0)
        squeeze = True
    batch, num_tokens, d = e.shape
    n_elements = batch * num_tokens * d
    h = torch.empty((batch, num_tokens, d), dtype=e.dtype, device=e.device)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    with torch.cuda.device(e.device):
        _exact_gelu_kernel[grid](
            e.reshape(-1), g.reshape(-1), h.reshape(-1),
            n_elements, BLOCK_SIZE=1024
        )
    if squeeze:
        return h.squeeze(0)
    return h
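
# Illustrative usage sketch, not part of the original file: compares
# geglu_exact_forward_kernel with PyTorch's erf-based GELU, F.gelu(..., approximate="none"),
# multiplied by the gate. The helper name, shapes, and tolerances are assumptions;
# a CUDA device is required.
def _example_geglu_exact_parity_check():
    import torch.nn.functional as F
    e = torch.randn(2, 64, 256, dtype=torch.float16, device="cuda")
    g = torch.randn_like(e)
    ref = F.gelu(e.float(), approximate="none").to(e.dtype) * g
    out = geglu_exact_forward_kernel(e, g)
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)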

def geglu_approx_forward_kernel(e, g):
    # If e is 2D (num_tokens x d), add a dummy batch dimension
    squeeze = False
    if e.dim() == 2:
        e = e.unsqueeze(0)
        g = g.unsqueeze(0)
        squeeze = True
    batch, num_tokens, d = e.shape
    n_elements = batch * num_tokens * d
    h = torch.empty((batch, num_tokens, d), dtype=e.dtype, device=e.device)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    with torch.cuda.device(e.device):
        _approx_gelu_kernel[grid](
            e.reshape(-1), g.reshape(-1), h.reshape(-1),
            n_elements, BLOCK_SIZE=1024
        )
    if squeeze:
        return h.squeeze(0)
    return h
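
# Illustrative usage sketch, not part of the original file: the approximate (tanh) GEGLU
# path should match PyTorch's F.gelu(..., approximate="tanh") multiplied by the gate.
# The helper name, shapes, and tolerances are assumptions; a CUDA device is required.
def _example_geglu_approx_parity_check():
    import torch.nn.functional as F
    e = torch.randn(2, 64, 256, dtype=torch.float16, device="cuda")
    g = torch.randn_like(e)
    ref = F.gelu(e.float(), approximate="tanh").to(e.dtype) * g
    out = geglu_approx_forward_kernel(e, g)
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)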

@triton.jit
def _gelu_new_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """
    Compute new GELU activation (same as approximate GELU):
    gelu_new(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask).to(tl.float32)
    x3 = x * x * x
    c = 0.79788456  # sqrt(2/pi)
    t = triton_tanh(c * (x + 0.044715 * x3))
    output = 0.5 * x * (1.0 + t)
    tl.store(output_ptr + offsets, output, mask=mask)

def gelu_new_kernel(x: torch.Tensor) -> torch.Tensor:
    """Triton kernel wrapper for new GELU activation."""
    # If x is 2D (num_tokens x d), add a dummy batch dimension
    squeeze = False
    if x.dim() == 2:
        x = x.unsqueeze(0)
        squeeze = True
    batch, num_tokens, d = x.shape
    n_elements = batch * num_tokens * d
    output = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    with torch.cuda.device(x.device):
        _gelu_new_kernel[grid](
            x.reshape(-1), output.reshape(-1),
            n_elements, BLOCK_SIZE=1024
        )
    if squeeze:
        return output.squeeze(0)
    return output
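
# Illustrative usage sketch, not part of the original file: gelu_new is the tanh
# approximation of GELU, so its output should agree with F.gelu(..., approximate="tanh")
# up to floating-point tolerance. The helper name and test sizes are assumptions.
def _example_gelu_new_parity_check():
    import torch.nn.functional as F
    x = torch.randn(2, 64, 256, dtype=torch.float16, device="cuda")
    ref = F.gelu(x.float(), approximate="tanh").to(x.dtype)
    out = gelu_new_kernel(x)
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)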

@triton.jit
def _fast_gelu_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """
    Compute fast GELU activation:
    gelu_fast(x) = 0.5 * x * (1 + tanh(0.7978845608 * x * (1 + 0.044715 * x^2)))
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask).to(tl.float32)
    c = 0.79788456  # sqrt(2/pi)
    inner = x * (1.0 + 0.044715 * x * x)
    t = triton_tanh(c * inner)
    output = 0.5 * x * (1.0 + t)
    tl.store(output_ptr + offsets, output, mask=mask)

@triton.jit
def _quick_gelu_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """
    Compute quick GELU activation:
    quick_gelu(x) = x * sigmoid(1.702 * x)
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask).to(tl.float32)
    # Compute x * sigmoid(1.702 * x)
    output = x * (1.0 / (1.0 + tl.exp(-1.702 * x)))
    tl.store(output_ptr + offsets, output, mask=mask)

def fast_gelu_kernel(x: torch.Tensor) -> torch.Tensor:
    """Triton kernel wrapper for fast GELU activation."""
    # If x is 2D (num_tokens x d), add a dummy batch dimension
    squeeze = False
    if x.dim() == 2:
        x = x.unsqueeze(0)
        squeeze = True
    batch, num_tokens, d = x.shape
    n_elements = batch * num_tokens * d
    output = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    with torch.cuda.device(x.device):
        _fast_gelu_kernel[grid](
            x.reshape(-1), output.reshape(-1),
            n_elements, BLOCK_SIZE=1024
        )
    if squeeze:
        return output.squeeze(0)
    return output
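
# Illustrative usage sketch, not part of the original file: gelu_fast rewrites the same
# tanh approximation as gelu_new (x + 0.044715 * x^3 == x * (1 + 0.044715 * x^2)), so it
# can be checked against the formula written out in eager PyTorch. Names and sizes are
# assumptions; a CUDA device is required.
def _example_fast_gelu_parity_check():
    x = torch.randn(2, 64, 256, dtype=torch.float32, device="cuda")
    c = 0.7978845608028654  # sqrt(2/pi)
    ref = 0.5 * x * (1.0 + torch.tanh(c * x * (1.0 + 0.044715 * x * x)))
    out = fast_gelu_kernel(x)
    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)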

def quick_gelu_kernel(x: torch.Tensor) -> torch.Tensor:
    """Triton kernel wrapper for quick GELU activation."""
    # If x is 2D (num_tokens x d), add a dummy batch dimension
    squeeze = False
    if x.dim() == 2:
        x = x.unsqueeze(0)
        squeeze = True
    batch, num_tokens, d = x.shape
    n_elements = batch * num_tokens * d
    output = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    with torch.cuda.device(x.device):
        _quick_gelu_kernel[grid](
            x.reshape(-1), output.reshape(-1),
            n_elements, BLOCK_SIZE=1024
        )
    if squeeze:
        return output.squeeze(0)
    return output
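
# Illustrative usage sketch, not part of the original file: quick GELU is simply
# x * sigmoid(1.702 * x), so the eager reference is a one-liner. The helper name and
# test sizes are assumptions; a CUDA device is required.
def _example_quick_gelu_parity_check():
    x = torch.randn(2, 64, 256, dtype=torch.float32, device="cuda")
    ref = x * torch.sigmoid(1.702 * x)
    out = quick_gelu_kernel(x)
    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)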

@triton.jit
def _relu_squared_kernel(
    x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """
    Compute Squared ReLU:
    relu2(x) = x^2 if x > 0 else 0
    Implemented branch-free: a single masked select (tl.where) squares the
    non-negative elements and zeroes the rest, instead of a separate relu
    followed by a square.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask).to(tl.float32)
    # Square only the non-negative values; everything else becomes 0.
    is_positive = x >= 0
    output = tl.where(is_positive, x * x, 0.0)
    tl.store(output_ptr + offsets, output, mask=mask)

def relu_squared_kernel(x: torch.Tensor) -> torch.Tensor:
    """Triton kernel wrapper for Squared ReLU activation."""
    # If x is 2D (num_tokens x d), add a dummy batch dimension
    squeeze = False
    if x.dim() == 2:
        x = x.unsqueeze(0)
        squeeze = True
    batch, num_tokens, d = x.shape
    n_elements = batch * num_tokens * d
    output = torch.empty_like(x)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    with torch.cuda.device(x.device):
        _relu_squared_kernel[grid](
            x.reshape(-1), output.reshape(-1),
            n_elements, BLOCK_SIZE=1024
        )
    if squeeze:
        return output.squeeze(0)
    return output
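
# Illustrative usage sketch, not part of the original file: squared ReLU has a trivial
# eager reference, relu(x) ** 2. The helper name and test sizes are assumptions; a CUDA
# device is required.
def _example_relu_squared_parity_check():
    x = torch.randn(2, 64, 256, dtype=torch.float32, device="cuda")
    ref = torch.relu(x) ** 2
    out = relu_squared_kernel(x)
    torch.testing.assert_close(out, ref, rtol=1e-6, atol=1e-6)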