# test_permute_cols.py
import pytest
import torch

from aphrodite._custom_ops import permute_cols
  4. @pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)])
  5. @pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16])
  6. def test_permute_cols(shape, dtype):
  7. x = torch.randn(shape, dtype=dtype).cuda()
  8. perm = torch.randperm(x.shape[1]).to(torch.int).cuda()
  9. y = permute_cols(x, perm)
  10. torch.testing.assert_close(y, x[:, perm])