import torch.nn as nn

from aphrodite.common.utils import is_cpu, is_hip, is_tpu, is_xpu


class CustomOp(nn.Module):

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Pick the backend-specific forward once, at construction time.
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """PyTorch-native implementation of the forward method.

        This method is optional. If implemented, it can be used with
        compilers such as torch.compile or PyTorch XLA. Also, it can be
        used for testing purposes.
        """
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        # By default, HIP ops are assumed to be compatible with CUDA ops.
        return self.forward_cuda(*args, **kwargs)

    def forward_xpu(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cpu(self, *args, **kwargs):
        # By default, CPU ops are assumed to be compatible with CUDA ops.
        return self.forward_cuda(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs):
        # By default, TPU ops are assumed to be compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    def forward_gaudi(self, *args, **kwargs):
        # By default, Gaudi ops are assumed to be compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    def dispatch_forward(self):
        # The backend is assumed to be fixed at build time, so the forward
        # method is selected once here rather than on every call.
        if is_hip():
            return self.forward_hip
        elif is_cpu():
            return self.forward_cpu
        elif is_tpu():
            return self.forward_tpu
        elif is_xpu():
            return self.forward_xpu
        else:
            return self.forward_cuda
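
To illustrate how the dispatcher is meant to be used, here is a minimal sketch of a subclass. The op itself, ScaledSiLU, is a hypothetical example rather than part of the library: it supplies a forward_native fallback and a forward_cuda override, and callers simply invoke the module, since dispatch_forward has already bound the appropriate implementation in __init__.

import torch
import torch.nn.functional as F


class ScaledSiLU(CustomOp):
    """Hypothetical op: SiLU activation followed by a constant scale."""

    def __init__(self, scale: float = 1.0):
        super().__init__()  # runs dispatch_forward() and binds _forward_method
        self.scale = scale

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Portable implementation, usable under torch.compile or XLA.
        return F.silu(x) * self.scale

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # A real op would call a fused custom kernel here; this sketch
        # simply reuses the native implementation.
        return self.forward_native(x)


op = ScaledSiLU(scale=2.0)
out = op(torch.randn(4, 8))  # routed through the dispatched forward method

Note that backend selection happens once per module instance, not per call: because __init__ stores a bound method in _forward_method, the if/elif chain in dispatch_forward adds no overhead to the hot path.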