2
0

custom.py 998 B

1234567891011121314151617181920212223242526272829303132333435363738
  1. import os
  2. import numpy as np
  3. import albumentations
  4. from torch.utils.data import Dataset
  5. from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
  6. class CustomBase(Dataset):
  7. def __init__(self, *args, **kwargs):
  8. super().__init__()
  9. self.data = None
  10. def __len__(self):
  11. return len(self.data)
  12. def __getitem__(self, i):
  13. example = self.data[i]
  14. return example
  15. class CustomTrain(CustomBase):
  16. def __init__(self, size, training_images_list_file):
  17. super().__init__()
  18. with open(training_images_list_file, "r") as f:
  19. paths = f.read().splitlines()
  20. self.data = ImagePaths(paths=paths, size=size, random_crop=False)
  21. class CustomTest(CustomBase):
  22. def __init__(self, size, test_images_list_file):
  23. super().__init__()
  24. with open(test_images_list_file, "r") as f:
  25. paths = f.read().splitlines()
  26. self.data = ImagePaths(paths=paths, size=size, random_crop=False)