test_activation.py

from typing import Callable, Tuple, Type

import pytest
import torch

from aphrodite.modeling.layers.activation import (FastGELU, GeluAndMul,
                                                  NewGELU, QuickGELU,
                                                  SiluAndMul)
from tests.kernels.utils import opcheck

from .allclose_default import get_default_atol, get_default_rtol

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


@pytest.mark.parametrize("activation", ["silu", "gelu", "gelu_tanh"])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_act_and_mul(
    activation: str,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, 2 * d, dtype=dtype)
    if activation == "silu":
        layer = SiluAndMul()
        fn = torch.ops._C.silu_and_mul
    elif activation == "gelu":
        layer = GeluAndMul(approximate="none")
        fn = torch.ops._C.gelu_and_mul
    elif activation == "gelu_tanh":
        layer = GeluAndMul(approximate="tanh")
        fn = torch.ops._C.gelu_tanh_and_mul
    out = layer(x)
    ref_out = layer.forward_native(x)
    # The SiLU and GELU implementations are equivalent to the native PyTorch
    # implementations, so we can do exact comparison.
    torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)

    d = x.shape[-1] // 2
    output_shape = (x.shape[:-1] + (d, ))
    out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
    opcheck(fn, (out, x))
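

# Illustrative sketch of the reference computation the fused "activation and
# mul" layers are expected to follow (assuming the usual gate/up split): the
# input is halved along the last dimension, the activation is applied to the
# first half, and the result is multiplied by the second half. The helper
# names below are not part of the layers under test and are not called by
# the tests.
def _ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return torch.nn.functional.silu(x[..., :d]) * x[..., d:]


def _ref_gelu_and_mul(x: torch.Tensor,
                      approximate: str = "none") -> torch.Tensor:
    d = x.shape[-1] // 2
    return torch.nn.functional.gelu(x[..., :d],
                                    approximate=approximate) * x[..., d:]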


@pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),
                                        (NewGELU, torch.ops._C.gelu_new),
                                        (QuickGELU, torch.ops._C.gelu_quick)])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_activation(
    activation: Tuple[Type[torch.nn.Module], Callable],
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, d, dtype=dtype)
    layer = activation[0]()
    fn = activation[1]
    out = layer(x)
    ref_out = layer.forward_native(x)
    # The fused kernels are only numerically close to the native
    # implementations, so compare within the default dtype-dependent
    # tolerances.
    torch.testing.assert_close(out,
                               ref_out,
                               atol=get_default_atol(out),
                               rtol=get_default_rtol(out))

    out = torch.empty_like(x)
    opcheck(fn, (out, x))
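

# Illustrative reference sketches of the elementwise activations exercised by
# test_activation, written from their commonly used closed forms; the layers
# under test carry their own exact expressions, so these helpers are for
# orientation only and are not called by the tests. (FastGELU is commonly an
# algebraically equivalent rearrangement of the tanh approximation below.)
def _ref_quick_gelu(x: torch.Tensor) -> torch.Tensor:
    # QuickGELU: x * sigmoid(1.702 * x)
    return x * torch.sigmoid(1.702 * x)


def _ref_new_gelu(x: torch.Tensor) -> torch.Tensor:
    # Tanh approximation of GELU:
    # 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    c = 0.7978845608028654  # sqrt(2 / pi)
    return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * x.pow(3))))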