# data_module.py
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py
from pytorch_lightning import LightningDataModule
from AR.data.bucket_sampler import DistributedBucketSampler
from AR.data.dataset import Text2SemanticDataset
from torch.utils.data import DataLoader
  6. class Text2SemanticDataModule(LightningDataModule):
  7. def __init__(
  8. self,
  9. config,
  10. train_semantic_path,
  11. train_phoneme_path,
  12. dev_semantic_path=None,
  13. dev_phoneme_path=None,
  14. ):
  15. super().__init__()
  16. self.config = config
  17. self.train_semantic_path = train_semantic_path
  18. self.train_phoneme_path = train_phoneme_path
  19. self.dev_semantic_path = dev_semantic_path
  20. self.dev_phoneme_path = dev_phoneme_path
  21. self.num_workers = self.config["data"]["num_workers"]
  22. def prepare_data(self):
  23. pass
  24. def setup(self, stage=None, output_logs=False):
  25. self._train_dataset = Text2SemanticDataset(
  26. phoneme_path=self.train_phoneme_path,
  27. semantic_path=self.train_semantic_path,
  28. max_sec=self.config["data"]["max_sec"],
  29. pad_val=self.config["data"]["pad_val"],
  30. )
  31. self._dev_dataset = self._train_dataset
  32. # self._dev_dataset = Text2SemanticDataset(
  33. # phoneme_path=self.dev_phoneme_path,
  34. # semantic_path=self.dev_semantic_path,
  35. # max_sample=self.config['data']['max_eval_sample'],
  36. # max_sec=self.config['data']['max_sec'],
  37. # pad_val=self.config['data']['pad_val'])
  38. def train_dataloader(self):
  39. batch_size=self.config["train"]["batch_size"]//2 if self.config["train"].get("if_dpo",False)==True else self.config["train"]["batch_size"]
  40. batch_size = max(min(batch_size,len(self._train_dataset)//4),1)#防止不保存
  41. sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
  42. return DataLoader(
  43. self._train_dataset,
  44. batch_size=batch_size,
  45. sampler=sampler,
  46. collate_fn=self._train_dataset.collate,
  47. num_workers=self.num_workers,
  48. persistent_workers=True,
  49. prefetch_factor=16,
  50. )
  51. def val_dataloader(self):
  52. return DataLoader(
  53. self._dev_dataset,
  54. batch_size=1,
  55. shuffle=False,
  56. collate_fn=self._train_dataset.collate,
  57. num_workers=max(self.num_workers, 12),
  58. persistent_workers=True,
  59. prefetch_factor=16,
  60. )
  61. # 这个会使用到嘛?
  62. def test_dataloader(self):
  63. return DataLoader(
  64. self._dev_dataset,
  65. batch_size=1,
  66. shuffle=False,
  67. collate_fn=self._train_dataset.collate,
  68. )