# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/data_module.py
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

from AR.data.bucket_sampler import DistributedBucketSampler
from AR.data.dataset import Text2SemanticDataset


class Text2SemanticDataModule(LightningDataModule):
    """Lightning data module for the text-to-semantic datasets used in training."""

    def __init__(
        self,
        config,
        train_semantic_path,
        train_phoneme_path,
        dev_semantic_path=None,
        dev_phoneme_path=None,
    ):
        super().__init__()
        self.config = config
        self.train_semantic_path = train_semantic_path
        self.train_phoneme_path = train_phoneme_path
        self.dev_semantic_path = dev_semantic_path
        self.dev_phoneme_path = dev_phoneme_path
        self.num_workers = self.config["data"]["num_workers"]

    def prepare_data(self):
        pass

    def setup(self, stage=None, output_logs=False):
        self._train_dataset = Text2SemanticDataset(
            phoneme_path=self.train_phoneme_path,
            semantic_path=self.train_semantic_path,
            max_sec=self.config["data"]["max_sec"],
            pad_val=self.config["data"]["pad_val"],
        )
        # The dev set currently reuses the training dataset; the dedicated dev
        # dataset below is kept commented out.
        self._dev_dataset = self._train_dataset
        # self._dev_dataset = Text2SemanticDataset(
        #     phoneme_path=self.dev_phoneme_path,
        #     semantic_path=self.dev_semantic_path,
        #     max_sample=self.config["data"]["max_eval_sample"],
        #     max_sec=self.config["data"]["max_sec"],
        #     pad_val=self.config["data"]["pad_val"],
        # )
- def train_dataloader(self):
- batch_size=self.config["train"]["batch_size"]//2 if self.config["train"].get("if_dpo",False)==True else self.config["train"]["batch_size"]
- batch_size = max(min(batch_size,len(self._train_dataset)//4),1)#防止不保存
- sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
- return DataLoader(
- self._train_dataset,
- batch_size=batch_size,
- sampler=sampler,
- collate_fn=self._train_dataset.collate,
- num_workers=self.num_workers,
- persistent_workers=True,
- prefetch_factor=16,
- )

    def val_dataloader(self):
        return DataLoader(
            self._dev_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self._train_dataset.collate,
            num_workers=max(self.num_workers, 12),
            persistent_workers=True,
            prefetch_factor=16,
        )

    # Is this dataloader actually used anywhere?
    def test_dataloader(self):
        return DataLoader(
            self._dev_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=self._train_dataset.collate,
        )
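

# Usage sketch (not part of the original module): illustrates which config keys
# the class reads above. All values and file paths below are hypothetical
# placeholders, not the project's actual defaults.
if __name__ == "__main__":
    example_config = {
        "data": {"num_workers": 4, "max_sec": 54, "pad_val": 1024},
        "train": {"batch_size": 8, "if_dpo": False},
    }
    dm = Text2SemanticDataModule(
        config=example_config,
        train_semantic_path="logs/exp/semantic_train.tsv",  # hypothetical path
        train_phoneme_path="logs/exp/phoneme_train.txt",  # hypothetical path
    )
    dm.setup()
    # Under a pytorch_lightning.Trainer, train_dataloader()/val_dataloader() are
    # invoked automatically; DistributedBucketSampler expects a distributed
    # process group, so the train loader is not built directly here.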