setup.py

# Copied from https://github.com/NVIDIA/apex/tree/master/csrc/megatron
# We add the cases seqlen = 4k and seqlen = 8k
import os
import subprocess

import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME


def get_cuda_bare_metal_version(cuda_dir):
    # Parse the toolkit version out of `nvcc -V`, e.g. "release 11.8" -> ("11", "8").
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]
    return raw_output, bare_metal_major, bare_metal_minor


def append_nvcc_threads(nvcc_extra_args):
    # nvcc supports multi-threaded compilation (--threads) from CUDA 11.2 onwards.
    # Compare as a (major, minor) tuple so that e.g. CUDA 12.0 also qualifies.
    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
    if (int(bare_metal_major), int(bare_metal_minor)) >= (11, 2):
        nvcc_threads = os.getenv("NVCC_THREADS") or "4"
        return nvcc_extra_args + ["--threads", nvcc_threads]
    return nvcc_extra_args


# Generate device code for Volta (sm_70) and Ampere (sm_80).
cc_flag = []
cc_flag.append("-gencode")
cc_flag.append("arch=compute_70,code=sm_70")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")

setup(
    name='fused_softmax_lib',
    ext_modules=[
        CUDAExtension(
            name='fused_softmax_lib',
            sources=[
                'fused_softmax.cpp',
                'scaled_masked_softmax_cuda.cu',
                'scaled_upper_triang_masked_softmax_cuda.cu',
            ],
            extra_compile_args={
                'cxx': ['-O3'],
                'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag),
            },
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    },
)
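
To compile the extension, run `pip install .` (or `python setup.py build_ext --inplace` for an in-tree build) from the directory containing this file and the three C++/CUDA sources. Below is a minimal smoke test for the built module; the import name is grounded in name='fused_softmax_lib' above, but the forward call is an assumption based on the apex Megatron bindings this file says it was copied from, so the exact function name and signature should be checked against fused_softmax.cpp.

# A minimal sketch, assuming the pybind bindings in fused_softmax.cpp mirror
# apex's Megatron fused softmax (the function name and signature below are
# assumptions, not confirmed by this setup.py).
import torch
import fused_softmax_lib  # import name matches name='fused_softmax_lib' in setup()

# Causal (upper-triangular masked) softmax over fp16 attention scores;
# the apex kernels expect a (attn_batches, seqlen, seqlen) layout.
scores = torch.randn(8, 2048, 2048, device="cuda", dtype=torch.float16)
probs = fused_softmax_lib.scaled_upper_triang_masked_softmax_forward(scores, 1.0)
print(probs.shape)  # torch.Size([8, 2048, 2048])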