base.yaml

# @package _global_
defaults:
  - override /trainer: default # choose trainer from 'configs/trainer/'
  - override /model: null
  - override /datamodule: openwebtext
  # FusedAdam from apex speeds up the optimizer step a bit: for GPT2-small, time
  # per global step (i.e. batch size 512) on 8 A100s goes from 376ms to 368ms.
  # For GPT2-medium, time per global step goes from 997ms to 972ms.
  - override /optimizer: adamw-apex
  - override /scheduler: linear-warmup
  - override /callbacks: [default, norm-monitor]
  - override /metrics: [perplexity, num-tokens]
  - override /logger: wandb

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

task:
  _target_: src.tasks.seq.SequenceLMModel

seed: 1111

trainer:
  accelerator: gpu
  devices: 8
  num_nodes: 1
  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${datamodule.batch_size} * ${trainer.num_nodes}}}
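  # With the values in this file (global_batch_size 512; 8 devices x batch_size 16 x 1 node
  # = 128 sequences per micro-batch), this resolves to div_up(512, 128) = 4.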
  max_steps: 400000
  val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}}
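  # val_check_interval counts micro-batches, so 1000 * accumulate_grad_batches
  # means validation runs every 1000 optimizer (global) steps.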
  check_val_every_n_epoch: null # We don't care about epoch boundary
  precision: 16
  gradient_clip_val: 1.0
  strategy: null

datamodule:
  batch_size: 16 # Per GPU
  batch_size_eval: ${.batch_size} # Fused dense only supports batch sizes of at most 64k
  max_length: 1024
  fault_tolerant: True
  ddp: ${eval:"${trainer.devices} > 1"}
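  # Resolves to True here (8 devices > 1); presumably the datamodule uses this to
  # set up distributed / per-rank data sharding.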

train:
  gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
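  # Shells out to nvidia-smi for GPU 0's total memory (reported in MiB), divides by
  # 1000 and rounds, giving a rough GB figure; presumably other experiment configs
  # key batch sizes off this value.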
  global_batch_size: 512
  optimizer:
    lr: 6e-4
    weight_decay: 0.1
  optimizer_param_grouping:
    bias_weight_decay: False
    normalization_weight_decay: False
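    # i.e. biases and normalization (LayerNorm) parameters go into a parameter group
    # with weight decay 0, the usual practice for GPT-style training.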
  scheduler:
    num_warmup_steps: ${eval:0.01 * ${trainer.max_steps}}
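    # 1% of max_steps, i.e. 4000 warmup steps with max_steps 400000.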
    num_training_steps: ${trainer.max_steps}
  loss_fn:
    # This is faster and uses less memory than torch.nn.CrossEntropyLoss.
    # It's also more numerically stable when training in 16-bit precision with DeepSpeed.
    _target_: flash_attn.losses.cross_entropy.CrossEntropyLoss
    inplace_backward: True # to save memory
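    # (inplace_backward lets the backward pass reuse the logits buffer for the
    # gradient, saving an activation-sized allocation)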

eval:
  log_on_step: True # one training epoch takes too long; we want to see metrics per training step

callbacks:
  model_checkpoint:
    monitor: val/loss
    mode: min
    save_top_k: 3
    save_last: True
    every_n_train_steps: 1000
    dirpath: ${work_dir}/checkpoints/${oc.select:name,''}
    filename: step_{step}
    auto_insert_metric_name: False
  model_checkpoint_progress:
    _target_: src.callbacks.model_checkpoint.ModelCheckpointMine
    fault_tolerant: True
    every_n_train_steps: 50000
    save_last: False
    save_top_k: -1 # Save all the checkpoints
    dirpath: ${..model_checkpoint.dirpath}
    filename: progress_step_{step}
    auto_insert_metric_name: False
  early_stopping: null
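
# Note: ${eval:...} and ${div_up:...} above are custom OmegaConf resolvers, so they must
# be registered before Hydra composes this config. A minimal sketch of the intended
# semantics (the exact registration site in the training code is an assumption, not part
# of this file):
#   from omegaconf import OmegaConf
#   OmegaConf.register_new_resolver("eval", eval)  # evaluate a Python expression
#   OmegaConf.register_new_resolver("div_up", lambda x, y: (x + y - 1) // y)  # ceiling division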