"""This file is a pure Python wrapper for the cudart library. It avoids the need to compile a separate shared library, and is convenient for use when we just need to call a few functions. """ import ctypes import glob import os import sys from dataclasses import dataclass from typing import Any, Dict, List, Optional # this line makes it possible to directly load `libcudart.so` using `ctypes` import torch # noqa # === export types and functions from cudart to Python === # for the original cudart definition, please check # https://docs.nvidia.com/cuda/cuda-runtime-api/index.html cudaError_t = ctypes.c_int cudaMemcpyKind = ctypes.c_int class cudaIpcMemHandle_t(ctypes.Structure): _fields_ = [("internal", ctypes.c_byte * 128)] @dataclass class Function: name: str restype: Any argtypes: List[Any] def get_pytorch_default_cudart_library_path() -> str: # code borrowed from https://github.com/pytorch/pytorch/blob/1cae60a87e5bdda8bcf55724a862eeed98a9747e/torch/__init__.py#L284 # noqa lib_folder = "cuda_runtime" lib_name = "libcudart.so.*[0-9]" lib_path = None for path in sys.path: nvidia_path = os.path.join(path, "nvidia") if not os.path.exists(nvidia_path): continue candidate_lib_paths = glob.glob( os.path.join(nvidia_path, lib_folder, "lib", lib_name)) if candidate_lib_paths and not lib_path: lib_path = candidate_lib_paths[0] if lib_path: break if not lib_path: raise ValueError(f"{lib_name} not found in the system path {sys.path}") return lib_path class CudaRTLibrary: exported_functions = [ # ​cudaError_t cudaSetDevice ( int device ) Function("cudaSetDevice", cudaError_t, [ctypes.c_int]), # cudaError_t cudaDeviceSynchronize ( void ) Function("cudaDeviceSynchronize", cudaError_t, []), # ​cudaError_t cudaDeviceReset ( void ) Function("cudaDeviceReset", cudaError_t, []), # const char* cudaGetErrorString ( cudaError_t error ) Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]), # ​cudaError_t cudaMalloc ( void** devPtr, size_t size ) Function("cudaMalloc", cudaError_t, [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]), # ​cudaError_t cudaFree ( void* devPtr ) Function("cudaFree", cudaError_t, [ctypes.c_void_p]), # ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count ) Function("cudaMemset", cudaError_t, [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]), # ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa Function("cudaMemcpy", cudaError_t, [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind ]), # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa Function("cudaIpcGetMemHandle", cudaError_t, [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]), # ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa Function("cudaIpcOpenMemHandle", cudaError_t, [ ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint ]), ] # class attribute to store the mapping from the path to the library # to avoid loading the same library multiple times path_to_library_cache: Dict[str, Any] = {} # class attribute to store the mapping from library path # to the corresponding dictionary path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} def __init__(self, so_file: Optional[str] = None): if so_file is None: assert torch.version.cuda is not None major_version = torch.version.cuda.split(".")[0] so_file = f"libcudart.so.{major_version}" if so_file not in CudaRTLibrary.path_to_library_cache: lib = ctypes.CDLL(so_file) 
            CudaRTLibrary.path_to_library_cache[so_file] = lib
        self.lib = CudaRTLibrary.path_to_library_cache[so_file]

        if so_file not in CudaRTLibrary.path_to_dict_mapping:
            _funcs = {}
            for func in CudaRTLibrary.exported_functions:
                f = getattr(self.lib, func.name)
                f.restype = func.restype
                f.argtypes = func.argtypes
                _funcs[func.name] = f
            CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
        self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]

    def CUDART_CHECK(self, result: cudaError_t) -> None:
        if result != 0:
            error_str = self.cudaGetErrorString(result)
            raise RuntimeError(f"CUDART error: {error_str}")

    def cudaGetErrorString(self, error: cudaError_t) -> str:
        return self.funcs["cudaGetErrorString"](error).decode("utf-8")

    def cudaSetDevice(self, device: int) -> None:
        self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))

    def cudaDeviceSynchronize(self) -> None:
        self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())

    def cudaDeviceReset(self) -> None:
        self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())

    def cudaMalloc(self, size: int) -> ctypes.c_void_p:
        devPtr = ctypes.c_void_p()
        self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
        return devPtr

    def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
        self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))

    def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,
                   count: int) -> None:
        self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))

    def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,
                   count: int) -> None:
        cudaMemcpyDefault = 4
        kind = cudaMemcpyDefault
        self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))

    def cudaIpcGetMemHandle(self,
                            devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
        handle = cudaIpcMemHandle_t()
        self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"](
            ctypes.byref(handle), devPtr))
        return handle

    def cudaIpcOpenMemHandle(self,
                             handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
        cudaIpcMemLazyEnablePeerAccess = 1
        devPtr = ctypes.c_void_p()
        self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"](
            ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess))
        return devPtr
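

# --- Usage sketch (illustrative addition, not part of the wrapper) ---
# A minimal smoke test, assuming an available CUDA device and a PyTorch
# install whose torch.version.cuda matches the installed libcudart:
# allocate two device buffers, zero one, copy it to the other (the kind
# is resolved via cudaMemcpyDefault inside cudaMemcpy), then free both.
# The buffer size and device index below are arbitrary example values.
if __name__ == "__main__":
    cudart = CudaRTLibrary()
    cudart.cudaSetDevice(0)
    size = 1024
    src = cudart.cudaMalloc(size)
    dst = cudart.cudaMalloc(size)
    cudart.cudaMemset(src, 0, size)
    cudart.cudaMemcpy(dst, src, size)
    cudart.cudaDeviceSynchronize()
    cudart.cudaFree(src)
    cudart.cudaFree(dst)
    print("cudart smoke test finished without errors")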