gpt2l-flash.yaml

# @package _global_
defaults:
  - /experiment/owt/gpt2m-flash.yaml
  - override /model/gpt2model: gpt2-large
  # TD [2022-08-03] Surprisingly it's faster to use the ZeRO optimizer than just AdamW.
  # Still, fairscale is even faster and uses less memory.
  # I think it's because PyTorch is using ZeRO stage 1 and fairscale is using ZeRO stage 2?
  # However, fairscale has issues with saving checkpoints (either OOM or very
  # slow, since it goes through the CPU?). Fairscale says PyTorch ZeRO is the
  # upstream version of OSS
  # https://github.com/facebookresearch/fairscale/issues/937
  # PyTorch ZeRO is also very slow for saving checkpoints due to
  # consolidate_state_dict(), but I've fixed it to save a separate checkpoint per GPU.
  - override /optimizer: adamw-zero
  # FusedAdam doesn't seem to speed things up here; time per global step
  # (i.e. batch size 512) on 8 A100s is around 2056ms for both AdamW and FusedAdam.
  # This could be because each GPU is only doing the optimizer step for 1 /
  # world_size of the parameters.
  # Maybe the bottleneck here is the NCCL call to exchange parameters (ZeRO).
  # - override /optimizer: adamw-apex-zero
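  # A minimal sketch (an assumption, not this repo's actual adamw-zero config) of what
  # the adamw-zero override presumably resolves to: PyTorch's ZeroRedundancyOptimizer
  # wrapping AdamW, so each rank only keeps 1/world_size of the optimizer state.
  #   optimizer:
  #     _target_: torch.distributed.optim.ZeroRedundancyOptimizer
  #     optimizer_class: torch.optim.AdamW
  #     # plus the usual AdamW kwargs (lr, betas, weight_decay)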
# Can enable mlp_checkpoint_lvl to fit batch_size 16 on A100 40GB
# model:
#   config:
#     # mlp_checkpoint_lvl: ${eval:"[1] * 18 + [2] * 18"}
#     mlp_checkpoint_lvl: 1
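# mlp_checkpoint_lvl sets the activation-checkpointing level of each MLP block;
# presumably higher levels recompute more of the MLP in the backward pass, trading
# compute for memory. The commented eval above builds a per-layer list for GPT2-large's
# 36 layers (level 1 for the first 18 layers, level 2 for the last 18).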
datamodule:
  # batch_size: 16
  batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))"}
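  # i.e. per-GPU batch size chosen from train.gpu_mem: <24GB -> 4, <40GB -> 8, <80GB -> 16, otherwise 32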
trainer:
  # strategy: null
  # strategy: ${eval:"None if ${trainer.devices} == 1 else 'ddp_sharded'"}
  strategy:
    _target_: src.utils.ddp_zero1.DDPStrategyZero1
    find_unused_parameters: False
    gradient_as_bucket_view: True
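  # DDPStrategyZero1 (defined in this repo) presumably combines DDP with the PyTorch
  # ZeRO stage-1 optimizer sharding and per-GPU checkpoint saving described above.
  # find_unused_parameters=False skips DDP's unused-parameter scan each step, and
  # gradient_as_bucket_view=True lets gradients alias the DDP buckets to save memory.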
  # TD [2022-08-03] Deepspeed makes the ppl curve go wild
  # strategy: deepspeed_stage_1