import os
from pathlib import Path
current_dir = Path(__file__).parent.absolute()

import pytest
import torch

import dotenv

from src.datamodules.language_modeling_hf import LMDataModule

# load environment variables from `.env` file if it exists
# recursively searches for `.env` in all folders starting from work dir
dotenv.load_dotenv(override=True)


def div_up(x: int, y: int) -> int:
    return (x + y - 1) // y
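
# div_up is ceiling division, e.g. div_up(10, 4) == 3. It is used below to count dataloader
# batches, assuming the loaders keep a final partial batch rather than dropping it.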


# https://stackoverflow.com/questions/1006289/how-to-find-out-the-number-of-cpus-using-python/55423170#55423170
def num_cpu_cores():
    """Physical core count via psutil if available; otherwise fall back to the number of
    CPUs this process may run on (os.sched_getaffinity is Linux-only)."""
    try:
        import psutil
        return psutil.cpu_count(logical=False)
    except ImportError:
        return len(os.sched_getaffinity(0))
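

# The assertions below encode the expected LMDataModule behaviour: each split is tokenized and
# concatenated into non-overlapping chunks of `max_length` tokens, so a split with n tokens
# yields (n - 1) // max_length examples (one token is held back so every position has a
# next-token target), batched into div_up(num_examples, batch_size) batches. For the wikitext-2
# train split below: (2391884 - 1) // 1024 == 2335 examples and div_up(2335, 7) == 334 batches.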
class TestLMDataModule:

    def test_wikitext2(self):
        batch_size = 7
        dataset_name = 'wikitext'
        dataset_config_name = 'wikitext-2-raw-v1'
        data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))
        cache_dir = data_dir / 'wikitext-2' / 'cache'
        max_length = 1024
        datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',
                                  dataset_config_name=dataset_config_name,
                                  max_length=max_length, cache_dir=cache_dir,
                                  add_eos=False, batch_size=batch_size, num_workers=4)
        datamodule.prepare_data()
        datamodule.setup(stage='fit')
        train_loader = datamodule.train_dataloader()
        val_loader = datamodule.val_dataloader()
        datamodule.setup(stage='test')
        test_loader = datamodule.test_dataloader()
        train_len = 2391884
        val_len = 247289
        test_len = 283287
        assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)
        assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)
        assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)
        for loader in [train_loader, val_loader, test_loader]:
            x, y = next(iter(loader))
            assert x.dim() == 2
            assert x.shape == (batch_size, max_length)
            assert x.dtype == torch.long
            assert torch.allclose(x[:, 1:], y[:, :-1])

    def test_wikitext103(self):
        batch_size = 7
        dataset_name = 'wikitext'
        dataset_config_name = 'wikitext-103-raw-v1'
        data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))
        cache_dir = data_dir / 'wikitext-103' / 'cache'
        max_length = 1024
        datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',
                                  dataset_config_name=dataset_config_name,
                                  max_length=max_length, cache_dir=cache_dir,
                                  add_eos=False, batch_size=batch_size, num_workers=4)
        datamodule.prepare_data()
        datamodule.setup(stage='fit')
        train_loader = datamodule.train_dataloader()
        val_loader = datamodule.val_dataloader()
        datamodule.setup(stage='test')
        test_loader = datamodule.test_dataloader()
        train_len = 117920140
        val_len = 247289
        test_len = 283287
        assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)
        assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)
        assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)
        for loader in [train_loader, val_loader, test_loader]:
            x, y = next(iter(loader))
            assert x.dim() == 2
            assert x.shape == (batch_size, max_length)
            assert x.dtype == torch.long
            assert torch.allclose(x[:, 1:], y[:, :-1])

    def test_openwebtext(self):
        batch_size = 8
        dataset_name = 'openwebtext'
        dataset_config_name = None
        data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))
        cache_dir = data_dir / 'openwebtext' / 'cache'
        max_length = 1024
        datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',
                                  dataset_config_name=dataset_config_name,
                                  max_length=max_length, cache_dir=cache_dir,
                                  add_eos=True, batch_size=batch_size,
                                  num_workers=num_cpu_cores() // 2)
        datamodule.prepare_data()
        datamodule.setup(stage='fit')
        train_loader = datamodule.train_dataloader()
        val_loader = datamodule.val_dataloader()
        datamodule.setup(stage='test')
        test_loader = datamodule.test_dataloader()
        train_len = 9035582198
        val_len = 4434897
        test_len = 4434897
        assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)
        assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)
        assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)
        for loader in [train_loader, val_loader, test_loader]:
            x, y = next(iter(loader))
            assert x.dim() == 2
            assert x.shape == (batch_size, max_length)
            assert x.dtype == torch.long
            assert torch.allclose(x[:, 1:], y[:, :-1])

    def test_lambada(self):
        batch_size = 8
        dataset_name = 'lambada'
        dataset_config_name = None
        data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))
        cache_dir = data_dir / 'lambada' / 'cache'
        max_length = 1024
        datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',
                                  dataset_config_name=dataset_config_name,
                                  max_length=max_length, cache_dir=cache_dir,
                                  add_eos=True, batch_size=batch_size,
                                  num_workers=64)
        datamodule.prepare_data()
        datamodule.setup(stage='fit')
        train_loader = datamodule.train_dataloader()
        val_loader = datamodule.val_dataloader()
        datamodule.setup(stage='test')
        test_loader = datamodule.test_dataloader()
        train_len = 9035582198
        val_len = 4434897
        test_len = 4434897
        assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)
        assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)
        assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)
        for loader in [train_loader, val_loader, test_loader]:
            x, y = next(iter(loader))
            assert x.dim() == 2
            assert x.shape == (batch_size, max_length)
            assert x.dtype == torch.long
            assert torch.allclose(x[:, 1:], y[:, :-1])

    def test_the_pile(self):
        batch_size = 8
        dataset_name = 'the_pile'
        dataset_config_name = None
        data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))
        cache_dir = data_dir / 'the_pile' / 'cache'
        max_length = 2048
        # Dataset is too large to fit into memory, need to use disk for concatenation
        datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',
                                  dataset_config_name=dataset_config_name,
                                  max_length=max_length, cache_dir=cache_dir,
                                  add_eos=True, batch_size=batch_size,
                                  num_workers=num_cpu_cores() // 2, use_shmem=False)
        datamodule.prepare_data()
        datamodule.setup(stage='fit')
        train_loader = datamodule.train_dataloader()
        val_loader = datamodule.val_dataloader()
        datamodule.setup(stage='test')
        test_loader = datamodule.test_dataloader()
        train_len = 374337375694
        val_len = 383326395
        test_len = 373297018
        assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)
        assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)
        assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)
        for loader in [train_loader, val_loader, test_loader]:
            x, y = next(iter(loader))
            assert x.dim() == 2
            assert x.shape == (batch_size, max_length)
            assert x.dtype == torch.long
            assert torch.allclose(x[:, 1:], y[:, :-1])

    def test_pg19(self):
        batch_size = 8
        dataset_name = 'pg19'
        dataset_config_name = None
        data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))
        cache_dir = data_dir / 'pg19' / 'cache'
        max_length = 2048
        # Dataset is too large to fit into memory, need to use disk for concatenation
        datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',
                                  dataset_config_name=dataset_config_name,
                                  max_length=max_length, cache_dir=cache_dir,
                                  add_eos=True, batch_size=batch_size,
                                  num_workers=num_cpu_cores() // 2)
        datamodule.prepare_data()
        datamodule.setup(stage='fit')
        train_loader = datamodule.train_dataloader()
        val_loader = datamodule.val_dataloader()
        datamodule.setup(stage='test')
        test_loader = datamodule.test_dataloader()
        train_len = 3066544128
        val_len = 4653056
        test_len = 10584064
        assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)
        assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)
        assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)
        for loader in [train_loader, val_loader, test_loader]:
            x, y = next(iter(loader))
            assert x.dim() == 2
            assert x.shape == (batch_size, max_length)
            assert x.dtype == torch.long
            assert torch.allclose(x[:, 1:], y[:, :-1])
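

# A minimal, dataset-independent sketch of the (x, y) convention the tests assert: y is x shifted
# left by one token, so a causal LM loss can be computed at every position. The toy vocabulary
# size and random logits below are illustrative stand-ins, not part of LMDataModule.
if __name__ == '__main__':
    import torch.nn.functional as F

    vocab_size, toy_batch, toy_len = 50257, 2, 8  # GPT-2 vocab size; tiny shapes for illustration
    tokens = torch.randint(vocab_size, (toy_batch, toy_len + 1))
    x, y = tokens[:, :-1], tokens[:, 1:]  # the same one-token shift the dataloaders are expected to produce
    assert torch.equal(x[:, 1:], y[:, :-1])  # mirrors the allclose check in the tests above
    logits = torch.randn(toy_batch, toy_len, vocab_size)  # stand-in for model(x)
    loss = F.cross_entropy(logits.reshape(-1, vocab_size), y.reshape(-1))
    print(f'toy next-token loss: {loss.item():.3f}')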