# permuter.py
  1. import torch
  2. import torch.nn as nn
  3. import numpy as np
  4. class AbstractPermuter(nn.Module):
  5. def __init__(self, *args, **kwargs):
  6. super().__init__()
  7. def forward(self, x, reverse=False):
  8. raise NotImplementedError
  9. class Identity(AbstractPermuter):
  10. def __init__(self):
  11. super().__init__()
  12. def forward(self, x, reverse=False):
  13. return x
  14. class Subsample(AbstractPermuter):
  15. def __init__(self, H, W):
  16. super().__init__()
  17. C = 1
  18. indices = np.arange(H*W).reshape(C,H,W)
  19. while min(H, W) > 1:
  20. indices = indices.reshape(C,H//2,2,W//2,2)
  21. indices = indices.transpose(0,2,4,1,3)
  22. indices = indices.reshape(C*4,H//2, W//2)
  23. H = H//2
  24. W = W//2
  25. C = C*4
  26. assert H == W == 1
  27. idx = torch.tensor(indices.ravel())
  28. self.register_buffer('forward_shuffle_idx',
  29. nn.Parameter(idx, requires_grad=False))
  30. self.register_buffer('backward_shuffle_idx',
  31. nn.Parameter(torch.argsort(idx), requires_grad=False))
  32. def forward(self, x, reverse=False):
  33. if not reverse:
  34. return x[:, self.forward_shuffle_idx]
  35. else:
  36. return x[:, self.backward_shuffle_idx]
  37. def mortonify(i, j):
  38. """(i,j) index to linear morton code"""
  39. i = np.uint64(i)
  40. j = np.uint64(j)
  41. z = np.uint(0)
  42. for pos in range(32):
  43. z = (z |
  44. ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
  45. ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
  46. )
  47. return z
  48. class ZCurve(AbstractPermuter):
  49. def __init__(self, H, W):
  50. super().__init__()
  51. reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
  52. idx = np.argsort(reverseidx)
  53. idx = torch.tensor(idx)
  54. reverseidx = torch.tensor(reverseidx)
  55. self.register_buffer('forward_shuffle_idx',
  56. idx)
  57. self.register_buffer('backward_shuffle_idx',
  58. reverseidx)
  59. def forward(self, x, reverse=False):
  60. if not reverse:
  61. return x[:, self.forward_shuffle_idx]
  62. else:
  63. return x[:, self.backward_shuffle_idx]
  64. class SpiralOut(AbstractPermuter):
  65. def __init__(self, H, W):
  66. super().__init__()
  67. assert H == W
  68. size = W
  69. indices = np.arange(size*size).reshape(size,size)
  70. i0 = size//2
  71. j0 = size//2-1
  72. i = i0
  73. j = j0
  74. idx = [indices[i0, j0]]
  75. step_mult = 0
  76. for c in range(1, size//2+1):
  77. step_mult += 1
  78. # steps left
  79. for k in range(step_mult):
  80. i = i - 1
  81. j = j
  82. idx.append(indices[i, j])
  83. # step down
  84. for k in range(step_mult):
  85. i = i
  86. j = j + 1
  87. idx.append(indices[i, j])
  88. step_mult += 1
  89. if c < size//2:
  90. # step right
  91. for k in range(step_mult):
  92. i = i + 1
  93. j = j
  94. idx.append(indices[i, j])
  95. # step up
  96. for k in range(step_mult):
  97. i = i
  98. j = j - 1
  99. idx.append(indices[i, j])
  100. else:
  101. # end reached
  102. for k in range(step_mult-1):
  103. i = i + 1
  104. idx.append(indices[i, j])
  105. assert len(idx) == size*size
  106. idx = torch.tensor(idx)
  107. self.register_buffer('forward_shuffle_idx', idx)
  108. self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
  109. def forward(self, x, reverse=False):
  110. if not reverse:
  111. return x[:, self.forward_shuffle_idx]
  112. else:
  113. return x[:, self.backward_shuffle_idx]
  114. class SpiralIn(AbstractPermuter):
  115. def __init__(self, H, W):
  116. super().__init__()
  117. assert H == W
  118. size = W
  119. indices = np.arange(size*size).reshape(size,size)
  120. i0 = size//2
  121. j0 = size//2-1
  122. i = i0
  123. j = j0
  124. idx = [indices[i0, j0]]
  125. step_mult = 0
  126. for c in range(1, size//2+1):
  127. step_mult += 1
  128. # steps left
  129. for k in range(step_mult):
  130. i = i - 1
  131. j = j
  132. idx.append(indices[i, j])
  133. # step down
  134. for k in range(step_mult):
  135. i = i
  136. j = j + 1
  137. idx.append(indices[i, j])
  138. step_mult += 1
  139. if c < size//2:
  140. # step right
  141. for k in range(step_mult):
  142. i = i + 1
  143. j = j
  144. idx.append(indices[i, j])
  145. # step up
  146. for k in range(step_mult):
  147. i = i
  148. j = j - 1
  149. idx.append(indices[i, j])
  150. else:
  151. # end reached
  152. for k in range(step_mult-1):
  153. i = i + 1
  154. idx.append(indices[i, j])
  155. assert len(idx) == size*size
  156. idx = idx[::-1]
  157. idx = torch.tensor(idx)
  158. self.register_buffer('forward_shuffle_idx', idx)
  159. self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
  160. def forward(self, x, reverse=False):
  161. if not reverse:
  162. return x[:, self.forward_shuffle_idx]
  163. else:
  164. return x[:, self.backward_shuffle_idx]
  165. class Random(nn.Module):
  166. def __init__(self, H, W):
  167. super().__init__()
  168. indices = np.random.RandomState(1).permutation(H*W)
  169. idx = torch.tensor(indices.ravel())
  170. self.register_buffer('forward_shuffle_idx', idx)
  171. self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
  172. def forward(self, x, reverse=False):
  173. if not reverse:
  174. return x[:, self.forward_shuffle_idx]
  175. else:
  176. return x[:, self.backward_shuffle_idx]
  177. class AlternateParsing(AbstractPermuter):
  178. def __init__(self, H, W):
  179. super().__init__()
  180. indices = np.arange(W*H).reshape(H,W)
  181. for i in range(1, H, 2):
  182. indices[i, :] = indices[i, ::-1]
  183. idx = indices.flatten()
  184. assert len(idx) == H*W
  185. idx = torch.tensor(idx)
  186. self.register_buffer('forward_shuffle_idx', idx)
  187. self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
  188. def forward(self, x, reverse=False):
  189. if not reverse:
  190. return x[:, self.forward_shuffle_idx]
  191. else:
  192. return x[:, self.backward_shuffle_idx]
  193. if __name__ == "__main__":
  194. p0 = AlternateParsing(16, 16)
  195. print(p0.forward_shuffle_idx)
  196. print(p0.backward_shuffle_idx)
  197. x = torch.randint(0, 768, size=(11, 256))
  198. y = p0(x)
  199. xre = p0(y, reverse=True)
  200. assert torch.equal(x, xre)
  201. p1 = SpiralOut(2, 2)
  202. print(p1.forward_shuffle_idx)
  203. print(p1.backward_shuffle_idx)