# @package _global_
# NOTE: the @package directive above must remain the first line of the file.

# Experiment config layered on top of the gpt3s-flash experiment: it swaps in
# the adamw-zero optimizer group and overrides the model size, batch size,
# learning rate, and trainer strategy below.
defaults:
  - /experiment/pile/gpt3s-flash.yaml
  - override /optimizer: adamw-zero

# Transformer dimensions for this variant.
model:
  config:
    n_embd: 1536
    n_head: 16
    n_layer: 24
    # mlp_checkpoint_lvl: 1  # To fit batch_size 8

# Batch size is chosen from train.gpu_mem via the `eval` resolver:
#   gpu_mem < 24 -> 2, < 40 -> 4, < 80 -> 8, otherwise 16.
# NOTE(review): gpu_mem is presumably in GB and the batch size per device;
# confirm against where train.gpu_mem is defined.
datamodule:
  batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else (8 if ${train.gpu_mem} < 80 else 16))"}

train:
  optimizer:
    lr: 2.5e-4

# Strategy object named by _target_, with the two keyword options below.
# NOTE(review): going by its name this is a DDP strategy paired with the
# adamw-zero optimizer selected in `defaults`; confirm in src/utils/ddp_zero1.
trainer:
  strategy:
    _target_: src.utils.ddp_zero1.DDPStrategyZero1
    find_unused_parameters: false
    gradient_as_bucket_view: true