123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248 |
- import torch
- import torch.nn as nn
- import numpy as np
class AbstractPermuter(nn.Module):
    """Base class for sequence permuters.

    A permuter reorders positions along the second (sequence) dimension of
    ``x``; calling it with ``reverse=True`` must apply the inverse
    reordering, so ``p(p(x), reverse=True) == x``.
    """
    def __init__(self, *args, **kwargs):
        super().__init__()
    def forward(self, x, reverse=False):
        # Subclasses implement the (inverse) permutation here.
        raise NotImplementedError
class Identity(AbstractPermuter):
    """No-op permuter: both directions return the input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x, reverse=False):
        # The identity permutation is its own inverse, so `reverse`
        # is deliberately ignored.
        return x
class Subsample(AbstractPermuter):
    """Hierarchical 2x2 subsampling permutation of a flattened H*W grid.

    Repeatedly halves the spatial resolution, moving each 2x2 phase into a
    "channel" axis (a pixel-unshuffle-style reordering), until the grid is
    reduced to 1x1.  The loop plus the final assert require H == W and both
    a power of two.

    Buffers:
        forward_shuffle_idx:  index tensor applying the permutation.
        backward_shuffle_idx: its inverse (argsort of the forward index).
    """
    def __init__(self, H, W):
        super().__init__()
        C = 1
        indices = np.arange(H*W).reshape(C,H,W)
        while min(H, W) > 1:
            # Split each spatial dim into (half, 2) and fold the 2x2
            # phase offsets into the channel dimension.
            indices = indices.reshape(C,H//2,2,W//2,2)
            indices = indices.transpose(0,2,4,1,3)
            indices = indices.reshape(C*4,H//2, W//2)
            H = H//2
            W = W//2
            C = C*4
        assert H == W == 1
        idx = torch.tensor(indices.ravel())
        # Fix: register plain tensors, not nn.Parameter wrappers —
        # buffers are non-trainable state and must not be Parameters.
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
    def forward(self, x, reverse=False):
        # x: (batch, H*W, ...) — permute along the sequence dimension.
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]
def mortonify(i, j):
    """(i,j) index to linear morton code.

    Interleaves the low 32 bits of ``i`` and ``j``: bits of ``j`` occupy
    the even bit positions of the result and bits of ``i`` the odd ones.

    Returns:
        np.uint64 morton code.
    """
    i = np.uint64(i)
    j = np.uint64(j)
    # Fix: was np.uint(0) — a platform-sized unsigned (32-bit on some
    # platforms), which can overflow/mix dtypes; use np.uint64 throughout.
    z = np.uint64(0)
    one = np.uint64(1)
    for pos in range(32):
        p = np.uint64(pos)
        z = (z |
             ((j & (one << p)) << p) |
             ((i & (one << p)) << (p + one))
             )
    return z
class ZCurve(AbstractPermuter):
    """Z-order (Morton-curve) permutation of a flattened H*W grid.

    The forward index is the argsort of the per-position morton codes;
    the codes themselves serve as the backward index (they equal the
    inverse permutation whenever they form a permutation of 0..H*W-1,
    i.e. for square power-of-two grids).
    """
    def __init__(self, H, W):
        super().__init__()
        morton_codes = [np.int64(mortonify(row, col))
                        for row in range(H) for col in range(W)]
        order = torch.tensor(np.argsort(morton_codes))
        codes = torch.tensor(morton_codes)
        self.register_buffer('forward_shuffle_idx', order)
        self.register_buffer('backward_shuffle_idx', codes)
    def forward(self, x, reverse=False):
        index = self.backward_shuffle_idx if reverse else self.forward_shuffle_idx
        return x[:, index]
class SpiralOut(AbstractPermuter):
    """Permute a flattened H*W grid into an outward spiral visiting order.

    The walk starts just off the grid center and spirals outward until every
    cell has been visited exactly once.  Requires H == W; the ring
    bookkeeping below only closes the spiral exactly for even sizes
    (presumably the intended use is power-of-two grids — TODO confirm).
    """
    def __init__(self, H, W):
        super().__init__()
        assert H == W
        size = W
        # Row-major linear indices of the grid.
        indices = np.arange(size*size).reshape(size,size)
        # Starting cell, adjacent to the grid center.
        i0 = size//2
        j0 = size//2-1
        i = i0
        j = j0
        idx = [indices[i0, j0]]
        step_mult = 0
        # Each ring consists of straight runs whose length grows by one
        # every half turn.
        # NOTE(review): with indices[i, j] row-major, i is the row index,
        # so the direction names below are transposed relative to the
        # usual screen orientation; the visit order is a spiral either way.
        for c in range(1, size//2+1):
            step_mult += 1
            # steps left
            for k in range(step_mult):
                i = i - 1
                j = j
                idx.append(indices[i, j])
            # step down
            for k in range(step_mult):
                i = i
                j = j + 1
                idx.append(indices[i, j])
            step_mult += 1
            if c < size//2:
                # step right
                for k in range(step_mult):
                    i = i + 1
                    j = j
                    idx.append(indices[i, j])
                # step up
                for k in range(step_mult):
                    i = i
                    j = j - 1
                    idx.append(indices[i, j])
            else:
                # end reached: final straight run completes the outer edge
                for k in range(step_mult-1):
                    i = i + 1
                    idx.append(indices[i, j])
        assert len(idx) == size*size
        idx = torch.tensor(idx)
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
    def forward(self, x, reverse=False):
        # x: (batch, H*W, ...) — permute along the sequence dimension.
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]
class SpiralIn(AbstractPermuter):
    """Permute a flattened H*W grid into an inward spiral visiting order.

    Builds the same outward spiral as ``SpiralOut`` and then reverses it,
    so the sequence starts at the grid edge and ends at the center.
    Requires H == W (and, like SpiralOut, an even size for the walk to
    cover the full grid — TODO confirm intended sizes).
    """
    def __init__(self, H, W):
        super().__init__()
        assert H == W
        size = W
        # Row-major linear indices of the grid.
        indices = np.arange(size*size).reshape(size,size)
        # Starting cell of the outward walk, adjacent to the grid center.
        i0 = size//2
        j0 = size//2-1
        i = i0
        j = j0
        idx = [indices[i0, j0]]
        step_mult = 0
        # NOTE(review): this walk duplicates SpiralOut.__init__ verbatim;
        # only the final idx[::-1] differs.
        for c in range(1, size//2+1):
            step_mult += 1
            # steps left
            for k in range(step_mult):
                i = i - 1
                j = j
                idx.append(indices[i, j])
            # step down
            for k in range(step_mult):
                i = i
                j = j + 1
                idx.append(indices[i, j])
            step_mult += 1
            if c < size//2:
                # step right
                for k in range(step_mult):
                    i = i + 1
                    j = j
                    idx.append(indices[i, j])
                # step up
                for k in range(step_mult):
                    i = i
                    j = j - 1
                    idx.append(indices[i, j])
            else:
                # end reached: final straight run completes the outer edge
                for k in range(step_mult-1):
                    i = i + 1
                    idx.append(indices[i, j])
        assert len(idx) == size*size
        # Reverse the outward spiral to obtain the inward ordering.
        idx = idx[::-1]
        idx = torch.tensor(idx)
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
    def forward(self, x, reverse=False):
        # x: (batch, H*W, ...) — permute along the sequence dimension.
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]
class Random(AbstractPermuter):
    """Fixed pseudo-random permutation of the H*W sequence positions.

    A seeded RandomState makes the permutation identical across runs and
    checkpoints.

    Fix: base class changed from nn.Module to AbstractPermuter for
    consistency with every other permuter in this file; behavior is
    unchanged (AbstractPermuter.__init__ only calls nn.Module.__init__).
    """
    def __init__(self, H, W):
        super().__init__()
        indices = np.random.RandomState(1).permutation(H*W)
        idx = torch.tensor(indices.ravel())
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
    def forward(self, x, reverse=False):
        # x: (batch, H*W, ...) — permute along the sequence dimension.
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]
class AlternateParsing(AbstractPermuter):
    """Boustrophedon ("snake") ordering of a flattened H*W grid.

    Even rows are kept left-to-right; every odd row is reversed, so the
    sequence zig-zags through the grid.
    """
    def __init__(self, H, W):
        super().__init__()
        grid = np.arange(W*H).reshape(H,W)
        # Reverse all odd rows in one vectorized assignment.
        grid[1::2, :] = grid[1::2, ::-1]
        flat = grid.flatten()
        assert len(flat) == H*W
        order = torch.tensor(flat)
        self.register_buffer('forward_shuffle_idx', order)
        self.register_buffer('backward_shuffle_idx', torch.argsort(order))
    def forward(self, x, reverse=False):
        if reverse:
            return x[:, self.backward_shuffle_idx]
        return x[:, self.forward_shuffle_idx]
if __name__ == "__main__":
    # Smoke test: a permutation followed by its inverse is the identity.
    permuter = AlternateParsing(16, 16)
    print(permuter.forward_shuffle_idx)
    print(permuter.backward_shuffle_idx)
    tokens = torch.randint(0, 768, size=(11, 256))
    shuffled = permuter(tokens)
    restored = permuter(shuffled, reverse=True)
    assert torch.equal(tokens, restored)
    spiral = SpiralOut(2, 2)
    print(spiral.forward_shuffle_idx)
    print(spiral.backward_shuffle_idx)
|