Lu Fang 74aed78373 Replace c10::optional with std::optional in flash_attn 5 days ago
README.md abbc131173 [LayerNorm] Switch from CUDA to Triton implementation 1 year ago
ln.h 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_api.cpp 74aed78373 Replace c10::optional with std::optional in flash_attn 5 days ago
ln_bwd_1024.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_1280.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_1536.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_2048.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_256.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_2560.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_3072.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_4096.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_512.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_5120.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_6144.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_7168.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_bwd_768.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_bwd_8192.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_bwd_kernels.cuh eb33e587e9 [LayerNorm] Rename x1 -> residual 2 years ago
ln_fwd_1024.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_1280.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_1536.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_2048.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_256.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_2560.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_3072.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_4096.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_512.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_5120.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_6144.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_7168.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_fwd_768.cu 8c6609ae1a [LayerNorm] Support all dimensions up to 6k (if divisible by 8) 2 years ago
ln_fwd_8192.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_fwd_kernels.cuh eb33e587e9 [LayerNorm] Rename x1 -> residual 2 years ago
ln_kernel_traits.h ae137ed17a [LayerNorm] Fuse LayerScale 2 years ago
ln_parallel_bwd_1024.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_1280.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_1536.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_2048.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_256.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_2560.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_3072.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_4096.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_512.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_5120.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_6144.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_7168.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_768.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_bwd_8192.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_1024.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_1280.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_1536.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_2048.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_256.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_2560.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_3072.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_4096.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_512.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_5120.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_6144.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_7168.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_768.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_fwd_8192.cu 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_residual_bwd_kernels.cuh 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_parallel_residual_fwd_kernels.cuh 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
ln_utils.cuh 393882bc08 [LayerNorm] Implement LN with parallel residual, support dim 8k 1 year ago
setup.py 50896ec574 Make nvcc threads configurable via environment variable (#885) 10 months ago
static_switch.h fa6d1ce44f Add fused_dense and dropout_add_layernorm CUDA extensions 2 years ago

README.md

This CUDA extension implements fused dropout + residual + LayerNorm, building on Apex's FastLayerNorm. Major changes:

  • Add dropout and residual.
  • Make it work for both pre-norm and post-norm architectures.
  • Support more hidden dimensions (all dimensions divisible by 8, up to 8192).
  • Implement RMSNorm as an option.
  • Support layer norm with parallel residual (e.g., GPT-J, GPT-NeoX, PaLM).
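
To make the semantics of the fused operation concrete, here is an unfused, pure-Python reference sketch for a single hidden vector. It is not this extension's API: the function names (`dropout_add_norm`, `parallel_dropout_add_norm`) and their parameters are hypothetical, and the output conventions (returning the pre-norm residual sum `h`, giving the two parallel-residual norms separate weights) are assumptions. A sketch like this is mainly useful as a slow correctness oracle when testing a fused kernel.

```python
# Hypothetical reference sketch, NOT the extension's actual API.
import math
import random
from typing import List, Optional, Sequence, Tuple, Union

Vec = List[float]


def dropout(x: Sequence[float], p: float, rng: random.Random) -> Vec:
    """Inverted dropout: zero each element with probability p, scale survivors by 1/(1-p)."""
    if p == 0.0:
        return [float(v) for v in x]
    keep = 1.0 - p
    return [v / keep if rng.random() < keep else 0.0 for v in x]


def normalize(x: Sequence[float], weight: Sequence[float],
              bias: Optional[Sequence[float]], eps: float, rms_norm: bool) -> Vec:
    """LayerNorm (center, then divide by std) or RMSNorm (divide by RMS, no centering)."""
    n = len(x)
    if rms_norm:
        denom = math.sqrt(sum(v * v for v in x) / n + eps)
        y = [v / denom for v in x]
    else:
        mean = sum(x) / n
        var = sum((v - mean) ** 2 for v in x) / n  # biased variance, as in LayerNorm
        denom = math.sqrt(var + eps)
        y = [(v - mean) / denom for v in x]
    out = [yi * wi for yi, wi in zip(y, weight)]
    if bias is not None:
        out = [oi + bi for oi, bi in zip(out, bias)]
    return out


def dropout_add_norm(x0: Sequence[float], residual: Optional[Sequence[float]],
                     weight: Sequence[float], bias: Optional[Sequence[float]], p: float,
                     eps: float = 1e-5, rms_norm: bool = False, prenorm: bool = False,
                     rng: Optional[random.Random] = None) -> Union[Vec, Tuple[Vec, Vec]]:
    """h = dropout(x0) + residual; out = norm(h).

    With prenorm=True, h is also returned because it carries the residual stream
    to the next block; in post-norm, out itself becomes the next residual.
    """
    if rng is None:
        rng = random.Random(0)
    h = dropout(x0, p, rng)
    if residual is not None:
        h = [a + b for a, b in zip(h, residual)]
    out = normalize(h, weight, bias, eps, rms_norm)
    return (out, h) if prenorm else out


def parallel_dropout_add_norm(x0: Sequence[float], x1: Sequence[float],
                              residual: Optional[Sequence[float]],
                              weight0: Sequence[float], bias0: Optional[Sequence[float]],
                              weight1: Sequence[float], bias1: Optional[Sequence[float]],
                              p: float, eps: float = 1e-5, rms_norm: bool = False,
                              rng: Optional[random.Random] = None) -> Tuple[Vec, Vec, Vec]:
    """Parallel residual: two branch outputs (e.g. attention and MLP) are summed into one
    residual h, which is then normalized with two separate sets of affine parameters."""
    if rng is None:
        rng = random.Random(0)
    h = [a + b for a, b in zip(dropout(x0, p, rng), dropout(x1, p, rng))]
    if residual is not None:
        h = [a + b for a, b in zip(h, residual)]
    return (normalize(h, weight0, bias0, eps, rms_norm),
            normalize(h, weight1, bias1, eps, rms_norm), h)
```

The point of fusing is that the kernel computes `h` and the normalized output in one pass over memory, instead of materializing the dropout result and the residual sum as separate tensors.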

If you want to use it for dimensions larger than 8192, please file an issue.
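
The dimension rule above (divisible by 8, at most 8192) can be checked before calling into the extension. The sketch below is hypothetical: `is_supported_hidden_dim` and `kernel_bucket` are illustrative names. The size list is taken from the `ln_fwd_*.cu` / `ln_bwd_*.cu` file names in the listing above, and the bucketing rule (any other supported size is served by the smallest compiled size at least as large) is an assumption not confirmed by this README.

```python
# Hypothetical helper; names and bucketing rule are assumptions, not the extension's API.
import bisect

# Sizes from the ln_fwd_*.cu / ln_bwd_*.cu file names in this directory.
KERNEL_SIZES = [256, 512, 768, 1024, 1280, 1536, 2048, 2560,
                3072, 4096, 5120, 6144, 7168, 8192]
MAX_HIDDEN_DIM = KERNEL_SIZES[-1]


def is_supported_hidden_dim(hidden_dim: int) -> bool:
    """README rule: divisible by 8 and no larger than 8192."""
    return 0 < hidden_dim <= MAX_HIDDEN_DIM and hidden_dim % 8 == 0


def kernel_bucket(hidden_dim: int) -> int:
    """ASSUMPTION: return the smallest compiled kernel size >= hidden_dim."""
    if not is_supported_hidden_dim(hidden_dim):
        raise ValueError(
            f"hidden_dim={hidden_dim} unsupported: must be divisible by 8 and <= {MAX_HIDDEN_DIM}")
    return KERNEL_SIZES[bisect.bisect_left(KERNEL_SIZES, hidden_dim)]
```

Note that zero-padding an unsupported dimension up to a multiple of 8 is not a safe workaround: the padded zeros would change the mean and variance that LayerNorm computes.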

This extension has only been tested on A100s.

To build and install the extension, run from the repository root:

cd csrc/layer_norm && pip install .

As of 2024-01-05, this extension is no longer used in the FlashAttention repo. We've instead switched to a Triton-based implementation.