# @package _global_
defaults:
  - /experiment/pile/gpt3xl-flash.yaml

model:
  config:
    n_embd: 2560
    n_head: 20  # Headdim 128 is faster than headdim 80
    n_layer: 32
    # sqrt(2 / (5 * n_embd)) ≈ 0.0125 for n_embd = 2560
    initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"}
    mlp_checkpoint_lvl: 0  # MLP activation checkpointing level (0 = disabled)

datamodule:
  # Per-GPU batch size picked from GPU memory (GB): <40 -> 1, <80 -> 2, else 4
  batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"}

train:
  optimizer:
    lr: 1.6e-4
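
# NOTE: the ${eval:"..."} interpolations above are not built into OmegaConf;
# they assume the training entry point registers an "eval" resolver before the
# config is composed. A minimal sketch of that registration follows (the script
# name is an assumption, not taken from this file):
#
#     # run.py -- hypothetical training entry point
#     from omegaconf import OmegaConf
#
#     # Make ${eval:"..."} evaluate its string argument as a Python expression.
#     OmegaConf.register_new_resolver("eval", eval)
#
# With that resolver in place, initializer_range resolves to
# (2 / (2560 * 5)) ** 0.5 = 0.0125, and batch_size resolves against
# train.gpu_mem (expected in GB, supplied elsewhere in the config tree)
# when the value is accessed.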