flow_util.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import torch
  2. def convert_flow_to_deformation(flow):
  3. r"""convert flow fields to deformations.
  4. Args:
  5. flow (tensor): Flow field obtained by the model
  6. Returns:
  7. deformation (tensor): The deformation used for warpping
  8. """
  9. b,c,h,w = flow.shape
  10. flow_norm = 2 * torch.cat([flow[:,:1,...]/(w-1),flow[:,1:,...]/(h-1)], 1)
  11. grid = make_coordinate_grid(flow)
  12. deformation = grid + flow_norm.permute(0,2,3,1)
  13. return deformation
  14. def make_coordinate_grid(flow):
  15. r"""obtain coordinate grid with the same size as the flow filed.
  16. Args:
  17. flow (tensor): Flow field obtained by the model
  18. Returns:
  19. grid (tensor): The grid with the same size as the input flow
  20. """
  21. b,c,h,w = flow.shape
  22. x = torch.arange(w).to(flow)
  23. y = torch.arange(h).to(flow)
  24. x = (2 * (x / (w - 1)) - 1)
  25. y = (2 * (y / (h - 1)) - 1)
  26. yy = y.view(-1, 1).repeat(1, w)
  27. xx = x.view(1, -1).repeat(h, 1)
  28. meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
  29. meshed = meshed.expand(b, -1, -1, -1)
  30. return meshed
  31. def warp_image(source_image, deformation):
  32. r"""warp the input image according to the deformation
  33. Args:
  34. source_image (tensor): source images to be warpped
  35. deformation (tensor): deformations used to warp the images; value in range (-1, 1)
  36. Returns:
  37. output (tensor): the warpped images
  38. """
  39. _, h_old, w_old, _ = deformation.shape
  40. _, _, h, w = source_image.shape
  41. if h_old != h or w_old != w:
  42. deformation = deformation.permute(0, 3, 1, 2)
  43. deformation = torch.nn.functional.interpolate(deformation, size=(h, w), mode='bilinear')
  44. deformation = deformation.permute(0, 2, 3, 1)
  45. return torch.nn.functional.grid_sample(source_image, deformation)