# @package _global_
defaults:
  - override /trainer: default # choose trainer from 'configs/trainer/'
  - override /model: null
  - override /datamodule: openwebtext
  # FusedAdam from apex speeds up the optimizer step a bit: for GPT2-small, time
  # per global step (i.e. batch size 512) on 8 A100s goes from 376ms to 368ms.
  # For GPT2-medium, time per global step goes from 997ms to 972ms.
  - override /optimizer: adamw-apex
  - override /scheduler: linear-warmup
  - override /callbacks: [default, norm-monitor]
  - override /metrics: [perplexity, num-tokens]
  - override /logger: wandb

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

task:
  _target_: src.tasks.seq.SequenceLMModel

seed: 1111

trainer:
  accelerator: gpu
  devices: 8
  num_nodes: 1
  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${datamodule.batch_size} * ${trainer.num_nodes}}}
  max_steps: 400000
  val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}}
  check_val_every_n_epoch: null  # We don't care about epoch boundaries
  precision: 16
  gradient_clip_val: 1.0
  strategy: null

datamodule:
  batch_size: 16  # Per GPU
  batch_size_eval: ${.batch_size}  # Fused dense only supports batch sizes up to 64k
  max_length: 1024
  fault_tolerant: True
  ddp: ${eval:"${trainer.devices} > 1"}

train:
  gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
  global_batch_size: 512
  optimizer:
    lr: 6e-4
    weight_decay: 0.1
  optimizer_param_grouping:
    bias_weight_decay: False
    normalization_weight_decay: False
  scheduler:
    num_warmup_steps: ${eval:0.01 * ${trainer.max_steps}}
    num_training_steps: ${trainer.max_steps}
  loss_fn:
    # This is faster and uses less memory than torch.nn.CrossEntropyLoss.
    # It's also more numerically stable when training with 16-bit precision (e.g. under DeepSpeed).
    _target_: flash_attn.losses.cross_entropy.CrossEntropyLoss
    inplace_backward: True  # to save memory

eval:
  log_on_step: True  # 1 training epoch takes too long; we want to see metrics per train step

callbacks:
  model_checkpoint:
    monitor: val/loss
    mode: min
    save_top_k: 3
    save_last: True
    every_n_train_steps: 1000
    dirpath: ${work_dir}/checkpoints/${oc.select:name,''}
    filename: step_{step}
    auto_insert_metric_name: False
  model_checkpoint_progress:
    _target_: src.callbacks.model_checkpoint.ModelCheckpointMine
    fault_tolerant: True
    every_n_train_steps: 50000
    save_last: False
    save_top_k: -1  # Save all the checkpoints
    dirpath: ${..model_checkpoint.dirpath}
    filename: progress_step_{step}
    auto_insert_metric_name: False
  early_stopping: null
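
# Worked example of the derived values above (not part of the config itself; it assumes
# the defaults in this file are left unchanged):
#   accumulate_grad_batches = div_up(512, 8 * 16 * 1) = div_up(512, 128) = 4
#   val_check_interval      = 1000 * 4 = 4000 micro-batches
#     (Lightning counts micro-batches here, so this validates every 1000 optimizer steps)
#   num_warmup_steps        = 0.01 * 400000 = 4000 steps
# Changing devices, num_nodes, or the per-GPU batch_size rescales accumulate_grad_batches
# automatically so the effective global batch size stays at train.global_batch_size (512).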