test_activation.py

from typing import Type

import pytest
import torch

from aphrodite.modeling.layers.activation import (FastGELU, GeluAndMul,
                                                  NewGELU, SiluAndMul)

from .allclose_default import get_default_atol, get_default_rtol
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
SEEDS = [0]
# Use a single device when only one GPU is visible; otherwise spread the
# parametrized runs across the first two devices.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

@pytest.mark.parametrize("activation", ["silu", "gelu", "gelu_tanh"])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_act_and_mul(
    activation: str,
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, 2 * d, dtype=dtype)
    if activation == "silu":
        layer = SiluAndMul()
    elif activation == "gelu":
        layer = GeluAndMul(approximate="none")
    elif activation == "gelu_tanh":
        layer = GeluAndMul(approximate="tanh")
    out = layer(x)
    ref_out = layer.forward_native(x)
    # The SiLU and GELU implementations are equivalent to the native PyTorch
    # implementations, so we can do exact comparison.
    torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
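
# A rough sketch of what forward_native is expected to compute for the
# act-and-mul layers (an assumption inferred from the 2 * d input width above;
# the actual reference lives in aphrodite.modeling.layers.activation):
#
#     d = x.shape[-1] // 2
#     SiluAndMul:  F.silu(x[..., :d]) * x[..., d:]
#     GeluAndMul:  F.gelu(x[..., :d], approximate=approximate) * x[..., d:]
#
# The custom kernels are expected to reproduce this exactly, which is why the
# comparison above uses atol=0.0 and rtol=0.0.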

@pytest.mark.parametrize("activation", [FastGELU, NewGELU])
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_activation(
    activation: Type[torch.nn.Module],
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    x = torch.randn(num_tokens, d, dtype=dtype)
    layer = activation()
    out = layer(x)
    ref_out = layer.forward_native(x)
    # Unlike the exact check above, allow dtype-dependent default tolerances
    # here, since the kernel and the native reference may differ in
    # floating-point rounding.
    torch.testing.assert_close(out,
                               ref_out,
                               atol=get_default_atol(out),
                               rtol=get_default_rtol(out))
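
# Example invocation (assumes pytest and a CUDA build of the aphrodite kernels
# are installed; adjust the path to wherever this file lives in the repo):
#
#     pytest test_activation.py -q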