# @package _global_
defaults:
  - override /trainer: default  # choose trainer from 'configs/trainer/'
  - override /model: null  # NOTE(review): no model group selected here; presumably supplied by a config extending this one
  - override /datamodule: thepile
  - override /optimizer: adamw-apex  # slight speedup (1-2%) over PyTorch AdamW
  - override /scheduler: cosine-warmup-timm
  - override /callbacks: [default, norm-monitor]
  - override /metrics: [perplexity, num-tokens]
  - override /logger: wandb

# All parameters below are merged with the parameters from the default
# configurations selected above, so only the values listed here are overridden.

task:
  _target_: src.tasks.seq.SequenceLMModel

seed: 1111

trainer:
  accelerator: gpu
  devices: 8
  num_nodes: 1
  # Number of micro-batches to accumulate so that
  # devices * batch_size * num_nodes * accumulate_grad_batches reaches
  # train.global_batch_size (div_up is presumably ceiling division; with the
  # values in this file: 256 / (8 * 16 * 1) = 2).
  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${datamodule.batch_size} * ${trainer.num_nodes}}}
  max_steps: 800000
  # Scaled by accumulate_grad_batches, presumably so that validation runs every
  # 2000 optimizer steps regardless of the accumulation factor.
  val_check_interval: ${eval:2000 * ${.accumulate_grad_batches}}
  check_val_every_n_epoch: null  # We don't care about epoch boundaries
  precision: bf16
  gradient_clip_val: 1.0
  strategy: null

datamodule:
  batch_size: 16  # Per GPU
  batch_size_eval: ${.batch_size}  # Fused dense only supports a batch size of at most 64k
  max_length: 2048
  fault_tolerant: true
  ddp: ${eval:"${trainer.devices} > 1"}  # true whenever trainer.devices > 1

train:
  # Total memory of GPU 0 as reported by nvidia-smi (no units), divided by 1000
  # and rounded. This shells out when the value is resolved, so it requires
  # nvidia-smi to be available on the machine that resolves it.
  gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
  global_batch_size: 256
  optimizer:
    # Float literals are written with a decimal point so that YAML 1.1 loaders
    # (e.g. plain PyYAML) also read them as floats rather than strings.
    lr: 6.0e-4
    weight_decay: 0.1
  optimizer_param_grouping:
    bias_weight_decay: false
    normalization_weight_decay: false
  scheduler:
    t_in_epochs: false
    # NOTE(review): t_initial (600k) is shorter than trainer.max_steps (800k);
    # confirm the intended learning rate for the final 200k steps.
    t_initial: 600000
    warmup_lr_init: 1.0e-6
    warmup_t: ${eval:0.01 * ${trainer.max_steps}}  # 1% of max_steps
    lr_min: ${eval:0.1 * ${train.optimizer.lr}}  # 10% of the peak lr
  loss_fn:
    # This is faster and uses less memory than torch.nn.CrossEntropyLoss.
    # It's also more numerically stable if we're using DeepSpeed 16 bits.
    _target_: flash_attn.losses.cross_entropy.CrossEntropyLoss
    inplace_backward: true  # to save memory

eval:
  # 1 training epoch takes too long; we want to see metrics per train step.
  log_on_step: true

callbacks:
  model_checkpoint:
    monitor: val/loss
    mode: min
    save_top_k: 3
    save_last: true
    every_n_train_steps: 1000
    # ${oc.select:name,''} falls back to an empty string when `name` is unset.
    dirpath: ${work_dir}/checkpoints/${oc.select:name,''}
    filename: step_{step}
    auto_insert_metric_name: false
  # Periodic snapshots, kept in addition to the top-3 (lowest val/loss)
  # checkpoints saved by model_checkpoint.
  model_checkpoint_progress:
    _target_: src.callbacks.model_checkpoint.ModelCheckpointMine
    # fault_tolerant: True  # The .pl_auto_save.ckpt doesn't get saved by all workers
    every_n_train_steps: 50000
    save_last: false
    save_top_k: -1  # Save all the checkpoints
    dirpath: ${..model_checkpoint.dirpath}  # same directory as model_checkpoint
    filename: progress_step_{step}
    auto_insert_metric_name: false
  early_stopping: null