coord.py 904 B

12345678910111213141516171819202122232425262728293031
  1. import torch
  2. class CoordStage(object):
  3. def __init__(self, n_embed, down_factor):
  4. self.n_embed = n_embed
  5. self.down_factor = down_factor
  6. def eval(self):
  7. return self
  8. def encode(self, c):
  9. """fake vqmodel interface"""
  10. assert 0.0 <= c.min() and c.max() <= 1.0
  11. b,ch,h,w = c.shape
  12. assert ch == 1
  13. c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
  14. mode="area")
  15. c = c.clamp(0.0, 1.0)
  16. c = self.n_embed*c
  17. c_quant = c.round()
  18. c_ind = c_quant.to(dtype=torch.long)
  19. info = None, None, c_ind
  20. return c_quant, None, info
  21. def decode(self, c):
  22. c = c/self.n_embed
  23. c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
  24. mode="nearest")
  25. return c