# gpt3l-flash.yaml
  1. # @package _global_
  2. defaults:
  3. - /experiment/pile/gpt3s-flash.yaml
  4. - override /optimizer: adamw-zero
  5. model:
  6. config:
  7. n_embd: 1536
  8. n_head: 16
  9. n_layer: 24
  10. # mlp_checkpoint_lvl: 1 # To fit batch_size 8
  11. datamodule:
  12. batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else (8 if ${train.gpu_mem} < 80 else 16))"}
  13. train:
  14. optimizer:
  15. lr: 2.5e-4
  16. trainer:
  17. strategy:
  18. _target_: src.utils.ddp_zero1.DDPStrategyZero1
  19. find_unused_parameters: False
  20. gradient_as_bucket_view: True