123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163 |
- """This file is a pure Python wrapper for the cudart library.
- It avoids the need to compile a separate shared library, and is
- convenient for use when we just need to call a few functions.
- """
- import ctypes
- import glob
- import os
- import sys
- from dataclasses import dataclass
- from typing import Any, Dict, List, Optional
- # this line makes it possible to directly load `libcudart.so` using `ctypes`
- import torch # noqa
# === export types and functions from cudart to Python ===
# The names below mirror the C declarations of the CUDA runtime API:
# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
# Both C enum types are mapped onto a plain ctypes int.
cudaError_t = ctypes.c_int  # status code returned by most cudart calls
cudaMemcpyKind = ctypes.c_int  # direction argument of cudaMemcpy


class cudaIpcMemHandle_t(ctypes.Structure):
    """Opaque IPC memory handle exchanged with the CUDA runtime.

    Filled in by ``cudaIpcGetMemHandle`` (passed by pointer) and handed back,
    by value, to ``cudaIpcOpenMemHandle``.
    """
    # NOTE(review): CUDA's driver_types.h declares this struct as
    # `char reserved[CUDA_IPC_HANDLE_SIZE]` (64 bytes, as far as I know), while
    # this buffer is 128 bytes. Presumably deliberate headroom — confirm before
    # changing it, since the size is visible in any serialized handle bytes.
    _fields_ = [("internal", ctypes.c_byte * 128)]


@dataclass
class Function:
    """Declarative signature of one C function exported by libcudart.

    The loader resolves ``name`` in the shared library and installs
    ``restype`` / ``argtypes`` on the resulting ctypes function pointer.
    """
    name: str  # exact C symbol name, e.g. "cudaMalloc"
    restype: Any  # ctypes type of the C return value
    argtypes: List[Any]  # ctypes types of the positional C arguments, in order
def get_pytorch_default_cudart_library_path() -> str:
    """Return the path of the cudart shared library bundled with PyTorch.

    Mirrors PyTorch's own lookup. Each ``sys.path`` entry is visited in order,
    and ``<entry>/nvidia/cuda_runtime/lib`` is globbed for
    ``libcudart.so.*[0-9]``. The first match of the first entry with any match
    is returned.

    Raises:
        ValueError: if no ``sys.path`` entry contains a matching library.
    """
    # code borrowed from https://github.com/pytorch/pytorch/blob/1cae60a87e5bdda8bcf55724a862eeed98a9747e/torch/__init__.py#L284 # noqa
    lib_folder = "cuda_runtime"
    lib_name = "libcudart.so.*[0-9]"
    for entry in sys.path:
        nvidia_dir = os.path.join(entry, "nvidia")
        if not os.path.exists(nvidia_dir):
            continue
        pattern = os.path.join(nvidia_dir, lib_folder, "lib", lib_name)
        matches = glob.glob(pattern)
        if matches:
            # glob's ordering is filesystem-dependent; the first hit is used
            # as-is, exactly like the upstream PyTorch lookup.
            return matches[0]
    raise ValueError(f"{lib_name} not found in the system path {sys.path}")
class CudaRTLibrary:
    """Minimal ctypes wrapper around a few CUDA runtime (cudart) functions.

    Each wrapper method forwards to the matching C function. When the
    returned status is non-zero, ``CUDART_CHECK`` raises ``RuntimeError``.
    Loaded libraries and their configured function pointers are cached in
    class attributes keyed by the ``.so`` path, so constructing several
    instances for the same path reuses them.
    """

    # C prototypes of the wrapped symbols. Each entry's restype/argtypes are
    # installed on the ctypes function pointer when a library is first bound.
    exported_functions = [
        # cudaError_t cudaSetDevice ( int device )
        Function(name="cudaSetDevice",
                 restype=cudaError_t,
                 argtypes=[ctypes.c_int]),
        # cudaError_t cudaDeviceSynchronize ( void )
        Function(name="cudaDeviceSynchronize",
                 restype=cudaError_t,
                 argtypes=[]),
        # cudaError_t cudaDeviceReset ( void )
        Function(name="cudaDeviceReset",
                 restype=cudaError_t,
                 argtypes=[]),
        # const char* cudaGetErrorString ( cudaError_t error )
        Function(name="cudaGetErrorString",
                 restype=ctypes.c_char_p,
                 argtypes=[cudaError_t]),
        # cudaError_t cudaMalloc ( void** devPtr, size_t size )
        Function(name="cudaMalloc",
                 restype=cudaError_t,
                 argtypes=[ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
        # cudaError_t cudaFree ( void* devPtr )
        Function(name="cudaFree",
                 restype=cudaError_t,
                 argtypes=[ctypes.c_void_p]),
        # cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
        Function(name="cudaMemset",
                 restype=cudaError_t,
                 argtypes=[ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
        # cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
        Function(name="cudaMemcpy",
                 restype=cudaError_t,
                 argtypes=[
                     ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t,
                     cudaMemcpyKind
                 ]),
        # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
        Function(name="cudaIpcGetMemHandle",
                 restype=cudaError_t,
                 argtypes=[
                     ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p
                 ]),
        # cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
        Function(name="cudaIpcOpenMemHandle",
                 restype=cudaError_t,
                 argtypes=[
                     ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t,
                     ctypes.c_uint
                 ]),
    ]

    # so_file path -> loaded ctypes.CDLL; shared by every instance so the
    # same library is never loaded twice through this class.
    path_to_library_cache: Dict[str, Any] = {}
    # so_file path -> {C function name: configured ctypes function pointer}
    path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}

    def __init__(self, so_file: Optional[str] = None):
        """Bind to the cudart shared library at ``so_file``.

        When ``so_file`` is None, the path comes from
        ``get_pytorch_default_cudart_library_path()``, which raises
        ``ValueError`` if no library is found.
        """
        if so_file is None:
            so_file = get_pytorch_default_cudart_library_path()

        libraries = CudaRTLibrary.path_to_library_cache
        if so_file not in libraries:
            libraries[so_file] = ctypes.CDLL(so_file)
        self.lib = libraries[so_file]

        def _bind(spec):
            # Resolve the C symbol and declare its signature for ctypes.
            fn = getattr(self.lib, spec.name)
            fn.restype = spec.restype
            fn.argtypes = spec.argtypes
            return fn

        bindings = CudaRTLibrary.path_to_dict_mapping
        if so_file not in bindings:
            # The mapping is built completely before being published to the
            # cache, so a missing symbol leaves no partial entry behind.
            bindings[so_file] = {
                spec.name: _bind(spec)
                for spec in CudaRTLibrary.exported_functions
            }
        self.funcs = bindings[so_file]

    def CUDART_CHECK(self, result: cudaError_t) -> None:
        """Raise ``RuntimeError`` carrying cudart's message if ``result`` != 0."""
        if result == 0:  # cudaSuccess
            return
        raise RuntimeError(f"CUDART error: {self.cudaGetErrorString(result)}")

    def cudaGetErrorString(self, error: cudaError_t) -> str:
        """Return cudart's human-readable description of ``error``."""
        raw = self.funcs["cudaGetErrorString"](error)
        return raw.decode("utf-8")

    def cudaSetDevice(self, device: int) -> None:
        """Make ``device`` the current CUDA device for the calling thread."""
        status = self.funcs["cudaSetDevice"](device)
        self.CUDART_CHECK(status)

    def cudaDeviceSynchronize(self) -> None:
        """Block until the current device has finished all queued work."""
        status = self.funcs["cudaDeviceSynchronize"]()
        self.CUDART_CHECK(status)

    def cudaDeviceReset(self) -> None:
        """Reset the current device via ``cudaDeviceReset``."""
        status = self.funcs["cudaDeviceReset"]()
        self.CUDART_CHECK(status)

    def cudaMalloc(self, size: int) -> ctypes.c_void_p:
        """Allocate ``size`` bytes of device memory; return the pointer."""
        allocation = ctypes.c_void_p()
        status = self.funcs["cudaMalloc"](ctypes.byref(allocation), size)
        self.CUDART_CHECK(status)
        return allocation

    def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
        """Release device memory previously returned by ``cudaMalloc``."""
        status = self.funcs["cudaFree"](devPtr)
        self.CUDART_CHECK(status)

    def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,
                   count: int) -> None:
        """Fill ``count`` bytes starting at ``devPtr`` with byte ``value``."""
        status = self.funcs["cudaMemset"](devPtr, value, count)
        self.CUDART_CHECK(status)

    def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,
                   count: int) -> None:
        """Copy ``count`` bytes from ``src`` to ``dst``.

        The copy kind is always ``cudaMemcpyDefault`` (value 4). Per the CUDA
        runtime docs, that kind lets the runtime infer the transfer direction
        from the pointer values.
        """
        default_kind = 4  # cudaMemcpyDefault
        status = self.funcs["cudaMemcpy"](dst, src, count, default_kind)
        self.CUDART_CHECK(status)

    def cudaIpcGetMemHandle(self,
                            devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
        """Return an IPC handle for the device allocation at ``devPtr``."""
        ipc_handle = cudaIpcMemHandle_t()
        status = self.funcs["cudaIpcGetMemHandle"](ctypes.byref(ipc_handle),
                                                   devPtr)
        self.CUDART_CHECK(status)
        return ipc_handle

    def cudaIpcOpenMemHandle(self,
                             handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
        """Map the allocation behind ``handle`` into this process.

        The call always passes flag value 1 (cudaIpcMemLazyEnablePeerAccess).
        Returns the device pointer of the mapped allocation.
        """
        lazy_peer_access = 1  # cudaIpcMemLazyEnablePeerAccess
        mapped = ctypes.c_void_p()
        status = self.funcs["cudaIpcOpenMemHandle"](ctypes.byref(mapped),
                                                    handle, lazy_peer_access)
        self.CUDART_CHECK(status)
        return mapped
|