# tpu.py
import torch

from .interface import Platform, PlatformEnum
  3. class TpuPlatform(Platform):
  4. _enum = PlatformEnum.TPU
  5. @staticmethod
  6. def inference_mode():
  7. return torch.no_grad()