import time
from typing import Callable, Tuple, Type

import pytest
import torch

from aphrodite.modeling.layers.activation import (FastGELU, GeluAndMul,
                                                  NewGELU, QuickGELU,
                                                  ReLUSquaredActivation,
                                                  SiluAndMul)
from tests.kernels.utils import opcheck

from .allclose_default import get_default_atol, get_default_rtol

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


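# Fused "activation and mul" layers: the last dimension of the input is split
# in half, the activation is applied to the first half, and the result is
# multiplied elementwise by the second half, so a (num_tokens, 2 * d) input
# yields a (num_tokens, d) output.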
@pytest.mark.parametrize("activation", ["silu", "gelu", "gelu_tanh"])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_act_and_mul(
    activation: str,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, 2 * d, dtype=dtype)
    if activation == "silu":
        layer = SiluAndMul()
        fn = torch.ops._C.silu_and_mul
    elif activation == "gelu":
        layer = GeluAndMul(approximate="none")
        fn = torch.ops._C.gelu_and_mul
    elif activation == "gelu_tanh":
        layer = GeluAndMul(approximate="tanh")
        fn = torch.ops._C.gelu_tanh_and_mul
    out = layer(x)
    ref_out = layer.forward_native(x)
    # The SiLU and GELU implementations are equivalent to the native PyTorch
    # implementations, so we can do exact comparison.
    torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
    d = x.shape[-1] // 2
    output_shape = (x.shape[:-1] + (d, ))
    out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
    opcheck(fn, (out, x))


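# Element-wise GELU variants: each custom op takes (out, x) and is checked
# against the layer's native PyTorch forward within the default tolerances
# for the output dtype.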
@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),
                                        (NewGELU, torch.ops._C.gelu_new),
                                        (QuickGELU, torch.ops._C.gelu_quick)])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_activation(
    activation: Tuple[Type[torch.nn.Module], Callable],
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, d, dtype=dtype)
    layer = activation[0]()
    fn = activation[1]
    out = layer(x)
    ref_out = layer.forward_native(x)
    torch.testing.assert_close(out,
                               ref_out,
                               atol=get_default_atol(out),
                               rtol=get_default_rtol(out))
    out = torch.empty_like(x)
    opcheck(fn, (out, x))


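# Triton paths: forward_triton is compared against forward_native with a
# relaxed 1e-2 tolerance, since the Triton kernels are not expected to match
# the reference bit-for-bit.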
@pytest.mark.parametrize("activation_cls, kwargs", [
    (SiluAndMul, {}),
    (GeluAndMul, {"approximate": "none"}),
    (GeluAndMul, {"approximate": "tanh"}),
    (NewGELU, {}),
    (FastGELU, {}),
    (QuickGELU, {}),
    (ReLUSquaredActivation, {}),
])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_activation_triton(
        activation_cls, kwargs, num_tokens, d, dtype, seed, device):
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    activation = activation_cls(**kwargs).to(device=device, dtype=dtype)
    # Use a (num_tokens, 2 * d) input: the fused *AndMul layers require an
    # even last dimension, and the element-wise activations accept any shape.
    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=device)
    native_out = activation.forward_native(x)
    triton_out = activation.forward_triton(x)
    torch.testing.assert_close(triton_out, native_out, atol=1e-2, rtol=1e-2)


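# The benchmark below warms up both code paths, then times 100 calls each of
# forward_cuda and forward_triton, synchronizing before reading the clock.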
# TODO: enable this test after fixing the performance issue
@pytest.mark.skip("skipping performance test")
@pytest.mark.parametrize("activation_cls, kwargs", [
    (SiluAndMul, {}),
    (GeluAndMul, {"approximate": "none"}),
    (GeluAndMul, {"approximate": "tanh"}),
    (NewGELU, {}),
    (FastGELU, {}),
    (QuickGELU, {}),
    (ReLUSquaredActivation, {}),
])
@pytest.mark.parametrize("batch_size, seq_len, hidden_size", [
    (1, 2048, 4096),
    (32, 512, 4096),
])
@torch.inference_mode()
def test_activation_performance(
        activation_cls, kwargs, batch_size: int, seq_len: int,
        hidden_size: int, device: str = "cuda") -> None:
    """Check that the Triton implementation's performance is close to CUDA's.

    Note: performance in isolation might not reflect real-world performance,
    where the activation is one stage of a larger pipeline.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    torch.set_default_device(device)
    activation = activation_cls(**kwargs).to(device=device,
                                             dtype=torch.float16)
    # SiluAndMul and GeluAndMul expect an input width of 2 * hidden_size.
    if activation_cls in [SiluAndMul, GeluAndMul]:
        x = torch.randn(batch_size, seq_len, 2 * hidden_size,
                        dtype=torch.float16, device=device)
    else:
        x = torch.randn(batch_size, seq_len, hidden_size,
                        dtype=torch.float16, device=device)
    # Warmup
    for _ in range(10):
        activation.forward_cuda(x)
        activation.forward_triton(x)
    # Time the CUDA implementation
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(100):
        activation.forward_cuda(x)
    torch.cuda.synchronize()
    cuda_time = time.perf_counter() - start
    # Time the Triton implementation
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(100):
        activation.forward_triton(x)
    torch.cuda.synchronize()
    triton_time = time.perf_counter() - start
    # Triton must be within 1% of CUDA for inference-like shapes
    # (batch_size=1) and within 20% for the other shapes.
    max_slowdown = 1.01 if batch_size == 1 else 1.2
    assert triton_time <= cuda_time * max_slowdown, (
        f"{activation_cls.__name__} Triton implementation is significantly "
        f"slower than CUDA (Triton: {triton_time:.3f}s, "
        f"CUDA: {cuda_time:.3f}s) for shape (batch={batch_size}, "
        f"seq={seq_len}, hidden={hidden_size}); slowdown: "
        f"{(triton_time - cuda_time) / cuda_time * 100:.2f}%")