# gpu_info_tab.py — Gradio tab reporting CUDA device stats via torch.
import gradio as gr
import torch
  3. def gpu_info_tab():
  4. with gr.Tab("GPU Info") as gpu_info_tab:
  5. gpu_info = gr.Markdown("")
  6. gr.Button("Refresh").click(
  7. fn=refresh_gpu_info,
  8. outputs=gpu_info,
  9. api_name="refresh_gpu_info", # , every=1
  10. )
  11. gpu_info_tab.select(
  12. fn=refresh_gpu_info,
  13. outputs=gpu_info,
  14. )
  15. gr.Button("API_GET_GPU_INFO", visible=False).click(
  16. fn=get_gpu_info,
  17. outputs=[gr.JSON(None, visible=False)],
  18. api_name="get_gpu_info",
  19. )
  20. def get_gpu_info():
  21. if not torch.cuda.is_available():
  22. return []
  23. device_count = torch.cuda.device_count()
  24. return [get_gpu_info_idx(idx) for idx in range(device_count)]
  25. def get_pynvml_fields(idx=0):
  26. # check if pynvml is installed
  27. try:
  28. # import pynvml
  29. return {
  30. "temperature": torch.cuda.temperature(idx),
  31. "power_draw": torch.cuda.power_draw(idx) / 1000,
  32. "utilization": torch.cuda.utilization(idx),
  33. }
  34. # except ImportError:
  35. except:
  36. return {
  37. "temperature": 0,
  38. "power_draw": 0,
  39. "utilization": 0,
  40. }
  41. def get_gpu_info_idx(idx=0):
  42. return {
  43. "torch_version": torch.__version__,
  44. "cuda_version": torch.version.cuda,
  45. "vram": torch.cuda.get_device_properties(idx).total_memory / 1024**2,
  46. "name": torch.cuda.get_device_properties(idx).name,
  47. "cuda_capabilities": torch.cuda.get_device_capability(idx),
  48. "used_vram": torch.cuda.memory_allocated(idx) / 1024**2,
  49. "used_vram_total": (
  50. torch.cuda.mem_get_info(idx)[1] - torch.cuda.mem_get_info(idx)[0]
  51. )
  52. / 1024**2,
  53. "cached_vram": torch.cuda.memory_reserved(idx) / 1024**2,
  54. "idx": idx,
  55. "multi_processor_count": torch.cuda.get_device_properties(
  56. idx
  57. ).multi_processor_count,
  58. **get_pynvml_fields(idx),
  59. }
  60. def render_gpu_info(gpu_info):
  61. if isinstance(gpu_info, dict):
  62. return f"""VRAM: {gpu_info['vram']} MB
  63. Used VRAM: {gpu_info['used_vram']} MB
  64. Total Used VRAM: {gpu_info['used_vram_total']} MB
  65. Name: {gpu_info['name']}
  66. CUDA Capabilities: {gpu_info['cuda_capabilities']}
  67. Cached VRAM: {gpu_info['cached_vram']} MB
  68. Torch Version: {gpu_info['torch_version']}"""
  69. else:
  70. return gpu_info
  71. def refresh_gpu_info():
  72. return "".join([render_gpu_info(x) for x in get_gpu_info()])
  73. if __name__ == "__main__":
  74. if "demo" in locals():
  75. demo.close() # type: ignore
  76. with gr.Blocks() as demo:
  77. gpu_info_tab()
  78. demo.launch(
  79. server_port=7770,
  80. )