cpu.py 286 B

12345678910111213
  1. import torch
  2. from .interface import Platform, PlatformEnum
  3. class CpuPlatform(Platform):
  4. _enum = PlatformEnum.CPU
  5. @staticmethod
  6. def get_device_name(device_id: int = 0) -> str:
  7. return "cpu"
  8. @staticmethod
  9. def inference_mode():
  10. return torch.no_grad()