Chirag Jain 50896ec574 Make nvcc threads configurable via environment variable (#885) 10 months ago
README.md abbc131173 [LayerNorm] Switch from CUDA to Triton implementation 1 year ago
ln.h 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_api.cpp 767b71ccf0 Fix random state for dropout_layer_norm (#315) 1 year ago
ln_bwd_1024.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_1280.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_1536.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_2048.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_256.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_2560.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_3072.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_4096.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_512.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_5120.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_6144.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_7168.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_bwd_768.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_8192.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_bwd_kernels.cuh eb33e587e9 [LayerNorm] Rename x1 -> residual 2 years ago
ln_fwd_1024.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_1280.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_1536.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_2048.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_256.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_2560.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_3072.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_4096.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_512.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_5120.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_6144.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_7168.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_fwd_768.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_8192.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_fwd_kernels.cuh eb33e587e9 [LayerNorm] Rename x1 -> residual 2 years ago
ln_kernel_traits.h ae137ed17a [LayerNorm] Fuse LayerScale 2 years ago
ln_parallel_bwd_1024.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_1280.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_1536.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_2048.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_256.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_2560.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_3072.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_4096.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_512.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_5120.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_6144.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_7168.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_768.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_8192.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_1024.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_1280.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_1536.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_2048.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_256.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_2560.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_3072.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_4096.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_512.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_5120.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_6144.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_7168.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_768.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_8192.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_residual_bwd_kernels.cuh 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_residual_fwd_kernels.cuh 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_utils.cuh 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
setup.py 50896ec574 Make nvcc threads configurable via environment variable (#885) 10 months ago
static_switch.h fa6d1ce44f Add fused_dense and dropout_add_layernorm CUDA extensions 2 years ago

README.md

This CUDA extension implements fused dropout + residual + LayerNorm, building on Apex's FastLayerNorm. Major changes relative to FastLayerNorm:

  • Add dropout and residual.
  • Make it work for both pre-norm and post-norm architectures.
  • Support more hidden dimensions (all dimensions divisible by 8, up to 8192).
  • Implement RMSNorm as an option.
  • Support layer norm with parallel residual (e.g., GPT-J, GPT-NeoX, PaLM).
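The fused operations listed above can be pinned down with a slow, unfused reference for a single row. The following is a minimal sketch in plain Python, not the extension's Python API; every function name is hypothetical. It assumes the following:

  • Dropout is inverted dropout, so surviving values are scaled by 1/(1-p).
  • A hypothetical `prenorm` flag also returns the pre-normalization sum z. A pre-norm block carries z forward as the next residual, while a post-norm block only needs the normalized output.
  • An `rms_norm` flag skips mean subtraction.
  • The parallel-residual variant computes normalization statistics once and shares them between two affine transforms (GPT-NeoX style). Passing `weight1=None` gives the single-norm GPT-J case.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple

Vec = Sequence[float]


def dropout(x: Vec, p: float, rng: random.Random) -> List[float]:
    """Inverted dropout: zero each element with probability p, scale survivors by 1/(1-p)."""
    if p == 0.0:
        return list(x)
    keep = 1.0 - p
    return [v / keep if rng.random() < keep else 0.0 for v in x]


def normalize(z: Vec, eps: float, rms_norm: bool) -> List[float]:
    """LayerNorm statistics (mean and variance), or RMSNorm (root-mean-square only)."""
    n = len(z)
    if rms_norm:
        mean = 0.0
        var = sum(v * v for v in z) / n
    else:
        mean = sum(z) / n
        var = sum((v - mean) ** 2 for v in z) / n
    rstd = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * rstd for v in z]


def affine(y: Vec, weight: Vec, bias: Optional[Vec]) -> List[float]:
    """Per-feature scale, plus an optional per-feature shift."""
    if bias is None:
        return [v * w for v, w in zip(y, weight)]
    return [v * w + b for v, w, b in zip(y, weight, bias)]


def dropout_add_norm(x0: Vec, residual: Optional[Vec], weight: Vec, bias: Optional[Vec],
                     p: float, eps: float = 1e-5, rms_norm: bool = False,
                     prenorm: bool = False, rng: Optional[random.Random] = None):
    """Hypothetical reference: z = dropout(x0) + residual; out = norm(z) * weight + bias."""
    rng = rng if rng is not None else random.Random(0)
    z = dropout(x0, p, rng)
    if residual is not None:
        z = [a + b for a, b in zip(z, residual)]
    out = affine(normalize(z, eps, rms_norm), weight, bias)
    # Pre-norm blocks also keep z as the residual stream for the next layer.
    return (out, z) if prenorm else out


def dropout_add_norm_parallel(x0: Vec, x1: Optional[Vec], residual: Optional[Vec],
                              weight0: Vec, bias0: Optional[Vec],
                              weight1: Optional[Vec], bias1: Optional[Vec],
                              p: float, eps: float = 1e-5, rms_norm: bool = False,
                              rng: Optional[random.Random] = None
                              ) -> Tuple[List[float], Optional[List[float]], List[float]]:
    """Hypothetical reference: one shared normalization of z, two affine outputs."""
    rng = rng if rng is not None else random.Random(0)
    z = dropout(x0, p, rng)
    if x1 is not None:
        z = [a + b for a, b in zip(z, dropout(x1, p, rng))]
    if residual is not None:
        z = [a + b for a, b in zip(z, residual)]
    y = normalize(z, eps, rms_norm)
    out0 = affine(y, weight0, bias0)
    out1 = affine(y, weight1, bias1) if weight1 is not None else None
    return out0, out1, z
```

A reference like this is mainly useful for testing. With p=0, or with the same dropout mask, a fused kernel's output should match it up to floating-point tolerance.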

If you want to use it for dimensions larger than 8k, please file an issue.

This extension has only been tested on A100s.

To install:

cd csrc/layer_norm && pip install .
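After installation, you can check whether the compiled module is visible to Python without importing it. This sketch assumes the extension is importable as `dropout_layer_norm`, the name that appears in the ln_api.cpp commit message. Confirm the actual module name against setup.py.

```python
import importlib.util


def extension_available(module_name: str = "dropout_layer_norm") -> bool:
    """Return True if module_name can be found on the import path (it is not imported)."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    print("dropout_layer_norm installed:", extension_available())
```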

As of 2024-01-05, this extension is no longer used in the FlashAttention repo. We've instead switched to a Triton-based implementation.