cuda_wrapper.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. """This file is a pure Python wrapper for the cudart library.
  2. It avoids the need to compile a separate shared library, and is
  3. convenient for use when we just need to call a few functions.
  4. """
  5. import ctypes
  6. import glob
  7. import os
  8. import sys
  9. from dataclasses import dataclass
  10. from typing import Any, Dict, List, Optional
  11. # this line makes it possible to directly load `libcudart.so` using `ctypes`
  12. import torch # noqa
  13. # === export types and functions from cudart to Python ===
  14. # for the original cudart definition, please check
  15. # https://docs.nvidia.com/cuda/cuda-runtime-api/index.html
  16. cudaError_t = ctypes.c_int
  17. cudaMemcpyKind = ctypes.c_int
class cudaIpcMemHandle_t(ctypes.Structure):
    """ctypes mirror of the CUDA runtime's opaque IPC memory handle.

    Filled in through a pointer by ``cudaIpcGetMemHandle`` and passed by
    value to ``cudaIpcOpenMemHandle``; the contents are treated as opaque
    bytes.
    """
    # NOTE(review): the CUDA runtime headers appear to declare this struct
    # as a 64-byte ``reserved`` array (CUDA_IPC_HANDLE_SIZE), so 128 bytes
    # here is larger than that. Confirm whether the extra space is
    # intentional before changing it: the size affects sizeof() and any
    # handle bytes that are serialized and sent to other processes.
    _fields_ = [("internal", ctypes.c_byte * 128)]
@dataclass
class Function:
    """Declaration of one cudart symbol to bind through ctypes.

    ``CudaRTLibrary`` looks each declared name up in the loaded library and
    copies ``restype`` / ``argtypes`` onto the resulting function pointer.
    """
    # C symbol name looked up in the loaded shared library.
    name: str
    # ctypes type assigned to the function pointer's ``restype``.
    restype: Any
    # ctypes types assigned to the function pointer's ``argtypes``, in order.
    argtypes: List[Any]
  25. def get_pytorch_default_cudart_library_path() -> str:
  26. # code borrowed from https://github.com/pytorch/pytorch/blob/1cae60a87e5bdda8bcf55724a862eeed98a9747e/torch/__init__.py#L284 # noqa
  27. lib_folder = "cuda_runtime"
  28. lib_name = "libcudart.so.*[0-9]"
  29. lib_path = None
  30. for path in sys.path:
  31. nvidia_path = os.path.join(path, "nvidia")
  32. if not os.path.exists(nvidia_path):
  33. continue
  34. candidate_lib_paths = glob.glob(
  35. os.path.join(nvidia_path, lib_folder, "lib", lib_name))
  36. if candidate_lib_paths and not lib_path:
  37. lib_path = candidate_lib_paths[0]
  38. if lib_path:
  39. break
  40. if not lib_path:
  41. raise ValueError(f"{lib_name} not found in the system path {sys.path}")
  42. return lib_path
  43. class CudaRTLibrary:
  44. exported_functions = [
  45. # ​cudaError_t cudaSetDevice ( int device )
  46. Function("cudaSetDevice", cudaError_t, [ctypes.c_int]),
  47. # cudaError_t cudaDeviceSynchronize ( void )
  48. Function("cudaDeviceSynchronize", cudaError_t, []),
  49. # ​cudaError_t cudaDeviceReset ( void )
  50. Function("cudaDeviceReset", cudaError_t, []),
  51. # const char* cudaGetErrorString ( cudaError_t error )
  52. Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]),
  53. # ​cudaError_t cudaMalloc ( void** devPtr, size_t size )
  54. Function("cudaMalloc", cudaError_t,
  55. [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]),
  56. # ​cudaError_t cudaFree ( void* devPtr )
  57. Function("cudaFree", cudaError_t, [ctypes.c_void_p]),
  58. # ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count )
  59. Function("cudaMemset", cudaError_t,
  60. [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]),
  61. # ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa
  62. Function("cudaMemcpy", cudaError_t, [
  63. ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind
  64. ]),
  65. # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa
  66. Function("cudaIpcGetMemHandle", cudaError_t,
  67. [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]),
  68. # ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa
  69. Function("cudaIpcOpenMemHandle", cudaError_t, [
  70. ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint
  71. ]),
  72. ]
  73. # class attribute to store the mapping from the path to the library
  74. # to avoid loading the same library multiple times
  75. path_to_library_cache: Dict[str, Any] = {}
  76. # class attribute to store the mapping from library path
  77. # to the corresponding dictionary
  78. path_to_dict_mapping: Dict[str, Dict[str, Any]] = {}
  79. def __init__(self, so_file: Optional[str] = None):
  80. if so_file is None:
  81. so_file = get_pytorch_default_cudart_library_path()
  82. if so_file not in CudaRTLibrary.path_to_library_cache:
  83. lib = ctypes.CDLL(so_file)
  84. CudaRTLibrary.path_to_library_cache[so_file] = lib
  85. self.lib = CudaRTLibrary.path_to_library_cache[so_file]
  86. if so_file not in CudaRTLibrary.path_to_dict_mapping:
  87. _funcs = {}
  88. for func in CudaRTLibrary.exported_functions:
  89. f = getattr(self.lib, func.name)
  90. f.restype = func.restype
  91. f.argtypes = func.argtypes
  92. _funcs[func.name] = f
  93. CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs
  94. self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file]
  95. def CUDART_CHECK(self, result: cudaError_t) -> None:
  96. if result != 0:
  97. error_str = self.cudaGetErrorString(result)
  98. raise RuntimeError(f"CUDART error: {error_str}")
  99. def cudaGetErrorString(self, error: cudaError_t) -> str:
  100. return self.funcs["cudaGetErrorString"](error).decode("utf-8")
  101. def cudaSetDevice(self, device: int) -> None:
  102. self.CUDART_CHECK(self.funcs["cudaSetDevice"](device))
  103. def cudaDeviceSynchronize(self) -> None:
  104. self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]())
  105. def cudaDeviceReset(self) -> None:
  106. self.CUDART_CHECK(self.funcs["cudaDeviceReset"]())
  107. def cudaMalloc(self, size: int) -> ctypes.c_void_p:
  108. devPtr = ctypes.c_void_p()
  109. self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size))
  110. return devPtr
  111. def cudaFree(self, devPtr: ctypes.c_void_p) -> None:
  112. self.CUDART_CHECK(self.funcs["cudaFree"](devPtr))
  113. def cudaMemset(self, devPtr: ctypes.c_void_p, value: int,
  114. count: int) -> None:
  115. self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count))
  116. def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p,
  117. count: int) -> None:
  118. cudaMemcpyDefault = 4
  119. kind = cudaMemcpyDefault
  120. self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind))
  121. def cudaIpcGetMemHandle(self,
  122. devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t:
  123. handle = cudaIpcMemHandle_t()
  124. self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"](
  125. ctypes.byref(handle), devPtr))
  126. return handle
  127. def cudaIpcOpenMemHandle(self,
  128. handle: cudaIpcMemHandle_t) -> ctypes.c_void_p:
  129. cudaIpcMemLazyEnablePeerAccess = 1
  130. devPtr = ctypes.c_void_p()
  131. self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"](
  132. ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess))
  133. return devPtr