# test_layernorm.py
  1. import pytest
  2. import torch
  3. from aphrodite.modeling.layers.layernorm import RMSNorm
  4. DTYPES = [torch.half, torch.bfloat16, torch.float]
  5. NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
  6. HIDDEN_SIZES = [768, 5120, 8192] # Arbitrary values for testing
  7. ADD_RESIDUAL = [False, True]
  8. SEEDS = [0]
  9. DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
  10. @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
  11. @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
  12. @pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
  13. @pytest.mark.parametrize("dtype", DTYPES)
  14. @pytest.mark.parametrize("seed", SEEDS)
  15. @pytest.mark.parametrize("device", DEVICES)
  16. @torch.inference_mode()
  17. def test_rms_norm(
  18. num_tokens: int,
  19. hidden_size: int,
  20. add_residual: bool,
  21. dtype: torch.dtype,
  22. seed: int,
  23. device: int,
  24. ) -> None:
  25. torch.random.manual_seed(seed)
  26. torch.cuda.manual_seed(seed)
  27. gpu_id = f"cuda:{device}"
  28. layer = RMSNorm(hidden_size).to(dtype=dtype, device=gpu_id)
  29. layer.weight.data.normal_(mean=1.0, std=0.1)
  30. scale = 1 / (2 * hidden_size)
  31. x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=gpu_id)
  32. x *= scale
  33. residual = torch.randn_like(x) * scale if add_residual else None
  34. # NOTE: The reference implementation should be executed first
  35. # because the custom kernel is in-place.
  36. ref_out = layer._forward(x, residual)
  37. out = layer(x, residual)
  38. # NOTE: LayerNorm operators (including RMS) typically have larger
  39. # numerical errors than other operators because they involve reductions.
  40. # Therefore, we use a larger tolerance.
  41. if add_residual:
  42. assert torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2)
  43. assert torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
  44. else:
  45. assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2)