_custom_op.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import torch.nn as nn
  2. from aphrodite.common.utils import is_cpu, is_hip, is_xpu
  3. from aphrodite.platforms import current_platform
  4. class CustomOp(nn.Module):
  5. def __init__(self, *args, **kwargs):
  6. super().__init__()
  7. self._forward_method = self.dispatch_forward()
  8. def forward(self, *args, **kwargs):
  9. return self._forward_method(*args, **kwargs)
  10. def forward_native(self, *args, **kwargs):
  11. """PyTorch-native implementation of the forward method.
  12. This method is optional. If implemented, it can be used with compilers
  13. such as torch.compile or PyTorch XLA. Also, it can be used for testing
  14. purposes.
  15. """
  16. raise NotImplementedError
  17. def forward_cuda(self, *args, **kwargs):
  18. raise NotImplementedError
  19. def forward_hip(self, *args, **kwargs):
  20. # By default, we assume that HIP ops are compatible with CUDA ops.
  21. return self.forward_cuda(*args, **kwargs)
  22. def forward_xpu(self, *args, **kwargs):
  23. raise NotImplementedError
  24. def forward_cpu(self, *args, **kwargs):
  25. # By default, we assume that CPU ops are compatible with CUDA ops.
  26. return self.forward_cuda(*args, **kwargs)
  27. def forward_tpu(self, *args, **kwargs):
  28. # By default, we assume that TPU ops are compatible with the
  29. # PyTorch-native implementation.
  30. # NOTE: This is a placeholder for future extensions.
  31. return self.forward_native(*args, **kwargs)
  32. def forward_gaudi(self, *args, **kwargs):
  33. # By default, we assume that Gaudi ops are compatible with the
  34. # PyTorch-native implementation.
  35. # NOTE: This is a placeholder for future extensions.
  36. return self.forward_native(*args, **kwargs)
  37. def dispatch_forward(self):
  38. # NOTE: Here we assume that Aphrodite was built for only one
  39. # specific backend. Currently, we do not support dynamic dispatching.
  40. if is_hip():
  41. return self.forward_hip
  42. elif is_cpu():
  43. return self.forward_cpu
  44. elif current_platform.is_tpu():
  45. return self.forward_tpu
  46. elif is_xpu():
  47. return self.forward_xpu
  48. else:
  49. return self.forward_cuda