
Release training code

Tri Dao, 2 years ago
parent
commit
0bf5e50038
100 files changed, with 1495 additions and 0 deletions
  1. +2 -0    csrc/layer_norm/README.md
  2. +107 -0  training/Dockerfile
  3. +133 -0  training/README.md
  4. +2 -0    training/configs/callbacks/causality-monitor.yaml
  5. +45 -0   training/configs/callbacks/default.yaml
  6. +4 -0    training/configs/callbacks/ema.yaml
  7. +5 -0    training/configs/callbacks/flop-count.yaml
  8. +11 -0   training/configs/callbacks/gpu-monitor.yaml
  9. +2 -0    training/configs/callbacks/model-summary.yaml
 10. +0 -0    training/configs/callbacks/none.yaml
 11. +2 -0    training/configs/callbacks/norm-monitor.yaml
 12. +5 -0    training/configs/callbacks/params-log.yaml
 13. +26 -0   training/configs/callbacks/wandb.yaml
 14. +50 -0   training/configs/config.yaml
 15. +15 -0   training/configs/datamodule/openwebtext.yaml
 16. +14 -0   training/configs/datamodule/thepile.yaml
 17. +82 -0   training/configs/experiment/owt/base.yaml
 18. +41 -0   training/configs/experiment/owt/gpt2l-flash.yaml
 19. +14 -0   training/configs/experiment/owt/gpt2l-hf.yaml
 20. +14 -0   training/configs/experiment/owt/gpt2l.yaml
 21. +17 -0   training/configs/experiment/owt/gpt2m-flash.yaml
 22. +11 -0   training/configs/experiment/owt/gpt2m-hf.yaml
 23. +11 -0   training/configs/experiment/owt/gpt2m.yaml
 24. +18 -0   training/configs/experiment/owt/gpt2s-flash.yaml
 25. +23 -0   training/configs/experiment/owt/gpt2s-hf.yaml
 26. +8 -0    training/configs/experiment/owt/gpt2s.yaml
 27. +21 -0   training/configs/experiment/owt/gpt2xl-flash.yaml
 28. +14 -0   training/configs/experiment/owt/gpt2xl.yaml
 29. +83 -0   training/configs/experiment/pile/base.yaml
 30. +18 -0   training/configs/experiment/pile/gpt3-2.7B-flash-8k.yaml
 31. +18 -0   training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary-8k.yaml
 32. +18 -0   training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary.yaml
 33. +18 -0   training/configs/experiment/pile/gpt3-2.7B-flash-rotary-8k.yaml
 34. +18 -0   training/configs/experiment/pile/gpt3-2.7B-flash-rotary.yaml
 35. +10 -0   training/configs/experiment/pile/gpt3l-flash-8k.yaml
 36. +10 -0   training/configs/experiment/pile/gpt3l-flash-rotary-30B.yaml
 37. +8 -0    training/configs/experiment/pile/gpt3l-flash-rotary-8k.yaml
 38. +8 -0    training/configs/experiment/pile/gpt3l-flash-rotary.yaml
 39. +24 -0   training/configs/experiment/pile/gpt3l-flash.yaml
 40. +10 -0   training/configs/experiment/pile/gpt3m-flash-8k.yaml
 41. +10 -0   training/configs/experiment/pile/gpt3m-flash-rotary-30B.yaml
 42. +8 -0    training/configs/experiment/pile/gpt3m-flash-rotary-8k.yaml
 43. +8 -0    training/configs/experiment/pile/gpt3m-flash-rotary.yaml
 44. +16 -0   training/configs/experiment/pile/gpt3m-flash.yaml
 45. +10 -0   training/configs/experiment/pile/gpt3s-flash-8k.yaml
 46. +10 -0   training/configs/experiment/pile/gpt3s-flash-rotary-30B.yaml
 47. +8 -0    training/configs/experiment/pile/gpt3s-flash-rotary-8k.yaml
 48. +8 -0    training/configs/experiment/pile/gpt3s-flash-rotary.yaml
 49. +17 -0   training/configs/experiment/pile/gpt3s-flash.yaml
 50. +10 -0   training/configs/experiment/pile/gpt3xl-flash-8k.yaml
 51. +10 -0   training/configs/experiment/pile/gpt3xl-flash-rotary-60B.yaml
 52. +8 -0    training/configs/experiment/pile/gpt3xl-flash-rotary-8k.yaml
 53. +8 -0    training/configs/experiment/pile/gpt3xl-flash-rotary.yaml
 54. +35 -0   training/configs/experiment/pile/gpt3xl-flash.yaml
 55. +7 -0    training/configs/logger/comet.yaml
 56. +8 -0    training/configs/logger/csv.yaml
 57. +9 -0    training/configs/logger/many_loggers.yaml
 58. +10 -0   training/configs/logger/mlflow.yaml
 59. +11 -0   training/configs/logger/neptune.yaml
 60. +10 -0   training/configs/logger/tensorboard.yaml
 61. +15 -0   training/configs/logger/wandb.yaml
 62. +3 -0    training/configs/metrics/acc.yaml
 63. +4 -0    training/configs/metrics/acc_ignore_index.yaml
 64. +4 -0    training/configs/metrics/acctop5.yaml
 65. +3 -0    training/configs/metrics/mse.yaml
 66. +3 -0    training/configs/metrics/num-tokens.yaml
 67. +3 -0    training/configs/metrics/perplexity.yaml
 68. +27 -0   training/configs/mode/debug.yaml
 69. +13 -0   training/configs/mode/default.yaml
 70. +17 -0   training/configs/mode/exp.yaml
 71. +31 -0   training/configs/mode/profile.yaml
 72. +22 -0   training/configs/mode/smoke.yaml
 73. +13 -0   training/configs/model/gpt2-hf.yaml
 74. +13 -0   training/configs/model/gpt2.yaml
 75. +6 -0    training/configs/model/gpt2model/gpt2-large.yaml
 76. +6 -0    training/configs/model/gpt2model/gpt2-medium.yaml
 77. +6 -0    training/configs/model/gpt2model/gpt2-small.yaml
 78. +6 -0    training/configs/model/gpt2model/gpt2-xlarge.yaml
 79. +2 -0    training/configs/optimizer/adam.yaml
 80. +3 -0    training/configs/optimizer/adamw-apex-distributed.yaml
 81. +7 -0    training/configs/optimizer/adamw-apex-zero.yaml
 82. +3 -0    training/configs/optimizer/adamw-apex.yaml
 83. +7 -0    training/configs/optimizer/adamw-zero.yaml
 84. +2 -0    training/configs/optimizer/adamw.yaml
 85. +2 -0    training/configs/optimizer/fusedlamb-ds.yaml
 86. +2 -0    training/configs/optimizer/fusedlamb.yaml
 87. +2 -0    training/configs/optimizer/sgd.yaml
 88. +2 -0    training/configs/scheduler/cosine-warmup-timm.yaml
 89. +2 -0    training/configs/scheduler/cosine-warmup.yaml
 90. +3 -0    training/configs/scheduler/invsqrt.yaml
 91. +2 -0    training/configs/scheduler/linear-warmup.yaml
 92. +2 -0    training/configs/scheduler/multi-step.yaml
 93. +9 -0    training/configs/scheduler/plateau.yaml
 94. +2 -0    training/configs/scheduler/poly-warmup.yaml
 95. +3 -0    training/configs/scheduler/step.yaml
 96. +1 -0    training/configs/task/sequence-model.yaml
 97. +49 -0   training/configs/trainer/all_params.yaml
 98. +6 -0    training/configs/trainer/ddp.yaml
 99. +21 -0   training/configs/trainer/debug.yaml
100. +7 -0    training/configs/trainer/default.yaml

+ 2 - 0
csrc/layer_norm/README.md

@@ -2,6 +2,8 @@ This CUDA extension implements fused dropout + residual + LayerNorm, based on
 Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
 We add dropout and residual, and make it work for both pre-norm and post-norm architectures.
 
+This only supports a limited set of dimensions; see `csrc/layer_norm/ln_fwd_cuda_kernel.cu`.
+
 It has only been tested on A100s.
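For reference, a rough unfused PyTorch equivalent of the fused op described above (a sketch of the semantics only, not the extension's code):

```python
# Unfused reference: dropout on the incoming branch, add the residual, then LayerNorm.
import torch.nn.functional as F

def dropout_add_layer_norm_ref(x, residual, weight, bias, p, eps=1e-5, prenorm=False):
    added = residual + F.dropout(x, p=p, training=True)
    out = F.layer_norm(added, (x.shape[-1],), weight, bias, eps)
    # Pre-norm blocks also need the un-normalized sum for the next residual branch.
    return (out, added) if prenorm else out
```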
 
 ```sh

+ 107 - 0
training/Dockerfile

@@ -0,0 +1,107 @@
+# Inspired by https://github.com/anibali/docker-pytorch/blob/master/dockerfiles/1.10.0-cuda11.3-ubuntu20.04/Dockerfile
+# ARG COMPAT=0
+ARG PERSONAL=0
+# FROM nvidia/cuda:11.3.1-devel-ubuntu20.04 as base-0
+FROM nvcr.io/nvidia/pytorch:22.11-py3 as base
+
+ENV HOST docker
+ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
+# https://serverfault.com/questions/683605/docker-container-time-timezone-will-not-reflect-changes
+ENV TZ America/Los_Angeles
+RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
+
+# git for installing dependencies
+# tzdata to set time zone
+# wget and unzip to download data
+# [2021-09-09] TD: zsh, stow, subversion, fasd are for setting up my personal environment.
+# [2021-12-07] TD: openmpi-bin for MPI (multi-node training)
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    build-essential \
+    cmake \
+    curl \
+    ca-certificates \
+    sudo \
+    less \
+    htop \
+    git \
+    tzdata \
+    wget \
+    tmux \
+    zip \
+    unzip \
+    zsh stow subversion fasd \
+    && rm -rf /var/lib/apt/lists/*
+    # openmpi-bin \
+
+# Allow running mpirun as root
+# ENV OMPI_ALLOW_RUN_AS_ROOT=1 OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1
+
+# # Create a non-root user and switch to it
+# RUN adduser --disabled-password --gecos '' --shell /bin/bash user \
+#     && echo "user ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-user
+# USER user
+
+# All users can use /home/user as their home directory
+ENV HOME=/home/user
+RUN mkdir -p /home/user && chmod 777 /home/user
+WORKDIR /home/user
+
+# Set up personal environment
+# FROM base-${COMPAT} as env-0
+FROM base as env-0
+FROM env-0 as env-1
+# Use ONBUILD so that the dotfiles dir doesn't need to exist unless we're building a personal image
+# https://stackoverflow.com/questions/31528384/conditional-copy-add-in-dockerfile
+ONBUILD COPY dotfiles ./dotfiles
+ONBUILD RUN cd ~/dotfiles && stow bash zsh tmux && sudo chsh -s /usr/bin/zsh $(whoami)
+# nvcr pytorch image sets SHELL=/bin/bash
+ONBUILD ENV SHELL=/bin/zsh
+
+FROM env-${PERSONAL} as packages
+
+# Disable pip cache: https://stackoverflow.com/questions/45594707/what-is-pips-no-cache-dir-good-for
+ENV PIP_NO_CACHE_DIR=1
+
+# # apex and pytorch-fast-transformers take a while to compile so we install them first
+# TD [2022-04-28] apex is already installed. In case we need a newer commit:
+# RUN pip install --upgrade --force-reinstall --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" --global-option="--fmha" --global-option="--fast_layer_norm" --global-option="--xentropy" git+https://github.com/NVIDIA/apex.git#egg=apex
+# TD [2021-10-28] pytorch-fast-transformers doesn't have a wheel compatible with CUDA 11.3 and Pytorch 1.10
+# So we install from source, and change compiler flag -arch=compute_60 -> -arch=compute_70 for V100
+# RUN pip install pytorch-fast-transformers==0.4.0
+# RUN pip install git+git://github.com/idiap/fast-transformers.git@v0.4.0  # doesn't work on V100
+RUN git clone https://github.com/idiap/fast-transformers \
+    && sed -i 's/\["-arch=compute_60"\]/\["-arch=compute_70"\]/' fast-transformers/setup.py \
+    && pip install fast-transformers/ \
+    && rm -rf fast-transformers
+
+# xgboost conflicts with deepspeed
+RUN pip uninstall -y xgboost && DS_BUILD_UTILS=1 DS_BUILD_FUSED_LAMB=1 pip install deepspeed==0.7.5
+
+# General packages that we don't care about the version
+# zstandard to extract the_pile dataset
+# psutil to get the number of cpu physical cores
+# twine to upload package to PyPI
+# ninja is broken for some reason, it returns error code 245
+RUN pip uninstall -y ninja && pip install ninja
+RUN pip install pytest matplotlib jupyter ipython ipdb gpustat scikit-learn spacy munch einops opt_einsum fvcore gsutil cmake pykeops zstandard psutil h5py twine \
+    && python -m spacy download en_core_web_sm
+# hydra
+RUN pip install hydra-core==1.2.0 hydra-colorlog==1.2.0 hydra-optuna-sweeper==1.2.0 pyrootutils rich
+# Core packages
+RUN pip install transformers==4.24.0 datasets==2.7.1 pytorch-lightning==1.7.7 triton==2.0.0.dev20221120 wandb==0.13.5 timm==0.6.12 torchmetrics==0.10.3
+
+# For MLPerf
+RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
+
+# Install FlashAttention
+RUN pip install flash-attn==0.2.2
+
+# Install CUDA extensions for cross-entropy, fused dense, layer norm
+RUN git clone https://github.com/HazyResearch/flash-attention \
+    && cd flash-attention && git checkout v0.2.2 \
+    && cd csrc/fused_softmax && pip install . && cd ../../ \
+    && cd csrc/rotary && pip install . && cd ../../ \
+    && cd csrc/xentropy && pip install . && cd ../../ \
+    && cd csrc/layer_norm && pip install . && cd ../../ \
+    && cd csrc/fused_dense_lib && pip install . && cd ../../ \
+    && cd .. && rm -rf flash-attention
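A quick way to sanity-check the resulting image is to import the extensions just installed (a sketch; the module names below are assumptions inferred from the `csrc/` directory names, not verified against this commit):

```python
# Smoke test for the image: import FlashAttention and the CUDA extensions built above.
# Module names are assumptions based on the csrc/ directories.
import flash_attn
import dropout_layer_norm   # csrc/layer_norm
import fused_dense_lib      # csrc/fused_dense_lib
import rotary_emb           # csrc/rotary
import xentropy_cuda_lib    # csrc/xentropy
print("flash-attn and CUDA extensions imported OK")
```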

+ 133 - 0
training/README.md

@@ -0,0 +1,133 @@
+Examples of how FlashAttention can be integrated into a model (e.g., GPT, ViT)
+and trained end-to-end.
+We also added optimized implementations of other layers (e.g., MLP, LayerNorm,
+cross-entropy loss, rotary embedding).
+
+Goals:
+- Performance: we optimize for model speed and memory, especially on 1-node
+  (e.g., with 8 A100s).
+- Flexibility: we provide optimized building blocks (MLP, attention, LayerNorm),
+  and the model code illustrates how these components can be put together.
+  The training code also aims to be model- & task-agnostic.
+
+Non-goals (and other resources):
+- Support as many models as possible: Huggingface's
+  [transformers](https://github.com/huggingface/transformers) and
+  [timm](https://github.com/rwightman/pytorch-image-models/) are great for this.
+- Large-scale distributed training: our codebase has been used for multi-GPU and multi-node
+  training for models up to 2.7B parameters. However, if you're looking for large-scale distributed
+  training techniques (e.g., pipeline parallelism, tensor parallelism),
+  check out [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/) and
+  [DeepSpeed](https://github.com/microsoft/deepspeed).
+- Inference: we currently focus on training (this might change in the future).
+  If you want fast inference, take a look at
+  [FasterTransformer](https://github.com/NVIDIA/FasterTransformer).
+- Production: this codebase was written during several research projects to validate ideas
+  on speeding up ML models.
+
+## Model Components
+
+The GPT model is implemented
+[here](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py).
+
+We provide the following optimized components:
+
+- FlashAttention: fast and memory-efficient exact attention. This makes
+attention much faster and saves a lot of activation memory. As a result we don't need
+to use any activation checkpointing.
+```sh
+pip install flash-attn
+```
+
+- Fused matmul + bias (forward and backward), and fused matmul + bias + gelu
+(forward and backward), adapted from Apex's
+[FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense). We
+make it work for bfloat16. For best performance, you should use CUDA >= 11.8. cuBLAS versions before
+this don't have the best matmul + bias + gelu performance for bfloat16.
+```sh
+cd ../csrc/fused_dense_lib && pip install .
+```
+- Optimized cross-entropy loss, adapted from Apex's
+[Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy). We make it work for bfloat16 and support in-place backward to save memory.
+```sh
+cd ../csrc/xentropy && pip install .
+```
+- Fused rotary embedding:
+```sh
+cd ../csrc/rotary && pip install .
+```
+- Fused dropout + residual + LayerNorm, adapted from Apex's
+[FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm). We add dropout and residual, and make it work for both pre-norm and post-norm architectures.
+This only supports a limited set of dimensions; see `csrc/layer_norm/ln_fwd_cuda_kernel.cu`.
+```sh
+cd ../csrc/layer_norm && pip install .
+```
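Putting these together: the following is a minimal sketch of instantiating the GPT model with the fused options enabled. The flag names are taken from the Hydra configs in this commit (`training/configs/model/gpt2.yaml` and `training/configs/experiment/owt/gpt2s-flash.yaml`); the exact constructor and forward signatures are defined in `flash_attn/models/gpt.py`, so treat this as illustrative rather than exact.

```python
# Sketch: GPT2-small with FlashAttention and the fused kernels enabled.
import torch
from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel

config = GPT2Config(
    n_positions=1024, n_embd=768, n_head=12, n_layer=12,   # gpt2-small
    scale_attn_by_inverse_layer_idx=True,
    reorder_and_upcast_attn=False,
)
# Extra fields read by the flash_attn GPT implementation (see the Hydra configs):
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_dense_gelu_dense = True
config.fused_dropout_add_ln = True
config.pad_vocab_size_multiple = 8

# FlashAttention kernels need fp16/bf16 inputs on GPU.
model = GPTLMHeadModel(config).to(device="cuda", dtype=torch.float16)
input_ids = torch.randint(0, config.vocab_size, (2, 1024), device="cuda")
logits = model(input_ids).logits  # (batch, seqlen, padded vocab size)
```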
+
+## Training
+
+Feel free to use the model in your own training setup. As examples, we also provide
+training scripts to train GPT2 on Openwebtext and GPT3 on The Pile.
+
+We use [Hydra](https://hydra.cc/) for configuration,
+[Pytorch-Lightning](https://github.com/Lightning-AI/lightning) for training, and
+[Wandb](https://wandb.ai/) for logging.
+
+We use the template from `https://github.com/ashleve/lightning-hydra-template`.
+Please read the instructions there to understand the repo structure.
+
+### Dataset preparation
+
+Running the training command will automatically download the datasets
+(Openwebtext, Pile), tokenize them with the GPT2 tokenizer, concatenate all the
+tokens, then save this cache to disk. Alternatively, you can also prepare the
+datasets as a separate step.
+
+The cached datasets are saved to `${DATA_DIR}/openwebtext` and
+`${DATA_DIR}/the_pile`. If `${DATA_DIR}` is not set, they will be saved to
+`./data/{openwebtext,the_pile}`. 
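For example, you can run `export DATA_DIR=/path/to/data` beforehand to put the caches on a different disk.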
+
+- Openwebtext:
+```sh
+export PYTHONPATH=$PWD:$PYTHONPATH
+pytest -q -s tests/datamodules/test_language_modeling_hf.py -k "openwebtext"
+```
+This takes around 1h on a 64-core CPU. The processed dataset is about 17GB.
+
+- The Pile:
+```sh
+export PYTHONPATH=$PWD:$PYTHONPATH
+pytest -q -s tests/datamodules/test_language_modeling_hf.py -k "pile"
+```
+This takes around 20h on a 96-core CPU. The processed dataset is about 699GB.
+
+### GPT2 training on Openwebtext
+To train GPT2 on Openwebtext with 8 GPUs:
+```sh
+python run.py experiment=owt/gpt2s-flash trainer.devices=8
+python run.py experiment=owt/gpt2m-flash trainer.devices=8
+python run.py experiment=owt/gpt2l-flash trainer.devices=8
+python run.py experiment=owt/gpt2xl-flash trainer.devices=8
+```
+The default parameters are set for 8 x A100 80GB.
+
+To train with bf16 instead of fp16, add `trainer.precision=bf16`.
+To adjust the per-device batch size to fit GPU memory (the global batch size stays the
+same, and gradient accumulation is calculated automatically), set `datamodule.batch_size` to the desired value.
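For example, on 4 GPUs you might run something like `python run.py experiment=owt/gpt2m-flash trainer.devices=4 trainer.precision=bf16 datamodule.batch_size=8`; with the default global batch size of 512 this implies `accumulate_grad_batches = 512 / (4 * 8) = 16`, which the config computes automatically.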
+
+### GPT3 training on The Pile
+To train GPT3 on The Pile with 8 GPUs:
+```sh
+python run.py experiment=pile/gpt3s-flash trainer.devices=8
+python run.py experiment=pile/gpt3m-flash trainer.devices=8
+python run.py experiment=pile/gpt3l-flash trainer.devices=8
+python run.py experiment=pile/gpt3xl-flash trainer.devices=8
+```
+The default parameters are set for 8 x A100 80GB.
+
+## Requirements
+
+Python 3.8+, Pytorch 1.12+, torchvision, einops, timm, hydra-core,
+hydra-colorlog, python-dotenv, rich, pytorch-lightning, triton, flash-attn.
+We recommend CUDA 11.8 (e.g., using Nvidia's Pytorch Docker image from https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch).
+
+We provide a Dockerfile that lists all the required packages.

+ 2 - 0
training/configs/callbacks/causality-monitor.yaml

@@ -0,0 +1,2 @@
+causality-monitor:
+  _target_: src.callbacks.causality_monitor.CausalityMonitor

+ 45 - 0
training/configs/callbacks/default.yaml

@@ -0,0 +1,45 @@
+# rich_progress_bar:
+#   _target_: pytorch_lightning.callbacks.RichProgressBar
+
+rich_model_summary:
+  _target_: pytorch_lightning.callbacks.RichModelSummary
+
+model_checkpoint:
+  _target_: pytorch_lightning.callbacks.ModelCheckpoint
+  monitor: "val/acc" # name of the logged metric which determines when model is improving
+  mode: "max" # can be "max" or "min"
+  save_top_k: 1 # save k best models (determined by above metric)
+  save_last: True # additionally always save model from last epoch
+  verbose: False
+  dirpath: ${oc.env:CHECKPOINT_DIR,checkpoints}/${oc.select:name,''}
+  filename: "epoch_{epoch:03d}"
+  auto_insert_metric_name: False
+
+early_stopping:
+  _target_: pytorch_lightning.callbacks.EarlyStopping
+  monitor: "val/acc" # name of the logged metric which determines when model is improving
+  mode: "max" # can be "max" or "min"
+  patience: 100 # how many epochs of not improving until training stops
+  min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
+
+learning_rate_monitor:
+  _target_: pytorch_lightning.callbacks.LearningRateMonitor
+  logging_interval: step
+
+speed_monitor:
+  _target_: src.callbacks.speed_monitor.SpeedMonitor
+  intra_step_time: True
+  inter_step_time: True
+  epoch_time: True
+
+loss_scale_monitor:
+  _target_: src.callbacks.loss_scale_monitor.LossScaleMonitor
+
+params_log:
+  _target_: src.callbacks.params_log.ParamsLog
+  total_params_log: True
+  trainable_params_log: True
+  non_trainable_params_log: True
+
+gpu_affinity:
+  _target_: src.callbacks.gpu_affinity.GpuAffinity

+ 4 - 0
training/configs/callbacks/ema.yaml

@@ -0,0 +1,4 @@
+ema:
+  _target_: src.callbacks.ema.EMACallback
+  decay: ???
+  use_num_updates: False

+ 5 - 0
training/configs/callbacks/flop-count.yaml

@@ -0,0 +1,5 @@
+flop_count:
+  _target_: src.callbacks.flop_count.FlopCount
+  profilers: ['fvcore']
+  input_size: [3, 224, 224]
+  device: null

+ 11 - 0
training/configs/callbacks/gpu-monitor.yaml

@@ -0,0 +1,11 @@
+defaults:
+  - default.yaml
+
+gpu_stats_monitor:
+  _target_: pytorch_lightning.callbacks.GPUStatsMonitor
+  # [2021-08-13] TD: I just want the intra_step_size but it'll error if I
+  # don't have memory_utilization and gpu_utilization.
+  # Maybe I should write a callback with just the intra_step_size.
+  memory_utilization: True
+  gpu_utilization: True
+  intra_step_time: True

+ 2 - 0
training/configs/callbacks/model-summary.yaml

@@ -0,0 +1,2 @@
+model_summary:
+  _target_: pytorch_lightning.callbacks.RichModelSummary

+ 0 - 0
training/configs/callbacks/none.yaml


+ 2 - 0
training/configs/callbacks/norm-monitor.yaml

@@ -0,0 +1,2 @@
+norm_monitor:
+  _target_: src.callbacks.norm_monitor.NormMonitor

+ 5 - 0
training/configs/callbacks/params-log.yaml

@@ -0,0 +1,5 @@
+params_log:
+  _target_: src.callbacks.params_log.ParamsLog
+  total_params_log: True
+  trainable_params_log: True
+  non_trainable_params_log: True

+ 26 - 0
training/configs/callbacks/wandb.yaml

@@ -0,0 +1,26 @@
+defaults:
+  - default.yaml
+
+watch_model:
+  _target_: src.callbacks.wandb_callbacks.WatchModel
+  log: "all"
+  log_freq: 100
+
+upload_code_as_artifact:
+  _target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact
+  code_dir: ${work_dir}/src
+
+upload_ckpts_as_artifact:
+  _target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
+  ckpt_dir: "checkpoints/"
+  upload_best_only: True
+
+log_f1_precision_recall_heatmap:
+  _target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap
+
+log_confusion_matrix:
+  _target_: src.callbacks.wandb_callbacks.LogConfusionMatrix
+
+log_image_predictions:
+  _target_: src.callbacks.wandb_callbacks.LogImagePredictions
+  num_samples: 8

+ 50 - 0
training/configs/config.yaml

@@ -0,0 +1,50 @@
+# @package _global_
+
+# specify here default training configuration
+defaults:
+  - _self_
+  - trainer: default
+  - optimizer: adamw
+  - scheduler: null
+  - task: sequence-model
+  - model: null
+  - datamodule: null
+  - callbacks: default # set this to null if you don't want to use callbacks
+  - metrics: null
+  - logger: null # set logger here or use command line (e.g. `python run.py logger=wandb`)
+
+  - mode: default
+
+  - experiment: null
+  - hparams_search: null
+
+  # enable color logging
+  - override hydra/hydra_logging: colorlog
+  - override hydra/job_logging: colorlog
+
+# path to original working directory
+# hydra hijacks working directory by changing it to the current log directory,
+# so it's useful to have this path as a special variable
+# https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
+work_dir: ${hydra:runtime.cwd}
+
+# path to folder with data
+data_dir: ${work_dir}/data/
+
+# pretty print config at the start of the run using Rich library
+print_config: True
+
+# disable python warnings if they annoy you
+ignore_warnings: True
+
+# check performance on test set, using the best model achieved during training
+# lightning chooses best model based on metric specified in checkpoint callback
+test_after_training: True
+
+resume: False
+
+# seed for random number generators in pytorch, numpy and python.random
+seed: null
+
+# name of the run, accessed by loggers
+name: null

+ 15 - 0
training/configs/datamodule/openwebtext.yaml

@@ -0,0 +1,15 @@
+_target_: src.datamodules.language_modeling_hf.LMDataModule
+dataset_name: openwebtext
+dataset_config_name: null
+tokenizer_name: gpt2
+cache_dir: ${oc.env:DATA_DIR,${data_dir}}/openwebtext/cache
+max_length: 1024
+val_ratio: 0.0005
+val_split_seed: 2357
+add_eos: True
+batch_size: 8  # per GPU
+batch_size_eval: ${eval:${.batch_size} * 2}
+num_workers: 32  # For preprocessing only
+shuffle: True
+pin_memory: True
+__train_len: ${div_up:9035582198, ${.max_length}}

+ 14 - 0
training/configs/datamodule/thepile.yaml

@@ -0,0 +1,14 @@
+_target_: src.datamodules.language_modeling_hf.LMDataModule
+dataset_name: the_pile
+dataset_config_name: null
+tokenizer_name: gpt2
+cache_dir: ${oc.env:DATA_DIR,${data_dir}}/the_pile/cache
+max_length: 2048
+add_eos: True
+batch_size: 4  # per GPU
+batch_size_eval: ${eval:${.batch_size} * 2}
+num_workers: 64  # For preprocessing only
+use_shmem: False
+shuffle: True
+pin_memory: True
+__train_len: ${div_up:374337375694, ${.max_length}}

+ 82 - 0
training/configs/experiment/owt/base.yaml

@@ -0,0 +1,82 @@
+# @package _global_
+defaults:
+  - override /trainer: default # choose trainer from 'configs/trainer/'
+  - override /model: null
+  - override /datamodule: openwebtext
+  # FusedAdam from apex speeds up the optimizer step a bit: for GPT2-small, time
+  # per global step (i.e. batch size 512) on 8 A100s goes from 376ms to 368ms.
+  # For GPT2-medium, time per global step goes from 997ms to 972ms.
+  - override /optimizer: adamw-apex
+  - override /scheduler: linear-warmup
+  - override /callbacks: [default, norm-monitor]
+  - override /metrics: [perplexity, num-tokens]
+  - override /logger: wandb
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+task:
+  _target_: src.tasks.seq.SequenceLMModel
+
+seed: 1111
+
+trainer:
+  accelerator: gpu
+  devices: 8
+  num_nodes: 1
+  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${datamodule.batch_size} * ${trainer.num_nodes}}}
+  max_steps: 400000
+  val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}}
+  check_val_every_n_epoch: null  # We don't care about epoch boundary
+  precision: 16
+  gradient_clip_val: 1.0
+  strategy: null
+
+datamodule:
+  batch_size: 16  # Per GPU
+  batch_size_eval: ${.batch_size}  # Fused dense only supports batch sizes up to 64k
+  max_length: 1024
+  fault_tolerant: True
+  ddp: ${eval:"${trainer.devices} > 1"}
+
+train:
+  gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
+  global_batch_size: 512
+  optimizer:
+    lr: 6e-4
+    weight_decay: 0.1
+  optimizer_param_grouping:
+    bias_weight_decay: False
+    normalization_weight_decay: False
+  scheduler:
+    num_warmup_steps: ${eval:0.01 * ${trainer.max_steps}}
+    num_training_steps: ${trainer.max_steps}
+  loss_fn:
+    # This is faster and uses less memory than torch.nn.CrossEntropyLoss.
+    # It's also more numerically stable if we're using DeepSpeed 16 bits.
+    _target_: src.losses.cross_entropy_apex.CrossEntropyLossApex
+    inplace_backward: True  # to save memory
+
+eval:
+  log_on_step: True  # 1 training epoch takes too long, we want to see metrics per train step
+
+callbacks:
+  model_checkpoint:
+    monitor: val/loss
+    mode: min
+    save_top_k: 3
+    save_last: True
+    every_n_train_steps: 1000
+    dirpath: ${work_dir}/checkpoints/${oc.select:name,''}
+    filename: step_{step}
+    auto_insert_metric_name: False
+  model_checkpoint_progress:
+    _target_: src.callbacks.model_checkpoint.ModelCheckpointMine
+    fault_tolerant: True
+    every_n_train_steps: 50000
+    save_last: False
+    save_top_k: -1  # Save all the checkpoints
+    dirpath: ${..model_checkpoint.dirpath}
+    filename: progress_step_{step}
+    auto_insert_metric_name: False
+  early_stopping: null

+ 41 - 0
training/configs/experiment/owt/gpt2l-flash.yaml

@@ -0,0 +1,41 @@
+# @package _global_
+defaults:
+  - /experiment/owt/gpt2m-flash.yaml
+  - override /model/gpt2model: gpt2-large
+  # TD [2022-08-03] Surprisingly it's faster to use the ZeRO optimizer than just AdamW.
+  # Still, fairscale is even faster and uses less memory.
+  # I think it's because Pytorch is using ZeRO stage 1 and fairscale is using ZeRO stage 2?
+  # However, fairscale has issues with saving checkpoint (either OOM or very
+  # slow since it goes through the CPU?). Fairscale says Pytorch ZeRO is the
+  # upstream version of OSS
+  # https://github.com/facebookresearch/fairscale/issues/937
+  # Pytorch ZeRO is also very slow for saving checkpoints due to
+  # consolidate_state_dict(), but I've fixed it to save separate checkpoint per GPU.
+  - override /optimizer: adamw-zero
+
+  # FusedAdam doesn't seem to speed things up here, time per global step
+  # (i.e. batch size 512) on 8 A100s is around 2056ms for both AdamW and FusedAdam.
+  # This could be because each GPU is only doing the optimizer step for 1 /
+  # world_size of the parameters.
+  # Maybe the bottleneck here is the NCCL call to exchange parameters (ZeRO).
+  # - override /optimizer: adamw-apex-zero
+
+# Can enable mlp_checkpoint_lvl to fit batch_size 16 on A100 40GB
+# model:
+#   config:
+#     # mlp_checkpoint_lvl: ${eval:"[1] * 18 + [2] * 18"}
+#     mlp_checkpoint_lvl: 1
+
+datamodule:
+  # batch_size: 16
+  batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else 16)"}
+
+trainer:
+  # strategy: null
+  # strategy: ${eval:"None if ${trainer.devices} == 1 else 'ddp_sharded'"}
+  strategy:
+    _target_: src.utils.ddp_zero1.DDPStrategyZero1
+    find_unused_parameters: False
+    gradient_as_bucket_view: True
+  # TD [2022-08-03] Deepspeed makes the ppl curve go wild
+  # strategy: deepspeed_stage_1

+ 14 - 0
training/configs/experiment/owt/gpt2l-hf.yaml

@@ -0,0 +1,14 @@
+# @package _global_
+defaults:
+  - /experiment/owt/gpt2m-hf.yaml
+  - override /model/gpt2model: gpt2-large
+  - override /optimizer: adamw-zero
+
+datamodule:
+  batch_size: 2
+
+trainer:
+  strategy:
+    _target_: src.utils.ddp_zero1.DDPStrategyZero1
+    find_unused_parameters: False
+    gradient_as_bucket_view: True

+ 14 - 0
training/configs/experiment/owt/gpt2l.yaml

@@ -0,0 +1,14 @@
+# @package _global_
+defaults:
+  - /experiment/owt/gpt2m.yaml
+  - override /model/gpt2model: gpt2-large
+  - override /optimizer: adamw-zero
+
+datamodule:
+  batch_size: 4  # Per GPU
+
+trainer:
+  strategy:
+    _target_: src.utils.ddp_zero1.DDPStrategyZero1
+    find_unused_parameters: False
+    gradient_as_bucket_view: True

+ 17 - 0
training/configs/experiment/owt/gpt2m-flash.yaml

@@ -0,0 +1,17 @@
+# @package _global_
+defaults:
+  - /experiment/owt/gpt2s-flash.yaml
+  - override /model/gpt2model: gpt2-medium
+
+# Can enable mlp_checkpoint_lvl to fit batch_size 32 on A100 40GB
+model:
+  config:
+    mlp_checkpoint_lvl: 1
+
+datamodule:
+  # batch_size: 32
+  batch_size: ${eval:"8 if ${train.gpu_mem} < 24 else (16 if ${train.gpu_mem} < 40 else 32)"}
+
+train:
+  optimizer:
+    lr: 1.5e-4

+ 11 - 0
training/configs/experiment/owt/gpt2m-hf.yaml

@@ -0,0 +1,11 @@
+# @package _global_
+defaults:
+  - /experiment/owt/gpt2s-hf.yaml
+  - override /model/gpt2model: gpt2-medium
+
+datamodule:
+  batch_size: 4
+
+train:
+  optimizer:
+    lr: 1.5e-4

+ 11 - 0
training/configs/experiment/owt/gpt2m.yaml

@@ -0,0 +1,11 @@
+# @package _global_
+defaults:
+  - /experiment/owt/gpt2s.yaml
+  - override /model/gpt2model: gpt2-medium
+
+datamodule:
+  batch_size: 8  # Per GPU
+
+train:
+  optimizer:
+    lr: 1.5e-4

+ 18 - 0
training/configs/experiment/owt/gpt2s-flash.yaml

@@ -0,0 +1,18 @@
+# @package _global_
+defaults:
+  - /experiment/owt/base.yaml
+  - override /model: gpt2
+  - override /model/gpt2model: gpt2-small
+
+model:
+  config:
+    # n_positions is already set to ${datamodule.max_length}
+    use_flash_attn: True
+    fused_bias_fc: True
+    fused_dense_gelu_dense: True
+    fused_dropout_add_ln: True
+    pad_vocab_size_multiple: 8
+
+datamodule:
+  # batch_size: 64
+  batch_size: ${eval:"16 if ${train.gpu_mem} < 24 else (32 if ${train.gpu_mem} < 40 else 64)"}

+ 23 - 0
training/configs/experiment/owt/gpt2s-hf.yaml

@@ -0,0 +1,23 @@
+# @package _global_
+defaults:
+  - /experiment/owt/base.yaml
+  - override /model: gpt2-hf
+  - override /model/gpt2model: gpt2-small
+  - override /callbacks: [default, norm-monitor, flop-count]
+
+datamodule:
+  batch_size: 8
+
+train:
+  # Use the standard torch.nn.CrossEntropyLoss
+  loss_fn: null
+
+callbacks:
+  flop_count:
+    input_size:
+      - ${datamodule.max_length}
+    input_dtype:
+      # It's surprisingly hard to get hydra to return torch.long since it's not a callable
+      _target_: torch.__getattribute__
+      _args_:
+        - long

+ 8 - 0
training/configs/experiment/owt/gpt2s.yaml

@@ -0,0 +1,8 @@
+# @package _global_
+defaults:
+  - /experiment/owt/base.yaml
+  - override /model: gpt2
+  - override /model/gpt2model: gpt2-small
+
+datamodule:
+  batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else 16)"}

+ 21 - 0
training/configs/experiment/owt/gpt2xl-flash.yaml

@@ -0,0 +1,21 @@
+# @package _global_
+defaults:
+  - /experiment/owt/gpt2l-flash.yaml
+  - override /model/gpt2model: gpt2-xlarge
+
+# Can enable mlp_checkpoint_lvl to fit to A100 40GB
+# model:
+#   config:
+#     # mlp_checkpoint_lvl: ${eval:"[1] * 18 + [2] * 18"}
+#     mlp_checkpoint_lvl: 1
+
+datamodule:
+  batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else 8)"}
+  # With adamw-zero optimizer:
+  # checkpoint_lvl=1, batch size = 4: mem 37GB, 4650ms / batch of 512 (285ms * 15 + 375ms * 1)
+  # checkpoint_lvl=1, batch size = 8: mem 46GB, 4330ms / batch of 512 (530ms * 7 + 620ms * 1)
+  # checkpoint_lvl=2, batch size = 8: mem 41GB, 4570ms / batch of 512 (560ms * 7 + 650ms * 1)
+  # With adamw-apex-distributed optimizer:
+  # checkpoint_lvl=1, batch size = 8: mem 41.5GB, 4500ms / batch of 512 (550ms * 7 + 650ms * 1)
+  # checkpoint_lvl=1 for 24 layers and checkpoint_lvl=2 for 24 layers,
+  # batch size = 8: mem 39GB, 4640ms / batch of 512 (565ms * 7 + 675ms * 1)

+ 14 - 0
training/configs/experiment/owt/gpt2xl.yaml

@@ -0,0 +1,14 @@
+# @package _global_
+defaults:
+  - /experiment/owt/gpt2m.yaml
+  - override /model/gpt2model: gpt2-xlarge
+  - override /optimizer: adamw-zero
+
+datamodule:
+  batch_size: 2  # Per GPU
+
+trainer:
+  strategy:
+    _target_: src.utils.ddp_zero1.DDPStrategyZero1
+    find_unused_parameters: False
+    gradient_as_bucket_view: True

+ 83 - 0
training/configs/experiment/pile/base.yaml

@@ -0,0 +1,83 @@
+# @package _global_
+defaults:
+  - override /trainer: default # choose trainer from 'configs/trainer/'
+  - override /model: null
+  - override /datamodule: thepile
+  - override /optimizer: adamw-apex  # slight speedup (1-2%) over Pytorch AdamW
+  - override /scheduler: cosine-warmup-timm
+  - override /callbacks: [default, norm-monitor]
+  - override /metrics: [perplexity, num-tokens]
+  - override /logger: wandb
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+task:
+  _target_: src.tasks.seq.SequenceLMModel
+
+seed: 1111
+
+trainer:
+  accelerator: gpu
+  devices: 8
+  num_nodes: 1
+  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${datamodule.batch_size} * ${trainer.num_nodes}}}
+  max_steps: 800000
+  val_check_interval: ${eval:2000 * ${.accumulate_grad_batches}}
+  check_val_every_n_epoch: null  # We don't care about epoch boundary
+  precision: bf16
+  gradient_clip_val: 1.0
+  strategy: null
+
+datamodule:
+  batch_size: 16  # Per GPU
+  batch_size_eval: ${.batch_size}  # Fused dense only supports batch sizes up to 64k
+  max_length: 2048
+  fault_tolerant: True
+  ddp: ${eval:"${trainer.devices} > 1"}
+
+train:
+  gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
+  global_batch_size: 256
+  optimizer:
+    lr: 6e-4
+    weight_decay: 0.1
+  optimizer_param_grouping:
+    bias_weight_decay: False
+    normalization_weight_decay: False
+  scheduler:
+    t_in_epochs: False
+    t_initial: 600000
+    warmup_lr_init: 1e-6
+    warmup_t: ${eval:0.01 * ${trainer.max_steps}}
+    lr_min: ${eval:0.1 * ${train.optimizer.lr}}
+  loss_fn:
+    # This is faster and uses less memory than torch.nn.CrossEntropyLoss.
+    # It's also more numerically stable if we're using DeepSpeed 16 bits.
+    _target_: src.losses.cross_entropy_apex.CrossEntropyLossApex
+    inplace_backward: True  # to save memory
+
+eval:
+  log_on_step: True  # 1 training epoch takes too long, we want to see metrics per train step
+
+callbacks:
+  model_checkpoint:
+    monitor: val/loss
+    mode: min
+    save_top_k: 3
+    save_last: True
+    every_n_train_steps: 1000
+    dirpath: ${work_dir}/checkpoints/${oc.select:name,''}
+    filename: step_{step}
+    auto_insert_metric_name: False
+  model_checkpoint_progress:
+    _target_: src.callbacks.model_checkpoint.ModelCheckpointMine
+    # fault_tolerant: True  # The .pl_auto_save.ckpt doesn't get saved by all workers
+    every_n_train_steps: 50000
+    save_last: False
+    save_top_k: -1  # Save all the checkpoints
+    dirpath: ${..model_checkpoint.dirpath}
+    filename: progress_step_{step}
+    auto_insert_metric_name: False
+  early_stopping: null
+

+ 18 - 0
training/configs/experiment/pile/gpt3-2.7B-flash-8k.yaml

@@ -0,0 +1,18 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3xl-flash-8k.yaml
+
+model:
+  config:
+    n_embd: 2560
+    n_head: 32
+    n_layer: 32
+    initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"}
+    mlp_checkpoint_lvl: 0
+
+datamodule:
+  batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"}
+
+train:
+  optimizer:
+    lr: 1.6e-4

+ 18 - 0
training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary-8k.yaml

@@ -0,0 +1,18 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3xl-flash-rotary-8k.yaml
+
+model:
+  config:
+    n_embd: 2560
+    n_head: 20
+    n_layer: 32
+    initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"}
+    mlp_checkpoint_lvl: 0
+
+datamodule:
+  batch_size: ${eval:"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu} < 80 else 8))"}
+
+train:
+  optimizer:
+    lr: 1.6e-4

+ 18 - 0
training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary.yaml

@@ -0,0 +1,18 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3xl-flash-rotary.yaml
+
+model:
+  config:
+    n_embd: 2560
+    n_head: 20
+    n_layer: 32
+    initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"}
+    mlp_checkpoint_lvl: 0
+
+datamodule:
+  batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu} < 80 else 32))"}
+
+train:
+  optimizer:
+    lr: 1.6e-4

+ 18 - 0
training/configs/experiment/pile/gpt3-2.7B-flash-rotary-8k.yaml

@@ -0,0 +1,18 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3xl-flash-rotary-8k.yaml
+
+model:
+  config:
+    n_embd: 2560
+    n_head: 32
+    n_layer: 32
+    initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"}
+    mlp_checkpoint_lvl: 0
+
+datamodule:
+  batch_size: ${eval:"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu} < 80 else 8))"}
+
+train:
+  optimizer:
+    lr: 1.6e-4

+ 18 - 0
training/configs/experiment/pile/gpt3-2.7B-flash-rotary.yaml

@@ -0,0 +1,18 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3xl-flash-rotary.yaml
+
+model:
+  config:
+    n_embd: 2560
+    n_head: 32
+    n_layer: 32
+    initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"}
+    mlp_checkpoint_lvl: 0
+
+datamodule:
+  batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu} < 80 else 32))"}
+
+train:
+  optimizer:
+    lr: 1.6e-4

+ 10 - 0
training/configs/experiment/pile/gpt3l-flash-8k.yaml

@@ -0,0 +1,10 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3l-flash.yaml
+
+datamodule:
+  max_length: 8192
+  batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"}
+
+train:
+  global_batch_size: 64

+ 10 - 0
training/configs/experiment/pile/gpt3l-flash-rotary-30B.yaml

@@ -0,0 +1,10 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3l-flash-rotary.yaml
+
+trainer:
+  max_steps: 60000
+
+train:
+  scheduler:
+    t_initial: ${trainer.max_steps}

+ 8 - 0
training/configs/experiment/pile/gpt3l-flash-rotary-8k.yaml

@@ -0,0 +1,8 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3l-flash-8k.yaml
+
+model:
+  config:
+    max_position_embeddings: 0  # Disable absolute position embedding
+    rotary_emb_fraction: 0.5

+ 8 - 0
training/configs/experiment/pile/gpt3l-flash-rotary.yaml

@@ -0,0 +1,8 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3l-flash.yaml
+
+model:
+  config:
+    max_position_embeddings: 0  # Disable absolute position embedding
+    rotary_emb_fraction: 0.5

+ 24 - 0
training/configs/experiment/pile/gpt3l-flash.yaml

@@ -0,0 +1,24 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3s-flash.yaml
+  - override /optimizer: adamw-zero
+
+model:
+  config:
+    n_embd: 1536
+    n_head: 16
+    n_layer: 24
+    # mlp_checkpoint_lvl: 1  # To fit batch_size 8
+
+datamodule:
+  batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else (8 if ${train.gpu_mem} < 80 else 16))"}
+
+train:
+  optimizer:
+    lr: 2.5e-4
+
+trainer:
+  strategy:
+    _target_: src.utils.ddp_zero1.DDPStrategyZero1
+    find_unused_parameters: False
+    gradient_as_bucket_view: True

+ 10 - 0
training/configs/experiment/pile/gpt3m-flash-8k.yaml

@@ -0,0 +1,10 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3m-flash.yaml
+
+datamodule:
+  max_length: 8192
+  batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else 8)"}
+
+train:
+  global_batch_size: 64

+ 10 - 0
training/configs/experiment/pile/gpt3m-flash-rotary-30B.yaml

@@ -0,0 +1,10 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3m-flash-rotary.yaml
+
+trainer:
+  max_steps: 60000
+
+train:
+  scheduler:
+    t_initial: ${trainer.max_steps}

+ 8 - 0
training/configs/experiment/pile/gpt3m-flash-rotary-8k.yaml

@@ -0,0 +1,8 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3m-flash-8k.yaml
+
+model:
+  config:
+    max_position_embeddings: 0  # Disable absolute position embedding
+    rotary_emb_fraction: 0.5

+ 8 - 0
training/configs/experiment/pile/gpt3m-flash-rotary.yaml

@@ -0,0 +1,8 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3m-flash.yaml
+
+model:
+  config:
+    max_position_embeddings: 0  # Disable absolute position embedding
+    rotary_emb_fraction: 0.5

+ 16 - 0
training/configs/experiment/pile/gpt3m-flash.yaml

@@ -0,0 +1,16 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3s-flash.yaml
+  - override /model/gpt2model: gpt2-medium
+
+# Can enable mlp_checkpoint_lvl to fit batch_size 16 on A100 40GB
+# model:
+#   config:
+#     mlp_checkpoint_lvl: 1
+
+datamodule:
+  batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else 16)"}
+
+train:
+  optimizer:
+    lr: 3.0e-4

+ 10 - 0
training/configs/experiment/pile/gpt3s-flash-8k.yaml

@@ -0,0 +1,10 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3s-flash.yaml
+
+datamodule:
+  max_length: 8192
+  batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else 8)"}
+
+train:
+  global_batch_size: 64

+ 10 - 0
training/configs/experiment/pile/gpt3s-flash-rotary-30B.yaml

@@ -0,0 +1,10 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3s-flash-rotary.yaml
+
+trainer:
+  max_steps: 60000
+
+train:
+  scheduler:
+    t_initial: ${trainer.max_steps}

+ 8 - 0
training/configs/experiment/pile/gpt3s-flash-rotary-8k.yaml

@@ -0,0 +1,8 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3s-flash-8k.yaml
+
+model:
+  config:
+    max_position_embeddings: 0  # Disable absolute position embedding
+    rotary_emb_fraction: 0.5

+ 8 - 0
training/configs/experiment/pile/gpt3s-flash-rotary.yaml

@@ -0,0 +1,8 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3s-flash.yaml
+
+model:
+  config:
+    max_position_embeddings: 0  # Disable absolute position embedding
+    rotary_emb_fraction: 0.5

+ 17 - 0
training/configs/experiment/pile/gpt3s-flash.yaml

@@ -0,0 +1,17 @@
+# @package _global_
+defaults:
+  - /experiment/pile/base.yaml
+  - override /model: gpt2
+  - override /model/gpt2model: gpt2-small
+
+model:
+  config:
+    # n_positions is already set to ${datamodule.max_length}
+    use_flash_attn: True
+    fused_dropout_add_ln: True
+    fused_dense_gelu_dense: True
+    fused_bias_fc: True
+    pad_vocab_size_multiple: 8
+
+datamodule:
+  batch_size: ${eval:"8 if ${train.gpu_mem} < 24 else (16 if ${train.gpu_mem} < 40 else 32)"}

+ 10 - 0
training/configs/experiment/pile/gpt3xl-flash-8k.yaml

@@ -0,0 +1,10 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3xl-flash.yaml
+
+datamodule:
+  max_length: 8192
+  batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"}
+
+train:
+  global_batch_size: 128

+ 10 - 0
training/configs/experiment/pile/gpt3xl-flash-rotary-60B.yaml

@@ -0,0 +1,10 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3xl-flash-rotary.yaml
+
+trainer:
+  max_steps: 60000
+
+train:
+  scheduler:
+    t_initial: ${trainer.max_steps}

+ 8 - 0
training/configs/experiment/pile/gpt3xl-flash-rotary-8k.yaml

@@ -0,0 +1,8 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3xl-flash-8k.yaml
+
+model:
+  config:
+    max_position_embeddings: 0  # Disable absolute position embedding
+    rotary_emb_fraction: 0.5

+ 8 - 0
training/configs/experiment/pile/gpt3xl-flash-rotary.yaml

@@ -0,0 +1,8 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3xl-flash.yaml
+
+model:
+  config:
+    max_position_embeddings: 0  # Disable absolute position embedding
+    rotary_emb_fraction: 0.5

+ 35 - 0
training/configs/experiment/pile/gpt3xl-flash.yaml

@@ -0,0 +1,35 @@
+# @package _global_
+defaults:
+  - /experiment/pile/gpt3s-flash.yaml
+  - override /optimizer: adamw-zero
+
+model:
+  config:
+    n_embd: 2048
+    n_head: 16
+    n_layer: 24
+
+datamodule:
+  batch_size: ${eval:"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu} < 80 else 8))"}
+
+train:
+  global_batch_size: 512
+  optimizer:
+    lr: 2.0e-4
+  scheduler:
+    t_initial: 300000
+
+trainer:
+  strategy:
+    _target_: src.utils.ddp_zero1.DDPStrategyZero1
+    find_unused_parameters: False
+    gradient_as_bucket_view: True
+  max_steps: 400000
+  val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}}
+
+callbacks:
+  model_checkpoint:
+    every_n_train_steps: 1000
+  model_checkpoint_progress:
+    every_n_train_steps: 12500
+    fault_tolerant: False  # Saving takes too long

+ 7 - 0
training/configs/logger/comet.yaml

@@ -0,0 +1,7 @@
+# https://www.comet.ml
+
+comet:
+  _target_: pytorch_lightning.loggers.comet.CometLogger
+  api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable
+  project_name: "template-tests"
+  experiment_name: ${name}

+ 8 - 0
training/configs/logger/csv.yaml

@@ -0,0 +1,8 @@
+# csv logger built in lightning
+
+csv:
+  _target_: pytorch_lightning.loggers.csv_logs.CSVLogger
+  save_dir: "."
+  name: "csv/"
+  version: ${name}
+  prefix: ""

+ 9 - 0
training/configs/logger/many_loggers.yaml

@@ -0,0 +1,9 @@
+# train with many loggers at once
+
+defaults:
+  # - comet.yaml
+  - csv.yaml
+  # - mlflow.yaml
+  # - neptune.yaml
+  # - tensorboard.yaml
+  - wandb.yaml

+ 10 - 0
training/configs/logger/mlflow.yaml

@@ -0,0 +1,10 @@
+# https://mlflow.org
+
+mlflow:
+  _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger
+  experiment_name: ${name}
+  tracking_uri: null
+  tags: null
+  save_dir: ./mlruns
+  prefix: ""
+  artifact_location: null

+ 11 - 0
training/configs/logger/neptune.yaml

@@ -0,0 +1,11 @@
+# https://neptune.ai
+
+neptune:
+  _target_: pytorch_lightning.loggers.neptune.NeptuneLogger
+  api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable
+  project_name: your_name/template-tests
+  close_after_fit: True
+  offline_mode: False
+  experiment_name: ${name}
+  experiment_id: null
+  prefix: ""

+ 10 - 0
training/configs/logger/tensorboard.yaml

@@ -0,0 +1,10 @@
+# https://www.tensorflow.org/tensorboard/
+
+tensorboard:
+  _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
+  save_dir: "tensorboard/"
+  name: "default"
+  version: ${name}
+  log_graph: False
+  default_hp_metric: True
+  prefix: ""

+ 15 - 0
training/configs/logger/wandb.yaml

@@ -0,0 +1,15 @@
+# https://wandb.ai
+
+wandb:
+  _target_: pytorch_lightning.loggers.wandb.WandbLogger
+  project: attention
+  name: ${name}
+  save_dir: "."
+  mode: online # set offline to store all logs only locally
+  id: ${oc.select:name} # pass correct id to resume experiment!
+  # entity: ""  # set to name of your wandb team or just remove it
+  log_model: False
+  prefix: ""
+  job_type: "train"
+  group: ""
+  tags: []

+ 3 - 0
training/configs/metrics/acc.yaml

@@ -0,0 +1,3 @@
+# @package eval.metrics
+acc:
+  _target_: src.metrics.accuracy.AccuracyMine

+ 4 - 0
training/configs/metrics/acc_ignore_index.yaml

@@ -0,0 +1,4 @@
+# @package eval.metrics
+acc:
+  _target_: torchmetrics.Accuracy
+  ignore_index: -100

+ 4 - 0
training/configs/metrics/acctop5.yaml

@@ -0,0 +1,4 @@
+# @package eval.metrics
+acctop5:
+  _target_: src.metrics.accuracy.AccuracyMine
+  top_k: 5

+ 3 - 0
training/configs/metrics/mse.yaml

@@ -0,0 +1,3 @@
+# @package eval.metrics
+mse:
+  _target_: torchmetrics.MeanSquaredError

+ 3 - 0
training/configs/metrics/num-tokens.yaml

@@ -0,0 +1,3 @@
+# @package eval.metrics
+num-tokens:
+  _target_: src.metrics.num_tokens.NumTokens

+ 3 - 0
training/configs/metrics/perplexity.yaml

@@ -0,0 +1,3 @@
+# @package eval.metrics
+ppl:
+  _target_: src.metrics.perplexity.Perplexity

+ 27 - 0
training/configs/mode/debug.yaml

@@ -0,0 +1,27 @@
+# @package _global_
+
+# run in debug mode with:
+# `python run.py mode=debug`
+
+defaults:
+  - override /trainer: debug.yaml
+
+debug_mode: True
+
+hydra:
+  # sets level of all command line loggers to 'DEBUG'
+  verbose: True
+
+  # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
+  # sets level of only chosen command line loggers to 'DEBUG'
+  # verbose: [src.train, src.utils.utils]
+
+  # sets output paths for all file logs to 'logs/debug/'
+  run:
+    dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/${now:%Y-%m-%d}/${now:%H-%M-%S}
+  sweep:
+    dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/multirun_${now:%Y-%m-%d_%H-%M-%S}
+    subdir: ${hydra.job.num}
+
+# disable rich config printing, since it will be already printed by hydra when `verbose: True`
+print_config: False

+ 13 - 0
training/configs/mode/default.yaml

@@ -0,0 +1,13 @@
+# @package _global_
+
+# default running mode
+
+default_mode: True
+
+hydra:
+  # default output paths for all file logs
+  run:
+    dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/runs/${now:%Y-%m-%d}/${now:%H-%M-%S}
+  sweep:
+    dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/multiruns/${now:%Y-%m-%d_%H-%M-%S}
+    subdir: ${hydra.job.num}

+ 17 - 0
training/configs/mode/exp.yaml

@@ -0,0 +1,17 @@
+# @package _global_
+
+# run in experiment mode with:
+# `python run.py mode=exp name=experiment_name`
+
+experiment_mode: True
+
+# allows for custom naming of the experiment
+name: ???
+
+hydra:
+  # sets output paths for all file logs to `logs/experiment/name'
+  run:
+    dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/experiments/${name}
+  sweep:
+    dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/experiments/${name}
+    subdir: ${hydra.job.num}

+ 31 - 0
training/configs/mode/profile.yaml

@@ -0,0 +1,31 @@
+# @package _global_
+# Run the Pytorch profiler
+
+trainer:
+  profiler:
+    _target_: pytorch_lightning.profilers.PyTorchProfiler
+    dirpath: ${hydra.run.dir}
+    schedule:
+      _target_: torch.profiler.schedule
+      wait: 5
+      warmup: 5
+      active: 5
+    use_cuda: True
+  max_steps: 20
+
+logger:
+  wandb:
+    mode: disabled
+
+callbacks:
+  model_checkpoint: null
+  model_checkpoint_progress: null
+  early_stopping: null
+
+hydra:
+  # sets output paths for all file logs to 'logs/profile/'
+  run:
+    dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/profile/${now:%Y-%m-%d}/${now:%H-%M-%S}
+  sweep:
+    dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/profile/multirun_${now:%Y-%m-%d_%H-%M-%S}
+    subdir: ${hydra.job.num}

+ 22 - 0
training/configs/mode/smoke.yaml

@@ -0,0 +1,22 @@
+# @package _global_
+# Smoke test: disable logging and model checkpointing
+
+logger:
+  wandb:
+    mode: disabled
+
+callbacks:
+  model_checkpoint: null
+  model_checkpoint_progress: null
+
+hydra:
+  # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
+  # sets level of only chosen command line loggers to 'DEBUG'
+  # verbose: [src.train, src.utils.utils]
+
+  # sets output paths for all file logs to 'logs/debug/'
+  run:
+    dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/${now:%Y-%m-%d}/${now:%H-%M-%S}
+  sweep:
+    dir: ${oc.env:RESULT_DIR,${work_dir}/logs}/debug/multirun_${now:%Y-%m-%d_%H-%M-%S}
+    subdir: ${hydra.job.num}

+ 13 - 0
training/configs/model/gpt2-hf.yaml

@@ -0,0 +1,13 @@
+defaults:
+  - _self_
+  - gpt2model: gpt2-small
+
+_target_: transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel
+_recursive_: True
+config:
+  _target_: transformers.GPT2Config
+  # Mistral's config: https://github.com/stanford-crfm/mistral/blob/main/conf/models/gpt2-small.yaml
+  # However, reorder_and_upcast_attn slows things down
+  reorder_and_upcast_attn: false
+  scale_attn_by_inverse_layer_idx: true
+  n_positions: ${datamodule.max_length}

+ 13 - 0
training/configs/model/gpt2.yaml

@@ -0,0 +1,13 @@
+defaults:
+  - _self_
+  - gpt2model: gpt2-small
+
+_target_: flash_attn.models.gpt.GPTLMHeadModel
+_recursive_: True
+config:
+  _target_: transformers.GPT2Config
+  # Mistral's config: https://github.com/stanford-crfm/mistral/blob/main/conf/models/mistral-small.yaml
+  # However, reorder_and_upcast_attn slows things down
+  reorder_and_upcast_attn: false
+  scale_attn_by_inverse_layer_idx: true
+  n_positions: ${datamodule.max_length}

+ 6 - 0
training/configs/model/gpt2model/gpt2-large.yaml

@@ -0,0 +1,6 @@
+# @package _global_
+model:
+  config:
+    n_embd: 1280
+    n_head: 20
+    n_layer: 36

+ 6 - 0
training/configs/model/gpt2model/gpt2-medium.yaml

@@ -0,0 +1,6 @@
+# @package _global_
+model:
+  config:
+    n_embd: 1024
+    n_head: 16
+    n_layer: 24

+ 6 - 0
training/configs/model/gpt2model/gpt2-small.yaml

@@ -0,0 +1,6 @@
+# @package _global_
+model:
+  config:
+    n_embd: 768
+    n_head: 12
+    n_layer: 12

+ 6 - 0
training/configs/model/gpt2model/gpt2-xlarge.yaml

@@ -0,0 +1,6 @@
+# @package _global_
+model:
+  config:
+    n_embd: 1600
+    n_head: 25
+    n_layer: 48

+ 2 - 0
training/configs/optimizer/adam.yaml

@@ -0,0 +1,2 @@
+# @package train.optimizer
+_target_: torch.optim.Adam

+ 3 - 0
training/configs/optimizer/adamw-apex-distributed.yaml

@@ -0,0 +1,3 @@
+# @package train.optimizer
+_target_: apex.contrib.optimizers.distributed_fused_adam.DistributedFusedAdam
+adam_w_mode: True

+ 7 - 0
training/configs/optimizer/adamw-apex-zero.yaml

@@ -0,0 +1,7 @@
+# @package train.optimizer
+_target_: torch.distributed.optim.ZeroRedundancyOptimizer
+_recursive_: True
+optimizer_class:
+  _target_: apex.optimizers.FusedAdam
+  _partial_: True
+  adam_w_mode: True

+ 3 - 0
training/configs/optimizer/adamw-apex.yaml

@@ -0,0 +1,3 @@
+# @package train.optimizer
+_target_: apex.optimizers.FusedAdam
+adam_w_mode: True

+ 7 - 0
training/configs/optimizer/adamw-zero.yaml

@@ -0,0 +1,7 @@
+# @package train.optimizer
+_target_: torch.distributed.optim.ZeroRedundancyOptimizer
+_recursive_: True
+optimizer_class:
+  _target_: torch.optim.__getattribute__
+  _args_:
+    - "AdamW"

+ 2 - 0
training/configs/optimizer/adamw.yaml

@@ -0,0 +1,2 @@
+# @package train.optimizer
+_target_: torch.optim.AdamW

+ 2 - 0
training/configs/optimizer/fusedlamb-ds.yaml

@@ -0,0 +1,2 @@
+# @package train.optimizer
+_target_: deepspeed.ops.lamb.FusedLamb

+ 2 - 0
training/configs/optimizer/fusedlamb.yaml

@@ -0,0 +1,2 @@
+# @package train.optimizer
+_target_: apex.optimizers.FusedLAMB

+ 2 - 0
training/configs/optimizer/sgd.yaml

@@ -0,0 +1,2 @@
+# @package train.optimizer
+_target_: torch.optim.SGD

+ 2 - 0
training/configs/scheduler/cosine-warmup-timm.yaml

@@ -0,0 +1,2 @@
+# @package train.scheduler
+_target_: src.optim.timm_lr_scheduler.TimmCosineLRScheduler

+ 2 - 0
training/configs/scheduler/cosine-warmup.yaml

@@ -0,0 +1,2 @@
+# @package train.scheduler
+_target_: transformers.get_cosine_schedule_with_warmup

+ 3 - 0
training/configs/scheduler/invsqrt.yaml

@@ -0,0 +1,3 @@
+# @package train.scheduler
+_target_: src.optim.lr_scheduler.InvSqrt
+num_warmup_steps: ???

+ 2 - 0
training/configs/scheduler/linear-warmup.yaml

@@ -0,0 +1,2 @@
+# @package train.scheduler
+_target_: transformers.get_linear_schedule_with_warmup

+ 2 - 0
training/configs/scheduler/multi-step.yaml

@@ -0,0 +1,2 @@
+# @package train.scheduler
+_target_: torch.optim.lr_scheduler.MultiStepLR

+ 9 - 0
training/configs/scheduler/plateau.yaml

@@ -0,0 +1,9 @@
+# @package _global_
+train:
+  scheduler_interval: epoch
+  scheduler_monitor: ???
+  scheduler:
+    _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+    factor: 0.2  # Decay factor when ReduceLROnPlateau is used
+    patience: 20
+    min_lr: 0.0  # Minimum learning rate during annealing

+ 2 - 0
training/configs/scheduler/poly-warmup.yaml

@@ -0,0 +1,2 @@
+# @package train.scheduler
+_target_: transformers.get_polynomial_decay_schedule_with_warmup

+ 3 - 0
training/configs/scheduler/step.yaml

@@ -0,0 +1,3 @@
+# @package train.scheduler
+_target_: torch.optim.lr_scheduler.StepLR
+step_size: ???

+ 1 - 0
training/configs/task/sequence-model.yaml

@@ -0,0 +1 @@
+_target_: src.tasks.seq.SequenceModel

+ 49 - 0
training/configs/trainer/all_params.yaml

@@ -0,0 +1,49 @@
+_target_: pytorch_lightning.Trainer
+
+# default values for all trainer parameters
+checkpoint_callback: True
+default_root_dir: null
+gradient_clip_val: 0.0
+process_position: 0
+num_nodes: 1
+num_processes: 1
+gpus: null
+auto_select_gpus: False
+tpu_cores: null
+log_gpu_memory: null
+overfit_batches: 0.0
+track_grad_norm: -1
+check_val_every_n_epoch: 1
+fast_dev_run: False
+accumulate_grad_batches: 1
+max_epochs: 1
+min_epochs: 1
+max_steps: null
+min_steps: null
+limit_train_batches: 1.0
+limit_val_batches: 1.0
+limit_test_batches: 1.0
+val_check_interval: 1.0
+flush_logs_every_n_steps: 100
+log_every_n_steps: 50
+accelerator: null
+sync_batchnorm: False
+precision: 32
+weights_summary: "top"
+weights_save_path: null
+num_sanity_val_steps: 2
+truncated_bptt_steps: null
+resume_from_checkpoint: null
+profiler: null
+benchmark: False
+deterministic: False
+reload_dataloaders_every_epoch: False
+auto_lr_find: False
+replace_sampler_ddp: True
+terminate_on_nan: False
+auto_scale_batch_size: False
+prepare_data_per_node: True
+plugins: null
+amp_backend: "native"
+amp_level: "O2"
+move_metrics_to_cpu: False

+ 6 - 0
training/configs/trainer/ddp.yaml

@@ -0,0 +1,6 @@
+defaults:
+  - default.yaml
+
+accelerator: gpu
+devices: 4
+strategy: ddp

+ 21 - 0
training/configs/trainer/debug.yaml

@@ -0,0 +1,21 @@
+defaults:
+  - default.yaml
+
+gpus: 0
+
+min_epochs: 1
+max_epochs: 2
+
+# prints
+weights_summary: "full"
+profiler: null
+
+# debugs
+fast_dev_run: true
+num_sanity_val_steps: 2
+overfit_batches: 0
+limit_train_batches: 1.0
+limit_val_batches: 1.0
+limit_test_batches: 1.0
+track_grad_norm: -1
+terminate_on_nan: true

+ 7 - 0
training/configs/trainer/default.yaml

@@ -0,0 +1,7 @@
+_target_: pytorch_lightning.Trainer
+
+# set to `gpu` to train on GPU, null to train on CPU only
+accelerator: null
+
+min_epochs: 1
+max_epochs: 1000

Some files were not shown because too many files changed in this diff.