# Copied from https://github.com/NVIDIA/apex/tree/master/csrc/megatron
# We add the cases where seqlen = 4k and seqlen = 8k
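# Build and install into the current environment with:
#   pip install .
# (or the legacy `python setup.py install`)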
import os
import subprocess

import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME


def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    release = output[release_idx].split(".")
    bare_metal_major = release[0]
    bare_metal_minor = release[1][0]
    return raw_output, bare_metal_major, bare_metal_minor
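
# For example, the line "Cuda compilation tools, release 11.8, V11.8.89" in
# the `nvcc -V` output parses to bare_metal_major="11", bare_metal_minor="8"
# (release[1] is "8," because of the trailing comma, hence the [0]).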


def append_nvcc_threads(nvcc_extra_args):
    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
    # nvcc gained the --threads flag (parallel compilation of the per-arch
    # device code) in CUDA 11.2; compare (major, minor) as a tuple so that
    # e.g. CUDA 12.0 also qualifies, which the original
    # `major >= 11 and minor >= 2` conjunction wrongly excluded.
    if (int(bare_metal_major), int(bare_metal_minor)) >= (11, 2):
        nvcc_threads = os.getenv("NVCC_THREADS") or "4"
        return nvcc_extra_args + ["--threads", nvcc_threads]
    return nvcc_extra_args
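
# The thread count can be overridden at build time, e.g.:
#   NVCC_THREADS=8 pip install .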


cc_flag = []
cc_flag.append("-gencode")
cc_flag.append("arch=compute_70,code=sm_70")
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")
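# compute_70/sm_70 targets Volta GPUs (e.g. V100) and compute_80/sm_80
# targets Ampere GPUs (e.g. A100); extend cc_flag with further -gencode
# pairs to cover other architectures.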

setup(
    name='fused_softmax_lib',
    ext_modules=[
        CUDAExtension(
            name='fused_softmax_lib',
            sources=[
                'fused_softmax.cpp',
                'scaled_masked_softmax_cuda.cu',
                'scaled_upper_triang_masked_softmax_cuda.cu',
            ],
            extra_compile_args={
                'cxx': ['-O3'],
                'nvcc': append_nvcc_threads(['-O3', '--use_fast_math'] + cc_flag),
            },
        )
    ],
    cmdclass={'build_ext': BuildExtension},
)
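
# A minimal usage sketch after building; this assumes the pybind11 bindings
# in fused_softmax.cpp expose the same entry points as the upstream apex
# source (e.g. scaled_masked_softmax_forward). Check fused_softmax.cpp for
# the actual names before relying on this:
#
#   import torch
#   import fused_softmax_lib
#   x = torch.randn(2, 8, 128, 128, dtype=torch.float16, device='cuda')
#   mask = torch.zeros(2, 1, 128, 128, dtype=torch.uint8, device='cuda')
#   out = fused_softmax_lib.scaled_masked_softmax_forward(x, mask, 1.0)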