# test_language_modeling_hf.py

import os
from pathlib import Path
current_dir = Path(__file__).parent.absolute()

import pytest
import torch

import dotenv
from src.datamodules.language_modeling_hf import LMDataModule

# load environment variables from `.env` file if it exists
# recursively searches for `.env` in all folders starting from work dir
dotenv.load_dotenv(override=True)
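

# Ceiling division: the number of size-y chunks needed to cover x items
# (used below to compute expected dataloader lengths when the last batch is partial).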
def div_up(x: int, y: int) -> int:
    return (x + y - 1) // y


# https://stackoverflow.com/questions/1006289/how-to-find-out-the-number-of-cpus-using-python/55423170#55423170
def num_cpu_cores():
    try:
        import psutil
        return psutil.cpu_count(logical=False)
    except ImportError:
        return len(os.sched_getaffinity(0))
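

# Each test below builds an LMDataModule for one corpus, runs prepare_data() and
# setup(), then checks that every dataloader has the expected length and yields
# (input, target) batches where the target is the input shifted by one token.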
class TestLMDataModule:

    def test_wikitext2(self):
        batch_size = 7
        dataset_name = 'wikitext'
        dataset_config_name = 'wikitext-2-raw-v1'
        data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))
        cache_dir = data_dir / 'wikitext-2' / 'cache'
        max_length = 1024
        datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',
                                  dataset_config_name=dataset_config_name,
                                  max_length=max_length, cache_dir=cache_dir,
                                  add_eos=False, batch_size=batch_size, num_workers=4)
        datamodule.prepare_data()
        datamodule.setup(stage='fit')
        train_loader = datamodule.train_dataloader()
        val_loader = datamodule.val_dataloader()
        datamodule.setup(stage='test')
        test_loader = datamodule.test_dataloader()
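        # Hard-coded totals: the number of tokens in each split after tokenization
        # with the GPT-2 tokenizer (no EOS appended for wikitext).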
        train_len = 2391884
        val_len = 247289
        test_len = 283287
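        # Each example needs max_length tokens plus one more for the shifted target,
        # so a split of N tokens yields (N - 1) // max_length examples; div_up then
        # accounts for a final, possibly partial batch.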
        assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)
        assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)
        assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)
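        # Spot-check one batch from each split: inputs are int64 of shape
        # (batch_size, max_length) and the target is the input shifted left by one.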
        for loader in [train_loader, val_loader, test_loader]:
            x, y = next(iter(loader))
            assert x.dim() == 2
            assert x.shape == (batch_size, max_length)
            assert x.dtype == torch.long
            assert torch.allclose(x[:, 1:], y[:, :-1])

    def test_wikitext103(self):
        batch_size = 7
        dataset_name = 'wikitext'
        dataset_config_name = 'wikitext-103-raw-v1'
        data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))
        cache_dir = data_dir / 'wikitext-103' / 'cache'
        max_length = 1024
        datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',
                                  dataset_config_name=dataset_config_name,
                                  max_length=max_length, cache_dir=cache_dir,
                                  add_eos=False, batch_size=batch_size, num_workers=4)
        datamodule.prepare_data()
        datamodule.setup(stage='fit')
        train_loader = datamodule.train_dataloader()
        val_loader = datamodule.val_dataloader()
        datamodule.setup(stage='test')
        test_loader = datamodule.test_dataloader()
        train_len = 117920140
        val_len = 247289
        test_len = 283287
        assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)
        assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)
        assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)
        for loader in [train_loader, val_loader, test_loader]:
            x, y = next(iter(loader))
            assert x.dim() == 2
            assert x.shape == (batch_size, max_length)
            assert x.dtype == torch.long
            assert torch.allclose(x[:, 1:], y[:, :-1])

    def test_openwebtext(self):
        batch_size = 8
        dataset_name = 'openwebtext'
        dataset_config_name = None
        data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))
        cache_dir = data_dir / 'openwebtext' / 'cache'
        max_length = 1024
        datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',
                                  dataset_config_name=dataset_config_name,
                                  max_length=max_length, cache_dir=cache_dir,
                                  add_eos=True, batch_size=batch_size,
                                  num_workers=num_cpu_cores() // 2)
        datamodule.prepare_data()
        datamodule.setup(stage='fit')
        train_loader = datamodule.train_dataloader()
        val_loader = datamodule.val_dataloader()
        datamodule.setup(stage='test')
        test_loader = datamodule.test_dataloader()
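        # openwebtext ships only a train split, so validation and test here appear to
        # share the same held-out slice (their token counts match); totals include the
        # EOS token appended to each document.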
        train_len = 9035582198
        val_len = 4434897
        test_len = 4434897
        assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)
        assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)
        assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)
        for loader in [train_loader, val_loader, test_loader]:
            x, y = next(iter(loader))
            assert x.dim() == 2
            assert x.shape == (batch_size, max_length)
            assert x.dtype == torch.long
            assert torch.allclose(x[:, 1:], y[:, :-1])

    def test_lambada(self):
        batch_size = 8
        dataset_name = 'lambada'
        dataset_config_name = None
        data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))
        cache_dir = data_dir / 'lambada' / 'cache'
        max_length = 1024
        datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',
                                  dataset_config_name=dataset_config_name,
                                  max_length=max_length, cache_dir=cache_dir,
                                  add_eos=True, batch_size=batch_size,
                                  num_workers=64)
        datamodule.prepare_data()
        datamodule.setup(stage='fit')
        train_loader = datamodule.train_dataloader()
        val_loader = datamodule.val_dataloader()
        datamodule.setup(stage='test')
        test_loader = datamodule.test_dataloader()
        train_len = 9035582198
        val_len = 4434897
        test_len = 4434897
        assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)
        assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)
        assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)
        for loader in [train_loader, val_loader, test_loader]:
            x, y = next(iter(loader))
            assert x.dim() == 2
            assert x.shape == (batch_size, max_length)
            assert x.dtype == torch.long
            assert torch.allclose(x[:, 1:], y[:, :-1])

    def test_the_pile(self):
        batch_size = 8
        dataset_name = 'the_pile'
        dataset_config_name = None
        data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))
        cache_dir = data_dir / 'the_pile' / 'cache'
        max_length = 2048
        # Dataset is too large to fit into memory, need to use disk for concatenation
        datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',
                                  dataset_config_name=dataset_config_name,
                                  max_length=max_length, cache_dir=cache_dir,
                                  add_eos=True, batch_size=batch_size,
                                  num_workers=num_cpu_cores() // 2, use_shmem=False)
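        # use_shmem=False keeps the concatenated token array on disk rather than in
        # shared memory, per the comment above.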
        datamodule.prepare_data()
        datamodule.setup(stage='fit')
        train_loader = datamodule.train_dataloader()
        val_loader = datamodule.val_dataloader()
        datamodule.setup(stage='test')
        test_loader = datamodule.test_dataloader()
        train_len = 374337375694
        val_len = 383326395
        test_len = 373297018
        assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)
        assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)
        assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)
        for loader in [train_loader, val_loader, test_loader]:
            x, y = next(iter(loader))
            assert x.dim() == 2
            assert x.shape == (batch_size, max_length)
            assert x.dtype == torch.long
            assert torch.allclose(x[:, 1:], y[:, :-1])

    def test_pg19(self):
        batch_size = 8
        dataset_name = 'pg19'
        dataset_config_name = None
        data_dir = Path(os.getenv('DATA_DIR', current_dir.parent.parent / 'data'))
        cache_dir = data_dir / 'pg19' / 'cache'
        max_length = 2048
        # Dataset is too large to fit into memory, need to use disk for concatenation
        datamodule = LMDataModule(dataset_name, tokenizer_name='gpt2',
                                  dataset_config_name=dataset_config_name,
                                  max_length=max_length, cache_dir=cache_dir,
                                  add_eos=True, batch_size=batch_size,
                                  num_workers=num_cpu_cores() // 2)
        datamodule.prepare_data()
        datamodule.setup(stage='fit')
        train_loader = datamodule.train_dataloader()
        val_loader = datamodule.val_dataloader()
        datamodule.setup(stage='test')
        test_loader = datamodule.test_dataloader()
        train_len = 3066544128
        val_len = 4653056
        test_len = 10584064
        assert len(train_loader) == div_up((train_len - 1) // max_length, batch_size)
        assert len(val_loader) == div_up((val_len - 1) // max_length, batch_size)
        assert len(test_loader) == div_up((test_len - 1) // max_length, batch_size)
        for loader in [train_loader, val_loader, test_loader]:
            x, y = next(iter(loader))
            assert x.dim() == 2
            assert x.shape == (batch_size, max_length)
            assert x.dtype == torch.long
            assert torch.allclose(x[:, 1:], y[:, :-1])